diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py index 7525c906c..f3e0f3d2b 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py @@ -32,7 +32,7 @@ class Paraformer(): plot_timestamp_to: str = "", quantize: bool = False, intra_op_num_threads: int = 4, - cache_dir=None + cache_dir: str = None ): if not Path(model_dir).exists(): @@ -41,6 +41,12 @@ class Paraformer(): model_dir = snapshot_download(model_dir, cache_dir=cache_dir) except: raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir) + + model_file = os.path.join(model_dir, 'model.onnx') + if quantize: + model_file = os.path.join(model_dir, 'model_quant.onnx') + if not os.path.exists(model_file): + print(".onnx is not exist, begin to export onnx") from funasr.export.export_model import ModelExport export_model = ModelExport( cache_dir=cache_dir, @@ -50,11 +56,6 @@ class Paraformer(): ) export_model.export(model_dir) - - - model_file = os.path.join(model_dir, 'model.onnx') - if quantize: - model_file = os.path.join(model_dir, 'model_quant.onnx') config_file = os.path.join(model_dir, 'config.yaml') cmvn_file = os.path.join(model_dir, 'am.mvn') config = read_yaml(config_file) diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py index 6fd01e404..8890714e6 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py @@ -24,15 +24,32 @@ class CT_Transformer(): batch_size: int = 1, device_id: Union[str, int] = "-1", quantize: bool = False, - intra_op_num_threads: int = 4 + intra_op_num_threads: int = 4, + cache_dir: str = None, ): - + if not Path(model_dir).exists(): - raise FileNotFoundError(f'{model_dir} does not exist.') - + from modelscope.hub.snapshot_download import snapshot_download + try: + model_dir = snapshot_download(model_dir, cache_dir=cache_dir) + except: + raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format( + model_dir) + model_file = os.path.join(model_dir, 'model.onnx') if quantize: model_file = os.path.join(model_dir, 'model_quant.onnx') + if not os.path.exists(model_file): + print(".onnx is not exist, begin to export onnx") + from funasr.export.export_model import ModelExport + export_model = ModelExport( + cache_dir=cache_dir, + onnx=True, + device="cpu", + quant=quantize, + ) + export_model.export(model_dir) + config_file = os.path.join(model_dir, 'punc.yaml') config = read_yaml(config_file) @@ -135,9 +152,10 @@ class CT_Transformer_VadRealtime(CT_Transformer): batch_size: int = 1, device_id: Union[str, int] = "-1", quantize: bool = False, - intra_op_num_threads: int = 4 + intra_op_num_threads: int = 4, + cache_dir: str = None ): - super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads) + super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads, cache_dir=cache_dir) def __call__(self, text: str, param_dict: map, split_size=20): cache_key = "cache" diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py index 78c3f0d98..dcee42500 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py @@ -271,4 +271,5 @@ def get_logger(name='funasr_onnx'): logger.addHandler(sh) logger_initialized[name] = True logger.propagate = False + logging.basicConfig(level=logging.ERROR) return logger diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py index 022f1e780..244dd757a 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py @@ -31,14 +31,30 @@ class Fsmn_vad(): quantize: bool = False, intra_op_num_threads: int = 4, max_end_sil: int = None, + cache_dir: str = None ): if not Path(model_dir).exists(): - raise FileNotFoundError(f'{model_dir} does not exist.') + from modelscope.hub.snapshot_download import snapshot_download + try: + model_dir = snapshot_download(model_dir, cache_dir=cache_dir) + except: + raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format( + model_dir) model_file = os.path.join(model_dir, 'model.onnx') if quantize: model_file = os.path.join(model_dir, 'model_quant.onnx') + if not os.path.exists(model_file): + print(".onnx is not exist, begin to export onnx") + from funasr.export.export_model import ModelExport + export_model = ModelExport( + cache_dir=cache_dir, + onnx=True, + device="cpu", + quant=quantize, + ) + export_model.export(model_dir) config_file = os.path.join(model_dir, 'vad.yaml') cmvn_file = os.path.join(model_dir, 'vad.mvn') config = read_yaml(config_file) @@ -172,14 +188,29 @@ class Fsmn_vad_online(): quantize: bool = False, intra_op_num_threads: int = 4, max_end_sil: int = None, + cache_dir: str = None ): - if not Path(model_dir).exists(): - raise FileNotFoundError(f'{model_dir} does not exist.') + from modelscope.hub.snapshot_download import snapshot_download + try: + model_dir = snapshot_download(model_dir, cache_dir=cache_dir) + except: + raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format( + model_dir) model_file = os.path.join(model_dir, 'model.onnx') if quantize: model_file = os.path.join(model_dir, 'model_quant.onnx') + if not os.path.exists(model_file): + print(".onnx is not exist, begin to export onnx") + from funasr.export.export_model import ModelExport + export_model = ModelExport( + cache_dir=cache_dir, + onnx=True, + device="cpu", + quant=quantize, + ) + export_model.export(model_dir) config_file = os.path.join(model_dir, 'vad.yaml') cmvn_file = os.path.join(model_dir, 'vad.mvn') config = read_yaml(config_file)