From e7351db81b3bfc4000633eca274c46893d68f64e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BB=B4=E7=9F=B3?= Date: Tue, 28 May 2024 19:07:22 +0800 Subject: [PATCH] update export --- .../paraformer/export.py | 16 ++++---- funasr/auto/auto_model.py | 8 +--- funasr/models/paraformer/export_meta.py | 1 + funasr/models/seaco_paraformer/export_meta.py | 7 ++-- funasr/utils/export_utils.py | 38 ++++++++++++++----- 5 files changed, 41 insertions(+), 29 deletions(-) diff --git a/examples/industrial_data_pretraining/paraformer/export.py b/examples/industrial_data_pretraining/paraformer/export.py index 19512c1f3..fd5938aba 100644 --- a/examples/industrial_data_pretraining/paraformer/export.py +++ b/examples/industrial_data_pretraining/paraformer/export.py @@ -13,16 +13,16 @@ model = AutoModel( model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", ) -res = model.export(type="onnx", quantize=False) +res = model.export(type="torchscript", quantize=False) print(res) -# method2, inference from local path -from funasr import AutoModel +# # method2, inference from local path +# from funasr import AutoModel -model = AutoModel( - model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" -) +# model = AutoModel( +# model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +# ) -res = model.export(type="onnx", quantize=False) -print(res) +# res = model.export(type="onnx", quantize=False) +# print(res) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 97eb325da..faa5bd92b 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -580,12 +580,6 @@ class AutoModel: ) with torch.no_grad(): - - if type == "onnx": - export_dir = export_utils.export_onnx(model=model, data_in=data_list, **kwargs) - else: - export_dir = export_utils.export_torchscripts( - model=model, data_in=data_list, **kwargs - ) + export_dir = export_utils.export(model=model, data_in=data_list, **kwargs) return export_dir diff --git a/funasr/models/paraformer/export_meta.py b/funasr/models/paraformer/export_meta.py index 5c1b6c08d..db9385519 100644 --- a/funasr/models/paraformer/export_meta.py +++ b/funasr/models/paraformer/export_meta.py @@ -31,6 +31,7 @@ def export_rebuild_model(model, **kwargs): model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model) model.export_name = types.MethodType(export_name, model) + model.export_name = 'model' return model diff --git a/funasr/models/seaco_paraformer/export_meta.py b/funasr/models/seaco_paraformer/export_meta.py index 6d8096f59..db27c914a 100644 --- a/funasr/models/seaco_paraformer/export_meta.py +++ b/funasr/models/seaco_paraformer/export_meta.py @@ -109,7 +109,9 @@ def export_rebuild_model(model, **kwargs): backbone_model.export_dynamic_axes = types.MethodType( export_backbone_dynamic_axes, backbone_model ) - backbone_model.export_name = types.MethodType(export_backbone_name, backbone_model) + + embedder_model.export_name = "model_eb" + backbone_model.export_name = "model_bb" return backbone_model, embedder_model @@ -192,6 +194,3 @@ def export_backbone_dynamic_axes(self): "pre_acoustic_embeds": {1: "feats_length1"}, } - -def export_backbone_name(self): - return "model.onnx" diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py index bc7953917..7d6606b6f 100644 --- a/funasr/utils/export_utils.py +++ b/funasr/utils/export_utils.py @@ -2,7 +2,7 @@ import os import torch -def export_onnx(model, data_in=None, quantize: bool = False, opset_version: int = 14, **kwargs): +def export(model, data_in=None, quantize: bool = False, opset_version: int = 14, type='onnx', **kwargs): model_scripts = model.export(**kwargs) export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param"))) os.makedirs(export_dir, exist_ok=True) @@ -11,14 +11,20 @@ def export_onnx(model, data_in=None, quantize: bool = False, opset_version: int model_scripts = (model_scripts,) for m in model_scripts: m.eval() - _onnx( - m, - data_in=data_in, - quantize=quantize, - opset_version=opset_version, - export_dir=export_dir, - **kwargs - ) + if type == 'onnx': + _onnx( + m, + data_in=data_in, + quantize=quantize, + opset_version=opset_version, + export_dir=export_dir, + **kwargs + ) + elif type == 'torchscript': + _torchscripts( + m, + path=export_dir, + ) print("output dir: {}".format(export_dir)) return export_dir @@ -37,7 +43,7 @@ def _onnx( verbose = kwargs.get("verbose", False) - export_name = model.export_name() if hasattr(model, "export_name") else "model.onnx" + export_name = model.export_name + '.onnx' model_path = os.path.join(export_dir, export_name) torch.onnx.export( model, @@ -70,3 +76,15 @@ def _onnx( weight_type=QuantType.QUInt8, nodes_to_exclude=nodes_to_exclude, ) + + +def _torchscripts(model, path, device='cpu'): + dummy_input = model.export_dummy_inputs() + + if device == 'cuda': + model = model.cuda() + dummy_input = tuple([i.cuda() for i in dummy_input]) + + # model_script = torch.jit.script(model) + model_script = torch.jit.trace(model, dummy_input) + model_script.save(os.path.join(path, f'{model.export_name}.torchscripts'))