This commit is contained in:
speech_asr 2023-04-20 17:10:22 +08:00
parent 993fdd8ecf
commit aa07151996
2 changed files with 24 additions and 4 deletions

View File

@ -1,4 +1,7 @@
import logging
from funasr.lm.abs_model import AbsLM
from funasr.lm.abs_model import LanguageModel
from funasr.lm.seq_rnn_lm import SequentialRNNLM
from funasr.lm.transformer_lm import TransformerLM
from funasr.torch_utils.initialize import initialize
@ -13,10 +16,19 @@ lm_choices = ClassChoices(
type_check=AbsLM,
default="seq_rnn",
)
model_choices = ClassChoices(
"model",
classes=dict(
lm=LanguageModel,
),
default="lm",
)
class_choices_list = [
# --lm and --lm_conf
lm_choices
lm_choices,
# --model and --model_conf
model_choices
]
@ -31,4 +43,15 @@ def build_lm_model(args):
else:
vocab_size = None
# lm
lm_class = lm_choices.get_class(args.lm)
lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
model_class = model_choices.get_class(args.model)
model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
# initialize
if args.init is not None:
initialize(model, args.init)
return model

View File

@ -206,6 +206,3 @@ class LMTask(AbsTask):
# 3. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model