From dfa356a10c698e4e0548ab2d05ae31ab142bd4aa Mon Sep 17 00:00:00 2001
From: speech_asr
Date: Tue, 11 Apr 2023 00:27:54 +0800
Subject: [PATCH] replace Abs* interface type hints with torch.nn.Module

---
 funasr/models/e2e_diar_sond.py | 14 ++++----------
 funasr/models/e2e_sv.py        | 19 ++++---------------
 funasr/models/e2e_tp.py        |  9 ++-------
 funasr/models/e2e_uni_asr.py   | 13 ++++---------
 4 files changed, 14 insertions(+), 41 deletions(-)

diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index a6d780719..dc7135f4e 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -14,14 +14,8 @@ import torch
 from torch.nn import functional as F
 from typeguard import check_argument_types
 
-from funasr.modules.nets_utils import to_device
 from funasr.modules.nets_utils import make_pad_mask
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.base_model import FunASRModel
-from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
 from funasr.utils.misc import int2vec
@@ -43,9 +37,9 @@ class DiarSondModel(FunASRModel):
     def __init__(
         self,
         vocab_size: int,
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
+        frontend: Optional[torch.nn.Module],
+        specaug: Optional[torch.nn.Module],
+        normalize: Optional[torch.nn.Module],
         encoder: torch.nn.Module,
         speaker_encoder: Optional[torch.nn.Module],
         ci_scorer: torch.nn.Module,
@@ -348,7 +342,7 @@ class DiarSondModel(FunASRModel):
         cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
         cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
 
-        if isinstance(self.ci_scorer, AbsEncoder):
+        if isinstance(self.ci_scorer, torch.nn.Module):
             ci_simi = self.ci_scorer(ge_in, ge_len)[0]
             ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
         else:
diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py
index dd2ea8438..582c25dfb 100644
--- a/funasr/models/e2e_sv.py
+++ b/funasr/models/e2e_sv.py
@@ -10,21 +10,10 @@ from typing import Union
 import torch
 from typeguard import check_argument_types
 
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
-    LabelSmoothingLoss,  # noqa: H301
-)
-from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.base_model import FunASRModel
-from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.modules.e2e_asr_common import ErrorCalculator
-from funasr.modules.nets_utils import th_accuracy
 from funasr.torch_utils.device_funcs import force_gatherable
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -43,11 +32,11 @@ class ESPnetSVModel(FunASRModel):
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
+        frontend: Optional[torch.nn.Module],
+        specaug: Optional[torch.nn.Module],
+        normalize: Optional[torch.nn.Module],
         preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
+        encoder: torch.nn.Module,
         postencoder: Optional[AbsPostEncoder],
         pooling_layer: torch.nn.Module,
         decoder: AbsDecoder,
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index ac8c3b40c..c5dc63c59 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -2,17 +2,12 @@ import logging
 from contextlib import contextmanager
 from distutils.version import LooseVersion
 from typing import Dict
-from typing import List
 from typing import Optional
 from typing import Tuple
-from typing import Union
 
 import torch
-import numpy as np
 from typeguard import check_argument_types
 
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.predictor.cif import mae_loss
 from funasr.models.base_model import FunASRModel
 from funasr.modules.add_sos_eos import add_sos_eos
@@ -37,8 +32,8 @@ class TimestampPredictor(FunASRModel):
 
     def __init__(
         self,
-        frontend: Optional[AbsFrontend],
-        encoder: AbsEncoder,
+        frontend: Optional[torch.nn.Module],
+        encoder: torch.nn.Module,
         predictor: CifPredictorV3,
         predictor_bias: int = 0,
         token_list=None,
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index 0c26b8ea4..ee5e2baba 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -18,15 +18,10 @@ from funasr.losses.label_smoothing_loss import (
 )
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.base_model import FunASRModel
-from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
 from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
 from funasr.models.predictor.cif import mae_loss
 
@@ -48,11 +44,11 @@ class UniASR(FunASRModel):
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
+        frontend: Optional[torch.nn.Module],
+        specaug: Optional[torch.nn.Module],
+        normalize: Optional[torch.nn.Module],
         preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
+        encoder: torch.nn.Module,
         postencoder: Optional[AbsPostEncoder],
         decoder: AbsDecoder,
         ctc: CTC,
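
Note: every hunk in this patch makes the same move. Constructor annotations
change from FunASR's Abs* interface classes (AbsFrontend, AbsSpecAug,
AbsNormalize, AbsEncoder) to plain torch.nn.Module, and the imports of those
base classes are dropped. The sketch below shows the pattern the loosened
hints permit. It is a minimal illustration assuming only PyTorch is
installed; IdentityFrontend and TinyModel are hypothetical names for this
note, not FunASR APIs.

    from typing import Optional, Tuple

    import torch


    class IdentityFrontend(torch.nn.Module):
        # Subclasses only torch.nn.Module: no AbsFrontend inheritance is
        # required once the model accepts any module as its frontend.
        def forward(self, speech: torch.Tensor,
                    lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return speech, lengths


    class TinyModel(torch.nn.Module):
        # Mirrors the loosened signatures above: frontend may be None or any
        # nn.Module whose forward() returns (features, lengths).
        def __init__(self, frontend: Optional[torch.nn.Module],
                     encoder: torch.nn.Module):
            super().__init__()
            self.frontend = frontend
            self.encoder = encoder

        def forward(self, speech: torch.Tensor, lengths: torch.Tensor):
            if self.frontend is not None:
                speech, lengths = self.frontend(speech, lengths)
            return self.encoder(speech), lengths


    # Any nn.Module slots in; here a Linear layer stands in for an encoder.
    model = TinyModel(IdentityFrontend(), torch.nn.Linear(80, 4))
    feats, out_lens = model(torch.randn(2, 100, 80), torch.tensor([100, 90]))

One trade-off to keep in mind: isinstance checks against the old hierarchy
lose their discriminating power. In the e2e_diar_sond.py hunk,
isinstance(self.ci_scorer, torch.nn.Module) holds for any module-typed
ci_scorer, so the else branch only runs for a non-module scorer; dispatching
on an explicit attribute or config flag would keep both paths meaningful.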