diff --git a/funasr/build_utils/build_lm_model.py b/funasr/build_utils/build_lm_model.py index 2cd869daf..aaa4fb7a7 100644 --- a/funasr/build_utils/build_lm_model.py +++ b/funasr/build_utils/build_lm_model.py @@ -1,4 +1,7 @@ +import logging + from funasr.lm.abs_model import AbsLM +from funasr.lm.abs_model import LanguageModel from funasr.lm.seq_rnn_lm import SequentialRNNLM from funasr.lm.transformer_lm import TransformerLM from funasr.torch_utils.initialize import initialize @@ -13,10 +16,19 @@ lm_choices = ClassChoices( type_check=AbsLM, default="seq_rnn", ) +model_choices = ClassChoices( + "model", + classes=dict( + lm=LanguageModel, + ), + default="lm", +) class_choices_list = [ # --lm and --lm_conf - lm_choices + lm_choices, + # --model and --model_conf + model_choices ] @@ -31,4 +43,15 @@ def build_lm_model(args): else: vocab_size = None + # lm + lm_class = lm_choices.get_class(args.lm) + lm = lm_class(vocab_size=vocab_size, **args.lm_conf) + + model_class = model_choices.get_class(args.model) + model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf) + + # initialize + if args.init is not None: + initialize(model, args.init) + return model diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py index 80d66d52f..1e48655d9 100644 --- a/funasr/tasks/lm.py +++ b/funasr/tasks/lm.py @@ -206,6 +206,3 @@ class LMTask(AbsTask): # 3. Initialize if args.init is not None: initialize(model, args.init) - - assert check_return_type(model) - return model