diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index e1d5fdbbb..933a9271c 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -58,7 +58,7 @@ class ASRModelExportParaformer:
         if enc_size:
             dummy_input = model.get_dummy_inputs(enc_size)
         else:
-            dummy_input = model.get_dummy_inputs_txt()
+            dummy_input = model.get_dummy_inputs()
 
         # model_script = torch.jit.script(model)
         model_script = torch.jit.trace(model, dummy_input)
@@ -111,7 +111,7 @@ class ASRModelExportParaformer:
             dummy_input,
             os.path.join(path, f'{model.model_name}.onnx'),
             verbose=verbose,
-            opset_version=12,
+            opset_version=14,
             input_names=model.get_input_names(),
             output_names=model.get_output_names(),
             dynamic_axes=model.get_dynamic_axes()
diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index 034e2334c..c8df7f381 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -76,108 +76,6 @@ class CifPredictorV2(nn.Module):
 
         return hidden, alphas, token_num_floor
 
-# @torch.jit.script
-# def cif(hidden, alphas, threshold: float):
-#     batch_size, len_time, hidden_size = hidden.size()
-#     threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
-#
-#     # loop varss
-#     integrate = torch.zeros([batch_size], device=hidden.device)
-#     frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
-#     # intermediate vars along time
-#     list_fires = []
-#     list_frames = []
-#
-#     for t in range(len_time):
-#         alpha = alphas[:, t]
-#         distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
-#
-#         integrate += alpha
-#         list_fires.append(integrate)
-#
-#         fire_place = integrate >= threshold
-#         integrate = torch.where(fire_place,
-#                                 integrate - torch.ones([batch_size], device=hidden.device),
-#                                 integrate)
-#         cur = torch.where(fire_place,
-#                           distribution_completion,
-#                           alpha)
-#         remainds = alpha - cur
-#
-#         frame += cur[:, None] * hidden[:, t, :]
-#         list_frames.append(frame)
-#         frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
-#                             remainds[:, None] * hidden[:, t, :],
-#                             frame)
-#
-#     fires = torch.stack(list_fires, 1)
-#     frames = torch.stack(list_frames, 1)
-#     list_ls = []
-#     len_labels = torch.round(alphas.sum(-1)).int()
-#     max_label_len = len_labels.max().item()
-#     # print("type: {}".format(type(max_label_len)))
-#     for b in range(batch_size):
-#         fire = fires[b, :]
-#         l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
-#         pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], dtype=l.dtype, device=hidden.device)
-#         list_ls.append(torch.cat([l, pad_l], 0))
-#     return torch.stack(list_ls, 0), fires
-
-# @torch.jit.script
-# def cif(hidden, alphas, threshold: float):
-#     batch_size, len_time, hidden_size = hidden.size()
-#     threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
-#
-#     # loop varss
-#     integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
-#     frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
-#     # intermediate vars along time
-#     list_fires = []
-#     list_frames = []
-#
-#     for t in range(len_time):
-#         alpha = alphas[:, t]
-#         distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
-#
-#         integrate += alpha
-#         list_fires.append(integrate)
-#
-#         fire_place = integrate >= threshold
-#         integrate = torch.where(fire_place,
-#                                 integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
-#                                 integrate)
-#         cur = torch.where(fire_place,
-#                           distribution_completion,
-#                           alpha)
-#         remainds = alpha - cur
-#
-#         frame += cur[:, None] * hidden[:, t, :]
-#         list_frames.append(frame)
-#         frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
-#                             remainds[:, None] * hidden[:, t, :],
-#                             frame)
-#
-#     fires = torch.stack(list_fires, 1)
-#     frames = torch.stack(list_frames, 1)
-#     len_labels = torch.floor(torch.sum(alphas, dim=1)).int()
-#     max_label_len = torch.max(len_labels)
-#     pad_num = max_label_len - len_labels
-#     pad_num_max = torch.max(pad_num).item()
-#     frames_pad_tensor = torch.zeros([int(batch_size), int(pad_num_max), int(hidden_size)], dtype=frames.dtype,
-#                                     device=frames.device)
-#     fires_pad_tensor = torch.ones([int(batch_size), int(pad_num_max)], dtype=fires.dtype, device=fires.device)
-#     fires_pad_tensor_mask = sequence_mask_scripts(pad_num, maxlen=int(pad_num_max))
-#     fires_pad_tensor *= fires_pad_tensor_mask
-#     frames_pad = torch.cat([frames, frames_pad_tensor], dim=1)
-#     fires_pad = torch.cat([fires, fires_pad_tensor], dim=1)
-#     index_bool = fires_pad >= threshold
-#     frames_fire = frames_pad[index_bool]
-#     frames_fire = torch.reshape(frames_fire, (int(batch_size), -1, int(hidden_size)))
-#     frames_fire_mask = sequence_mask_scripts(len_labels, maxlen=int(max_label_len))
-#     frames_fire *= frames_fire_mask[:, :, None]
-#
-#     return frames_fire, fires
-
 
 @torch.jit.script
 def cif(hidden, alphas, threshold: float):
diff --git a/scan.py b/scan.py
new file mode 100644
index 000000000..e69de29bb
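
Note: the sketch below is a minimal, self-contained illustration of the export call shape these hunks touch (dummy inputs from get_dummy_inputs(), ONNX opset 14). The ToyExportModel wrapper, its layer sizes, and the output path are hypothetical placeholders, not FunASR's actual classes; only the helper names and torch.onnx.export keywords come from the diff above.

import os
import torch


class ToyExportModel(torch.nn.Module):
    """Hypothetical stand-in for a FunASR export wrapper (illustrative only)."""

    model_name = "toy_model"

    def __init__(self):
        super().__init__()
        # Toy layer so the module is traceable/exportable end to end.
        self.linear = torch.nn.Linear(80, 4)

    def forward(self, feats):
        return self.linear(feats)

    def get_dummy_inputs(self):
        # Dummy (batch, time, feature) input used for tracing/export.
        return torch.randn(1, 100, 80)

    def get_input_names(self):
        return ["feats"]

    def get_output_names(self):
        return ["logits"]

    def get_dynamic_axes(self):
        return {"feats": {0: "batch", 1: "time"}, "logits": {0: "batch", 1: "time"}}


def export_onnx(model, path, verbose=False):
    # Mirrors the call shape in export_model.py after the patch:
    # dummy inputs come from get_dummy_inputs() and opset_version is 14.
    dummy_input = model.get_dummy_inputs()
    torch.onnx.export(
        model,
        dummy_input,
        os.path.join(path, f"{model.model_name}.onnx"),
        verbose=verbose,
        opset_version=14,
        input_names=model.get_input_names(),
        output_names=model.get_output_names(),
        dynamic_axes=model.get_dynamic_axes(),
    )


if __name__ == "__main__":
    export_onnx(ToyExportModel(), ".")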