diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index f6ba61616..b1161cbf8 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -14,7 +14,7 @@ from funasr.utils.types import str2bool # torch_version = float(".".join(torch.__version__.split(".")[:2])) # assert torch_version > 1.9 -class ASRModelExportParaformer: +class ModelExport: def __init__( self, cache_dir: Union[Path, str] = None, @@ -240,7 +240,7 @@ if __name__ == '__main__': parser.add_argument('--calib_num', type=int, default=200, help='calib max num') args = parser.parse_args() - export_model = ASRModelExportParaformer( + export_model = ModelExport( cache_dir=args.export_dir, onnx=args.type == 'onnx', quant=args.quantize,