From 6c381c270094b107ab7a7f087f809fc38a1c69a8 Mon Sep 17 00:00:00 2001 From: aky15 Date: Thu, 18 May 2023 11:44:48 +0800 Subject: [PATCH] modify rnnt infer --- funasr/bin/asr_infer.py | 2 +- funasr/build_utils/build_asr_model.py | 43 +++++++++--------------- funasr/tasks/asr.py | 48 +++++++++++---------------- 3 files changed, 36 insertions(+), 57 deletions(-) diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py index 03145f859..d9d413b3a 100644 --- a/funasr/bin/asr_infer.py +++ b/funasr/bin/asr_infer.py @@ -1581,7 +1581,7 @@ class Speech2TextTransducer: d = ModelDownloader() kwargs.update(**d.download_and_unpack(model_tag)) - return Speech2Text(**kwargs) + return Speech2TextTransducer(**kwargs) class Speech2TextSAASR: diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py index 718736b9c..ddc827fb3 100644 --- a/funasr/build_utils/build_asr_model.py +++ b/funasr/build_utils/build_asr_model.py @@ -87,6 +87,8 @@ model_choices = ClassChoices( contextual_paraformer=ContextualParaformer, mfcca=MFCCA, timestamp_prediction=TimestampPredictor, + rnnt=TransducerModel, + rnnt_unified=UnifiedTransducerModel, ), default="asr", ) @@ -367,7 +369,7 @@ def build_asr_model(args): token_list=token_list, **args.model_conf, ) - elif args.model == "rnnt": + elif args.model == "rnnt" or args.model == "rnnt_unified": # 5. Decoder encoder_output_size = encoder.output_size() @@ -396,34 +398,21 @@ def build_asr_model(args): **args.joint_network_conf, ) + model_class = model_choices.get_class(args.model) # 7. Build model - if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training: - model = UnifiedTransducerModel( - vocab_size=vocab_size, - token_list=token_list, - frontend=frontend, - specaug=specaug, - normalize=normalize, - encoder=encoder, - decoder=decoder, - att_decoder=att_decoder, - joint_network=joint_network, - **args.model_conf, - ) + model = model_class( + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + encoder=encoder, + decoder=decoder, + att_decoder=att_decoder, + joint_network=joint_network, + **args.model_conf, + ) - else: - model = TransducerModel( - vocab_size=vocab_size, - token_list=token_list, - frontend=frontend, - specaug=specaug, - normalize=normalize, - encoder=encoder, - decoder=decoder, - att_decoder=att_decoder, - joint_network=joint_network, - **args.model_conf, - ) else: raise NotImplementedError("Not supported model: {}".format(args.model)) diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py index 5de475f19..8e4f9ccbf 100644 --- a/funasr/tasks/asr.py +++ b/funasr/tasks/asr.py @@ -132,6 +132,8 @@ model_choices = ClassChoices( neatcontextual_paraformer=NeatContextualParaformer, mfcca=MFCCA, timestamp_prediction=TimestampPredictor, + rnnt=TransducerModel, + rnnt_unified=UnifiedTransducerModel, ), type_check=FunASRModel, default="asr", @@ -1453,7 +1455,7 @@ class ASRTransducerTask(ASRTask): decoder_output_size = decoder.output_size if getattr(args, "decoder", None) is not None: - att_decoder_class = decoder_choices.get_class(args.att_decoder) + att_decoder_class = decoder_choices.get_class(args.decoder) att_decoder = att_decoder_class( vocab_size=vocab_size, @@ -1471,35 +1473,23 @@ class ASRTransducerTask(ASRTask): ) # 7. Build model + try: + model_class = model_choices.get_class(args.model) + except AttributeError: + model_class = model_choices.get_class("asr") - if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training: - model = UnifiedTransducerModel( - vocab_size=vocab_size, - token_list=token_list, - frontend=frontend, - specaug=specaug, - normalize=normalize, - encoder=encoder, - decoder=decoder, - att_decoder=att_decoder, - joint_network=joint_network, - **args.model_conf, - ) - - else: - model = TransducerModel( - vocab_size=vocab_size, - token_list=token_list, - frontend=frontend, - specaug=specaug, - normalize=normalize, - encoder=encoder, - decoder=decoder, - att_decoder=att_decoder, - joint_network=joint_network, - **args.model_conf, - ) - + model = model_class( + vocab_size=vocab_size, + token_list=token_list, + frontend=frontend, + specaug=specaug, + normalize=normalize, + encoder=encoder, + decoder=decoder, + att_decoder=att_decoder, + joint_network=joint_network, + **args.model_conf, + ) # 8. Initialize model if args.init is not None: raise NotImplementedError(