From 14a1b5eb20c951b1fe23ca7ea389778a6899332a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Sat, 25 Feb 2023 17:45:41 +0800 Subject: [PATCH] onnx --- funasr/export/models/predictor/cif.py | 141 +++++++++++++++++++++++--- 1 file changed, 128 insertions(+), 13 deletions(-) diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py index fcfcd5f51..034e2334c 100644 --- a/funasr/export/models/predictor/cif.py +++ b/funasr/export/models/predictor/cif.py @@ -16,6 +16,11 @@ def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None): return mask.type(dtype).to(device) if device is not None else mask.type(dtype) +def sequence_mask_scripts(lengths, maxlen:int): + row_vector = torch.arange(0, maxlen, 1).type(lengths.dtype).to(lengths.device) + matrix = torch.unsqueeze(lengths, dim=-1) + mask = row_vector < matrix + return mask.type(torch.float32).to(lengths.device) class CifPredictorV2(nn.Module): def __init__(self, model): @@ -71,28 +76,131 @@ class CifPredictorV2(nn.Module): return hidden, alphas, token_num_floor +# @torch.jit.script +# def cif(hidden, alphas, threshold: float): +# batch_size, len_time, hidden_size = hidden.size() +# threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device) +# +# # loop varss +# integrate = torch.zeros([batch_size], device=hidden.device) +# frame = torch.zeros([batch_size, hidden_size], device=hidden.device) +# # intermediate vars along time +# list_fires = [] +# list_frames = [] +# +# for t in range(len_time): +# alpha = alphas[:, t] +# distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate +# +# integrate += alpha +# list_fires.append(integrate) +# +# fire_place = integrate >= threshold +# integrate = torch.where(fire_place, +# integrate - torch.ones([batch_size], device=hidden.device), +# integrate) +# cur = torch.where(fire_place, +# distribution_completion, +# alpha) +# remainds = alpha - cur +# +# frame += cur[:, None] * hidden[:, t, :] +# list_frames.append(frame) +# frame = torch.where(fire_place[:, None].repeat(1, hidden_size), +# remainds[:, None] * hidden[:, t, :], +# frame) +# +# fires = torch.stack(list_fires, 1) +# frames = torch.stack(list_frames, 1) +# list_ls = [] +# len_labels = torch.round(alphas.sum(-1)).int() +# max_label_len = len_labels.max().item() +# # print("type: {}".format(type(max_label_len))) +# for b in range(batch_size): +# fire = fires[b, :] +# l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze()) +# pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], dtype=l.dtype, device=hidden.device) +# list_ls.append(torch.cat([l, pad_l], 0)) +# return torch.stack(list_ls, 0), fires + +# @torch.jit.script +# def cif(hidden, alphas, threshold: float): +# batch_size, len_time, hidden_size = hidden.size() +# threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device) +# +# # loop varss +# integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device) +# frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device) +# # intermediate vars along time +# list_fires = [] +# list_frames = [] +# +# for t in range(len_time): +# alpha = alphas[:, t] +# distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate +# +# integrate += alpha +# list_fires.append(integrate) +# +# fire_place = integrate >= threshold +# integrate = torch.where(fire_place, +# integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device), +# integrate) +# cur = torch.where(fire_place, +# distribution_completion, +# alpha) +# remainds = alpha - cur +# +# frame += cur[:, None] * hidden[:, t, :] +# list_frames.append(frame) +# frame = torch.where(fire_place[:, None].repeat(1, hidden_size), +# remainds[:, None] * hidden[:, t, :], +# frame) +# +# fires = torch.stack(list_fires, 1) +# frames = torch.stack(list_frames, 1) +# len_labels = torch.floor(torch.sum(alphas, dim=1)).int() +# max_label_len = torch.max(len_labels) +# pad_num = max_label_len - len_labels +# pad_num_max = torch.max(pad_num).item() +# frames_pad_tensor = torch.zeros([int(batch_size), int(pad_num_max), int(hidden_size)], dtype=frames.dtype, +# device=frames.device) +# fires_pad_tensor = torch.ones([int(batch_size), int(pad_num_max)], dtype=fires.dtype, device=fires.device) +# fires_pad_tensor_mask = sequence_mask_scripts(pad_num, maxlen=int(pad_num_max)) +# fires_pad_tensor *= fires_pad_tensor_mask +# frames_pad = torch.cat([frames, frames_pad_tensor], dim=1) +# fires_pad = torch.cat([fires, fires_pad_tensor], dim=1) +# index_bool = fires_pad >= threshold +# frames_fire = frames_pad[index_bool] +# frames_fire = torch.reshape(frames_fire, (int(batch_size), -1, int(hidden_size))) +# frames_fire_mask = sequence_mask_scripts(len_labels, maxlen=int(max_label_len)) +# frames_fire *= frames_fire_mask[:, :, None] +# +# return frames_fire, fires + + @torch.jit.script def cif(hidden, alphas, threshold: float): batch_size, len_time, hidden_size = hidden.size() threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device) # loop varss - integrate = torch.zeros([batch_size], device=hidden.device) - frame = torch.zeros([batch_size, hidden_size], device=hidden.device) + integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device) + frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device) # intermediate vars along time list_fires = [] list_frames = [] for t in range(len_time): alpha = alphas[:, t] - distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate + distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate integrate += alpha list_fires.append(integrate) fire_place = integrate >= threshold integrate = torch.where(fire_place, - integrate - torch.ones([batch_size], device=hidden.device), + integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device), integrate) cur = torch.where(fire_place, distribution_completion, @@ -107,13 +215,20 @@ def cif(hidden, alphas, threshold: float): fires = torch.stack(list_fires, 1) frames = torch.stack(list_frames, 1) - list_ls = [] - len_labels = torch.round(alphas.sum(-1)).int() - max_label_len = len_labels.max().item() - print("type: {}".format(type(max_label_len))) + # list_ls = [] + len_labels = torch.round(alphas.sum(-1)).type(torch.int32) + # max_label_len = int(torch.max(len_labels).item()) + # print("type: {}".format(type(max_label_len))) + fire_idxs = fires >= threshold + frame_fires = torch.zeros_like(hidden) + max_label_len = frames[0, fire_idxs[0]].size(0) for b in range(batch_size): - fire = fires[b, :] - l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze()) - pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device) - list_ls.append(torch.cat([l, pad_l], 0)) - return torch.stack(list_ls, 0), fires + # fire = fires[b, :] + frame_fire = frames[b, fire_idxs[b]] + frame_len = frame_fire.size(0) + frame_fires[b, :frame_len, :] = frame_fire + + if frame_len >= max_label_len: + max_label_len = frame_len + frame_fires = frame_fires[:, :max_label_len, :] + return frame_fires, fires