diff --git a/funasr/models/base_model.py b/funasr/models/base_model.py new file mode 100644 index 000000000..80b3bbd71 --- /dev/null +++ b/funasr/models/base_model.py @@ -0,0 +1,17 @@ +import torch + + +class FunASRModel(torch.nn.Module): + """The common model class + + """ + + def __init__(self): + super().__init__() + self.num_updates = 0 + + def set_num_updates(self, num_updates): + self.num_updates = num_updates + + def get_num_updates(self): + return self.num_updates diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py index f64ea3dbe..1d2b20d8d 100644 --- a/funasr/models/e2e_asr.py +++ b/funasr/models/e2e_asr.py @@ -24,11 +24,11 @@ from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.base_model import FunASRModel from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.e2e_asr_common import ErrorCalculator from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -39,7 +39,7 @@ else: yield -class ESPnetASRModel(AbsESPnetModel): +class ESPnetASRModel(FunASRModel): """CTC-attention hybrid Encoder-Decoder model""" def __init__( diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py index 033613382..d04255251 100644 --- a/funasr/models/e2e_asr_mfcca.py +++ b/funasr/models/e2e_asr_mfcca.py @@ -21,9 +21,10 @@ from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.base_model import FunASRModel from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -35,7 +36,7 @@ else: import pdb import random import math -class MFCCA(AbsESPnetModel): +class MFCCA(FunASRModel): """CTC-attention hybrid Encoder-Decoder model""" def __init__( diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index f1bb2bfc1..cf5c16d15 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -25,11 +25,11 @@ from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.predictor.cif import mae_loss from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.base_model import FunASRModel from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.nets_utils import make_pad_mask, pad_list from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel from funasr.models.predictor.cif import CifPredictorV3 @@ -42,7 +42,7 @@ else: yield -class Paraformer(AbsESPnetModel): +class Paraformer(FunASRModel): """ Author: Speech Lab, Alibaba Group, China Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py index 097b23a57..b4a3fa290 100644 --- a/funasr/models/e2e_diar_eend_ola.py +++ b/funasr/models/e2e_diar_eend_ola.py @@ -15,8 +15,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendMel23 from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor from funasr.modules.eend_ola.utils.power import generate_mapping_dict +from funasr.models.base_model import FunASRModel from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): pass @@ -34,7 +34,7 @@ def pad_attractor(att, max_n_speakers): return att -class DiarEENDOLAModel(AbsESPnetModel): +class DiarEENDOLAModel(FunASRModel): """EEND-OLA diarization model""" def __init__( diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py index de669f2ee..a6d780719 100644 --- a/funasr/models/e2e_diar_sond.py +++ b/funasr/models/e2e_diar_sond.py @@ -20,9 +20,9 @@ from funasr.models.decoder.abs_decoder import AbsDecoder from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.base_model import FunASRModel from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy from funasr.utils.misc import int2vec @@ -35,7 +35,7 @@ else: yield -class DiarSondModel(AbsESPnetModel): +class DiarSondModel(FunASRModel): """Speaker overlap-aware neural diarization model reference: https://arxiv.org/abs/2211.10243 """ diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py index eff63d967..dd2ea8438 100644 --- a/funasr/models/e2e_sv.py +++ b/funasr/models/e2e_sv.py @@ -21,11 +21,11 @@ from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.base_model import FunASRModel from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.e2e_asr_common import ErrorCalculator from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -36,7 +36,7 @@ else: yield -class ESPnetSVModel(AbsESPnetModel): +class ESPnetSVModel(FunASRModel): """CTC-attention hybrid Encoder-Decoder model""" def __init__( diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py index 887439c5e..ac8c3b40c 100644 --- a/funasr/models/e2e_tp.py +++ b/funasr/models/e2e_tp.py @@ -14,10 +14,10 @@ from typeguard import check_argument_types from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.predictor.cif import mae_loss +from funasr.models.base_model import FunASRModel from funasr.modules.add_sos_eos import add_sos_eos from funasr.modules.nets_utils import make_pad_mask, pad_list from funasr.torch_utils.device_funcs import force_gatherable -from funasr.train.abs_espnet_model import AbsESPnetModel from funasr.models.predictor.cif import CifPredictorV3 @@ -30,7 +30,7 @@ else: yield -class TimestampPredictor(AbsESPnetModel): +class TimestampPredictor(FunASRModel): """ Author: Speech Lab, Alibaba Group, China """ diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py index ac4db329b..0c26b8ea4 100644 --- a/funasr/models/e2e_uni_asr.py +++ b/funasr/models/e2e_uni_asr.py @@ -23,6 +23,7 @@ from funasr.models.frontend.abs_frontend import AbsFrontend from funasr.models.postencoder.abs_postencoder import AbsPostEncoder from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.base_model import FunASRModel from funasr.layers.abs_normalize import AbsNormalize from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel @@ -38,7 +39,7 @@ else: yield -class UniASR(AbsESPnetModel): +class UniASR(FunASRModel): """ Author: Speech Lab, Alibaba Group, China """