Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)

Commit a1b0cd33d5: rename register tables (parent commit: f920ca6298)
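This commit renames the class registry: the import "from funasr.utils.register import register_class, registry_tables" becomes "from funasr.register import tables", the decorator @register_class(table, name) becomes @tables.register(table, name), and lookups keep the form tables.<table>.get(name.lower()). Below is a minimal, self-contained Python sketch of that usage pattern; the _RegistryTables stand-in is written here purely for illustration and is not the implementation in funasr/register.py.

# Illustrative stand-in for the registry object this commit switches to.
# In FunASR the real object is imported as:  from funasr.register import tables
# (previously: from funasr.utils.register import register_class, registry_tables).
class _RegistryTables:
    def __init__(self):
        # one dict per registry table, mirroring table names used in the diff
        self.model_classes = {}
        self.encoder_classes = {}
        self.decoder_classes = {}

    def register(self, table_name, class_name):
        """Return a decorator that stores the class under <table>[class_name.lower()]."""
        def decorator(cls):
            getattr(self, table_name)[class_name.lower()] = cls
            return cls
        return decorator

    def print(self):
        # rough analogue of the tables.print() call used in inference.py and train.py
        for name in ("model_classes", "encoder_classes", "decoder_classes"):
            print(name, "->", sorted(getattr(self, name)))

tables = _RegistryTables()

@tables.register("model_classes", "Paraformer")
class Paraformer:
    pass

# Lookup mirrors the changed call sites, e.g.
#   model_class = tables.model_classes.get(kwargs["model"].lower())
model_class = tables.model_classes.get("Paraformer".lower())
assert model_class is Paraformer
tables.print()

In the diff itself only the import, the decorator, and the table lookups change; the table names (model_classes, encoder_classes, decoder_classes, frontend_classes, and so on) and the registered class names stay the same. Note that code indentation and per-hunk file paths are not preserved by this mirror, so the hunks below are shown flat.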
@@ -2,13 +2,13 @@
 cmd="funasr/bin/inference.py"

 python $cmd \
-+model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
++model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
 +input="/Users/zhifu/Downloads/asr_example.wav" \
 +output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
 +device="cpu" \

 python $cmd \
-+model="/Users/zhifu/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
++model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
 +input="/Users/zhifu/Downloads/asr_example.wav" \
 +output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
 +device="cpu" \

@@ -15,7 +15,7 @@ from funasr.train_utils.load_pretrained_model import load_pretrained_model
 import time
 import random
 import string
-from funasr.utils.register import registry_tables
+from funasr.register import tables


 def build_iter_for_infer(data_in, input_len=None, data_type="sound"):

@@ -81,7 +81,7 @@ def main_hydra(kwargs: DictConfig):

 class AutoModel:
 def __init__(self, **kwargs):
-registry_tables.print()
+tables.print()
 assert "model" in kwargs
 if "model_conf" not in kwargs:
 logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))

@@ -98,7 +98,7 @@ class AutoModel:
 # build tokenizer
 tokenizer = kwargs.get("tokenizer", None)
 if tokenizer is not None:
-tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
+tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
 tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
 kwargs["tokenizer"] = tokenizer
 kwargs["token_list"] = tokenizer.token_list

@@ -106,13 +106,13 @@ class AutoModel:
 # build frontend
 frontend = kwargs.get("frontend", None)
 if frontend is not None:
-frontend_class = registry_tables.frontend_classes.get(frontend.lower())
+frontend_class = tables.frontend_classes.get(frontend.lower())
 frontend = frontend_class(**kwargs["frontend_conf"])
 kwargs["frontend"] = frontend
 kwargs["input_size"] = frontend.output_size()

 # build model
-model_class = registry_tables.model_classes.get(kwargs["model"].lower())
+model_class = tables.model_classes.get(kwargs["model"].lower())
 model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
 model.eval()
 model.to(device)
@@ -21,7 +21,7 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from funasr.download.download_from_hub import download_model
-from funasr.utils.register import registry_tables
+from funasr.register import tables

 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):

@@ -39,7 +39,7 @@ def main(**kwargs):
 # preprocess_config(kwargs)
 # import pdb; pdb.set_trace()
 # set random seed
-registry_tables.print()
+tables.print()
 set_all_random_seed(kwargs.get("seed", 0))
 torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
 torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)

@@ -62,14 +62,14 @@ def main(**kwargs):

 tokenizer = kwargs.get("tokenizer", None)
 if tokenizer is not None:
-tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
+tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
 tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
 kwargs["tokenizer"] = tokenizer

 # build frontend if frontend is none None
 frontend = kwargs.get("frontend", None)
 if frontend is not None:
-frontend_class = registry_tables.frontend_classes.get(frontend.lower())
+frontend_class = tables.frontend_classes.get(frontend.lower())
 frontend = frontend_class(**kwargs["frontend_conf"])
 kwargs["frontend"] = frontend
 kwargs["input_size"] = frontend.output_size()

@@ -77,7 +77,7 @@ def main(**kwargs):
 # import pdb;
 # pdb.set_trace()
 # build model
-model_class = registry_tables.model_classes.get(kwargs["model"].lower())
+model_class = tables.model_classes.get(kwargs["model"].lower())
 model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))



@@ -139,12 +139,12 @@ def main(**kwargs):
 # import pdb;
 # pdb.set_trace()
 # dataset
-dataset_class = registry_tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower())
+dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower())
 dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))

 # dataloader
 batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
+batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler.lower())
 if batch_sampler is not None:
 batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
 dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
@@ -9,9 +9,9 @@ import time
 import logging

 from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

-@register_class("dataset_classes", "AudioDataset")
+@tables.register("dataset_classes", "AudioDataset")
 class AudioDataset(torch.utils.data.Dataset):
 def __init__(self,
 path,

@@ -22,16 +22,16 @@ class AudioDataset(torch.utils.data.Dataset):
 float_pad_value: float = 0.0,
 **kwargs):
 super().__init__()
-index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower())
+index_ds_class = tables.index_ds_classes.get(index_ds.lower())
 self.index_ds = index_ds_class(path)
 preprocessor_speech = kwargs.get("preprocessor_speech", None)
 if preprocessor_speech:
-preprocessor_speech_class = registry_tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
+preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
 preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
 self.preprocessor_speech = preprocessor_speech
 preprocessor_text = kwargs.get("preprocessor_text", None)
 if preprocessor_text:
-preprocessor_text_class = registry_tables.preprocessor_text_classes.get(preprocessor_text.lower())
+preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text.lower())
 preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
 self.preprocessor_text = preprocessor_text

@@ -4,9 +4,9 @@ import torch.distributed as dist
 import time
 import logging

-from funasr.utils.register import register_class
+from funasr.register import tables

-@register_class("index_ds_classes", "IndexDSJsonl")
+@tables.register("index_ds_classes", "IndexDSJsonl")
 class IndexDSJsonl(torch.utils.data.Dataset):

 def __init__(self, path):

@@ -2,9 +2,9 @@ import torch

 import numpy as np

-from funasr.utils.register import register_class
+from funasr.register import tables

-@register_class("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
+@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
 class BatchSampler(torch.utils.data.BatchSampler):

 def __init__(self, dataset,
@@ -9,7 +9,7 @@ import torchaudio.compliance.kaldi as kaldi
 from torch.nn.utils.rnn import pad_sequence

 import funasr.frontends.eend_ola_feature as eend_ola_feature
-from funasr.utils.register import register_class
+from funasr.register import tables



@@ -75,7 +75,7 @@ def apply_lfr(inputs, lfr_m, lfr_n):
 LFR_outputs = torch.vstack(LFR_inputs)
 return LFR_outputs.type(torch.float32)

-@register_class("frontend_classes", "WavFrontend")
+@tables.register("frontend_classes", "WavFrontend")
 class WavFrontend(nn.Module):
 """Conventional frontend structure for ASR.
 """

@@ -211,7 +211,7 @@ class WavFrontend(nn.Module):
 return feats_pad, feats_lens


-@register_class("frontend_classes", "WavFrontendOnline")
+@tables.register("frontend_classes", "WavFrontendOnline")
 class WavFrontendOnline(nn.Module):
 """Conventional frontend structure for streaming ASR/VAD.
 """
@@ -8,7 +8,7 @@
 # from funasr.models.scama.utils import sequence_mask
 # from typing import Optional, Tuple
 #
-# from funasr.utils.register import register_class
+# from funasr.register import tables
 #
 # class mae_loss(nn.Module):
 #

@@ -93,7 +93,7 @@
 # fires = torch.stack(list_fires, 1)
 # return fires
 #
-# @register_class("predictor_classes", "BATPredictor")
+# @tables.register("predictor_classes", "BATPredictor")
 # class BATPredictor(nn.Module):
 # def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False):
 # super(BATPredictor, self).__init__()
@@ -29,7 +29,7 @@ from funasr.models.transformer.utils.repeat import repeat, MultiBlocks
 from funasr.models.transformer.utils.subsampling import TooShortUttError
 from funasr.models.transformer.utils.subsampling import check_short_utt
 from funasr.models.transformer.utils.subsampling import StreamingConvInput
-from funasr.utils.register import register_class
+from funasr.register import tables



@@ -312,7 +312,7 @@ class CausalConvolution(nn.Module):

 return x, cache

-@register_class("encoder_classes", "ConformerChunkEncoder")
+@tables.register("encoder_classes", "ConformerChunkEncoder")
 class ConformerChunkEncoder(nn.Module):
 """Encoder module definition.
 Args:
@@ -8,7 +8,7 @@ from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.scama.utils import sequence_mask
 from typing import Optional, Tuple

-from funasr.utils.register import register_class
+from funasr.register import tables


 class mae_loss(nn.Module):

@@ -94,7 +94,7 @@ def cif_wo_hidden(alphas, threshold):
 fires = torch.stack(list_fires, 1)
 return fires

-@register_class("predictor_classes", "CifPredictorV3")
+@tables.register("predictor_classes", "CifPredictorV3")
 class CifPredictorV3(nn.Module):
 def __init__(self,
 idim,
@@ -27,12 +27,12 @@ from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,
 from funasr.utils import postprocess_utils
 from funasr.utils.datadir_writer import DatadirWriter
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables
 from funasr.models.ctc.ctc import CTC

 from funasr.models.paraformer.model import Paraformer

-@register_class("model_classes", "BiCifParaformer")
+@tables.register("model_classes", "BiCifParaformer")
 class BiCifParaformer(Paraformer):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -43,7 +43,7 @@ from funasr.models.transformer.utils.subsampling import (
 check_short_utt,
 )

-from funasr.utils.register import register_class
+from funasr.register import tables

 class BranchformerEncoderLayer(torch.nn.Module):
 """Branchformer encoder layer module.

@@ -291,7 +291,7 @@ class BranchformerEncoderLayer(torch.nn.Module):

 return x, mask

-@register_class("encoder_classes", "BranchformerEncoder")
+@tables.register("encoder_classes", "BranchformerEncoder")
 class BranchformerEncoder(nn.Module):
 """Branchformer encoder module."""

@@ -1,9 +1,9 @@
 import logging

 from funasr.models.transformer.model import Transformer
-from funasr.utils.register import register_class
+from funasr.register import tables

-@register_class("model_classes", "Branchformer")
+@tables.register("model_classes", "Branchformer")
 class Branchformer(Transformer):
 """CTC-attention hybrid Encoder-Decoder model"""


@@ -45,7 +45,7 @@ from funasr.models.transformer.utils.subsampling import TooShortUttError
 from funasr.models.transformer.utils.subsampling import check_short_utt
 from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
 from funasr.models.transformer.utils.subsampling import StreamingConvInput
-from funasr.utils.register import register_class
+from funasr.register import tables


 class ConvolutionModule(nn.Module):

@@ -283,7 +283,7 @@ class EncoderLayer(nn.Module):
 return x, mask


-@register_class("encoder_classes", "ConformerEncoder")
+@tables.register("encoder_classes", "ConformerEncoder")
 class ConformerEncoder(nn.Module):
 """Conformer encoder module.

@@ -3,9 +3,9 @@ import logging
 import torch

 from funasr.models.transformer.model import Transformer
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

-@register_class("model_classes", "Conformer")
+@tables.register("model_classes", "Conformer")
 class Conformer(Transformer):
 """CTC-attention hybrid Encoder-Decoder model"""


@@ -2,8 +2,8 @@
 # You can modify the configuration according to your own requirements.

 # to print the register_table:
-# from funasr.utils.register import registry_tables
-# registry_tables.print()
+# from funasr.register import tables
+# tables.print()

 # network architecture
 #model: funasr.models.paraformer.model:Paraformer
@@ -31,7 +31,7 @@ from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask

 from funasr.models.ctc.ctc import CTC

-from funasr.utils.register import register_class
+from funasr.register import tables

 class EncoderLayerSANM(nn.Module):
 def __init__(

@@ -155,7 +155,7 @@ class EncoderLayerSANM(nn.Module):
 return x, cache


-@register_class("encoder_classes", "SANMVadEncoder")
+@tables.register("encoder_classes", "SANMVadEncoder")
 class SANMVadEncoder(nn.Module):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -5,9 +5,9 @@ from typing import Tuple
 import torch
 import torch.nn as nn

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

-@register_class("model_classes", "CTTransformer")
+@tables.register("model_classes", "CTTransformer")
 class CTTransformer(nn.Module):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -37,7 +37,7 @@ class CTTransformer(nn.Module):


 self.embed = nn.Embedding(vocab_size, embed_unit)
-encoder_class = registry_tables.encoder_classes.get(encoder.lower())
+encoder_class = tables.encoder_classes.get(encoder.lower())
 encoder = encoder_class(**encoder_conf)

 self.decoder = nn.Linear(att_unit, punc_size)
@@ -42,7 +42,7 @@ from funasr.models.transformer.utils.subsampling import (
 TooShortUttError,
 check_short_utt,
 )
-from funasr.utils.register import register_class
+from funasr.register import tables

 class EBranchformerEncoderLayer(torch.nn.Module):
 """E-Branchformer encoder layer module.

@@ -174,7 +174,7 @@ class EBranchformerEncoderLayer(torch.nn.Module):

 return x, mask

-@register_class("encoder_classes", "EBranchformerEncoder")
+@tables.register("encoder_classes", "EBranchformerEncoder")
 class EBranchformerEncoder(nn.Module):
 """E-Branchformer encoder module."""


@@ -1,9 +1,9 @@
 import logging

 from funasr.models.transformer.model import Transformer
-from funasr.utils.register import register_class
+from funasr.register import tables

-@register_class("model_classes", "EBranchformer")
+@tables.register("model_classes", "EBranchformer")
 class EBranchformer(Transformer):
 """CTC-attention hybrid Encoder-Decoder model"""

@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

 class LinearTransform(nn.Module):


@@ -158,7 +158,7 @@ num_syn: output dimension
 fsmn_layers: no. of sequential fsmn layers
 '''

-@register_class("encoder_classes", "FSMN")
+@tables.register("encoder_classes", "FSMN")
 class FSMN(nn.Module):
 def __init__(
 self,

@@ -229,7 +229,7 @@ lstride: left stride
 rstride: right stride
 '''

-@register_class("encoder_classes", "DFSMN")
+@tables.register("encoder_classes", "DFSMN")
 class DFSMN(nn.Module):

 def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
@@ -8,7 +8,7 @@ from torch import nn
 import math
 from typing import Optional
 import time
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables
 from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,extract_fbank
 from funasr.utils.datadir_writer import DatadirWriter
 from torch.nn.utils.rnn import pad_sequence

@@ -218,7 +218,7 @@ class WindowDetector(object):
 return int(self.frame_size_ms)


-@register_class("model_classes", "FsmnVAD")
+@tables.register("model_classes", "FsmnVAD")
 class FsmnVAD(nn.Module):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -238,7 +238,7 @@ class FsmnVAD(nn.Module):
 self.vad_opts.speech_to_sil_time_thres,
 self.vad_opts.frame_in_ms)

-encoder_class = registry_tables.encoder_classes.get(encoder.lower())
+encoder_class = tables.encoder_classes.get(encoder.lower())
 encoder = encoder_class(**encoder_conf)
 self.encoder = encoder
 # init variables
funasr/models/fsmn_vad/template.yaml (new file, 62 lines)
@@ -0,0 +1,62 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: FsmnVAD
+model_conf:
+    sample_rate: 16000
+    detect_mode: 1
+    snr_mode: 0
+    max_end_silence_time: 800
+    max_start_silence_time: 3000
+    do_start_point_detection: True
+    do_end_point_detection: True
+    window_size_ms: 200
+    sil_to_speech_time_thres: 150
+    speech_to_sil_time_thres: 150
+    speech_2_noise_ratio: 1.0
+    do_extend: 1
+    lookback_time_start_point: 200
+    lookahead_time_end_point: 100
+    max_single_segment_time: 60000
+    snr_thres: -100.0
+    noise_frame_num_used_for_snr: 100
+    decibel_thres: -100.0
+    speech_noise_thres: 0.6
+    fe_prior_thres: 0.0001
+    silence_pdf_num: 1
+    sil_pdf_ids: [0]
+    speech_noise_thresh_low: -0.1
+    speech_noise_thresh_high: 0.3
+    output_frame_probs: False
+    frame_in_ms: 10
+    frame_length_ms: 25
+
+encoder: FSMN
+encoder_conf:
+    input_dim: 400
+    input_affine_dim: 140
+    fsmn_layers: 4
+    linear_dim: 250
+    proj_dim: 128
+    lorder: 20
+    rorder: 0
+    lstride: 1
+    rstride: 0
+    output_affine_dim: 140
+    output_dim: 248
+
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    dither: 0.0
+    lfr_m: 5
+    lfr_n: 1
@@ -14,7 +14,7 @@ from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForward
 from funasr.models.transformer.utils.repeat import repeat
 from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

 class ContextualDecoderLayer(nn.Module):
 def __init__(

@@ -98,7 +98,7 @@ class ContextualBiasDecoder(nn.Module):
 x = self.dropout(self.src_attn(x, memory, memory_mask))
 return x, tgt_mask, memory, memory_mask, cache

-@register_class("decoder_classes", "ContextualParaformerDecoder")
+@tables.register("decoder_classes", "ContextualParaformerDecoder")
 class ContextualParaformerDecoder(ParaformerSANMDecoder):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -53,9 +53,9 @@ from funasr.utils.datadir_writer import DatadirWriter

 from funasr.models.paraformer.model import Paraformer

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

-@register_class("model_classes", "NeatContextualParaformer")
+@tables.register("model_classes", "NeatContextualParaformer")
 class NeatContextualParaformer(Paraformer):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -2,8 +2,8 @@
 # You can modify the configuration according to your own requirements.

 # to print the register_table:
-# from funasr.utils.register import registry_tables
-# registry_tables.print()
+# from funasr.register import tables
+# tables.print()

 # network architecture
 model: NeatContextualParaformer
@@ -6,9 +6,9 @@ import numpy as np
 import torch

 from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

-@register_class("normalize_classes", "GlobalMVN")
+@tables.register("normalize_classes", "GlobalMVN")
 class GlobalMVN(torch.nn.Module):
 """Apply global mean and variance normalization
 TODO(kamo): Make this class portable somehow

@@ -3,9 +3,9 @@ from typing import Tuple
 import torch

 from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

-@register_class("normalize_classes", "UtteranceMVN")
+@tables.register("normalize_classes", "UtteranceMVN")
 class UtteranceMVN(torch.nn.Module):
 def __init__(
 self,
@@ -8,9 +8,9 @@ from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.scama.utils import sequence_mask
 from typing import Optional, Tuple

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

-@register_class("predictor_classes", "CifPredictor")
+@tables.register("predictor_classes", "CifPredictor")
 class CifPredictor(nn.Module):
 def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
 super().__init__()

@@ -136,7 +136,7 @@ class CifPredictor(nn.Module):
 predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
 return predictor_alignments.detach(), predictor_alignments_length.detach()

-@register_class("predictor_classes", "CifPredictorV2")
+@tables.register("predictor_classes", "CifPredictorV2")
 class CifPredictorV2(nn.Module):
 def __init__(self,
 idim,
@@ -17,7 +17,7 @@ from funasr.models.transformer.attention import MultiHeadedAttention
 from funasr.models.transformer.embedding import PositionalEncoding
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

 class DecoderLayerSANM(nn.Module):
 """Single decoder layer module.

@@ -200,7 +200,7 @@ class DecoderLayerSANM(nn.Module):
 return x, memory, fsmn_cache, opt_cache


-@register_class("decoder_classes", "ParaformerSANMDecoder")
+@tables.register("decoder_classes", "ParaformerSANMDecoder")
 class ParaformerSANMDecoder(BaseTransformerDecoder):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -525,7 +525,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
 return y, new_cache


-@register_class("decoder_classes", "ParaformerDecoderSAN")
+@tables.register("decoder_classes", "ParaformerDecoderSAN")
 class ParaformerDecoderSAN(BaseTransformerDecoder):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -25,10 +25,10 @@ from torch.cuda.amp import autocast
 from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
 from funasr.utils import postprocess_utils
 from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables
 from funasr.models.ctc.ctc import CTC

-@register_class("model_classes", "Paraformer")
+@tables.register("model_classes", "Paraformer")
 class Paraformer(nn.Module):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -79,17 +79,17 @@ class Paraformer(nn.Module):
 super().__init__()

 if specaug is not None:
-specaug_class = registry_tables.specaug_classes.get(specaug.lower())
+specaug_class = tables.specaug_classes.get(specaug.lower())
 specaug = specaug_class(**specaug_conf)
 if normalize is not None:
-normalize_class = registry_tables.normalize_classes.get(normalize.lower())
+normalize_class = tables.normalize_classes.get(normalize.lower())
 normalize = normalize_class(**normalize_conf)
-encoder_class = registry_tables.encoder_classes.get(encoder.lower())
+encoder_class = tables.encoder_classes.get(encoder.lower())
 encoder = encoder_class(input_size=input_size, **encoder_conf)
 encoder_output_size = encoder.output_size()

 if decoder is not None:
-decoder_class = registry_tables.decoder_classes.get(decoder.lower())
+decoder_class = tables.decoder_classes.get(decoder.lower())
 decoder = decoder_class(
 vocab_size=vocab_size,
 encoder_output_size=encoder_output_size,

@@ -104,7 +104,7 @@ class Paraformer(nn.Module):
 odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
 )
 if predictor is not None:
-predictor_class = registry_tables.predictor_classes.get(predictor.lower())
+predictor_class = tables.predictor_classes.get(predictor.lower())
 predictor = predictor_class(**predictor_conf)

 # note that eos is the same as sos (equivalent ID)
@@ -2,8 +2,8 @@
 # You can modify the configuration according to your own requirements.

 # to print the register_table:
-# from funasr.utils.register import registry_tables
-# registry_tables.print()
+# from funasr.register import tables
+# tables.print()

 # network architecture
 #model: funasr.models.paraformer.model:Paraformer

@@ -44,7 +44,7 @@ from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,
 from funasr.utils import postprocess_utils
 from funasr.utils.datadir_writer import DatadirWriter
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.utils.register import registry_tables
+from funasr.register import tables
 from funasr.models.ctc.ctc import CTC

 class Paraformer(nn.Module):

@@ -102,19 +102,19 @@ class Paraformer(nn.Module):
 # pdb.set_trace()

 if frontend is not None:
-frontend_class = registry_tables.frontend_classes.get_class(frontend.lower())
+frontend_class = tables.frontend_classes.get_class(frontend.lower())
 frontend = frontend_class(**frontend_conf)
 if specaug is not None:
-specaug_class = registry_tables.specaug_classes.get_class(specaug.lower())
+specaug_class = tables.specaug_classes.get_class(specaug.lower())
 specaug = specaug_class(**specaug_conf)
 if normalize is not None:
-normalize_class = registry_tables.normalize_classes.get_class(normalize.lower())
+normalize_class = tables.normalize_classes.get_class(normalize.lower())
 normalize = normalize_class(**normalize_conf)
-encoder_class = registry_tables.encoder_classes.get_class(encoder.lower())
+encoder_class = tables.encoder_classes.get_class(encoder.lower())
 encoder = encoder_class(input_size=input_size, **encoder_conf)
 encoder_output_size = encoder.output_size()
 if decoder is not None:
-decoder_class = registry_tables.decoder_classes.get_class(decoder.lower())
+decoder_class = tables.decoder_classes.get_class(decoder.lower())
 decoder = decoder_class(
 vocab_size=vocab_size,
 encoder_output_size=encoder_output_size,

@@ -129,7 +129,7 @@ class Paraformer(nn.Module):
 odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
 )
 if predictor is not None:
-predictor_class = registry_tables.predictor_classes.get_class(predictor.lower())
+predictor_class = tables.predictor_classes.get_class(predictor.lower())
 predictor = predictor_class(**predictor_conf)

 # note that eos is the same as sos (equivalent ID)
@@ -14,7 +14,7 @@ from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
 from funasr.models.transformer.utils.repeat import repeat

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

 class DecoderLayerSANM(nn.Module):
 """Single decoder layer module.

@@ -190,7 +190,7 @@ class DecoderLayerSANM(nn.Module):
 return x, memory, fsmn_cache, opt_cache


-@register_class("decoder_classes", "ParaformerSANMDecoder")
+@tables.register("decoder_classes", "ParaformerSANMDecoder")
 class ParaformerSANMDecoder(BaseTransformerDecoder):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -27,7 +27,7 @@ from funasr.models.transformer.positionwise_feed_forward import (
 from funasr.models.transformer.utils.repeat import repeat
 from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

 class DecoderLayer(nn.Module):
 """Single decoder layer module.

@@ -353,7 +353,7 @@ class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
 state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
 return logp, state_list

-@register_class("decoder_classes", "TransformerDecoder")
+@tables.register("decoder_classes", "TransformerDecoder")
 class TransformerDecoder(BaseTransformerDecoder):
 def __init__(
 self,
@@ -402,7 +402,7 @@ class TransformerDecoder(BaseTransformerDecoder):
 )


-@register_class("decoder_classes", "ParaformerDecoderSAN")
+@tables.register("decoder_classes", "ParaformerDecoderSAN")
 class ParaformerDecoderSAN(BaseTransformerDecoder):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -516,7 +516,7 @@ class ParaformerDecoderSAN(BaseTransformerDecoder):
 else:
 return x, olens

-@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
+@tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder")
 class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
 def __init__(
 self,

@@ -577,7 +577,7 @@ class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
 ),
 )

-@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
+@tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder")
 class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
 def __init__(
 self,

@@ -639,7 +639,7 @@ class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
 )


-@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
+@tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder")
 class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
 def __init__(
 self,

@@ -700,7 +700,7 @@ class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
 ),
 )

-@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
+@tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder")
 class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
 def __init__(
 self,
@@ -14,7 +14,7 @@ from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
 from funasr.models.transformer.utils.repeat import repeat

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

 class DecoderLayerSANM(nn.Module):
 """Single decoder layer module.

@@ -190,7 +190,7 @@ class DecoderLayerSANM(nn.Module):
 return x, memory, fsmn_cache, opt_cache


-@register_class("decoder_classes", "FsmnDecoder")
+@tables.register("decoder_classes", "FsmnDecoder")
 class FsmnDecoder(BaseTransformerDecoder):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -30,7 +30,7 @@ from funasr.models.transformer.utils.subsampling import check_short_utt

 from funasr.models.ctc.ctc import CTC

-from funasr.utils.register import register_class
+from funasr.register import tables

 class EncoderLayerSANM(nn.Module):
 def __init__(

@@ -153,7 +153,7 @@ class EncoderLayerSANM(nn.Module):

 return x, cache

-@register_class("encoder_classes", "SANMEncoder")
+@tables.register("encoder_classes", "SANMEncoder")
 class SANMEncoder(nn.Module):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -3,9 +3,9 @@ import logging
 import torch

 from funasr.models.transformer.model import Transformer
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

-@register_class("model_classes", "SANM")
+@tables.register("model_classes", "SANM")
 class SANM(Transformer):
 """CTC-attention hybrid Encoder-Decoder model"""


@@ -14,7 +14,7 @@ from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
 from funasr.models.transformer.utils.repeat import repeat

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

 class DecoderLayerSANM(nn.Module):
 """Single decoder layer module.

@@ -189,7 +189,7 @@ class DecoderLayerSANM(nn.Module):

 return x, memory, fsmn_cache, opt_cache

-@register_class("decoder_classes", "FsmnDecoderSCAMAOpt")
+@tables.register("decoder_classes", "FsmnDecoderSCAMAOpt")
 class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -30,7 +30,7 @@ from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask

 from funasr.models.ctc.ctc import CTC

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

 class EncoderLayerSANM(nn.Module):
 def __init__(

@@ -154,7 +154,7 @@ class EncoderLayerSANM(nn.Module):
 return x, cache


-@register_class("encoder_classes", "SANMEncoderChunkOpt")
+@tables.register("encoder_classes", "SANMEncoderChunkOpt")
 class SANMEncoderChunkOpt(nn.Module):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group

@@ -51,10 +51,10 @@ from funasr.utils import postprocess_utils
 from funasr.utils.datadir_writer import DatadirWriter

 from funasr.models.paraformer.model import Paraformer
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables


-@register_class("model_classes", "SeacoParaformer")
+@tables.register("model_classes", "SeacoParaformer")
 class SeacoParaformer(Paraformer):
 """
 Author: Speech Lab of DAMO Academy, Alibaba Group
@ -100,7 +100,7 @@ class SeacoParaformer(Paraformer):
|
|||||||
seaco_decoder = kwargs.get("seaco_decoder", None)
|
seaco_decoder = kwargs.get("seaco_decoder", None)
|
||||||
if seaco_decoder is not None:
|
if seaco_decoder is not None:
|
||||||
seaco_decoder_conf = kwargs.get("seaco_decoder_conf")
|
seaco_decoder_conf = kwargs.get("seaco_decoder_conf")
|
||||||
seaco_decoder_class = registry_tables.decoder_classes.get(seaco_decoder.lower())
|
seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower())
|
||||||
self.seaco_decoder = seaco_decoder_class(
|
self.seaco_decoder = seaco_decoder_class(
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
encoder_output_size=self.inner_dim,
|
encoder_output_size=self.inner_dim,
|
||||||
|
|||||||
@ -2,8 +2,8 @@
|
|||||||
# You can modify the configuration according to your own requirements.
|
# You can modify the configuration according to your own requirements.
|
||||||
|
|
||||||
# to print the register_table:
|
# to print the register_table:
|
||||||
# from funasr.utils.register import registry_tables
|
# from funasr.register import tables
|
||||||
# registry_tables.print()
|
# tables.print()
|
||||||
|
|
||||||
# network architecture
|
# network architecture
|
||||||
model: SeacoParaformer
|
model: SeacoParaformer
|
||||||
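The rewritten comment above points configs at the renamed module. Assuming the `RegisterTables` implementation that appears later in this diff, and that the relevant model modules have already been imported so their decorators have run, dumping the tables from Python would look roughly like this sketch:

from funasr.register import tables

# prints one "** <table> **" banner per table and, for each *_meta table,
# an aligned grid of class name / register name / class location
tables.print()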
@@ -7,11 +7,11 @@ from funasr.models.specaug.mask_along_axis import MaskAlongAxis
 from funasr.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth
 from funasr.models.specaug.mask_along_axis import MaskAlongAxisLFR
 from funasr.models.specaug.time_warp import TimeWarp
-from funasr.utils.register import register_class
+from funasr.register import tables

 import torch.nn as nn

-@register_class("specaug_classes", "SpecAug")
+@tables.register("specaug_classes", "SpecAug")
 class SpecAug(nn.Module):
     """Implementation of SpecAug.

@@ -101,7 +101,7 @@ class SpecAug(nn.Module):
             x, x_lengths = self.time_mask(x, x_lengths)
         return x, x_lengths

-@register_class("specaug_classes", "SpecAugLFR")
+@tables.register("specaug_classes", "SpecAugLFR")
 class SpecAugLFR(nn.Module):
     """Implementation of SpecAug.
     lfr_rate:low frame rate
@@ -26,7 +26,7 @@ from funasr.models.transformer.positionwise_feed_forward import (
 from funasr.models.transformer.utils.repeat import repeat
 from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface

-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

 class DecoderLayer(nn.Module):
     """Single decoder layer module.
@@ -352,7 +352,7 @@ class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
         state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
         return logp, state_list

-@register_class("decoder_classes", "TransformerDecoder")
+@tables.register("decoder_classes", "TransformerDecoder")
 class TransformerDecoder(BaseTransformerDecoder):
     def __init__(
         self,
@@ -401,7 +401,7 @@ class TransformerDecoder(BaseTransformerDecoder):
         )


-@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
+@tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder")
 class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
     def __init__(
         self,
@@ -462,7 +462,7 @@ class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
             ),
         )

-@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
+@tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder")
 class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
     def __init__(
         self,
@@ -524,7 +524,7 @@ class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
         )


-@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
+@tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder")
 class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
     def __init__(
         self,
@@ -585,7 +585,7 @@ class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
             ),
         )

-@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
+@tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder")
 class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
     def __init__(
         self,
@@ -28,7 +28,7 @@ from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
 from funasr.models.transformer.utils.subsampling import TooShortUttError
 from funasr.models.transformer.utils.subsampling import check_short_utt

-from funasr.utils.register import register_class
+from funasr.register import tables

 class EncoderLayer(nn.Module):
     """Encoder layer module.
@@ -136,7 +136,7 @@ class EncoderLayer(nn.Module):

         return x, mask

-@register_class("encoder_classes", "TransformerEncoder")
+@tables.register("encoder_classes", "TransformerEncoder")
 class TransformerEncoder(nn.Module):
     """Transformer encoder module.

@@ -15,9 +15,9 @@ from funasr.train_utils.device_funcs import force_gatherable
 from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
 from funasr.utils import postprocess_utils
 from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables

-@register_class("model_classes", "Transformer")
+@tables.register("model_classes", "Transformer")
 class Transformer(nn.Module):
     """CTC-attention hybrid Encoder-Decoder model"""

@@ -60,19 +60,19 @@ class Transformer(nn.Module):
         super().__init__()

         if frontend is not None:
-            frontend_class = registry_tables.frontend_classes.get_class(frontend.lower())
+            frontend_class = tables.frontend_classes.get_class(frontend.lower())
             frontend = frontend_class(**frontend_conf)
         if specaug is not None:
-            specaug_class = registry_tables.specaug_classes.get_class(specaug.lower())
+            specaug_class = tables.specaug_classes.get_class(specaug.lower())
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = registry_tables.normalize_classes.get_class(normalize.lower())
+            normalize_class = tables.normalize_classes.get_class(normalize.lower())
             normalize = normalize_class(**normalize_conf)
-        encoder_class = registry_tables.encoder_classes.get_class(encoder.lower())
+        encoder_class = tables.encoder_classes.get_class(encoder.lower())
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
         if decoder is not None:
-            decoder_class = registry_tables.decoder_classes.get_class(decoder.lower())
+            decoder_class = tables.decoder_classes.get_class(decoder.lower())
             decoder = decoder_class(
                 vocab_size=vocab_size,
                 encoder_output_size=encoder_output_size,
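The constructor above resolves each configured component name through the renamed tables with a `get_class` helper. A rough sketch of the underlying lookup, using plain dict access (which is what the `RegisterTables` dataclass later in this diff directly supports) and assuming the encoder module has already been imported so its `@tables.register` decorator has run:

from funasr.register import tables

# A config string such as "SANMEncoder" is lower-cased and looked up in the matching table.
encoder_class = tables.encoder_classes.get("sanmencoder")  # None if the encoder module was never imported
# encoder = encoder_class(input_size=input_size, **encoder_conf)  # as in Transformer.__init__ above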
@@ -2,8 +2,8 @@
 # You can modify the configuration according to your own requirements.

 # to print the register_table:
-# from funasr.utils.register import registry_tables
-# registry_tables.print()
+# from funasr.register import tables
+# tables.print()

 # network architecture
 #model: funasr.models.paraformer.model:Paraformer
funasr/register.py (new file, +77 lines)
@@ -0,0 +1,77 @@
+import logging
+import inspect
+from dataclasses import dataclass
+
+
+@dataclass
+class RegisterTables:
+    model_classes = {}
+    frontend_classes = {}
+    specaug_classes = {}
+    normalize_classes = {}
+    encoder_classes = {}
+    decoder_classes = {}
+    joint_network_classes = {}
+    predictor_classes = {}
+    stride_conv_classes = {}
+    tokenizer_classes = {}
+    batch_sampler_classes = {}
+    dataset_classes = {}
+    index_ds_classes = {}
+
+    def print(self,):
+        print("\ntables: \n")
+        fields = vars(self)
+        for classes_key, classes_dict in fields.items():
+            print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
+
+            if classes_key.endswith("_meta"):
+                headers = ["class name", "register name", "class location"]
+                metas = []
+                for register_key, meta in classes_dict.items():
+                    metas.append(meta)
+                metas.sort(key=lambda x: x[0])
+                data = [headers] + metas
+                col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
+
+                for row in data:
+                    print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |")
+                print("\n")
+
+
+    def register(self, register_tables_key: str, key=None):
+        def decorator(target_class):
+
+            if not hasattr(self, register_tables_key):
+                setattr(self, register_tables_key, {})
+                logging.info("new registry table has been added: {}".format(register_tables_key))
+
+            registry = getattr(self, register_tables_key)
+            registry_key = key if key is not None else target_class.__name__
+            registry_key = registry_key.lower()
+            # import pdb; pdb.set_trace()
+            assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(
+                registry_key, target_class, register_tables_key)
+
+            registry[registry_key] = target_class
+
+            # meta, headers = ["class name", "register name", "class location"]
+            register_tables_key_meta = register_tables_key + "_meta"
+            if not hasattr(self, register_tables_key_meta):
+                setattr(self, register_tables_key_meta, {})
+            registry_meta = getattr(self, register_tables_key_meta)
+            class_file = inspect.getfile(target_class)
+            class_line = inspect.getsourcelines(target_class)[1]
+            meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
+            registry_meta[registry_key] = meata_data
+            # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
+            return target_class
+
+        return decorator
+
+
+tables = RegisterTables()
+
+
+import funasr
+
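The new module keeps one plain dict per component family plus a parallel `*_meta` dict that records where each registered class was defined. A minimal usage sketch against this implementation (`DummyFrontend` is illustrative only, not part of FunASR):

from funasr.register import tables

@tables.register("frontend_classes", "DummyFrontend")  # stored under the lower-cased key "dummyfrontend"
class DummyFrontend:
    def __init__(self, fs=16000):
        self.fs = fs

frontend_class = tables.frontend_classes["dummyfrontend"]
frontend = frontend_class(fs=16000)

# frontend_classes_meta holds [class name, register name, file:line] for tables.print()
print(tables.frontend_classes_meta["dummyfrontend"])
tables.print()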
@@ -5,9 +5,9 @@ from typing import Union
 import warnings

 from funasr.tokenizer.abs_tokenizer import BaseTokenizer
-from funasr.utils.register import register_class
+from funasr.register import tables

-@register_class("tokenizer_classes", "CharTokenizer")
+@tables.register("tokenizer_classes", "CharTokenizer")
 class CharTokenizer(BaseTokenizer):
     def __init__(
         self,
funasr/utils/register.py (deleted file, -72 lines)
@@ -1,72 +0,0 @@
-import logging
-import inspect
-from dataclasses import dataclass
-
-
-@dataclass
-class ClassRegistryTables:
-    model_classes = {}
-    frontend_classes = {}
-    specaug_classes = {}
-    normalize_classes = {}
-    encoder_classes = {}
-    decoder_classes = {}
-    joint_network_classes = {}
-    predictor_classes = {}
-    stride_conv_classes = {}
-    tokenizer_classes = {}
-    batch_sampler_classes = {}
-    dataset_classes = {}
-    index_ds_classes = {}
-
-    def print(self,):
-        print("\nregister_tables: \n")
-        fields = vars(self)
-        for classes_key, classes_dict in fields.items():
-            print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
-
-            if classes_key.endswith("_meta"):
-                headers = ["class name", "register name", "class location"]
-                metas = []
-                for register_key, meta in classes_dict.items():
-                    metas.append(meta)
-                metas.sort(key=lambda x: x[0])
-                data = [headers] + metas
-                col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
-
-                for row in data:
-                    print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |")
-                print("\n")
-
-registry_tables = ClassRegistryTables()
-
-def register_class(registry_tables_key:str, key=None):
-    def decorator(target_class):
-
-        if not hasattr(registry_tables, registry_tables_key):
-            setattr(registry_tables, registry_tables_key, {})
-            logging.info("new registry table has been added: {}".format(registry_tables_key))
-
-        registry = getattr(registry_tables, registry_tables_key)
-        registry_key = key if key is not None else target_class.__name__
-        registry_key = registry_key.lower()
-        # import pdb; pdb.set_trace()
-        assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(registry_key, target_class, registry_tables_key)
-
-        registry[registry_key] = target_class
-
-        # meta, headers = ["class name", "register name", "class location"]
-        registry_tables_key_meta = registry_tables_key + "_meta"
-        if not hasattr(registry_tables, registry_tables_key_meta):
-            setattr(registry_tables, registry_tables_key_meta, {})
-        registry_meta = getattr(registry_tables, registry_tables_key_meta)
-        class_file = inspect.getfile(target_class)
-        class_line = inspect.getsourcelines(target_class)[1]
-        meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
-        registry_meta[registry_key] = meata_data
-        # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
-        return target_class
-    return decorator
-
-import funasr
-