diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
index 2d4711f89..380c137b6 100644
--- a/funasr/models/data2vec.py
+++ b/funasr/models/data2vec.py
@@ -12,7 +12,6 @@ from typing import Tuple
 import torch
 from typeguard import check_argument_types

-from funasr.layers.abs_normalize import AbsNormalize
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.models.base_model import FunASRModel
@@ -33,7 +32,7 @@ class Data2VecPretrainModel(FunASRModel):
         self,
         frontend: Optional[torch.nn.Module],
         specaug: Optional[torch.nn.Module],
-        normalize: Optional[AbsNormalize],
+        normalize: Optional[torch.nn.Module],
         preencoder: Optional[AbsPreEncoder],
         encoder: torch.nn.Module,
     ):
diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py
index 1d2b20d8d..950d69960 100644
--- a/funasr/models/e2e_asr.py
+++ b/funasr/models/e2e_asr.py
@@ -13,17 +13,13 @@ from typing import Union
 import torch
 from typeguard import check_argument_types

-from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.base_model import FunASRModel
 from funasr.modules.add_sos_eos import add_sos_eos
 from funasr.modules.e2e_asr_common import ErrorCalculator
@@ -46,11 +42,11 @@ class ESPnetASRModel(FunASRModel):
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
+        frontend: Optional[torch.nn.Module],
+        specaug: Optional[torch.nn.Module],
+        normalize: Optional[torch.nn.Module],
         preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
+        encoder: torch.nn.Module,
         postencoder: Optional[AbsPostEncoder],
         decoder: AbsDecoder,
         ctc: CTC,
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index d04255251..efdd90dc7 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -17,12 +17,8 @@ from funasr.losses.label_smoothing_loss import (
 )
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.base_model import FunASRModel
-from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable


@@ -43,11 +39,11 @@ class MFCCA(FunASRModel):
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
+        frontend: Optional[torch.nn.Module],
+        specaug: Optional[torch.nn.Module],
+        normalize: Optional[torch.nn.Module],
         preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
+        encoder: torch.nn.Module,
         decoder: AbsDecoder,
         ctc: CTC,
         rnnt_decoder: None,
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index cf5c16d15..288f46995 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -12,19 +12,15 @@ import random
 import numpy as np
 from typeguard import check_argument_types

-from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.e2e_asr_common import ErrorCalculator
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.predictor.cif import mae_loss
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.base_model import FunASRModel
 from funasr.modules.add_sos_eos import add_sos_eos
 from funasr.modules.nets_utils import make_pad_mask, pad_list
@@ -53,11 +49,11 @@ class Paraformer(FunASRModel):
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
+        frontend: Optional[torch.nn.Module],
+        specaug: Optional[torch.nn.Module],
+        normalize: Optional[torch.nn.Module],
         preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
+        encoder: torch.nn.Module,
         postencoder: Optional[AbsPostEncoder],
         decoder: AbsDecoder,
         ctc: CTC,
@@ -620,11 +616,11 @@ class ParaformerBert(Paraformer):
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
+        frontend: Optional[torch.nn.Module],
+        specaug: Optional[torch.nn.Module],
+        normalize: Optional[torch.nn.Module],
         preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
+        encoder: torch.nn.Module,
         postencoder: Optional[AbsPostEncoder],
         decoder: AbsDecoder,
         ctc: CTC,
@@ -898,11 +894,11 @@ class BiCifParaformer(Paraformer):
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
+        frontend: Optional[torch.nn.Module],
+        specaug: Optional[torch.nn.Module],
+        normalize: Optional[torch.nn.Module],
         preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
+        encoder: torch.nn.Module,
         postencoder: Optional[AbsPostEncoder],
         decoder: AbsDecoder,
         ctc: CTC,
@@ -1111,11 +1107,11 @@ class ContextualParaformer(Paraformer):
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
+        frontend: Optional[torch.nn.Module],
+        specaug: Optional[torch.nn.Module],
+        normalize: Optional[torch.nn.Module],
         preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
+        encoder: torch.nn.Module,
         postencoder: Optional[AbsPostEncoder],
         decoder: AbsDecoder,
         ctc: CTC,
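Taken together, these hunks relax the model constructors' annotations for frontend, specaug, normalize, and encoder from FunASR's abstract base classes (AbsFrontend, AbsSpecAug, AbsNormalize, AbsEncoder) to plain torch.nn.Module, and drop the now-unused imports. Since these files import typeguard's check_argument_types, which ESPnet-style models typically assert at the top of __init__, the annotation change is behaviorally meaningful: any nn.Module now passes the runtime type check for these slots, without subclassing the abstract classes. A minimal sketch of what the relaxed signatures accept; IdentityEncoder is a hypothetical stand-in written for illustration, not part of FunASR:

    import torch

    class IdentityEncoder(torch.nn.Module):
        """A pass-through encoder: a plain nn.Module, no AbsEncoder subclass."""

        def forward(self, xs_pad: torch.Tensor, ilens: torch.Tensor):
            # Return the padded features and their lengths unchanged.
            return xs_pad, ilens

    # Before this change, passing IdentityEncoder where `encoder: AbsEncoder`
    # was declared would trip check_argument_types(); under the relaxed
    # `encoder: torch.nn.Module` annotation it type-checks.
    encoder = IdentityEncoder()
    assert isinstance(encoder, torch.nn.Module)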