This commit is contained in:
嘉渊 2023-04-27 17:19:39 +08:00
parent 32feb7d2be
commit 6ed27c64c9
6 changed files with 42 additions and 6 deletions

View File

@ -17,6 +17,7 @@ from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.base_model import FunASRModel
@ -41,7 +42,7 @@ class ESPnetASRModel(FunASRModel):
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[torch.nn.Module],
frontend: Optional[AbsFrontend],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
encoder: AbsEncoder,

View File

@ -19,6 +19,7 @@ from funasr.modules.attention import (
RelPositionMultiHeadedAttention, # noqa: H301
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
@ -41,7 +42,8 @@ from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
class ConvolutionModule(nn.Module):
class ConvolutionModule(AbsEncoder):
"""ConvolutionModule in Conformer model.
Args:

View File

@ -13,6 +13,7 @@ from typeguard import check_argument_types
import logging
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
@ -36,7 +37,7 @@ from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
class EncoderLayer(nn.Module):
class EncoderLayer(AbsEncoder):
"""Encoder layer module.
Args:

View File

@ -0,0 +1,17 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
class AbsFrontend(torch.nn.Module, ABC):
    """Abstract interface for acoustic front-ends.

    A front-end turns a raw input tensor into a feature sequence consumed by
    downstream components (presumably waveform -> features; the exact input
    semantics are defined by each concrete subclass — TODO confirm against
    callers). Subclasses must implement both :meth:`output_size` and
    :meth:`forward`.
    """

    @abstractmethod
    def output_size(self) -> int:
        # Must return the feature dimension produced by forward(), so that
        # downstream modules (e.g. the encoder) can size their input layer.
        raise NotImplementedError

    @abstractmethod
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Must return a pair (features, feature_lengths): the extracted
        # feature tensor and the per-utterance valid lengths.
        raise NotImplementedError

View File

@ -0,0 +1,16 @@
from typing import Optional
from typing import Tuple
import torch
class AbsSpecAug(torch.nn.Module):
    """Abstract class for the augmentation of spectrogram.

    The process-flow:
        Frontend -> SpecAug -> Normalization -> Encoder -> Decoder

    Concrete subclasses must override :meth:`forward`; calling it on this
    base class raises ``NotImplementedError``.
    """

    def forward(
        # Fix: x_lengths defaulted to None but was annotated as a plain
        # torch.Tensor — implicit Optional is invalid per PEP 484; make the
        # Optional explicit. Runtime behavior is unchanged.
        self, x: torch.Tensor, x_lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply spectrogram augmentation.

        Args:
            x: Input feature tensor (shape defined by the front-end —
                presumably (batch, time, freq); TODO confirm).
            x_lengths: Valid lengths per utterance, or ``None``.

        Returns:
            Tuple of the augmented features and the (possibly ``None``)
            lengths tensor.

        Raises:
            NotImplementedError: always, on this abstract base class.
        """
        raise NotImplementedError

View File

@ -3,15 +3,14 @@ from typing import Optional
from typing import Sequence
from typing import Union
import torch.nn
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.mask_along_axis import MaskAlongAxis
from funasr.layers.mask_along_axis import MaskAlongAxisVariableMaxWidth
from funasr.layers.mask_along_axis import MaskAlongAxisLFR
from funasr.layers.time_warp import TimeWarp
class SpecAug(torch.nn.Module):
class SpecAug(AbsSpecAug):
"""Implementation of SpecAug.
Reference: