onnx export funasr_onnx

This commit is contained in:
游雁 2023-05-12 11:03:22 +08:00
parent a1cbcc09f4
commit d25d0942f9
4 changed files with 66 additions and 15 deletions

View File

@ -32,7 +32,7 @@ class Paraformer():
plot_timestamp_to: str = "",
quantize: bool = False,
intra_op_num_threads: int = 4,
cache_dir=None
cache_dir: str = None
):
if not Path(model_dir).exists():
@ -41,6 +41,12 @@ class Paraformer():
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
model_file = os.path.join(model_dir, 'model.onnx')
if quantize:
model_file = os.path.join(model_dir, 'model_quant.onnx')
if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx")
from funasr.export.export_model import ModelExport
export_model = ModelExport(
cache_dir=cache_dir,
@ -50,11 +56,6 @@ class Paraformer():
)
export_model.export(model_dir)
model_file = os.path.join(model_dir, 'model.onnx')
if quantize:
model_file = os.path.join(model_dir, 'model_quant.onnx')
config_file = os.path.join(model_dir, 'config.yaml')
cmvn_file = os.path.join(model_dir, 'am.mvn')
config = read_yaml(config_file)

View File

@ -24,15 +24,32 @@ class CT_Transformer():
batch_size: int = 1,
device_id: Union[str, int] = "-1",
quantize: bool = False,
intra_op_num_threads: int = 4
intra_op_num_threads: int = 4,
cache_dir: str = None,
):
if not Path(model_dir).exists():
raise FileNotFoundError(f'{model_dir} does not exist.')
from modelscope.hub.snapshot_download import snapshot_download
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir)
model_file = os.path.join(model_dir, 'model.onnx')
if quantize:
model_file = os.path.join(model_dir, 'model_quant.onnx')
if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx")
from funasr.export.export_model import ModelExport
export_model = ModelExport(
cache_dir=cache_dir,
onnx=True,
device="cpu",
quant=quantize,
)
export_model.export(model_dir)
config_file = os.path.join(model_dir, 'punc.yaml')
config = read_yaml(config_file)
@ -135,9 +152,10 @@ class CT_Transformer_VadRealtime(CT_Transformer):
batch_size: int = 1,
device_id: Union[str, int] = "-1",
quantize: bool = False,
intra_op_num_threads: int = 4
intra_op_num_threads: int = 4,
cache_dir: str = None
):
super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads, cache_dir=cache_dir)
def __call__(self, text: str, param_dict: map, split_size=20):
cache_key = "cache"

View File

@ -271,4 +271,5 @@ def get_logger(name='funasr_onnx'):
logger.addHandler(sh)
logger_initialized[name] = True
logger.propagate = False
logging.basicConfig(level=logging.ERROR)
return logger

View File

@ -31,14 +31,30 @@ class Fsmn_vad():
quantize: bool = False,
intra_op_num_threads: int = 4,
max_end_sil: int = None,
cache_dir: str = None
):
if not Path(model_dir).exists():
raise FileNotFoundError(f'{model_dir} does not exist.')
from modelscope.hub.snapshot_download import snapshot_download
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir)
model_file = os.path.join(model_dir, 'model.onnx')
if quantize:
model_file = os.path.join(model_dir, 'model_quant.onnx')
if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx")
from funasr.export.export_model import ModelExport
export_model = ModelExport(
cache_dir=cache_dir,
onnx=True,
device="cpu",
quant=quantize,
)
export_model.export(model_dir)
config_file = os.path.join(model_dir, 'vad.yaml')
cmvn_file = os.path.join(model_dir, 'vad.mvn')
config = read_yaml(config_file)
@ -172,14 +188,29 @@ class Fsmn_vad_online():
quantize: bool = False,
intra_op_num_threads: int = 4,
max_end_sil: int = None,
cache_dir: str = None
):
if not Path(model_dir).exists():
raise FileNotFoundError(f'{model_dir} does not exist.')
from modelscope.hub.snapshot_download import snapshot_download
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir)
model_file = os.path.join(model_dir, 'model.onnx')
if quantize:
model_file = os.path.join(model_dir, 'model_quant.onnx')
if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx")
from funasr.export.export_model import ModelExport
export_model = ModelExport(
cache_dir=cache_dir,
onnx=True,
device="cpu",
quant=quantize,
)
export_model.export(model_dir)
config_file = os.path.join(model_dir, 'vad.yaml')
cmvn_file = os.path.join(model_dir, 'vad.mvn')
config = read_yaml(config_file)