diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py index 806adceb5..e0e2c0967 100644 --- a/funasr/bin/asr_infer.py +++ b/funasr/bin/asr_infer.py @@ -83,7 +83,7 @@ class Speech2Text: # 1. Build ASR model scorers = {} asr_model, asr_train_args = build_model_from_file( - asr_train_config, asr_model_file, cmvn_file, device, mode="asr" + asr_train_config, asr_model_file, cmvn_file, device ) frontend = None if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py index 51de5b010..5488c1029 100644 --- a/funasr/build_utils/build_model_from_file.py +++ b/funasr/build_utils/build_model_from_file.py @@ -17,6 +17,7 @@ def build_model_from_file( model_file: Union[Path, str] = None, cmvn_file: Union[Path, str] = None, device: str = "cpu", + task_name: str = "asr", mode: str = "paraformer", ): """Build model from the files. @@ -44,6 +45,7 @@ def build_model_from_file( if cmvn_file is not None: args["cmvn_file"] = cmvn_file args = argparse.Namespace(**args) + args.task_name = task_name model = build_model(args) if not isinstance(model, FunASRModel): raise RuntimeError(