mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
80aeac6edc
commit
0149983c23
@ -17,10 +17,10 @@ from funasr.losses.label_smoothing_loss import (
|
||||
LabelSmoothingLoss, # noqa: H301
|
||||
)
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
|
||||
|
||||
21
funasr/models/encoder/abs_encoder.py
Normal file
21
funasr/models/encoder/abs_encoder.py
Normal file
@ -0,0 +1,21 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AbsEncoder(torch.nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def output_size(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
Loading…
Reference in New Issue
Block a user