mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
993fdd8ecf
commit
aa07151996
@ -1,4 +1,7 @@
|
||||
import logging
|
||||
|
||||
from funasr.lm.abs_model import AbsLM
|
||||
from funasr.lm.abs_model import LanguageModel
|
||||
from funasr.lm.seq_rnn_lm import SequentialRNNLM
|
||||
from funasr.lm.transformer_lm import TransformerLM
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
@ -13,10 +16,19 @@ lm_choices = ClassChoices(
|
||||
type_check=AbsLM,
|
||||
default="seq_rnn",
|
||||
)
|
||||
model_choices = ClassChoices(
|
||||
"model",
|
||||
classes=dict(
|
||||
lm=LanguageModel,
|
||||
),
|
||||
default="lm",
|
||||
)
|
||||
|
||||
class_choices_list = [
|
||||
# --lm and --lm_conf
|
||||
lm_choices
|
||||
lm_choices,
|
||||
# --model and --model_conf
|
||||
model_choices
|
||||
]
|
||||
|
||||
|
||||
@ -31,4 +43,15 @@ def build_lm_model(args):
|
||||
else:
|
||||
vocab_size = None
|
||||
|
||||
# lm
|
||||
lm_class = lm_choices.get_class(args.lm)
|
||||
lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
|
||||
|
||||
model_class = model_choices.get_class(args.model)
|
||||
model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
|
||||
|
||||
# initialize
|
||||
if args.init is not None:
|
||||
initialize(model, args.init)
|
||||
|
||||
return model
|
||||
|
||||
@ -206,6 +206,3 @@ class LMTask(AbsTask):
|
||||
# 3. Initialize
|
||||
if args.init is not None:
|
||||
initialize(model, args.init)
|
||||
|
||||
assert check_return_type(model)
|
||||
return model
|
||||
|
||||
Loading…
Reference in New Issue
Block a user