mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
5756ed9165
commit
d5a80d642a
@ -1,21 +0,0 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AbsEncoder(torch.nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def output_size(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
@ -14,7 +14,6 @@ from torch import nn
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.modules.attention import (
|
||||
MultiHeadedAttention, # noqa: H301
|
||||
RelPositionMultiHeadedAttention, # noqa: H301
|
||||
@ -277,7 +276,7 @@ class EncoderLayer(nn.Module):
|
||||
return x, mask
|
||||
|
||||
|
||||
class ConformerEncoder(AbsEncoder):
|
||||
class ConformerEncoder(torch.nn.Module):
|
||||
"""Conformer encoder module.
|
||||
|
||||
Args:
|
||||
|
||||
@ -12,7 +12,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.modules.data2vec.data_utils import compute_mask_indices
|
||||
from funasr.modules.data2vec.ema_module import EMAModule
|
||||
from funasr.modules.data2vec.grad_multiply import GradMultiply
|
||||
@ -29,7 +28,7 @@ def get_annealed_rate(start, end, curr_step, total_steps):
|
||||
return end - r * pct_remaining
|
||||
|
||||
|
||||
class Data2VecEncoder(AbsEncoder):
|
||||
class Data2VecEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
# for ConvFeatureExtractionModel
|
||||
|
||||
@ -34,8 +34,6 @@ from funasr.modules.subsampling import Conv2dSubsampling6
|
||||
from funasr.modules.subsampling import Conv2dSubsampling8
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.modules.subsampling import check_short_utt
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
import pdb
|
||||
import math
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
@ -108,7 +106,7 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
|
||||
|
||||
class MFCCAEncoder(AbsEncoder):
|
||||
class MFCCAEncoder(torch.nn.Module):
|
||||
"""Conformer encoder module.
|
||||
|
||||
Args:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from typing import Tuple, Optional
|
||||
from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
|
||||
from collections import OrderedDict
|
||||
@ -76,7 +75,7 @@ class BasicBlock(torch.nn.Module):
|
||||
return xs_pad, ilens
|
||||
|
||||
|
||||
class ResNet34(AbsEncoder):
|
||||
class ResNet34(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
|
||||
@ -9,10 +9,9 @@ from typeguard import check_argument_types
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.rnn.encoders import RNN
|
||||
from funasr.modules.rnn.encoders import RNNP
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
|
||||
|
||||
class RNNEncoder(AbsEncoder):
|
||||
class RNNEncoder(torch.nn.Module):
|
||||
"""RNNEncoder class.
|
||||
|
||||
Args:
|
||||
|
||||
@ -26,7 +26,6 @@ from funasr.modules.subsampling import Conv2dSubsampling8
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.modules.subsampling import check_short_utt
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.modules.mask import subsequent_mask, vad_mask
|
||||
|
||||
class EncoderLayerSANM(nn.Module):
|
||||
@ -115,7 +114,7 @@ class EncoderLayerSANM(nn.Module):
|
||||
|
||||
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
|
||||
|
||||
class SANMEncoder(AbsEncoder):
|
||||
class SANMEncoder(torch.nn.Module):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
San-m: Memory equipped self-attention for end-to-end speech recognition
|
||||
@ -547,7 +546,7 @@ class SANMEncoder(AbsEncoder):
|
||||
return var_dict_torch_update
|
||||
|
||||
|
||||
class SANMEncoderChunkOpt(AbsEncoder):
|
||||
class SANMEncoderChunkOpt(torch.nn.Module):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
|
||||
@ -960,7 +959,7 @@ class SANMEncoderChunkOpt(AbsEncoder):
|
||||
return var_dict_torch_update
|
||||
|
||||
|
||||
class SANMVadEncoder(AbsEncoder):
|
||||
class SANMVadEncoder(torch.nn.Module):
|
||||
"""
|
||||
author: Speech Lab, Alibaba Group, China
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@ from typeguard import check_argument_types
|
||||
import logging
|
||||
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.modules.attention import MultiHeadedAttention
|
||||
from funasr.modules.embedding import PositionalEncoding
|
||||
from funasr.modules.layer_norm import LayerNorm
|
||||
@ -144,7 +143,7 @@ class EncoderLayer(nn.Module):
|
||||
return x, mask
|
||||
|
||||
|
||||
class TransformerEncoder(AbsEncoder):
|
||||
class TransformerEncoder(torch.nn.Module):
|
||||
"""Transformer encoder module.
|
||||
|
||||
Args:
|
||||
|
||||
@ -1,55 +0,0 @@
|
||||
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Dict
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AbsESPnetModel(torch.nn.Module, ABC):
|
||||
"""The common abstract class among each tasks
|
||||
|
||||
"ESPnetModel" is referred to a class which inherits torch.nn.Module,
|
||||
and makes the dnn-models forward as its member field,
|
||||
a.k.a delegate pattern,
|
||||
and defines "loss", "stats", and "weight" for the task.
|
||||
|
||||
If you intend to implement new task in ESPNet,
|
||||
the model must inherit this class.
|
||||
In other words, the "mediator" objects between
|
||||
our training system and the your task class are
|
||||
just only these three values, loss, stats, and weight.
|
||||
|
||||
Example:
|
||||
>>> from funasr.tasks.abs_task import AbsTask
|
||||
>>> class YourESPnetModel(AbsESPnetModel):
|
||||
... def forward(self, input, input_lengths):
|
||||
... ...
|
||||
... return loss, stats, weight
|
||||
>>> class YourTask(AbsTask):
|
||||
... @classmethod
|
||||
... def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.num_updates = 0
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self, **batch: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_num_updates(self, num_updates):
|
||||
self.num_updates = num_updates
|
||||
|
||||
def get_num_updates(self):
|
||||
return self.num_updates
|
||||
Loading…
Reference in New Issue
Block a user