mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
56 lines
1.7 KiB
Python
56 lines
1.7 KiB
Python
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
|
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
|
|
from abc import ABC
|
|
from abc import abstractmethod
|
|
from typing import Dict
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
|
|
|
|
class AbsESPnetModel(torch.nn.Module, ABC):
    """Common abstract base class shared by the models of every task.

    An "ESPnetModel" is a ``torch.nn.Module`` subclass that holds the
    underlying network(s) as member fields and delegates its forward
    computation to them (delegate pattern). Its ``forward`` produces the
    "loss", "stats", and "weight" for the task.

    If you implement a new task, its model must inherit from this class.
    In other words, the only "mediator" objects between the training
    system and your task class are these three values: loss, stats, and
    weight.

    Both ``forward`` and ``collect_feats`` are abstract: a subclass must
    implement both, otherwise instantiating it raises ``TypeError``.

    Example:
        >>> import argparse
        >>> from funasr.tasks.abs_task import AbsTask
        >>> class YourESPnetModel(AbsESPnetModel):
        ...     def forward(self, speech, speech_lengths):
        ...         ...
        ...         return loss, stats, weight
        ...     def collect_feats(self, speech, speech_lengths):
        ...         ...
        ...         return feats  # a Dict[str, torch.Tensor]
        >>> class YourTask(AbsTask):
        ...     @classmethod
        ...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
        ...         return YourESPnetModel()
    """

    def __init__(self):
        super().__init__()
        # Update counter, starts at 0. Nothing in this class increments it;
        # NOTE(review): presumably the training loop keeps it current via
        # set_num_updates() — confirm against callers.
        self.num_updates: int = 0

    @abstractmethod
    def forward(
        self, **batch: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Run the model on one batch and return the training objective.

        Args:
            **batch: Named input tensors. The set of keys is defined by the
                concrete task, not by this base class.

        Returns:
            A ``(loss, stats, weight)`` tuple: the loss tensor, a mapping
            from statistic names to tensors, and the weight tensor.

        Raises:
            NotImplementedError: Always, in this abstract base implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return a mapping from names to feature tensors for one batch.

        NOTE(review): presumably used by the framework to gather feature
        statistics; the exact keys are task-specific — confirm against
        callers.

        Raises:
            NotImplementedError: Always, in this abstract base implementation.
        """
        raise NotImplementedError

    def set_num_updates(self, num_updates: int) -> None:
        """Store ``num_updates`` as the model's current update count."""
        self.num_updates = num_updates

    def get_num_updates(self) -> int:
        """Return the update count last stored (0 after construction)."""
        return self.num_updates