diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index 3c73152be..933a9271c 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -58,7 +58,7 @@ class ASRModelExportParaformer:
         if enc_size:
             dummy_input = model.get_dummy_inputs(enc_size)
         else:
-            dummy_input = model.get_dummy_inputs_txt()
+            dummy_input = model.get_dummy_inputs()
 
         # model_script = torch.jit.script(model)
         model_script = torch.jit.trace(model, dummy_input)
@@ -106,6 +106,110 @@ class ASRModelExportParaformer:
         # model_script = torch.jit.script(model)
         model_script = model #torch.jit.trace(model)
+        torch.onnx.export(
+            model_script,
+            dummy_input,
+            os.path.join(path, f'{model.model_name}.onnx'),
+            verbose=verbose,
+            opset_version=14,
+            input_names=model.get_input_names(),
+            output_names=model.get_output_names(),
+            dynamic_axes=model.get_dynamic_axes()
+        )
+
+
+class ASRModelExport:
+    def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True):
+        assert check_argument_types()
+        self.set_all_random_seed(0)
+        if cache_dir is None:
+            cache_dir = Path.home() / ".cache" / "export"
+
+        self.cache_dir = Path(cache_dir)
+        self.export_config = dict(
+            feats_dim=560,
+            onnx=False,
+        )
+        print("output dir: {}".format(self.cache_dir))
+        self.onnx = onnx
+
+    def _export(
+        self,
+        model: Speech2Text,
+        tag_name: str = None,
+        verbose: bool = False,
+    ):
+
+        export_dir = self.cache_dir / tag_name.replace(' ', '-')
+        os.makedirs(export_dir, exist_ok=True)
+
+        # export encoder1
+        self.export_config["model_name"] = "model"
+        model = get_model(
+            model,
+            self.export_config,
+        )
+        model.eval()
+        # self._export_onnx(model, verbose, export_dir)
+        if self.onnx:
+            self._export_onnx(model, verbose, export_dir)
+        else:
+            self._export_torchscripts(model, verbose, export_dir)
+
+        print("output dir: {}".format(export_dir))
+
+    def _export_torchscripts(self, model, verbose, path, enc_size=None):
+        if enc_size:
+            dummy_input = model.get_dummy_inputs(enc_size)
+        else:
+            dummy_input = model.get_dummy_inputs_txt()
+
+        # model_script = torch.jit.script(model)
+        model_script = torch.jit.trace(model, dummy_input)
+        model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
+
+    def set_all_random_seed(self, seed: int):
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.random.manual_seed(seed)
+
+    def export(self,
+               tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+               mode: str = 'paraformer',
+               ):
+
+        model_dir = tag_name
+        if model_dir.startswith('damo/'):
+            from modelscope.hub.snapshot_download import snapshot_download
+            model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
+        asr_train_config = os.path.join(model_dir, 'config.yaml')
+        asr_model_file = os.path.join(model_dir, 'model.pb')
+        cmvn_file = os.path.join(model_dir, 'am.mvn')
+        json_file = os.path.join(model_dir, 'configuration.json')
+        if mode is None:
+            import json
+            with open(json_file, 'r') as f:
+                config_data = json.load(f)
+            mode = config_data['model']['model_config']['mode']
+        if mode.startswith('paraformer'):
+            from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+        elif mode.startswith('uniasr'):
+            from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+
+        model, asr_train_args = ASRTask.build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, 'cpu'
+        )
+        self._export(model, tag_name)
+
+    def _export_onnx(self, model, verbose, path, enc_size=None):
+        if enc_size:
+            dummy_input = model.get_dummy_inputs(enc_size)
+        else:
+            dummy_input = model.get_dummy_inputs()
+
+        # model_script = torch.jit.script(model)
+        model_script = model # torch.jit.trace(model)
+
 
         torch.onnx.export(
             model_script,
             dummy_input,
@@ -117,6 +221,7 @@ class ASRModelExportParaformer:
             dynamic_axes=model.get_dynamic_axes()
         )
 
+
 if __name__ == '__main__':
     import sys
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index ca2c8138a..27a65af9b 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -1,5 +1,6 @@
 from funasr.models.e2e_asr_paraformer import Paraformer
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
+from funasr.models.e2e_uni_asr import UniASR
 
 
 def get_model(model, export_config=None):
diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
index bf5ed1ea6..5424a0a94 100644
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ b/funasr/export/models/e2e_asr_paraformer.py
@@ -59,7 +59,7 @@ class Paraformer(nn.Module):
         enc, enc_len = self.encoder(**batch)
         mask = self.make_pad_mask(enc_len)[:, None, :]
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
-        pre_token_length = pre_token_length.round().type(torch.int32)
+        pre_token_length = pre_token_length.floor().type(torch.int32)
 
         decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index 5518cb83d..c8df7f381 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -16,6 +16,11 @@ def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
     return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
 
 
+def sequence_mask_scripts(lengths, maxlen:int):
+    row_vector = torch.arange(0, maxlen, 1).type(lengths.dtype).to(lengths.device)
+    matrix = torch.unsqueeze(lengths, dim=-1)
+    mask = row_vector < matrix
+    return mask.type(torch.float32).to(lengths.device)
 
 class CifPredictorV2(nn.Module):
     def __init__(self, model):
@@ -71,28 +76,29 @@ class CifPredictorV2(nn.Module):
 
         return hidden, alphas, token_num_floor
 
 
+@torch.jit.script
 def cif(hidden, alphas, threshold: float):
     batch_size, len_time, hidden_size = hidden.size()
     threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
 
     # loop varss
-    integrate = torch.zeros([batch_size], device=hidden.device)
-    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
+    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
+    frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
     # intermediate vars along time
     list_fires = []
     list_frames = []
 
     for t in range(len_time):
         alpha = alphas[:, t]
-        distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
+        distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
 
         integrate += alpha
         list_fires.append(integrate)
 
         fire_place = integrate >= threshold
         integrate = torch.where(fire_place,
-                                integrate - torch.ones([batch_size], device=hidden.device),
+                                integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
                                 integrate)
         cur = torch.where(fire_place,
                           distribution_completion,
@@ -107,12 +113,20 @@ def cif(hidden, alphas, threshold: float):
     fires = torch.stack(list_fires, 1)
     frames = torch.stack(list_frames, 1)
 
-    list_ls = []
-    len_labels = torch.round(alphas.sum(-1)).int()
-    max_label_len = len_labels.max()
+    # list_ls = []
+    len_labels = torch.round(alphas.sum(-1)).type(torch.int32)
+    # max_label_len = int(torch.max(len_labels).item())
+    # print("type: {}".format(type(max_label_len)))
+    fire_idxs = fires >= threshold
+    frame_fires = torch.zeros_like(hidden)
+    max_label_len = frames[0, fire_idxs[0]].size(0)
     for b in range(batch_size):
-        fire = fires[b, :]
-        l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
-        pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
-        list_ls.append(torch.cat([l, pad_l], 0))
-    return torch.stack(list_ls, 0), fires
+        # fire = fires[b, :]
+        frame_fire = frames[b, fire_idxs[b]]
+        frame_len = frame_fire.size(0)
+        frame_fires[b, :frame_len, :] = frame_fire
+
+        if frame_len >= max_label_len:
+            max_label_len = frame_len
+    frame_fires = frame_fires[:, :max_label_len, :]
+    return frame_fires, fires
diff --git a/scan.py b/scan.py
new file mode 100644
index 000000000..e69de29bb
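
For reference, a minimal usage sketch of the new ASRModelExport class added in funasr/export/export_model.py above. It assumes the patched module is importable and that the ModelScope hub is reachable for 'damo/' tag names; the variable name and cache path are illustrative, not part of the diff.

from funasr.export.export_model import ASRModelExport

# onnx=True routes export through the new _export_onnx; onnx=False traces
# the model to TorchScript via _export_torchscripts instead.
exporter = ASRModelExport(cache_dir="./export_cache", onnx=True)
exporter.export(
    tag_name='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
    mode='paraformer',
)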
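
The reworked tail of cif() in funasr/export/models/predictor/cif.py replaces the torch.nonzero/torch.index_select gather-and-pad loop with boolean-mask indexing into a pre-allocated tensor, which avoids data-dependent Python integers and appears aimed at keeping the function compatible with the @torch.jit.script decorator added in the same hunk. Below is a self-contained toy reproduction of that logic; the batch size, sequence length, and fire positions are made up.

import torch

batch_size, len_time, hidden_size = 2, 5, 3   # illustrative sizes
threshold = 1.0
hidden = torch.randn(batch_size, len_time, hidden_size)
frames = torch.randn(batch_size, len_time, hidden_size)
fires = torch.tensor([[0.2, 1.1, 0.4, 1.3, 0.9],
                      [1.2, 0.3, 1.0, 0.5, 0.1]])

fire_idxs = fires >= threshold            # bool mask instead of torch.nonzero
frame_fires = torch.zeros_like(hidden)    # pre-allocated, already-padded output
max_label_len = frames[0, fire_idxs[0]].size(0)
for b in range(batch_size):
    frame_fire = frames[b, fire_idxs[b]]  # gather the fired frames for sample b
    frame_len = frame_fire.size(0)
    frame_fires[b, :frame_len, :] = frame_fire
    if frame_len >= max_label_len:
        max_label_len = frame_len
frame_fires = frame_fires[:, :max_label_len, :]  # trim to the longest label sequence
print(frame_fires.shape)                  # torch.Size([2, 2, 3]) for these fires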
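
Likewise, the sequence_mask_scripts helper added in the same file is a script-friendly variant of sequence_mask with a required integer maxlen and a fixed float32 output. A quick check with illustrative inputs:

import torch
from funasr.export.models.predictor.cif import sequence_mask_scripts  # path per the diff

lengths = torch.tensor([2, 4])
print(sequence_mask_scripts(lengths, maxlen=4))
# tensor([[1., 1., 0., 0.],
#         [1., 1., 1., 1.]])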