diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py index e5bed1d64..621c4d960 100644 --- a/funasr/build_utils/build_asr_model.py +++ b/funasr/build_utils/build_asr_model.py @@ -6,6 +6,7 @@ from funasr.models.ctc import CTC from funasr.models.decoder.abs_decoder import AbsDecoder from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder from funasr.models.decoder.rnn_decoder import RNNDecoder +from funasr.models.decoder.rnnt_decoder import RNNTDecoder from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt from funasr.models.decoder.transformer_decoder import ( DynamicConvolution2DTransformerDecoder, # noqa: H301 @@ -19,14 +20,13 @@ from funasr.models.decoder.transformer_decoder import ( ) from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN from funasr.models.decoder.transformer_decoder import TransformerDecoder -from funasr.models.decoder.rnnt_decoder import RNNTDecoder -from funasr.models.joint_net.joint_network import JointNetwork from funasr.models.e2e_asr import ASRModel from funasr.models.e2e_asr_mfcca import MFCCA -from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer +from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, \ + ContextualParaformer +from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel from funasr.models.e2e_tp import TimestampPredictor from funasr.models.e2e_uni_asr import UniASR -from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder from funasr.models.encoder.data2vec_encoder import Data2VecEncoder from funasr.models.encoder.mfcca_encoder import MFCCAEncoder @@ -39,6 +39,7 @@ from funasr.models.frontend.fused import FusedFrontends from funasr.models.frontend.s3prl import S3prlFrontend from funasr.models.frontend.wav_frontend import WavFrontend from funasr.models.frontend.windowing import SlidingWindow +from funasr.models.joint_net.joint_network import JointNetwork from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3 from funasr.models.specaug.specaug import SpecAug from funasr.models.specaug.specaug import SpecAugLFR @@ -252,7 +253,7 @@ def build_asr_model(args): args.frontend = None args.frontend_conf = {} frontend = None - input_size = args.input_size + input_size = args.input_size if hasattr(args, "input_size") else None # data augmentation for spectrogram if args.specaug is not None: @@ -298,7 +299,8 @@ def build_asr_model(args): token_list=token_list, **args.model_conf, ) - elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]: + elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", + "contextual_paraformer"]: # predictor predictor_class = predictor_choices.get_class(args.predictor) predictor = predictor_class(**args.predictor_conf)