This commit is contained in:
speech_asr 2023-04-20 16:03:54 +08:00
parent 200d1ede05
commit a29166b9a0

View File

@ -268,7 +268,7 @@ def build_asr_model(args):
token_list=token_list,
**args.model_conf,
)
elif args.model == "paraformer":
elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
# predictor
predictor_class = predictor_choices.get_class(args.predictor)
predictor = predictor_class(**args.predictor_conf)
@ -336,7 +336,14 @@ def build_asr_model(args):
stride_conv=stride_conv,
**args.model_conf,
)
elif args.model == "timestamp_prediction":
model_class = model_choices.get_class(args.model)
model = model_class(
frontend=frontend,
encoder=encoder,
token_list=token_list,
**args.model_conf,
)
else:
raise NotImplementedError("Not supported model: {}".format(args.model))