diff --git a/funasr/utils/build_asr_model.py b/funasr/utils/build_asr_model.py index 2da050cb4..4908cd49e 100644 --- a/funasr/utils/build_asr_model.py +++ b/funasr/utils/build_asr_model.py @@ -268,7 +268,7 @@ def build_asr_model(args): token_list=token_list, **args.model_conf, ) - elif args.model == "paraformer": + elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]: # predictor predictor_class = predictor_choices.get_class(args.predictor) predictor = predictor_class(**args.predictor_conf) @@ -336,7 +336,14 @@ def build_asr_model(args): stride_conv=stride_conv, **args.model_conf, ) - + elif args.model == "timestamp_prediction": + model_class = model_choices.get_class(args.model) + model = model_class( + frontend=frontend, + encoder=encoder, + token_list=token_list, + **args.model_conf, + ) else: raise NotImplementedError("Not supported model: {}".format(args.model))