mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'dev_infer' of https://github.com/alibaba/FunASR into dev_infer
This commit is contained in:
commit
17eaf419c0
@ -1581,7 +1581,7 @@ class Speech2TextTransducer:
|
||||
d = ModelDownloader()
|
||||
kwargs.update(**d.download_and_unpack(model_tag))
|
||||
|
||||
return Speech2Text(**kwargs)
|
||||
return Speech2TextTransducer(**kwargs)
|
||||
|
||||
|
||||
class Speech2TextSAASR:
|
||||
|
||||
@ -87,6 +87,8 @@ model_choices = ClassChoices(
|
||||
contextual_paraformer=ContextualParaformer,
|
||||
mfcca=MFCCA,
|
||||
timestamp_prediction=TimestampPredictor,
|
||||
rnnt=TransducerModel,
|
||||
rnnt_unified=UnifiedTransducerModel,
|
||||
),
|
||||
default="asr",
|
||||
)
|
||||
@ -367,7 +369,7 @@ def build_asr_model(args):
|
||||
token_list=token_list,
|
||||
**args.model_conf,
|
||||
)
|
||||
elif args.model == "rnnt":
|
||||
elif args.model == "rnnt" or args.model == "rnnt_unified":
|
||||
# 5. Decoder
|
||||
encoder_output_size = encoder.output_size()
|
||||
|
||||
@ -396,34 +398,21 @@ def build_asr_model(args):
|
||||
**args.joint_network_conf,
|
||||
)
|
||||
|
||||
model_class = model_choices.get_class(args.model)
|
||||
# 7. Build model
|
||||
if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
|
||||
model = UnifiedTransducerModel(
|
||||
vocab_size=vocab_size,
|
||||
token_list=token_list,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
att_decoder=att_decoder,
|
||||
joint_network=joint_network,
|
||||
**args.model_conf,
|
||||
)
|
||||
model = model_class(
|
||||
vocab_size=vocab_size,
|
||||
token_list=token_list,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
att_decoder=att_decoder,
|
||||
joint_network=joint_network,
|
||||
**args.model_conf,
|
||||
)
|
||||
|
||||
else:
|
||||
model = TransducerModel(
|
||||
vocab_size=vocab_size,
|
||||
token_list=token_list,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
att_decoder=att_decoder,
|
||||
joint_network=joint_network,
|
||||
**args.model_conf,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Not supported model: {}".format(args.model))
|
||||
|
||||
|
||||
@ -132,6 +132,8 @@ model_choices = ClassChoices(
|
||||
neatcontextual_paraformer=NeatContextualParaformer,
|
||||
mfcca=MFCCA,
|
||||
timestamp_prediction=TimestampPredictor,
|
||||
rnnt=TransducerModel,
|
||||
rnnt_unified=UnifiedTransducerModel,
|
||||
),
|
||||
type_check=FunASRModel,
|
||||
default="asr",
|
||||
@ -1453,7 +1455,7 @@ class ASRTransducerTask(ASRTask):
|
||||
decoder_output_size = decoder.output_size
|
||||
|
||||
if getattr(args, "decoder", None) is not None:
|
||||
att_decoder_class = decoder_choices.get_class(args.att_decoder)
|
||||
att_decoder_class = decoder_choices.get_class(args.decoder)
|
||||
|
||||
att_decoder = att_decoder_class(
|
||||
vocab_size=vocab_size,
|
||||
@ -1471,35 +1473,23 @@ class ASRTransducerTask(ASRTask):
|
||||
)
|
||||
|
||||
# 7. Build model
|
||||
try:
|
||||
model_class = model_choices.get_class(args.model)
|
||||
except AttributeError:
|
||||
model_class = model_choices.get_class("asr")
|
||||
|
||||
if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
|
||||
model = UnifiedTransducerModel(
|
||||
vocab_size=vocab_size,
|
||||
token_list=token_list,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
att_decoder=att_decoder,
|
||||
joint_network=joint_network,
|
||||
**args.model_conf,
|
||||
)
|
||||
|
||||
else:
|
||||
model = TransducerModel(
|
||||
vocab_size=vocab_size,
|
||||
token_list=token_list,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
att_decoder=att_decoder,
|
||||
joint_network=joint_network,
|
||||
**args.model_conf,
|
||||
)
|
||||
|
||||
model = model_class(
|
||||
vocab_size=vocab_size,
|
||||
token_list=token_list,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
att_decoder=att_decoder,
|
||||
joint_network=joint_network,
|
||||
**args.model_conf,
|
||||
)
|
||||
# 8. Initialize model
|
||||
if args.init is not None:
|
||||
raise NotImplementedError(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user