mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
onnx export funasr_onnx
This commit is contained in:
parent
a1cbcc09f4
commit
d25d0942f9
@ -32,7 +32,7 @@ class Paraformer():
|
||||
plot_timestamp_to: str = "",
|
||||
quantize: bool = False,
|
||||
intra_op_num_threads: int = 4,
|
||||
cache_dir=None
|
||||
cache_dir: str = None
|
||||
):
|
||||
|
||||
if not Path(model_dir).exists():
|
||||
@ -41,6 +41,12 @@ class Paraformer():
|
||||
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
|
||||
except:
|
||||
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
|
||||
|
||||
model_file = os.path.join(model_dir, 'model.onnx')
|
||||
if quantize:
|
||||
model_file = os.path.join(model_dir, 'model_quant.onnx')
|
||||
if not os.path.exists(model_file):
|
||||
print(".onnx is not exist, begin to export onnx")
|
||||
from funasr.export.export_model import ModelExport
|
||||
export_model = ModelExport(
|
||||
cache_dir=cache_dir,
|
||||
@ -50,11 +56,6 @@ class Paraformer():
|
||||
)
|
||||
export_model.export(model_dir)
|
||||
|
||||
|
||||
|
||||
model_file = os.path.join(model_dir, 'model.onnx')
|
||||
if quantize:
|
||||
model_file = os.path.join(model_dir, 'model_quant.onnx')
|
||||
config_file = os.path.join(model_dir, 'config.yaml')
|
||||
cmvn_file = os.path.join(model_dir, 'am.mvn')
|
||||
config = read_yaml(config_file)
|
||||
|
||||
@ -24,15 +24,32 @@ class CT_Transformer():
|
||||
batch_size: int = 1,
|
||||
device_id: Union[str, int] = "-1",
|
||||
quantize: bool = False,
|
||||
intra_op_num_threads: int = 4
|
||||
intra_op_num_threads: int = 4,
|
||||
cache_dir: str = None,
|
||||
):
|
||||
|
||||
|
||||
if not Path(model_dir).exists():
|
||||
raise FileNotFoundError(f'{model_dir} does not exist.')
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
try:
|
||||
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
|
||||
except:
|
||||
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
|
||||
model_dir)
|
||||
|
||||
model_file = os.path.join(model_dir, 'model.onnx')
|
||||
if quantize:
|
||||
model_file = os.path.join(model_dir, 'model_quant.onnx')
|
||||
if not os.path.exists(model_file):
|
||||
print(".onnx is not exist, begin to export onnx")
|
||||
from funasr.export.export_model import ModelExport
|
||||
export_model = ModelExport(
|
||||
cache_dir=cache_dir,
|
||||
onnx=True,
|
||||
device="cpu",
|
||||
quant=quantize,
|
||||
)
|
||||
export_model.export(model_dir)
|
||||
|
||||
config_file = os.path.join(model_dir, 'punc.yaml')
|
||||
config = read_yaml(config_file)
|
||||
|
||||
@ -135,9 +152,10 @@ class CT_Transformer_VadRealtime(CT_Transformer):
|
||||
batch_size: int = 1,
|
||||
device_id: Union[str, int] = "-1",
|
||||
quantize: bool = False,
|
||||
intra_op_num_threads: int = 4
|
||||
intra_op_num_threads: int = 4,
|
||||
cache_dir: str = None
|
||||
):
|
||||
super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
|
||||
super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads, cache_dir=cache_dir)
|
||||
|
||||
def __call__(self, text: str, param_dict: map, split_size=20):
|
||||
cache_key = "cache"
|
||||
|
||||
@ -271,4 +271,5 @@ def get_logger(name='funasr_onnx'):
|
||||
logger.addHandler(sh)
|
||||
logger_initialized[name] = True
|
||||
logger.propagate = False
|
||||
logging.basicConfig(level=logging.ERROR)
|
||||
return logger
|
||||
|
||||
@ -31,14 +31,30 @@ class Fsmn_vad():
|
||||
quantize: bool = False,
|
||||
intra_op_num_threads: int = 4,
|
||||
max_end_sil: int = None,
|
||||
cache_dir: str = None
|
||||
):
|
||||
|
||||
if not Path(model_dir).exists():
|
||||
raise FileNotFoundError(f'{model_dir} does not exist.')
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
try:
|
||||
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
|
||||
except:
|
||||
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
|
||||
model_dir)
|
||||
|
||||
model_file = os.path.join(model_dir, 'model.onnx')
|
||||
if quantize:
|
||||
model_file = os.path.join(model_dir, 'model_quant.onnx')
|
||||
if not os.path.exists(model_file):
|
||||
print(".onnx is not exist, begin to export onnx")
|
||||
from funasr.export.export_model import ModelExport
|
||||
export_model = ModelExport(
|
||||
cache_dir=cache_dir,
|
||||
onnx=True,
|
||||
device="cpu",
|
||||
quant=quantize,
|
||||
)
|
||||
export_model.export(model_dir)
|
||||
config_file = os.path.join(model_dir, 'vad.yaml')
|
||||
cmvn_file = os.path.join(model_dir, 'vad.mvn')
|
||||
config = read_yaml(config_file)
|
||||
@ -172,14 +188,29 @@ class Fsmn_vad_online():
|
||||
quantize: bool = False,
|
||||
intra_op_num_threads: int = 4,
|
||||
max_end_sil: int = None,
|
||||
cache_dir: str = None
|
||||
):
|
||||
|
||||
if not Path(model_dir).exists():
|
||||
raise FileNotFoundError(f'{model_dir} does not exist.')
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
try:
|
||||
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
|
||||
except:
|
||||
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
|
||||
model_dir)
|
||||
|
||||
model_file = os.path.join(model_dir, 'model.onnx')
|
||||
if quantize:
|
||||
model_file = os.path.join(model_dir, 'model_quant.onnx')
|
||||
if not os.path.exists(model_file):
|
||||
print(".onnx is not exist, begin to export onnx")
|
||||
from funasr.export.export_model import ModelExport
|
||||
export_model = ModelExport(
|
||||
cache_dir=cache_dir,
|
||||
onnx=True,
|
||||
device="cpu",
|
||||
quant=quantize,
|
||||
)
|
||||
export_model.export(model_dir)
|
||||
config_file = os.path.join(model_dir, 'vad.yaml')
|
||||
cmvn_file = os.path.join(model_dir, 'vad.mvn')
|
||||
config = read_yaml(config_file)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user