This commit is contained in:
speech_asr 2023-04-11 00:09:29 +08:00
parent 5756ed9165
commit d5a80d642a
9 changed files with 9 additions and 93 deletions

View File

@ -1,21 +0,0 @@
from abc import ABC
from abc import abstractmethod
from typing import Optional
from typing import Tuple
import torch
class AbsEncoder(torch.nn.Module, ABC):
    """Abstract base class for speech encoder modules.

    Subclasses must report their output feature dimension via
    ``output_size`` and implement ``forward``.
    """

    @abstractmethod
    def output_size(self) -> int:
        """Return the dimension of the encoder's output features."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        # PEP 484: a None default requires an explicit Optional annotation;
        # the original bare `torch.Tensor = None` relied on deprecated
        # implicit-Optional behavior.
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a padded batch of input sequences.

        Args:
            xs_pad: Padded input tensor (exact shape is subclass-specific).
            ilens: Lengths of the individual sequences in ``xs_pad``.
            prev_states: Optional previous encoder states (e.g. for
                streaming); ``None`` when no prior state exists.

        Returns:
            Tuple of (encoded output, output lengths, optional new states).
        """
        raise NotImplementedError

View File

@ -14,7 +14,6 @@ from torch import nn
from typeguard import check_argument_types
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
@ -277,7 +276,7 @@ class EncoderLayer(nn.Module):
return x, mask
class ConformerEncoder(AbsEncoder):
class ConformerEncoder(torch.nn.Module):
"""Conformer encoder module.
Args:

View File

@ -12,7 +12,6 @@ import torch.nn as nn
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.data2vec.data_utils import compute_mask_indices
from funasr.modules.data2vec.ema_module import EMAModule
from funasr.modules.data2vec.grad_multiply import GradMultiply
@ -29,7 +28,7 @@ def get_annealed_rate(start, end, curr_step, total_steps):
return end - r * pct_remaining
class Data2VecEncoder(AbsEncoder):
class Data2VecEncoder(torch.nn.Module):
def __init__(
self,
# for ConvFeatureExtractionModel

View File

@ -34,8 +34,6 @@ from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
import pdb
import math
class ConvolutionModule(nn.Module):
@ -108,7 +106,7 @@ class ConvolutionModule(nn.Module):
class MFCCAEncoder(AbsEncoder):
class MFCCAEncoder(torch.nn.Module):
"""Conformer encoder module.
Args:

View File

@ -1,6 +1,5 @@
import torch
from torch.nn import functional as F
from funasr.models.encoder.abs_encoder import AbsEncoder
from typing import Tuple, Optional
from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
from collections import OrderedDict
@ -76,7 +75,7 @@ class BasicBlock(torch.nn.Module):
return xs_pad, ilens
class ResNet34(AbsEncoder):
class ResNet34(torch.nn.Module):
def __init__(
self,
input_size,

View File

@ -9,10 +9,9 @@ from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.rnn.encoders import RNN
from funasr.modules.rnn.encoders import RNNP
from funasr.models.encoder.abs_encoder import AbsEncoder
class RNNEncoder(AbsEncoder):
class RNNEncoder(torch.nn.Module):
"""RNNEncoder class.
Args:

View File

@ -26,7 +26,6 @@ from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
@ -115,7 +114,7 @@ class EncoderLayerSANM(nn.Module):
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class SANMEncoder(AbsEncoder):
class SANMEncoder(torch.nn.Module):
"""
author: Speech Lab, Alibaba Group, China
San-m: Memory equipped self-attention for end-to-end speech recognition
@ -547,7 +546,7 @@ class SANMEncoder(AbsEncoder):
return var_dict_torch_update
class SANMEncoderChunkOpt(AbsEncoder):
class SANMEncoderChunkOpt(torch.nn.Module):
"""
author: Speech Lab, Alibaba Group, China
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
@ -960,7 +959,7 @@ class SANMEncoderChunkOpt(AbsEncoder):
return var_dict_torch_update
class SANMVadEncoder(AbsEncoder):
class SANMVadEncoder(torch.nn.Module):
"""
author: Speech Lab, Alibaba Group, China

View File

@ -13,7 +13,6 @@ from typeguard import check_argument_types
import logging
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
@ -144,7 +143,7 @@ class EncoderLayer(nn.Module):
return x, mask
class TransformerEncoder(AbsEncoder):
class TransformerEncoder(torch.nn.Module):
"""Transformer encoder module.
Args:

View File

@ -1,55 +0,0 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import Tuple
import torch
class AbsESPnetModel(torch.nn.Module, ABC):
    """Common abstract base class shared by the models of each task.

    An "ESPnetModel" is a class that inherits ``torch.nn.Module``, holds the
    task's neural networks as member fields (delegate pattern), and whose
    ``forward`` defines the "loss", "stats", and "weight" for the task.

    If you intend to implement a new task in ESPnet, the model must inherit
    this class.  In other words, these three values — loss, stats, and
    weight — are the only mediators between the training system and your
    task class.

    Example:
        >>> from funasr.tasks.abs_task import AbsTask
        >>> class YourESPnetModel(AbsESPnetModel):
        ...     def forward(self, input, input_lengths):
        ...         ...
        ...         return loss, stats, weight
        >>> class YourTask(AbsTask):
        ...     @classmethod
        ...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
    """

    def __init__(self):
        super().__init__()
        # Running count of optimizer updates; written via set_num_updates().
        self.num_updates = 0

    @abstractmethod
    def forward(
        self, **batch: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Compute and return (loss, stats, weight) for one batch."""
        raise NotImplementedError

    @abstractmethod
    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return the features extracted from one batch."""
        raise NotImplementedError

    def set_num_updates(self, num_updates: int) -> None:
        """Record the current number of optimizer updates."""
        self.num_updates = num_updates

    def get_num_updates(self) -> int:
        """Return the last recorded number of optimizer updates."""
        return self.num_updates