diff --git a/funasr/models/encoder/abs_encoder.py b/funasr/models/encoder/abs_encoder.py
deleted file mode 100644
index 1fb7c97c3..000000000
--- a/funasr/models/encoder/abs_encoder.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Optional
-from typing import Tuple
-
-import torch
-
-
-class AbsEncoder(torch.nn.Module, ABC):
-    @abstractmethod
-    def output_size(self) -> int:
-        raise NotImplementedError
-
-    @abstractmethod
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        prev_states: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        raise NotImplementedError
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 7c7f66142..e649ecada 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -14,7 +14,6 @@ from torch import nn
 from typeguard import check_argument_types
 
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.attention import (
     MultiHeadedAttention,  # noqa: H301
     RelPositionMultiHeadedAttention,  # noqa: H301
@@ -277,7 +276,7 @@ class EncoderLayer(nn.Module):
         return x, mask
 
 
-class ConformerEncoder(AbsEncoder):
+class ConformerEncoder(torch.nn.Module):
     """Conformer encoder module.
 
     Args:
diff --git a/funasr/models/encoder/data2vec_encoder.py b/funasr/models/encoder/data2vec_encoder.py
index fd1796ca9..a30e91ec7 100644
--- a/funasr/models/encoder/data2vec_encoder.py
+++ b/funasr/models/encoder/data2vec_encoder.py
@@ -12,7 +12,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typeguard import check_argument_types
 
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.data2vec.data_utils import compute_mask_indices
 from funasr.modules.data2vec.ema_module import EMAModule
 from funasr.modules.data2vec.grad_multiply import GradMultiply
@@ -29,7 +28,7 @@ def get_annealed_rate(start, end, curr_step, total_steps):
     return end - r * pct_remaining
 
 
-class Data2VecEncoder(AbsEncoder):
+class Data2VecEncoder(torch.nn.Module):
     def __init__(
         self,
         # for ConvFeatureExtractionModel
diff --git a/funasr/models/encoder/mfcca_encoder.py b/funasr/models/encoder/mfcca_encoder.py
index 83d0b0e24..9ffd452f3 100644
--- a/funasr/models/encoder/mfcca_encoder.py
+++ b/funasr/models/encoder/mfcca_encoder.py
@@ -34,8 +34,6 @@ from funasr.modules.subsampling import Conv2dSubsampling6
 from funasr.modules.subsampling import Conv2dSubsampling8
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
-from funasr.models.encoder.abs_encoder import AbsEncoder
-import pdb
 import math
 
 class ConvolutionModule(nn.Module):
@@ -108,7 +106,7 @@
 
 
 
-class MFCCAEncoder(AbsEncoder):
+class MFCCAEncoder(torch.nn.Module):
     """Conformer encoder module.
 
     Args:
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 7d7179a00..6f978eb34 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -1,6 +1,5 @@
 import torch
 from torch.nn import functional as F
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from typing import Tuple, Optional
 from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
 from collections import OrderedDict
@@ -76,7 +75,7 @@ class BasicBlock(torch.nn.Module):
         return xs_pad, ilens
 
 
-class ResNet34(AbsEncoder):
+class ResNet34(torch.nn.Module):
     def __init__(
         self,
         input_size,
diff --git a/funasr/models/encoder/rnn_encoder.py b/funasr/models/encoder/rnn_encoder.py
index 7a3b05399..6b75574c1 100644
--- a/funasr/models/encoder/rnn_encoder.py
+++ b/funasr/models/encoder/rnn_encoder.py
@@ -9,10 +9,9 @@ from typeguard import check_argument_types
 from funasr.modules.nets_utils import make_pad_mask
 from funasr.modules.rnn.encoders import RNN
 from funasr.modules.rnn.encoders import RNNP
-from funasr.models.encoder.abs_encoder import AbsEncoder
 
 
-class RNNEncoder(AbsEncoder):
+class RNNEncoder(torch.nn.Module):
     """RNNEncoder class.
 
     Args:
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a3a35353..14624037f 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -26,7 +26,6 @@ from funasr.modules.subsampling import Conv2dSubsampling8
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.mask import subsequent_mask, vad_mask
 
 class EncoderLayerSANM(nn.Module):
@@ -115,7 +114,7 @@
         return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
 
 
-class SANMEncoder(AbsEncoder):
+class SANMEncoder(torch.nn.Module):
     """
     author: Speech Lab, Alibaba Group, China
     San-m: Memory equipped self-attention for end-to-end speech recognition
@@ -547,7 +546,7 @@ class SANMEncoder(AbsEncoder):
         return var_dict_torch_update
 
 
-class SANMEncoderChunkOpt(AbsEncoder):
+class SANMEncoderChunkOpt(torch.nn.Module):
     """
     author: Speech Lab, Alibaba Group, China
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
@@ -960,7 +959,7 @@ class SANMEncoderChunkOpt(AbsEncoder):
         return var_dict_torch_update
 
 
-class SANMVadEncoder(AbsEncoder):
+class SANMVadEncoder(torch.nn.Module):
     """
     author: Speech Lab, Alibaba Group, China
diff --git a/funasr/models/encoder/transformer_encoder.py b/funasr/models/encoder/transformer_encoder.py
index ff9c3db51..55a65b30e 100644
--- a/funasr/models/encoder/transformer_encoder.py
+++ b/funasr/models/encoder/transformer_encoder.py
@@ -13,7 +13,6 @@ from typeguard import check_argument_types
 import logging
 
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.attention import MultiHeadedAttention
 from funasr.modules.embedding import PositionalEncoding
 from funasr.modules.layer_norm import LayerNorm
@@ -144,7 +143,7 @@ class EncoderLayer(nn.Module):
         return x, mask
 
 
-class TransformerEncoder(AbsEncoder):
+class TransformerEncoder(torch.nn.Module):
     """Transformer encoder module.
 
     Args:
diff --git a/funasr/train/abs_espnet_model.py b/funasr/train/abs_espnet_model.py
deleted file mode 100644
index cc6a5a2a0..000000000
--- a/funasr/train/abs_espnet_model.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-from abc import ABC
-from abc import abstractmethod
-from typing import Dict
-from typing import Tuple
-
-import torch
-
-
-class AbsESPnetModel(torch.nn.Module, ABC):
-    """The common abstract class among each tasks
-
-    "ESPnetModel" is referred to a class which inherits torch.nn.Module,
-    and makes the dnn-models forward as its member field,
-    a.k.a delegate pattern,
-    and defines "loss", "stats", and "weight" for the task.
-
-    If you intend to implement new task in ESPNet,
-    the model must inherit this class.
-    In other words, the "mediator" objects between
-    our training system and the your task class are
-    just only these three values, loss, stats, and weight.
-
-    Example:
-        >>> from funasr.tasks.abs_task import AbsTask
-        >>> class YourESPnetModel(AbsESPnetModel):
-        ...     def forward(self, input, input_lengths):
-        ...         ...
-        ...         return loss, stats, weight
-        >>> class YourTask(AbsTask):
-        ...     @classmethod
-        ...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.num_updates = 0
-
-    @abstractmethod
-    def forward(
-        self, **batch: torch.Tensor
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
-        raise NotImplementedError
-
-    def set_num_updates(self, num_updates):
-        self.num_updates = num_updates
-
-    def get_num_updates(self):
-        return self.num_updates
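Note on the pattern this patch establishes: the removed `AbsEncoder` base only declared two abstract members, `output_size()` and `forward(xs_pad, ilens, prev_states)`, with no concrete behaviour, so the encoders keep the same interface by convention after switching to plain `torch.nn.Module`. A minimal sketch of what an encoder looks like under the new arrangement (the `ToyEncoder` name and the layer sizes are illustrative only, not part of FunASR):

```python
from typing import Optional, Tuple

import torch


class ToyEncoder(torch.nn.Module):
    """Illustrative encoder that keeps the old AbsEncoder contract by convention."""

    def __init__(self, input_size: int = 80, output_size: int = 256):
        super().__init__()
        self._output_size = output_size
        self.proj = torch.nn.Linear(input_size, output_size)

    def output_size(self) -> int:
        # Previously declared as an abstract method on AbsEncoder.
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,          # (batch, time, input_size) padded features
        ilens: torch.Tensor,           # (batch,) valid lengths
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # Same (output, output_lengths, optional states) tuple as before.
        return self.proj(xs_pad), ilens, None
```

Callers that only invoke `output_size()` and `forward()` on an encoder instance are unaffected; the trade-off is that a missing method, which the ABC used to reject at instantiation time, now only surfaces when that method is actually called.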