Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

Commit 6ed27c64c9 ("update"), parent 32feb7d2be
@@ -17,6 +17,7 @@ from funasr.losses.label_smoothing_loss import (
     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.base_model import FunASRModel
@@ -41,7 +42,7 @@ class ESPnetASRModel(FunASRModel):
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[torch.nn.Module],
+        frontend: Optional[AbsFrontend],
         specaug: Optional[torch.nn.Module],
         normalize: Optional[torch.nn.Module],
         encoder: AbsEncoder,
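The signature change above narrows frontend from Optional[torch.nn.Module] to the new Optional[AbsFrontend], so runtime annotation checkers can reject modules that do not implement the frontend contract. A minimal sketch of the effect, assuming typeguard 2.x (which this commit's files import via check_argument_types); build_model is a hypothetical helper, not FunASR code:

    from typing import Optional

    from typeguard import check_argument_types

    from funasr.models.frontend.abs_frontend import AbsFrontend


    def build_model(frontend: Optional[AbsFrontend]) -> None:
        # Under the old annotation Optional[torch.nn.Module], any module
        # (e.g. torch.nn.Linear(3, 4)) passed this check; now only
        # AbsFrontend subclasses (or None) do.
        assert check_argument_types()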
@@ -19,6 +19,7 @@ from funasr.modules.attention import (
     RelPositionMultiHeadedAttention,  # noqa: H301
     LegacyRelPositionMultiHeadedAttention,  # noqa: H301
 )
+from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.embedding import (
     PositionalEncoding,  # noqa: H301
     ScaledPositionalEncoding,  # noqa: H301
@@ -41,7 +42,8 @@ from funasr.modules.subsampling import Conv2dSubsampling8
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
+from funasr.modules.subsampling import Conv2dSubsamplingPad
 
-class ConvolutionModule(nn.Module):
+class ConvolutionModule(AbsEncoder):
     """ConvolutionModule in Conformer model.
 
     Args:
@@ -13,6 +13,7 @@ from typeguard import check_argument_types
 import logging
 
 from funasr.models.ctc import CTC
+from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.attention import MultiHeadedAttention
 from funasr.modules.embedding import PositionalEncoding
 from funasr.modules.layer_norm import LayerNorm
@@ -36,7 +37,7 @@ from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
 
 
-class EncoderLayer(nn.Module):
+class EncoderLayer(AbsEncoder):
     """Encoder layer module.
 
     Args:
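AbsEncoder itself is not shown in this diff, although ConvolutionModule and EncoderLayer are rebased onto it above. Judging from the AbsFrontend file added below (and the ESPnet code FunASR derives from), it presumably declares an output-size query plus a forward contract, roughly as follows; the exact signatures are an assumption:

    # Assumed shape of funasr/models/encoder/abs_encoder.py; not part of
    # this commit, and the real file may differ in details.
    from abc import ABC, abstractmethod
    from typing import Optional, Tuple

    import torch


    class AbsEncoder(torch.nn.Module, ABC):
        @abstractmethod
        def output_size(self) -> int:
            raise NotImplementedError

        @abstractmethod
        def forward(
            self, xs_pad: torch.Tensor, ilens: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
            raise NotImplementedError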
funasr/models/frontend/abs_frontend.py  (new file, 17 lines)

@@ -0,0 +1,17 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Tuple
+
+import torch
+
+
+class AbsFrontend(torch.nn.Module, ABC):
+    @abstractmethod
+    def output_size(self) -> int:
+        raise NotImplementedError
+
+    @abstractmethod
+    def forward(
+        self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError
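A concrete frontend therefore only has to report its feature dimension and map (input, input_lengths) to (features, feature_lengths). A minimal sketch; IdentityFrontend is a hypothetical example for illustration, not part of FunASR:

    from typing import Tuple

    import torch

    from funasr.models.frontend.abs_frontend import AbsFrontend


    class IdentityFrontend(AbsFrontend):
        """Hypothetical frontend that passes precomputed features through."""

        def __init__(self, feat_dim: int = 80):
            super().__init__()
            self.feat_dim = feat_dim

        def output_size(self) -> int:
            # Downstream components (e.g. the encoder) size themselves
            # from this value.
            return self.feat_dim

        def forward(
            self, input: torch.Tensor, input_lengths: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            return input, input_lengths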
funasr/models/specaug/abs_specaug.py  (new file, 16 lines)

@@ -0,0 +1,16 @@
+from typing import Optional
+from typing import Tuple
+
+import torch
+
+
+class AbsSpecAug(torch.nn.Module):
+    """Abstract class for the augmentation of spectrogram
+    The process-flow:
+    Frontend -> SpecAug -> Normalization -> Encoder -> Decoder
+    """
+
+    def forward(
+        self, x: torch.Tensor, x_lengths: torch.Tensor = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        raise NotImplementedError
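Note that, unlike AbsFrontend, this base class does not inherit ABC and forward is not decorated with @abstractmethod, so a subclass that forgets to override forward fails only at call time rather than at instantiation. A minimal conforming subclass, with hypothetical masking logic for illustration only:

    from typing import Optional, Tuple

    import torch

    from funasr.models.specaug.abs_specaug import AbsSpecAug


    class ZeroTailSpecAug(AbsSpecAug):
        """Hypothetical example: zeroes out the last mask_width frames."""

        def __init__(self, mask_width: int = 10):
            super().__init__()
            self.mask_width = mask_width

        def forward(
            self, x: torch.Tensor, x_lengths: torch.Tensor = None
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            # Assumes x is (batch, time, feat), per the docstring's
            # Frontend -> SpecAug ordering.
            x = x.clone()
            x[:, -self.mask_width :, :] = 0.0
            return x, x_lengths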
@@ -3,15 +3,14 @@ from typing import Optional
 from typing import Sequence
 from typing import Union
 
 import torch.nn
 
+from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.layers.mask_along_axis import MaskAlongAxis
 from funasr.layers.mask_along_axis import MaskAlongAxisVariableMaxWidth
 from funasr.layers.mask_along_axis import MaskAlongAxisLFR
 from funasr.layers.time_warp import TimeWarp
 
 
-class SpecAug(torch.nn.Module):
+class SpecAug(AbsSpecAug):
     """Implementation of SpecAug.
 
     Reference:
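Taken together, the commit threads the new abstract classes through the pipeline that the AbsSpecAug docstring names (Frontend -> SpecAug -> Normalization -> Encoder -> Decoder). A hedged sketch of how the stages compose; the actual ESPnetASRModel encode path is not shown in this diff, and the attribute names and return shapes below are assumptions:

    import torch


    def encode(model, speech: torch.Tensor, speech_lengths: torch.Tensor):
        # Hypothetical composition of the pipeline stages; mirrors the
        # docstring's ordering, not necessarily FunASR's exact code.
        feats, feats_lengths = model.frontend(speech, speech_lengths)
        if model.specaug is not None and model.training:
            feats, feats_lengths = model.specaug(feats, feats_lengths)
        if model.normalize is not None:
            feats, feats_lengths = model.normalize(feats, feats_lengths)
        encoder_out, encoder_out_lens, _ = model.encoder(feats, feats_lengths)
        return encoder_out, encoder_out_lens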