general punc model conversion onnx, fix bug

This commit is contained in:
九耳 2023-03-29 17:01:26 +08:00
parent a62b457439
commit 9232f06604

View File

@ -3,6 +3,7 @@ from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_exp
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr.export.models.target_delay_transformer import TargetDelayTransformer as TargetDelayTransformer_export
def get_model(model, export_config=None):
if isinstance(model, BiCifParaformer):
@ -11,5 +12,8 @@ def get_model(model, export_config=None):
return Paraformer_export(model, **export_config)
elif isinstance(model, E2EVadModel):
return E2EVadModel_export(model, **export_config)
elif isinstance(model, ESPnetPunctuationModel):
if isinstance(model.punc_model, TargetDelayTransformer):
return TargetDelayTransformer_export(model.punc_model, **export_config)
else:
raise "Funasr does not support the given model type currently."
raise "Funasr does not support the given model type currently."