diff --git a/funasr/models/e2e_asr_contextual_paraformer.py b/funasr/models/e2e_asr_contextual_paraformer.py index cafb65314..e1dfe6cf0 100644 --- a/funasr/models/e2e_asr_contextual_paraformer.py +++ b/funasr/models/e2e_asr_contextual_paraformer.py @@ -1,4 +1,3 @@ -from json import decoder import logging from contextlib import contextmanager from distutils.version import LooseVersion @@ -7,35 +6,24 @@ from typing import List from typing import Optional from typing import Tuple from typing import Union -import random -from unicodedata import bidirectional import numpy as np import torch from typeguard import check_argument_types from funasr.layers.abs_normalize import AbsNormalize -from funasr.losses.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) from funasr.models.ctc import CTC from funasr.models.decoder.abs_decoder import AbsDecoder -from funasr.models.e2e_asr_common import ErrorCalculator from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.postencoder.abs_postencoder import AbsPostEncoder -from funasr.models.predictor.cif import mae_loss from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.nets_utils import make_pad_mask, pad_list from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel -from funasr.models.predictor.cif import CifPredictorV3 -from funasr.modules.streaming_utils import utils as myutils from funasr.models.e2e_asr_paraformer import Paraformer -from funasr.modules.layer_norm import LayerNorm if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): @@ -47,7 +35,7 @@ else: yield -class AdvancedContextualParaformer(Paraformer): +class NeatContextualParaformer(Paraformer): def __init__( self, vocab_size: int, diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py index d8d524605..4d1009220 100644 --- a/funasr/tasks/asr.py +++ b/funasr/tasks/asr.py @@ -42,7 +42,7 @@ from funasr.models.decoder.rnnt_decoder import RNNTDecoder from funasr.models.joint_net.joint_network import JointNetwork from funasr.models.e2e_asr import ESPnetASRModel from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer -from funasr.models.e2e_asr_contextual_paraformer import AdvancedContextualParaformer +from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer from funasr.models.e2e_tp import TimestampPredictor from funasr.models.e2e_asr_mfcca import MFCCA from funasr.models.e2e_uni_asr import UniASR @@ -129,7 +129,7 @@ model_choices = ClassChoices( paraformer_bert=ParaformerBert, bicif_paraformer=BiCifParaformer, contextual_paraformer=ContextualParaformer, - acontextual_paraformer=AdvancedContextualParaformer, + neatcontextual_paraformer=NeatContextualParaformer, mfcca=MFCCA, timestamp_prediction=TimestampPredictor, ),