diff --git a/examples/industrial_data_pretraining/paraformer-large/infer.sh b/examples/industrial_data_pretraining/paraformer-large/infer.sh index 48ad3bf8d..87260acfd 100644 --- a/examples/industrial_data_pretraining/paraformer-large/infer.sh +++ b/examples/industrial_data_pretraining/paraformer-large/infer.sh @@ -2,13 +2,13 @@ cmd="funasr/bin/inference.py" python $cmd \ -+model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \ ++model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \ +input="/Users/zhifu/Downloads/asr_example.wav" \ +output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \ +device="cpu" \ python $cmd \ -+model="/Users/zhifu/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \ ++model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \ +input="/Users/zhifu/Downloads/asr_example.wav" \ +output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \ +device="cpu" \ diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py index 50ea4d4b4..d7b33e379 100644 --- a/funasr/bin/inference.py +++ b/funasr/bin/inference.py @@ -15,7 +15,7 @@ from funasr.train_utils.load_pretrained_model import load_pretrained_model import time import random import string -from funasr.utils.register import registry_tables +from funasr.register import tables def build_iter_for_infer(data_in, input_len=None, data_type="sound"): @@ -81,7 +81,7 @@ def main_hydra(kwargs: DictConfig): class AutoModel: def __init__(self, **kwargs): - registry_tables.print() + tables.print() assert "model" in kwargs if "model_conf" not in kwargs: logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms"))) @@ -98,7 +98,7 @@ class AutoModel: # build tokenizer tokenizer = kwargs.get("tokenizer", None) if tokenizer is not None: - tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower()) + tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower()) tokenizer = tokenizer_class(**kwargs["tokenizer_conf"]) kwargs["tokenizer"] = tokenizer kwargs["token_list"] = tokenizer.token_list @@ -106,13 +106,13 @@ class AutoModel: # build frontend frontend = kwargs.get("frontend", None) if frontend is not None: - frontend_class = registry_tables.frontend_classes.get(frontend.lower()) + frontend_class = tables.frontend_classes.get(frontend.lower()) frontend = frontend_class(**kwargs["frontend_conf"]) kwargs["frontend"] = frontend kwargs["input_size"] = frontend.output_size() # build model - model_class = registry_tables.model_classes.get(kwargs["model"].lower()) + model_class = tables.model_classes.get(kwargs["model"].lower()) model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1) model.eval() model.to(device) diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 1e06c5037..b1f0d06b2 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -21,7 +21,7 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from funasr.download.download_from_hub import download_model -from funasr.utils.register import registry_tables +from funasr.register import tables @hydra.main(config_name=None, version_base=None) def main_hydra(kwargs: DictConfig): @@ -39,7 +39,7 @@ def main(**kwargs): # preprocess_config(kwargs) # import pdb; pdb.set_trace() # set random seed - registry_tables.print() + tables.print() set_all_random_seed(kwargs.get("seed", 0)) torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled) torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark) @@ -62,14 +62,14 @@ def main(**kwargs): tokenizer = kwargs.get("tokenizer", None) if tokenizer is not None: - tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower()) + tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower()) tokenizer = tokenizer_class(**kwargs["tokenizer_conf"]) kwargs["tokenizer"] = tokenizer # build frontend if frontend is none None frontend = kwargs.get("frontend", None) if frontend is not None: - frontend_class = registry_tables.frontend_classes.get(frontend.lower()) + frontend_class = tables.frontend_classes.get(frontend.lower()) frontend = frontend_class(**kwargs["frontend_conf"]) kwargs["frontend"] = frontend kwargs["input_size"] = frontend.output_size() @@ -77,7 +77,7 @@ def main(**kwargs): # import pdb; # pdb.set_trace() # build model - model_class = registry_tables.model_classes.get(kwargs["model"].lower()) + model_class = tables.model_classes.get(kwargs["model"].lower()) model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)) @@ -139,12 +139,12 @@ def main(**kwargs): # import pdb; # pdb.set_trace() # dataset - dataset_class = registry_tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower()) + dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower()) dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf")) # dataloader batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler") - batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower()) + batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler.lower()) if batch_sampler is not None: batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) dataloader_tr = torch.utils.data.DataLoader(dataset_tr, diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py index d69d0b573..bfdf86a4f 100644 --- a/funasr/datasets/audio_datasets/datasets.py +++ b/funasr/datasets/audio_datasets/datasets.py @@ -9,9 +9,9 @@ import time import logging from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables -@register_class("dataset_classes", "AudioDataset") +@tables.register("dataset_classes", "AudioDataset") class AudioDataset(torch.utils.data.Dataset): def __init__(self, path, @@ -22,16 +22,16 @@ class AudioDataset(torch.utils.data.Dataset): float_pad_value: float = 0.0, **kwargs): super().__init__() - index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower()) + index_ds_class = tables.index_ds_classes.get(index_ds.lower()) self.index_ds = index_ds_class(path) preprocessor_speech = kwargs.get("preprocessor_speech", None) if preprocessor_speech: - preprocessor_speech_class = registry_tables.preprocessor_speech_classes.get(preprocessor_speech.lower()) + preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech.lower()) preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf")) self.preprocessor_speech = preprocessor_speech preprocessor_text = kwargs.get("preprocessor_text", None) if preprocessor_text: - preprocessor_text_class = registry_tables.preprocessor_text_classes.get(preprocessor_text.lower()) + preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text.lower()) preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) self.preprocessor_text = preprocessor_text diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py index 33b309adb..79bb26eb0 100644 --- a/funasr/datasets/audio_datasets/index_ds.py +++ b/funasr/datasets/audio_datasets/index_ds.py @@ -4,9 +4,9 @@ import torch.distributed as dist import time import logging -from funasr.utils.register import register_class +from funasr.register import tables -@register_class("index_ds_classes", "IndexDSJsonl") +@tables.register("index_ds_classes", "IndexDSJsonl") class IndexDSJsonl(torch.utils.data.Dataset): def __init__(self, path): diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py index 7d3a94197..d34fdeae7 100644 --- a/funasr/datasets/audio_datasets/samplers.py +++ b/funasr/datasets/audio_datasets/samplers.py @@ -2,9 +2,9 @@ import torch import numpy as np -from funasr.utils.register import register_class +from funasr.register import tables -@register_class("batch_sampler_classes", "DynamicBatchLocalShuffleSampler") +@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler") class BatchSampler(torch.utils.data.BatchSampler): def __init__(self, dataset, diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py index 4866fa188..746bf82ce 100644 --- a/funasr/frontends/wav_frontend.py +++ b/funasr/frontends/wav_frontend.py @@ -9,7 +9,7 @@ import torchaudio.compliance.kaldi as kaldi from torch.nn.utils.rnn import pad_sequence import funasr.frontends.eend_ola_feature as eend_ola_feature -from funasr.utils.register import register_class +from funasr.register import tables @@ -75,7 +75,7 @@ def apply_lfr(inputs, lfr_m, lfr_n): LFR_outputs = torch.vstack(LFR_inputs) return LFR_outputs.type(torch.float32) -@register_class("frontend_classes", "WavFrontend") +@tables.register("frontend_classes", "WavFrontend") class WavFrontend(nn.Module): """Conventional frontend structure for ASR. """ @@ -211,7 +211,7 @@ class WavFrontend(nn.Module): return feats_pad, feats_lens -@register_class("frontend_classes", "WavFrontendOnline") +@tables.register("frontend_classes", "WavFrontendOnline") class WavFrontendOnline(nn.Module): """Conventional frontend structure for streaming ASR/VAD. """ diff --git a/funasr/models/bat/cif_predictor.py b/funasr/models/bat/cif_predictor.py index 9aa3e337d..d8915c226 100644 --- a/funasr/models/bat/cif_predictor.py +++ b/funasr/models/bat/cif_predictor.py @@ -8,7 +8,7 @@ # from funasr.models.scama.utils import sequence_mask # from typing import Optional, Tuple # -# from funasr.utils.register import register_class +# from funasr.register import tables # # class mae_loss(nn.Module): # @@ -93,7 +93,7 @@ # fires = torch.stack(list_fires, 1) # return fires # -# @register_class("predictor_classes", "BATPredictor") +# @tables.register("predictor_classes", "BATPredictor") # class BATPredictor(nn.Module): # def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False): # super(BATPredictor, self).__init__() diff --git a/funasr/models/bat/conformer_chunk_encoder.py b/funasr/models/bat/conformer_chunk_encoder.py index 2dc03c3b4..7635c0289 100644 --- a/funasr/models/bat/conformer_chunk_encoder.py +++ b/funasr/models/bat/conformer_chunk_encoder.py @@ -29,7 +29,7 @@ from funasr.models.transformer.utils.repeat import repeat, MultiBlocks from funasr.models.transformer.utils.subsampling import TooShortUttError from funasr.models.transformer.utils.subsampling import check_short_utt from funasr.models.transformer.utils.subsampling import StreamingConvInput -from funasr.utils.register import register_class +from funasr.register import tables @@ -312,7 +312,7 @@ class CausalConvolution(nn.Module): return x, cache -@register_class("encoder_classes", "ConformerChunkEncoder") +@tables.register("encoder_classes", "ConformerChunkEncoder") class ConformerChunkEncoder(nn.Module): """Encoder module definition. Args: diff --git a/funasr/models/bici_paraformer/cif_predictor.py b/funasr/models/bici_paraformer/cif_predictor.py index 67d801c07..5a1488e9e 100644 --- a/funasr/models/bici_paraformer/cif_predictor.py +++ b/funasr/models/bici_paraformer/cif_predictor.py @@ -8,7 +8,7 @@ from funasr.models.transformer.utils.nets_utils import make_pad_mask from funasr.models.scama.utils import sequence_mask from typing import Optional, Tuple -from funasr.utils.register import register_class +from funasr.register import tables class mae_loss(nn.Module): @@ -94,7 +94,7 @@ def cif_wo_hidden(alphas, threshold): fires = torch.stack(list_fires, 1) return fires -@register_class("predictor_classes", "CifPredictorV3") +@tables.register("predictor_classes", "CifPredictorV3") class CifPredictorV3(nn.Module): def __init__(self, idim, diff --git a/funasr/models/bici_paraformer/model.py b/funasr/models/bici_paraformer/model.py index 23a698598..52eac875d 100644 --- a/funasr/models/bici_paraformer/model.py +++ b/funasr/models/bici_paraformer/model.py @@ -27,12 +27,12 @@ from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, from funasr.utils import postprocess_utils from funasr.utils.datadir_writer import DatadirWriter from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables from funasr.models.ctc.ctc import CTC from funasr.models.paraformer.model import Paraformer -@register_class("model_classes", "BiCifParaformer") +@tables.register("model_classes", "BiCifParaformer") class BiCifParaformer(Paraformer): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/branchformer/encoder.py b/funasr/models/branchformer/encoder.py index 11b64290d..4b5b2377d 100644 --- a/funasr/models/branchformer/encoder.py +++ b/funasr/models/branchformer/encoder.py @@ -43,7 +43,7 @@ from funasr.models.transformer.utils.subsampling import ( check_short_utt, ) -from funasr.utils.register import register_class +from funasr.register import tables class BranchformerEncoderLayer(torch.nn.Module): """Branchformer encoder layer module. @@ -291,7 +291,7 @@ class BranchformerEncoderLayer(torch.nn.Module): return x, mask -@register_class("encoder_classes", "BranchformerEncoder") +@tables.register("encoder_classes", "BranchformerEncoder") class BranchformerEncoder(nn.Module): """Branchformer encoder module.""" diff --git a/funasr/models/branchformer/model.py b/funasr/models/branchformer/model.py index a14b40798..53f254df7 100644 --- a/funasr/models/branchformer/model.py +++ b/funasr/models/branchformer/model.py @@ -1,9 +1,9 @@ import logging from funasr.models.transformer.model import Transformer -from funasr.utils.register import register_class +from funasr.register import tables -@register_class("model_classes", "Branchformer") +@tables.register("model_classes", "Branchformer") class Branchformer(Transformer): """CTC-attention hybrid Encoder-Decoder model""" diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py index 709e10e41..1ca437da4 100644 --- a/funasr/models/conformer/encoder.py +++ b/funasr/models/conformer/encoder.py @@ -45,7 +45,7 @@ from funasr.models.transformer.utils.subsampling import TooShortUttError from funasr.models.transformer.utils.subsampling import check_short_utt from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad from funasr.models.transformer.utils.subsampling import StreamingConvInput -from funasr.utils.register import register_class +from funasr.register import tables class ConvolutionModule(nn.Module): @@ -283,7 +283,7 @@ class EncoderLayer(nn.Module): return x, mask -@register_class("encoder_classes", "ConformerEncoder") +@tables.register("encoder_classes", "ConformerEncoder") class ConformerEncoder(nn.Module): """Conformer encoder module. diff --git a/funasr/models/conformer/model.py b/funasr/models/conformer/model.py index 5319a73fd..2c267532e 100644 --- a/funasr/models/conformer/model.py +++ b/funasr/models/conformer/model.py @@ -3,9 +3,9 @@ import logging import torch from funasr.models.transformer.model import Transformer -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables -@register_class("model_classes", "Conformer") +@tables.register("model_classes", "Conformer") class Conformer(Transformer): """CTC-attention hybrid Encoder-Decoder model""" diff --git a/funasr/models/conformer/template.yaml b/funasr/models/conformer/template.yaml index 609431391..4cbeca46f 100644 --- a/funasr/models/conformer/template.yaml +++ b/funasr/models/conformer/template.yaml @@ -2,8 +2,8 @@ # You can modify the configuration according to your own requirements. # to print the register_table: -# from funasr.utils.register import registry_tables -# registry_tables.print() +# from funasr.register import tables +# tables.print() # network architecture #model: funasr.models.paraformer.model:Paraformer diff --git a/funasr/models/ct_transformer/encoder.py b/funasr/models/ct_transformer/encoder.py index 1bdf5d586..784baf37d 100644 --- a/funasr/models/ct_transformer/encoder.py +++ b/funasr/models/ct_transformer/encoder.py @@ -31,7 +31,7 @@ from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask from funasr.models.ctc.ctc import CTC -from funasr.utils.register import register_class +from funasr.register import tables class EncoderLayerSANM(nn.Module): def __init__( @@ -155,7 +155,7 @@ class EncoderLayerSANM(nn.Module): return x, cache -@register_class("encoder_classes", "SANMVadEncoder") +@tables.register("encoder_classes", "SANMVadEncoder") class SANMVadEncoder(nn.Module): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py index 31b2af2aa..d8c7fc3bf 100644 --- a/funasr/models/ct_transformer/model.py +++ b/funasr/models/ct_transformer/model.py @@ -5,9 +5,9 @@ from typing import Tuple import torch import torch.nn as nn -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables -@register_class("model_classes", "CTTransformer") +@tables.register("model_classes", "CTTransformer") class CTTransformer(nn.Module): """ Author: Speech Lab of DAMO Academy, Alibaba Group @@ -37,7 +37,7 @@ class CTTransformer(nn.Module): self.embed = nn.Embedding(vocab_size, embed_unit) - encoder_class = registry_tables.encoder_classes.get(encoder.lower()) + encoder_class = tables.encoder_classes.get(encoder.lower()) encoder = encoder_class(**encoder_conf) self.decoder = nn.Linear(att_unit, punc_size) diff --git a/funasr/models/e_branchformer/encoder.py b/funasr/models/e_branchformer/encoder.py index 5604c9fed..4084e2116 100644 --- a/funasr/models/e_branchformer/encoder.py +++ b/funasr/models/e_branchformer/encoder.py @@ -42,7 +42,7 @@ from funasr.models.transformer.utils.subsampling import ( TooShortUttError, check_short_utt, ) -from funasr.utils.register import register_class +from funasr.register import tables class EBranchformerEncoderLayer(torch.nn.Module): """E-Branchformer encoder layer module. @@ -174,7 +174,7 @@ class EBranchformerEncoderLayer(torch.nn.Module): return x, mask -@register_class("encoder_classes", "EBranchformerEncoder") +@tables.register("encoder_classes", "EBranchformerEncoder") class EBranchformerEncoder(nn.Module): """E-Branchformer encoder module.""" diff --git a/funasr/models/e_branchformer/model.py b/funasr/models/e_branchformer/model.py index ccf132063..4ffeb3e43 100644 --- a/funasr/models/e_branchformer/model.py +++ b/funasr/models/e_branchformer/model.py @@ -1,9 +1,9 @@ import logging from funasr.models.transformer.model import Transformer -from funasr.utils.register import register_class +from funasr.register import tables -@register_class("model_classes", "EBranchformer") +@tables.register("model_classes", "EBranchformer") class EBranchformer(Transformer): """CTC-attention hybrid Encoder-Decoder model""" diff --git a/funasr/models/fsmn_vad/encoder.py b/funasr/models/fsmn_vad/encoder.py index 50e31fc32..54410acb3 100755 --- a/funasr/models/fsmn_vad/encoder.py +++ b/funasr/models/fsmn_vad/encoder.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables class LinearTransform(nn.Module): @@ -158,7 +158,7 @@ num_syn: output dimension fsmn_layers: no. of sequential fsmn layers ''' -@register_class("encoder_classes", "FSMN") +@tables.register("encoder_classes", "FSMN") class FSMN(nn.Module): def __init__( self, @@ -229,7 +229,7 @@ lstride: left stride rstride: right stride ''' -@register_class("encoder_classes", "DFSMN") +@tables.register("encoder_classes", "DFSMN") class DFSMN(nn.Module): def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1): diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py index 16f21dca9..488e05e9a 100644 --- a/funasr/models/fsmn_vad/model.py +++ b/funasr/models/fsmn_vad/model.py @@ -8,7 +8,7 @@ from torch import nn import math from typing import Optional import time -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,extract_fbank from funasr.utils.datadir_writer import DatadirWriter from torch.nn.utils.rnn import pad_sequence @@ -218,7 +218,7 @@ class WindowDetector(object): return int(self.frame_size_ms) -@register_class("model_classes", "FsmnVAD") +@tables.register("model_classes", "FsmnVAD") class FsmnVAD(nn.Module): """ Author: Speech Lab of DAMO Academy, Alibaba Group @@ -238,7 +238,7 @@ class FsmnVAD(nn.Module): self.vad_opts.speech_to_sil_time_thres, self.vad_opts.frame_in_ms) - encoder_class = registry_tables.encoder_classes.get(encoder.lower()) + encoder_class = tables.encoder_classes.get(encoder.lower()) encoder = encoder_class(**encoder_conf) self.encoder = encoder # init variables diff --git a/funasr/models/fsmn_vad/template.yaml b/funasr/models/fsmn_vad/template.yaml new file mode 100644 index 000000000..90032eb83 --- /dev/null +++ b/funasr/models/fsmn_vad/template.yaml @@ -0,0 +1,62 @@ +# This is an example that demonstrates how to configure a model file. +# You can modify the configuration according to your own requirements. + +# to print the register_table: +# from funasr.register import tables +# tables.print() + +# network architecture +model: FsmnVAD +model_conf: + sample_rate: 16000 + detect_mode: 1 + snr_mode: 0 + max_end_silence_time: 800 + max_start_silence_time: 3000 + do_start_point_detection: True + do_end_point_detection: True + window_size_ms: 200 + sil_to_speech_time_thres: 150 + speech_to_sil_time_thres: 150 + speech_2_noise_ratio: 1.0 + do_extend: 1 + lookback_time_start_point: 200 + lookahead_time_end_point: 100 + max_single_segment_time: 60000 + snr_thres: -100.0 + noise_frame_num_used_for_snr: 100 + decibel_thres: -100.0 + speech_noise_thres: 0.6 + fe_prior_thres: 0.0001 + silence_pdf_num: 1 + sil_pdf_ids: [0] + speech_noise_thresh_low: -0.1 + speech_noise_thresh_high: 0.3 + output_frame_probs: False + frame_in_ms: 10 + frame_length_ms: 25 + +encoder: FSMN +encoder_conf: + input_dim: 400 + input_affine_dim: 140 + fsmn_layers: 4 + linear_dim: 250 + proj_dim: 128 + lorder: 20 + rorder: 0 + lstride: 1 + rstride: 0 + output_affine_dim: 140 + output_dim: 248 + +frontend: WavFrontend +frontend_conf: + fs: 16000 + window: hamming + n_mels: 80 + frame_length: 25 + frame_shift: 10 + dither: 0.0 + lfr_m: 5 + lfr_n: 1 diff --git a/funasr/models/neat_contextual_paraformer/decoder.py b/funasr/models/neat_contextual_paraformer/decoder.py index ca689d37a..5ec27560e 100644 --- a/funasr/models/neat_contextual_paraformer/decoder.py +++ b/funasr/models/neat_contextual_paraformer/decoder.py @@ -14,7 +14,7 @@ from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForward from funasr.models.transformer.utils.repeat import repeat from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables class ContextualDecoderLayer(nn.Module): def __init__( @@ -98,7 +98,7 @@ class ContextualBiasDecoder(nn.Module): x = self.dropout(self.src_attn(x, memory, memory_mask)) return x, tgt_mask, memory, memory_mask, cache -@register_class("decoder_classes", "ContextualParaformerDecoder") +@tables.register("decoder_classes", "ContextualParaformerDecoder") class ContextualParaformerDecoder(ParaformerSANMDecoder): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/neat_contextual_paraformer/model.py b/funasr/models/neat_contextual_paraformer/model.py index 78913072b..d056ab980 100644 --- a/funasr/models/neat_contextual_paraformer/model.py +++ b/funasr/models/neat_contextual_paraformer/model.py @@ -53,9 +53,9 @@ from funasr.utils.datadir_writer import DatadirWriter from funasr.models.paraformer.model import Paraformer -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables -@register_class("model_classes", "NeatContextualParaformer") +@tables.register("model_classes", "NeatContextualParaformer") class NeatContextualParaformer(Paraformer): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/neat_contextual_paraformer/template.yaml b/funasr/models/neat_contextual_paraformer/template.yaml index 012ecf7bc..6efc62c8c 100644 --- a/funasr/models/neat_contextual_paraformer/template.yaml +++ b/funasr/models/neat_contextual_paraformer/template.yaml @@ -2,8 +2,8 @@ # You can modify the configuration according to your own requirements. # to print the register_table: -# from funasr.utils.register import registry_tables -# registry_tables.print() +# from funasr.register import tables +# tables.print() # network architecture model: NeatContextualParaformer diff --git a/funasr/models/normalize/global_mvn.py b/funasr/models/normalize/global_mvn.py index eea84dc83..8df7c3372 100644 --- a/funasr/models/normalize/global_mvn.py +++ b/funasr/models/normalize/global_mvn.py @@ -6,9 +6,9 @@ import numpy as np import torch from funasr.models.transformer.utils.nets_utils import make_pad_mask -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables -@register_class("normalize_classes", "GlobalMVN") +@tables.register("normalize_classes", "GlobalMVN") class GlobalMVN(torch.nn.Module): """Apply global mean and variance normalization TODO(kamo): Make this class portable somehow diff --git a/funasr/models/normalize/utterance_mvn.py b/funasr/models/normalize/utterance_mvn.py index 60703fbb3..9558402ee 100644 --- a/funasr/models/normalize/utterance_mvn.py +++ b/funasr/models/normalize/utterance_mvn.py @@ -3,9 +3,9 @@ from typing import Tuple import torch from funasr.models.transformer.utils.nets_utils import make_pad_mask -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables -@register_class("normalize_classes", "UtteranceMVN") +@tables.register("normalize_classes", "UtteranceMVN") class UtteranceMVN(torch.nn.Module): def __init__( self, diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py index c1b7d7a24..383d9ca7c 100644 --- a/funasr/models/paraformer/cif_predictor.py +++ b/funasr/models/paraformer/cif_predictor.py @@ -8,9 +8,9 @@ from funasr.models.transformer.utils.nets_utils import make_pad_mask from funasr.models.scama.utils import sequence_mask from typing import Optional, Tuple -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables -@register_class("predictor_classes", "CifPredictor") +@tables.register("predictor_classes", "CifPredictor") class CifPredictor(nn.Module): def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45): super().__init__() @@ -136,7 +136,7 @@ class CifPredictor(nn.Module): predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype) return predictor_alignments.detach(), predictor_alignments_length.detach() -@register_class("predictor_classes", "CifPredictorV2") +@tables.register("predictor_classes", "CifPredictorV2") class CifPredictorV2(nn.Module): def __init__(self, idim, diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py index f59ce4db8..b4de6cd13 100644 --- a/funasr/models/paraformer/decoder.py +++ b/funasr/models/paraformer/decoder.py @@ -17,7 +17,7 @@ from funasr.models.transformer.attention import MultiHeadedAttention from funasr.models.transformer.embedding import PositionalEncoding from funasr.models.transformer.utils.nets_utils import make_pad_mask from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables class DecoderLayerSANM(nn.Module): """Single decoder layer module. @@ -200,7 +200,7 @@ class DecoderLayerSANM(nn.Module): return x, memory, fsmn_cache, opt_cache -@register_class("decoder_classes", "ParaformerSANMDecoder") +@tables.register("decoder_classes", "ParaformerSANMDecoder") class ParaformerSANMDecoder(BaseTransformerDecoder): """ Author: Speech Lab of DAMO Academy, Alibaba Group @@ -525,7 +525,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder): return y, new_cache -@register_class("decoder_classes", "ParaformerDecoderSAN") +@tables.register("decoder_classes", "ParaformerDecoderSAN") class ParaformerDecoderSAN(BaseTransformerDecoder): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py index 03a0bd25a..d92d08d5c 100644 --- a/funasr/models/paraformer/model.py +++ b/funasr/models/paraformer/model.py @@ -25,10 +25,10 @@ from torch.cuda.amp import autocast from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank from funasr.utils import postprocess_utils from funasr.utils.datadir_writer import DatadirWriter -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables from funasr.models.ctc.ctc import CTC -@register_class("model_classes", "Paraformer") +@tables.register("model_classes", "Paraformer") class Paraformer(nn.Module): """ Author: Speech Lab of DAMO Academy, Alibaba Group @@ -79,17 +79,17 @@ class Paraformer(nn.Module): super().__init__() if specaug is not None: - specaug_class = registry_tables.specaug_classes.get(specaug.lower()) + specaug_class = tables.specaug_classes.get(specaug.lower()) specaug = specaug_class(**specaug_conf) if normalize is not None: - normalize_class = registry_tables.normalize_classes.get(normalize.lower()) + normalize_class = tables.normalize_classes.get(normalize.lower()) normalize = normalize_class(**normalize_conf) - encoder_class = registry_tables.encoder_classes.get(encoder.lower()) + encoder_class = tables.encoder_classes.get(encoder.lower()) encoder = encoder_class(input_size=input_size, **encoder_conf) encoder_output_size = encoder.output_size() if decoder is not None: - decoder_class = registry_tables.decoder_classes.get(decoder.lower()) + decoder_class = tables.decoder_classes.get(decoder.lower()) decoder = decoder_class( vocab_size=vocab_size, encoder_output_size=encoder_output_size, @@ -104,7 +104,7 @@ class Paraformer(nn.Module): odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf ) if predictor is not None: - predictor_class = registry_tables.predictor_classes.get(predictor.lower()) + predictor_class = tables.predictor_classes.get(predictor.lower()) predictor = predictor_class(**predictor_conf) # note that eos is the same as sos (equivalent ID) diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml index 000f641c2..94eebf7bd 100644 --- a/funasr/models/paraformer/template.yaml +++ b/funasr/models/paraformer/template.yaml @@ -2,8 +2,8 @@ # You can modify the configuration according to your own requirements. # to print the register_table: -# from funasr.utils.register import registry_tables -# registry_tables.print() +# from funasr.register import tables +# tables.print() # network architecture #model: funasr.models.paraformer.model:Paraformer diff --git a/funasr/models/paraformer_online/model.py b/funasr/models/paraformer_online/model.py index 5cbed26ee..27871bca5 100644 --- a/funasr/models/paraformer_online/model.py +++ b/funasr/models/paraformer_online/model.py @@ -44,7 +44,7 @@ from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, from funasr.utils import postprocess_utils from funasr.utils.datadir_writer import DatadirWriter from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard -from funasr.utils.register import registry_tables +from funasr.register import tables from funasr.models.ctc.ctc import CTC class Paraformer(nn.Module): @@ -102,19 +102,19 @@ class Paraformer(nn.Module): # pdb.set_trace() if frontend is not None: - frontend_class = registry_tables.frontend_classes.get_class(frontend.lower()) + frontend_class = tables.frontend_classes.get_class(frontend.lower()) frontend = frontend_class(**frontend_conf) if specaug is not None: - specaug_class = registry_tables.specaug_classes.get_class(specaug.lower()) + specaug_class = tables.specaug_classes.get_class(specaug.lower()) specaug = specaug_class(**specaug_conf) if normalize is not None: - normalize_class = registry_tables.normalize_classes.get_class(normalize.lower()) + normalize_class = tables.normalize_classes.get_class(normalize.lower()) normalize = normalize_class(**normalize_conf) - encoder_class = registry_tables.encoder_classes.get_class(encoder.lower()) + encoder_class = tables.encoder_classes.get_class(encoder.lower()) encoder = encoder_class(input_size=input_size, **encoder_conf) encoder_output_size = encoder.output_size() if decoder is not None: - decoder_class = registry_tables.decoder_classes.get_class(decoder.lower()) + decoder_class = tables.decoder_classes.get_class(decoder.lower()) decoder = decoder_class( vocab_size=vocab_size, encoder_output_size=encoder_output_size, @@ -129,7 +129,7 @@ class Paraformer(nn.Module): odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf ) if predictor is not None: - predictor_class = registry_tables.predictor_classes.get_class(predictor.lower()) + predictor_class = tables.predictor_classes.get_class(predictor.lower()) predictor = predictor_class(**predictor_conf) # note that eos is the same as sos (equivalent ID) diff --git a/funasr/models/paraformer_online/sanm_decoder.py b/funasr/models/paraformer_online/sanm_decoder.py index b1e94d755..2ae433586 100644 --- a/funasr/models/paraformer_online/sanm_decoder.py +++ b/funasr/models/paraformer_online/sanm_decoder.py @@ -14,7 +14,7 @@ from funasr.models.transformer.layer_norm import LayerNorm from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM from funasr.models.transformer.utils.repeat import repeat -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables class DecoderLayerSANM(nn.Module): """Single decoder layer module. @@ -190,7 +190,7 @@ class DecoderLayerSANM(nn.Module): return x, memory, fsmn_cache, opt_cache -@register_class("decoder_classes", "ParaformerSANMDecoder") +@tables.register("decoder_classes", "ParaformerSANMDecoder") class ParaformerSANMDecoder(BaseTransformerDecoder): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/sa_asr/transformer_decoder.py b/funasr/models/sa_asr/transformer_decoder.py index 3319212b0..b34a3aacf 100644 --- a/funasr/models/sa_asr/transformer_decoder.py +++ b/funasr/models/sa_asr/transformer_decoder.py @@ -27,7 +27,7 @@ from funasr.models.transformer.positionwise_feed_forward import ( from funasr.models.transformer.utils.repeat import repeat from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables class DecoderLayer(nn.Module): """Single decoder layer module. @@ -353,7 +353,7 @@ class BaseTransformerDecoder(nn.Module, BatchScorerInterface): state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] return logp, state_list -@register_class("decoder_classes", "TransformerDecoder") +@tables.register("decoder_classes", "TransformerDecoder") class TransformerDecoder(BaseTransformerDecoder): def __init__( self, @@ -402,7 +402,7 @@ class TransformerDecoder(BaseTransformerDecoder): ) -@register_class("decoder_classes", "ParaformerDecoderSAN") +@tables.register("decoder_classes", "ParaformerDecoderSAN") class ParaformerDecoderSAN(BaseTransformerDecoder): """ Author: Speech Lab of DAMO Academy, Alibaba Group @@ -516,7 +516,7 @@ class ParaformerDecoderSAN(BaseTransformerDecoder): else: return x, olens -@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder") +@tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder") class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): def __init__( self, @@ -577,7 +577,7 @@ class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): ), ) -@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder") +@tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder") class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): def __init__( self, @@ -639,7 +639,7 @@ class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): ) -@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder") +@tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder") class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): def __init__( self, @@ -700,7 +700,7 @@ class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): ), ) -@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder") +@tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder") class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): def __init__( self, diff --git a/funasr/models/sanm/decoder.py b/funasr/models/sanm/decoder.py index 64033add3..190ada0f3 100644 --- a/funasr/models/sanm/decoder.py +++ b/funasr/models/sanm/decoder.py @@ -14,7 +14,7 @@ from funasr.models.transformer.layer_norm import LayerNorm from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM from funasr.models.transformer.utils.repeat import repeat -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables class DecoderLayerSANM(nn.Module): """Single decoder layer module. @@ -190,7 +190,7 @@ class DecoderLayerSANM(nn.Module): return x, memory, fsmn_cache, opt_cache -@register_class("decoder_classes", "FsmnDecoder") +@tables.register("decoder_classes", "FsmnDecoder") class FsmnDecoder(BaseTransformerDecoder): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py index 8e159e2f6..cb4e21af4 100644 --- a/funasr/models/sanm/encoder.py +++ b/funasr/models/sanm/encoder.py @@ -30,7 +30,7 @@ from funasr.models.transformer.utils.subsampling import check_short_utt from funasr.models.ctc.ctc import CTC -from funasr.utils.register import register_class +from funasr.register import tables class EncoderLayerSANM(nn.Module): def __init__( @@ -153,7 +153,7 @@ class EncoderLayerSANM(nn.Module): return x, cache -@register_class("encoder_classes", "SANMEncoder") +@tables.register("encoder_classes", "SANMEncoder") class SANMEncoder(nn.Module): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/sanm/model.py b/funasr/models/sanm/model.py index e01394c89..d51478f73 100644 --- a/funasr/models/sanm/model.py +++ b/funasr/models/sanm/model.py @@ -3,9 +3,9 @@ import logging import torch from funasr.models.transformer.model import Transformer -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables -@register_class("model_classes", "SANM") +@tables.register("model_classes", "SANM") class SANM(Transformer): """CTC-attention hybrid Encoder-Decoder model""" diff --git a/funasr/models/scama/sanm_decoder.py b/funasr/models/scama/sanm_decoder.py index 53423d039..4222e5f85 100644 --- a/funasr/models/scama/sanm_decoder.py +++ b/funasr/models/scama/sanm_decoder.py @@ -14,7 +14,7 @@ from funasr.models.transformer.layer_norm import LayerNorm from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM from funasr.models.transformer.utils.repeat import repeat -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables class DecoderLayerSANM(nn.Module): """Single decoder layer module. @@ -189,7 +189,7 @@ class DecoderLayerSANM(nn.Module): return x, memory, fsmn_cache, opt_cache -@register_class("decoder_classes", "FsmnDecoderSCAMAOpt") +@tables.register("decoder_classes", "FsmnDecoderSCAMAOpt") class FsmnDecoderSCAMAOpt(BaseTransformerDecoder): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/scama/sanm_encoder.py b/funasr/models/scama/sanm_encoder.py index c89bfb374..4bf6ef0ed 100644 --- a/funasr/models/scama/sanm_encoder.py +++ b/funasr/models/scama/sanm_encoder.py @@ -30,7 +30,7 @@ from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask from funasr.models.ctc.ctc import CTC -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables class EncoderLayerSANM(nn.Module): def __init__( @@ -154,7 +154,7 @@ class EncoderLayerSANM(nn.Module): return x, cache -@register_class("encoder_classes", "SANMEncoderChunkOpt") +@tables.register("encoder_classes", "SANMEncoderChunkOpt") class SANMEncoderChunkOpt(nn.Module): """ Author: Speech Lab of DAMO Academy, Alibaba Group diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py index 86aa7602a..d25babe6f 100644 --- a/funasr/models/seaco_paraformer/model.py +++ b/funasr/models/seaco_paraformer/model.py @@ -51,10 +51,10 @@ from funasr.utils import postprocess_utils from funasr.utils.datadir_writer import DatadirWriter from funasr.models.paraformer.model import Paraformer -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables -@register_class("model_classes", "SeacoParaformer") +@tables.register("model_classes", "SeacoParaformer") class SeacoParaformer(Paraformer): """ Author: Speech Lab of DAMO Academy, Alibaba Group @@ -100,7 +100,7 @@ class SeacoParaformer(Paraformer): seaco_decoder = kwargs.get("seaco_decoder", None) if seaco_decoder is not None: seaco_decoder_conf = kwargs.get("seaco_decoder_conf") - seaco_decoder_class = registry_tables.decoder_classes.get(seaco_decoder.lower()) + seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower()) self.seaco_decoder = seaco_decoder_class( vocab_size=self.vocab_size, encoder_output_size=self.inner_dim, diff --git a/funasr/models/seaco_paraformer/template.yaml b/funasr/models/seaco_paraformer/template.yaml index 266386ffe..52654ac0f 100644 --- a/funasr/models/seaco_paraformer/template.yaml +++ b/funasr/models/seaco_paraformer/template.yaml @@ -2,8 +2,8 @@ # You can modify the configuration according to your own requirements. # to print the register_table: -# from funasr.utils.register import registry_tables -# registry_tables.print() +# from funasr.register import tables +# tables.print() # network architecture model: SeacoParaformer diff --git a/funasr/models/specaug/specaug.py b/funasr/models/specaug/specaug.py index 17f265770..49f83e2e4 100644 --- a/funasr/models/specaug/specaug.py +++ b/funasr/models/specaug/specaug.py @@ -7,11 +7,11 @@ from funasr.models.specaug.mask_along_axis import MaskAlongAxis from funasr.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth from funasr.models.specaug.mask_along_axis import MaskAlongAxisLFR from funasr.models.specaug.time_warp import TimeWarp -from funasr.utils.register import register_class +from funasr.register import tables import torch.nn as nn -@register_class("specaug_classes", "SpecAug") +@tables.register("specaug_classes", "SpecAug") class SpecAug(nn.Module): """Implementation of SpecAug. @@ -101,7 +101,7 @@ class SpecAug(nn.Module): x, x_lengths = self.time_mask(x, x_lengths) return x, x_lengths -@register_class("specaug_classes", "SpecAugLFR") +@tables.register("specaug_classes", "SpecAugLFR") class SpecAugLFR(nn.Module): """Implementation of SpecAug. lfr_rate:low frame rate diff --git a/funasr/models/transformer/decoder.py b/funasr/models/transformer/decoder.py index 3e8d224f8..820de4a1f 100644 --- a/funasr/models/transformer/decoder.py +++ b/funasr/models/transformer/decoder.py @@ -26,7 +26,7 @@ from funasr.models.transformer.positionwise_feed_forward import ( from funasr.models.transformer.utils.repeat import repeat from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables class DecoderLayer(nn.Module): """Single decoder layer module. @@ -352,7 +352,7 @@ class BaseTransformerDecoder(nn.Module, BatchScorerInterface): state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] return logp, state_list -@register_class("decoder_classes", "TransformerDecoder") +@tables.register("decoder_classes", "TransformerDecoder") class TransformerDecoder(BaseTransformerDecoder): def __init__( self, @@ -401,7 +401,7 @@ class TransformerDecoder(BaseTransformerDecoder): ) -@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder") +@tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder") class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): def __init__( self, @@ -462,7 +462,7 @@ class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): ), ) -@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder") +@tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder") class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): def __init__( self, @@ -524,7 +524,7 @@ class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): ) -@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder") +@tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder") class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): def __init__( self, @@ -585,7 +585,7 @@ class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): ), ) -@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder") +@tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder") class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): def __init__( self, diff --git a/funasr/models/transformer/encoder.py b/funasr/models/transformer/encoder.py index a3d524961..1f1486704 100644 --- a/funasr/models/transformer/encoder.py +++ b/funasr/models/transformer/encoder.py @@ -28,7 +28,7 @@ from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8 from funasr.models.transformer.utils.subsampling import TooShortUttError from funasr.models.transformer.utils.subsampling import check_short_utt -from funasr.utils.register import register_class +from funasr.register import tables class EncoderLayer(nn.Module): """Encoder layer module. @@ -136,7 +136,7 @@ class EncoderLayer(nn.Module): return x, mask -@register_class("encoder_classes", "TransformerEncoder") +@tables.register("encoder_classes", "TransformerEncoder") class TransformerEncoder(nn.Module): """Transformer encoder module. diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py index e4eae1059..b710c9732 100644 --- a/funasr/models/transformer/model.py +++ b/funasr/models/transformer/model.py @@ -15,9 +15,9 @@ from funasr.train_utils.device_funcs import force_gatherable from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank from funasr.utils import postprocess_utils from funasr.utils.datadir_writer import DatadirWriter -from funasr.utils.register import register_class, registry_tables +from funasr.register import tables -@register_class("model_classes", "Transformer") +@tables.register("model_classes", "Transformer") class Transformer(nn.Module): """CTC-attention hybrid Encoder-Decoder model""" @@ -60,19 +60,19 @@ class Transformer(nn.Module): super().__init__() if frontend is not None: - frontend_class = registry_tables.frontend_classes.get_class(frontend.lower()) + frontend_class = tables.frontend_classes.get_class(frontend.lower()) frontend = frontend_class(**frontend_conf) if specaug is not None: - specaug_class = registry_tables.specaug_classes.get_class(specaug.lower()) + specaug_class = tables.specaug_classes.get_class(specaug.lower()) specaug = specaug_class(**specaug_conf) if normalize is not None: - normalize_class = registry_tables.normalize_classes.get_class(normalize.lower()) + normalize_class = tables.normalize_classes.get_class(normalize.lower()) normalize = normalize_class(**normalize_conf) - encoder_class = registry_tables.encoder_classes.get_class(encoder.lower()) + encoder_class = tables.encoder_classes.get_class(encoder.lower()) encoder = encoder_class(input_size=input_size, **encoder_conf) encoder_output_size = encoder.output_size() if decoder is not None: - decoder_class = registry_tables.decoder_classes.get_class(decoder.lower()) + decoder_class = tables.decoder_classes.get_class(decoder.lower()) decoder = decoder_class( vocab_size=vocab_size, encoder_output_size=encoder_output_size, diff --git a/funasr/models/transformer/template.yaml b/funasr/models/transformer/template.yaml index 798e37462..c9228f433 100644 --- a/funasr/models/transformer/template.yaml +++ b/funasr/models/transformer/template.yaml @@ -2,8 +2,8 @@ # You can modify the configuration according to your own requirements. # to print the register_table: -# from funasr.utils.register import registry_tables -# registry_tables.print() +# from funasr.register import tables +# tables.print() # network architecture #model: funasr.models.paraformer.model:Paraformer diff --git a/funasr/register.py b/funasr/register.py new file mode 100644 index 000000000..145a6986e --- /dev/null +++ b/funasr/register.py @@ -0,0 +1,77 @@ +import logging +import inspect +from dataclasses import dataclass + + +@dataclass +class RegisterTables: + model_classes = {} + frontend_classes = {} + specaug_classes = {} + normalize_classes = {} + encoder_classes = {} + decoder_classes = {} + joint_network_classes = {} + predictor_classes = {} + stride_conv_classes = {} + tokenizer_classes = {} + batch_sampler_classes = {} + dataset_classes = {} + index_ds_classes = {} + + def print(self,): + print("\ntables: \n") + fields = vars(self) + for classes_key, classes_dict in fields.items(): + print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------") + + if classes_key.endswith("_meta"): + headers = ["class name", "register name", "class location"] + metas = [] + for register_key, meta in classes_dict.items(): + metas.append(meta) + metas.sort(key=lambda x: x[0]) + data = [headers] + metas + col_widths = [max(len(str(item)) for item in col) for col in zip(*data)] + + for row in data: + print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |") + print("\n") + + + def register(self, register_tables_key: str, key=None): + def decorator(target_class): + + if not hasattr(self, register_tables_key): + setattr(self, register_tables_key, {}) + logging.info("new registry table has been added: {}".format(register_tables_key)) + + registry = getattr(self, register_tables_key) + registry_key = key if key is not None else target_class.__name__ + registry_key = registry_key.lower() + # import pdb; pdb.set_trace() + assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format( + registry_key, target_class, register_tables_key) + + registry[registry_key] = target_class + + # meta, headers = ["class name", "register name", "class location"] + register_tables_key_meta = register_tables_key + "_meta" + if not hasattr(self, register_tables_key_meta): + setattr(self, register_tables_key_meta, {}) + registry_meta = getattr(self, register_tables_key_meta) + class_file = inspect.getfile(target_class) + class_line = inspect.getsourcelines(target_class)[1] + meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"] + registry_meta[registry_key] = meata_data + # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}") + return target_class + + return decorator + + +tables = RegisterTables() + + +import funasr + diff --git a/funasr/tokenizer/char_tokenizer.py b/funasr/tokenizer/char_tokenizer.py index 23ff74350..8c6c214f6 100644 --- a/funasr/tokenizer/char_tokenizer.py +++ b/funasr/tokenizer/char_tokenizer.py @@ -5,9 +5,9 @@ from typing import Union import warnings from funasr.tokenizer.abs_tokenizer import BaseTokenizer -from funasr.utils.register import register_class +from funasr.register import tables -@register_class("tokenizer_classes", "CharTokenizer") +@tables.register("tokenizer_classes", "CharTokenizer") class CharTokenizer(BaseTokenizer): def __init__( self, diff --git a/funasr/utils/register.py b/funasr/utils/register.py deleted file mode 100644 index 6fe04f709..000000000 --- a/funasr/utils/register.py +++ /dev/null @@ -1,72 +0,0 @@ -import logging -import inspect -from dataclasses import dataclass - - -@dataclass -class ClassRegistryTables: - model_classes = {} - frontend_classes = {} - specaug_classes = {} - normalize_classes = {} - encoder_classes = {} - decoder_classes = {} - joint_network_classes = {} - predictor_classes = {} - stride_conv_classes = {} - tokenizer_classes = {} - batch_sampler_classes = {} - dataset_classes = {} - index_ds_classes = {} - - def print(self,): - print("\nregister_tables: \n") - fields = vars(self) - for classes_key, classes_dict in fields.items(): - print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------") - - if classes_key.endswith("_meta"): - headers = ["class name", "register name", "class location"] - metas = [] - for register_key, meta in classes_dict.items(): - metas.append(meta) - metas.sort(key=lambda x: x[0]) - data = [headers] + metas - col_widths = [max(len(str(item)) for item in col) for col in zip(*data)] - - for row in data: - print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |") - print("\n") - -registry_tables = ClassRegistryTables() - -def register_class(registry_tables_key:str, key=None): - def decorator(target_class): - - if not hasattr(registry_tables, registry_tables_key): - setattr(registry_tables, registry_tables_key, {}) - logging.info("new registry table has been added: {}".format(registry_tables_key)) - - registry = getattr(registry_tables, registry_tables_key) - registry_key = key if key is not None else target_class.__name__ - registry_key = registry_key.lower() - # import pdb; pdb.set_trace() - assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(registry_key, target_class, registry_tables_key) - - registry[registry_key] = target_class - - # meta, headers = ["class name", "register name", "class location"] - registry_tables_key_meta = registry_tables_key + "_meta" - if not hasattr(registry_tables, registry_tables_key_meta): - setattr(registry_tables, registry_tables_key_meta, {}) - registry_meta = getattr(registry_tables, registry_tables_key_meta) - class_file = inspect.getfile(target_class) - class_line = inspect.getsourcelines(target_class)[1] - meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"] - registry_meta[registry_key] = meata_data - # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}") - return target_class - return decorator - -import funasr -