Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
This commit is contained in:
commit 1bdb956318
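Most hunks below delete `typeguard` imports and runtime type assertions. For reference, a minimal sketch of the typeguard 2.x pattern being removed (assumed from the call sites visible in the diff):

    from typeguard import check_argument_types, check_return_type

    def str2int_tuple(integers: str):
        # Validates the caller's arguments against the function's type
        # annotations by inspecting the current frame; raises on mismatch.
        assert check_argument_types()
        result = tuple(map(int, integers.strip().split(",")))
        # Validates the value against the annotated return type, if any.
        assert check_return_type(result)
        return result

Because both helpers return True on success, wrapping them in `assert` means the checks disappear along with other asserts under `python -O`.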
@@ -11,7 +11,6 @@ import numpy as np
import resampy
import soundfile
from tqdm import tqdm
from typeguard import check_argument_types

from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.read_text import read_2column_text
@@ -31,7 +30,6 @@ def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
        (3, 4, 5)

    """
    assert check_argument_types()
    if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return tuple(map(int, integers.strip().split(",")))
@@ -1,5 +1,4 @@
import os
<<<<<<< HEAD

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
@@ -21,50 +20,17 @@ def modelscope_finetune(params):
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
=======
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from funasr.datasets.ms_dataset import MsDataset


def modelscope_finetune(params):
    if not os.path.exists(params["output_dir"]):
        os.makedirs(params["output_dir"], exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = MsDataset.load(params["data_dir"])
    kwargs = dict(
        model=params["model"],
        model_revision=params["model_revision"],
        data_dir=ds_dict,
        dataset_type=params["dataset_type"],
        work_dir=params["output_dir"],
        batch_bins=params["batch_bins"],
        max_epoch=params["max_epoch"],
        lr=params["lr"])
>>>>>>> main
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()


if __name__ == '__main__':
<<<<<<< HEAD
    params = modelscope_args(model="damo/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch", data_path="./data")
    params.output_dir = "./checkpoint"    # path to save the model
    params.data_path = "./example_data/"  # data path
    params.dataset_type = "small"         # use "small" for small datasets; for more than 1000 hours of data, use "large"
    params.batch_bins = 2000              # batch size; with dataset_type="small" the unit is fbank feature frames, with dataset_type="large" the unit is milliseconds
    params.max_epoch = 50                 # maximum number of training epochs
    params.max_epoch = 20                 # maximum number of training epochs
    params.lr = 0.00005                   # learning rate

=======
    params = {}
    params["output_dir"] = "./checkpoint"
    params["data_dir"] = "./data"
    params["batch_bins"] = 2000
    params["dataset_type"] = "small"
    params["max_epoch"] = 50
    params["lr"] = 0.00005
    params["model"] = "damo/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch"
    params["model_revision"] = None
>>>>>>> main
    modelscope_finetune(params)
    modelscope_finetune(params)
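Both hunks above keep the `main` side of the conflict. A minimal sketch of the finetune script after that resolution (assuming the dict-style `params` from `main` is what survives):

    import os

    from modelscope.metainfo import Trainers
    from modelscope.trainers import build_trainer
    from funasr.datasets.ms_dataset import MsDataset


    def modelscope_finetune(params):
        if not os.path.exists(params["output_dir"]):
            os.makedirs(params["output_dir"], exist_ok=True)
        # dataset split ["train", "validation"]
        ds_dict = MsDataset.load(params["data_dir"])
        kwargs = dict(
            model=params["model"],
            model_revision=params["model_revision"],
            data_dir=ds_dict,
            dataset_type=params["dataset_type"],
            work_dir=params["output_dir"],
            batch_bins=params["batch_bins"],
            max_epoch=params["max_epoch"],
            lr=params["lr"])
        trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
        trainer.train()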
@@ -1,33 +1,3 @@
<<<<<<< HEAD
import os
import shutil
import argparse
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

def modelscope_infer(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=args.model,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
    )
    inference_pipeline(audio_in=args.audio_in)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default="damo/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch")
    parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
    parser.add_argument('--output_dir', type=str, default="./results/")
    parser.add_argument('--decoding_mode', type=str, default="normal")
    parser.add_argument('--hotword_txt', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--gpuid', type=str, default="0")
    args = parser.parse_args()
    modelscope_infer(args)
=======
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

@@ -40,5 +10,4 @@ if __name__ == "__main__":
        output_dir=output_dir,
    )
    rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
>>>>>>> main
    print(rec_result)
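Here too the `main` side wins. A minimal sketch of the surviving inference call; the `model`, `output_dir`, and `audio_in` variables are assumed to be set earlier in the file, outside the visible hunk:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # model, output_dir, audio_in are defined earlier in the script (not shown).
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=model,
        output_dir=output_dir,
    )
    rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "offline"})
    print(rec_result)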
@@ -22,9 +22,7 @@ import numpy as np
import requests
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -78,7 +76,6 @@ class Speech2Text:
        frontend_conf: dict = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
@@ -192,7 +189,6 @@ class Speech2Text:
            text, token, token_int, hyp

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -248,7 +244,6 @@ class Speech2Text:
                text = None
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)
        return results


@@ -288,7 +283,6 @@ class Speech2TextParaformer:
        decoding_ind: int = 0,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
@@ -413,7 +407,6 @@ class Speech2TextParaformer:
            text, token, token_int, hyp

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -516,7 +509,6 @@ class Speech2TextParaformer:
                vad_offset=begin_time)
            results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))

        # assert check_return_type(results)
        return results

    def generate_hotwords_list(self, hotword_list_or_file):
@@ -656,7 +648,6 @@ class Speech2TextParaformerOnline:
        hotword_list_or_file: str = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
@@ -776,7 +767,6 @@ class Speech2TextParaformerOnline:
            text, token, token_int, hyp

        """
        assert check_argument_types()
        results = []
        cache_en = cache["encoder"]
        if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
@@ -871,7 +861,6 @@ class Speech2TextParaformerOnline:

            results.append(postprocessed_result)

        # assert check_return_type(results)
        return results


@@ -912,7 +901,6 @@ class Speech2TextUniASR:
        frontend_conf: dict = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
@@ -1036,7 +1024,6 @@ class Speech2TextUniASR:
            text, token, token_int, hyp

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -1104,7 +1091,6 @@ class Speech2TextUniASR:
                text = None
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)
        return results


@@ -1143,7 +1129,6 @@ class Speech2TextMFCCA:
        streaming: bool = False,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
@@ -1248,7 +1233,6 @@ class Speech2TextMFCCA:
            text, token, token_int, hyp

        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1298,7 +1282,6 @@ class Speech2TextMFCCA:
                text = None
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)
        return results


@@ -1355,7 +1338,6 @@ class Speech2TextTransducer:
        """Construct a Speech2Text object."""
        super().__init__()

        assert check_argument_types()
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
@@ -1534,7 +1516,6 @@ class Speech2TextTransducer:
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()

        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1566,7 +1547,6 @@ class Speech2TextTransducer:
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()

        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1608,36 +1588,9 @@ class Speech2TextTransducer:
                text = None
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)

        return results

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ) -> Speech2Text:
        """Build Speech2Text instance from the pretrained model.
        Args:
            model_tag: Model tag of the pretrained models.
        Return:
            : Speech2Text instance.
        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader

            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        return Speech2TextTransducer(**kwargs)


class Speech2TextSAASR:
    """Speech2Text class
@@ -1675,7 +1628,6 @@ class Speech2TextSAASR:
        frontend_conf: dict = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
@@ -1793,7 +1745,6 @@ class Speech2TextSAASR:
            text, text_id, token, token_int, hyp

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -1886,5 +1837,4 @@ class Speech2TextSAASR:

            results.append((text, text_id, token, token_int, hyp))

        assert check_return_type(results)
        return results
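Note that the `from_pretrained`/`espnet_model_zoo` download path is dropped here. A hypothetical sketch of constructing one of the classes above directly from local files instead; the keyword names follow the constructor calls visible in the hunks, but the paths are placeholders, not a confirmed API:

    from funasr.bin.asr_infer import Speech2Text

    # Placeholder paths; real names depend on the trained model directory.
    speech2text = Speech2Text(
        asr_train_config="exp/asr/config.yaml",
        asr_model_file="exp/asr/model.pb",
        device="cpu",
    )
    results = speech2text(speech)  # yields (text, token, token_int, hyp) tuples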
@@ -21,7 +21,6 @@ import torch
import torchaudio
import soundfile
import yaml
from typeguard import check_argument_types

from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextMFCCA
@@ -80,7 +79,6 @@ def inference_asr(
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
@@ -240,7 +238,6 @@ def inference_paraformer(
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

@@ -481,7 +478,6 @@ def inference_paraformer_vad_punc(
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

@@ -749,7 +745,6 @@ def inference_paraformer_online(
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()

    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
@@ -957,7 +952,6 @@ def inference_uniasr(
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
@@ -1126,7 +1120,6 @@ def inference_mfcca(
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
@@ -1314,7 +1307,6 @@ def inference_transducer(
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
    """
    assert check_argument_types()

    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
@@ -1464,7 +1456,6 @@ def inference_sa_asr(
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:

@@ -15,7 +15,6 @@ import numpy as np
import torch
from scipy.ndimage import median_filter
from torch.nn import functional as F
from typeguard import check_argument_types

from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import DiarTask
@@ -45,7 +44,6 @@ class Speech2DiarizationEEND:
        device: str = "cpu",
        dtype: str = "float32",
    ):
        assert check_argument_types()

        # 1. Build Diarization model
        diar_model, diar_train_args = build_model_from_file(
@@ -88,7 +86,6 @@ class Speech2DiarizationEEND:
            diarization results

        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -107,36 +104,6 @@ class Speech2DiarizationEEND:

        return results

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        """Build Speech2Diarization instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.

        Returns:
            Speech2Diarization: Speech2Diarization instance.

        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader

            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        return Speech2DiarizationEEND(**kwargs)


class Speech2DiarizationSOND:
    """Speech2Xvector class
@@ -163,7 +130,6 @@ class Speech2DiarizationSOND:
        smooth_size: int = 83,
        dur_threshold: float = 10,
    ):
        assert check_argument_types()

        # TODO: 1. Build Diarization model
        diar_model, diar_train_args = build_model_from_file(
@@ -283,7 +249,6 @@ class Speech2DiarizationSOND:
            diarization results for each speaker

        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -305,33 +270,3 @@ class Speech2DiarizationSOND:
        results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)

        return results, pse_labels

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        """Build Speech2Xvector instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.

        Returns:
            Speech2Xvector: Speech2Xvector instance.

        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader

            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        return Speech2DiarizationSOND(**kwargs)

@@ -18,7 +18,6 @@ import numpy as np
import soundfile
import torch
from scipy.signal import medfilt
from typeguard import check_argument_types

from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
from funasr.datasets.iterable_dataset import load_bytes
@@ -52,7 +51,6 @@ def inference_sond(
    mode: str = "sond",
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
@@ -233,7 +231,6 @@ def inference_eend(
    param_dict: Optional[dict] = None,
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:

@@ -15,7 +15,6 @@ from typing import Union
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types

from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
@@ -50,7 +49,6 @@ def inference_lm(
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

@@ -14,7 +14,6 @@ from typing import Optional
from typing import Union

import torch
from typeguard import check_argument_types

from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
@@ -38,7 +37,6 @@ def inference_punc(
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@@ -118,7 +116,6 @@ def inference_punc_vad_realtime(
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

@@ -12,8 +12,6 @@ from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
@@ -42,7 +40,6 @@ class Speech2Xvector:
        streaming: bool = False,
        embedding_node: str = "resnet1_dense",
    ):
        assert check_argument_types()

        # TODO: 1. Build SV model
        sv_model, sv_train_args = build_model_from_file(
@@ -108,7 +105,6 @@ class Speech2Xvector:
            embedding, ref_embedding, similarity_score

        """
        assert check_argument_types()
        self.sv_model.eval()
        embedding = self.calculate_embedding(speech)
        ref_emb, score = None, None
@@ -117,35 +113,4 @@ class Speech2Xvector:
            score = torch.cosine_similarity(embedding, ref_emb)

        results = (embedding, ref_emb, score)
        assert check_return_type(results)
        return results

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        """Build Speech2Xvector instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.

        Returns:
            Speech2Xvector: Speech2Xvector instance.

        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader

            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        return Speech2Xvector(**kwargs)

@@ -15,7 +15,6 @@ from typing import Union
import numpy as np
import torch
from kaldiio import WriteHelper
from typeguard import check_argument_types

from funasr.bin.sv_infer import Speech2Xvector
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
@@ -46,7 +45,6 @@ def inference_sv(
    param_dict: Optional[dict] = None,
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

@@ -79,10 +77,7 @@ def inference_sv(
        embedding_node=embedding_node
    )
    logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
    speech2xvector = Speech2Xvector.from_pretrained(
        model_tag=model_tag,
        **speech2xvector_kwargs,
    )
    speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
    speech2xvector.sv_model.eval()

    def _forward(

@@ -7,7 +7,6 @@ import sys
from typing import List
from typing import Optional

from typeguard import check_argument_types

from funasr.utils.cli_utils import get_commandline_args
from funasr.text.build_tokenizer import build_tokenizer
@@ -81,7 +80,6 @@ def tokenize(
    cleaner: Optional[str],
    g2p: Optional[str],
):
    assert check_argument_types()

    logging.basicConfig(
        level=log_level,

@@ -9,7 +9,6 @@ from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
@@ -26,7 +25,6 @@ class Speech2Timestamp:
        dtype: str = "float32",
        **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        tp_model, tp_train_args = build_model_from_file(
            timestamp_infer_config, timestamp_model_file, cmvn_file=None, device=device, task_name="asr", mode="tp"
@@ -64,7 +62,6 @@ class Speech2Timestamp:
        speech_lengths: Union[torch.Tensor, np.ndarray] = None,
        text_lengths: Union[torch.Tensor, np.ndarray] = None
    ):
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):

@@ -13,7 +13,6 @@ from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types

from funasr.bin.tp_infer import Speech2Timestamp
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
@@ -47,7 +46,6 @@ def inference_tp(
    seg_dict_file: Optional[str] = None,
    **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

@@ -13,7 +13,6 @@ from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types

from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -42,7 +41,6 @@ class Speech2VadSegment:
        dtype: str = "float32",
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build vad model
        vad_model, vad_infer_args = build_model_from_file(
@@ -76,7 +74,6 @@ class Speech2VadSegment:
            text, token, token_int, hyp

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -149,7 +146,6 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
            text, token, token_int, hyp

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):

@@ -18,7 +18,6 @@ from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
@@ -47,7 +46,6 @@ def inference_vad(
    num_workers: int = 1,
    **kwargs,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")

@@ -148,7 +146,6 @@ def inference_vad_online(
    num_workers: int = 1,
    **kwargs,
):
    assert check_argument_types()

    logging.basicConfig(
        level=log_level,

@@ -6,7 +6,6 @@ from typing import Union

import torch
import yaml
from typeguard import check_argument_types

from funasr.build_utils.build_model import build_model
from funasr.models.base_model import FunASRModel
@@ -30,7 +29,6 @@ def build_model_from_file(
        device: Device type, "cpu", "cuda", or "cuda:N".

    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "

@@ -1,6 +1,5 @@
import numpy as np
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from funasr.datasets.iterable_dataset import IterableESPnetDataset
from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
@@ -23,7 +22,6 @@ def build_streaming_iterator(
    train: bool = False,
) -> DataLoader:
    """Build DataLoader using iterable dataset"""
    assert check_argument_types()

    # preprocess
    if preprocess_fn is not None:

@@ -1,7 +1,6 @@
import logging

import torch
from typeguard import check_return_type

from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
@@ -254,5 +253,4 @@ def build_sv_model(args):
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model

@@ -25,7 +25,6 @@ import oss2
import torch
import torch.nn
import torch.optim
from typeguard import check_argument_types

from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
@@ -118,7 +117,6 @@ class Trainer:

    def build_options(self, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval()"""
        assert check_argument_types()
        return build_dataclass(TrainerOptions, args)

    @classmethod
@@ -156,7 +154,6 @@ class Trainer:

    def run(self) -> None:
        """Perform training. This method performs the main process of training."""
        assert check_argument_types()
        # NOTE(kamo): Don't check the type more strictly as far trainer_options
        model = self.model
        optimizers = self.optimizers
@@ -522,7 +519,6 @@ class Trainer:
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> Tuple[bool, bool]:
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
@@ -758,7 +754,6 @@ class Trainer:
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        assert check_argument_types()
        ngpu = options.ngpu
        # no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed

@@ -6,8 +6,6 @@ from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.modules.nets_utils import pad_list

@@ -22,7 +20,6 @@ class CommonCollateFn:
        not_sequence: Collection[str] = (),
        max_sample_size=None
    ):
        assert check_argument_types()
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
@@ -53,7 +50,6 @@ def common_collate_fn(
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor.
    """
    assert check_argument_types()
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

@@ -79,7 +75,6 @@ def common_collate_fn(
        output[key + "_lengths"] = lens

    output = (uttids, output)
    assert check_return_type(output)
    return output

def crop_to_max_size(feature, target_size):
@@ -99,7 +94,6 @@ def clipping_collate_fn(
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    # mainly for pre-training
    assert check_argument_types()
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

@@ -131,5 +125,4 @@ def clipping_collate_fn(
        output[key + "_lengths"] = lens

    output = (uttids, output)
    assert check_return_type(output)
    return output
@@ -23,8 +23,6 @@ import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
@@ -37,7 +35,6 @@ from funasr.utils.sized_dict import SizedDict

class AdapterForSoundScpReader(collections.abc.Mapping):
    def __init__(self, loader, dtype=None):
        assert check_argument_types()
        self.loader = loader
        self.dtype = dtype
        self.rate = None
@@ -284,7 +281,6 @@ class ESPnetDataset(AbsDataset):
        max_cache_fd: int = 0,
        dest_sample_rate: int = 16000,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
@@ -379,7 +375,6 @@ class ESPnetDataset(AbsDataset):
        return _mes

    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
        assert check_argument_types()

        # Change integer-id to string-id
        if isinstance(uid, int):
@@ -444,5 +439,4 @@ class ESPnetDataset(AbsDataset):
            self.cache[uid] = data

        retval = uid, data
        assert check_return_type(retval)
        return retval

@@ -16,7 +16,6 @@ import torch
import torchaudio
import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path

from funasr.datasets.dataset import ESPnetDataset
@@ -121,7 +120,6 @@ class IterableESPnetDataset(IterableDataset):
        int_dtype: str = "long",
        key_file: str = None,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'

@@ -6,7 +6,6 @@ from typing import Union

import sentencepiece as spm
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from funasr.datasets.large_datasets.dataset import Dataset
from funasr.iterators.abs_iter_factory import AbsIterFactory
@@ -43,7 +42,6 @@ def load_seg_dict(seg_dict_file):

class SentencepiecesTokenizer(AbsTokenizer):
    def __init__(self, model: Union[Path, str]):
        assert check_argument_types()
        self.model = str(model)
        self.sp = None

@@ -11,8 +11,6 @@ from typing import Union
import numpy as np
import scipy.signal
import soundfile
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
@@ -268,7 +266,6 @@ class CommonPreprocessor(AbsPreprocessor):
    def _speech_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, Union[str, np.ndarray]]:
        assert check_argument_types()
        if self.speech_name in data:
            if self.train and (self.rirs is not None or self.noises is not None):
                speech = data[self.speech_name]
@@ -355,7 +352,6 @@ class CommonPreprocessor(AbsPreprocessor):
                speech = data[self.speech_name]
                ma = np.max(np.abs(speech))
                data[self.speech_name] = speech * self.speech_volume_normalize / ma
        assert check_return_type(data)
        return data

    def _text_process(
@@ -372,13 +368,11 @@ class CommonPreprocessor(AbsPreprocessor):
            tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data

    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        assert check_argument_types()

        data = self._speech_process(data)
        data = self._text_process(data)
@@ -445,7 +439,6 @@ class LMPreprocessor(CommonPreprocessor):
            tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data


@@ -502,13 +495,11 @@ class CommonPreprocessor_multi(AbsPreprocessor):
                tokens = self.tokenizer.text2tokens(text)
                text_ints = self.token_id_converter.tokens2ids(tokens)
                data[text_n] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data

    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        assert check_argument_types()

        if self.speech_name in data:
            # Nothing now: candidates:
@@ -612,7 +603,6 @@ class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
                tokens = self.tokenizer[i].text2tokens(text)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data

class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
@@ -690,7 +680,6 @@ class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
    def __call__(
        self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
    ) -> Dict[str, Union[list, np.ndarray]]:
        assert check_argument_types()
        # Split words.
        if isinstance(data[self.text_name], str):
            split_text = self.split_words(data[self.text_name])

@@ -6,8 +6,6 @@ from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.modules.nets_utils import pad_list

@@ -22,7 +20,6 @@ class CommonCollateFn:
        not_sequence: Collection[str] = (),
        max_sample_size=None
    ):
        assert check_argument_types()
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
@@ -53,7 +50,6 @@ def common_collate_fn(
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor.
    """
    assert check_argument_types()
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

@@ -79,7 +75,6 @@ def common_collate_fn(
        output[key + "_lengths"] = lens

    output = (uttids, output)
    assert check_return_type(output)
    return output

def crop_to_max_size(feature, target_size):

@@ -15,8 +15,6 @@ import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.sound_scp import SoundScpReader
@@ -24,7 +22,6 @@ from funasr.fileio.sound_scp import SoundScpReader

class AdapterForSoundScpReader(collections.abc.Mapping):
    def __init__(self, loader, dtype=None):
        assert check_argument_types()
        self.loader = loader
        self.dtype = dtype
        self.rate = None
@@ -112,7 +109,6 @@ class ESPnetDataset(Dataset):
        speed_perturb: Union[list, tuple] = None,
        mode: str = "train",
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
@@ -207,7 +203,6 @@ class ESPnetDataset(Dataset):
        return _mes

    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
        assert check_argument_types()

        # Change integer-id to string-id
        if isinstance(uid, int):
@@ -265,5 +260,4 @@ class ESPnetDataset(Dataset):
            data[name] = value

        retval = uid, data
        assert check_return_type(retval)
        return retval

@@ -4,7 +4,6 @@ from typing import Dict
from typing import Tuple
from typing import Union

from typeguard import check_argument_types

from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@@ -21,7 +20,6 @@ class LengthBatchSampler(AbsSampler):
        drop_last: bool = False,
        padding: bool = True,
    ):
        assert check_argument_types()
        assert batch_bins > 0
        if sort_batch != "ascending" and sort_batch != "descending":
            raise ValueError(

@@ -10,8 +10,6 @@ from typing import Union
import numpy as np
import scipy.signal
import soundfile
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
@@ -260,7 +258,6 @@ class CommonPreprocessor(AbsPreprocessor):
    def _speech_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, Union[str, np.ndarray]]:
        assert check_argument_types()
        if self.speech_name in data:
            if self.train and (self.rirs is not None or self.noises is not None):
                speech = data[self.speech_name]
@@ -347,7 +344,6 @@ class CommonPreprocessor(AbsPreprocessor):
                speech = data[self.speech_name]
                ma = np.max(np.abs(speech))
                data[self.speech_name] = speech * self.speech_volume_normalize / ma
        assert check_return_type(data)
        return data

    def _text_process(
@@ -365,13 +361,11 @@ class CommonPreprocessor(AbsPreprocessor):
            tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data

    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        assert check_argument_types()

        data = self._speech_process(data)
        data = self._text_process(data)
@@ -439,7 +433,6 @@ class LMPreprocessor(CommonPreprocessor):
            tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data


@@ -496,13 +489,11 @@ class CommonPreprocessor_multi(AbsPreprocessor):
                tokens = self.tokenizer.text2tokens(text)
                text_ints = self.token_id_converter.tokens2ids(tokens)
                data[text_n] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data

    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        assert check_argument_types()

        if self.speech_name in data:
            # Nothing now: candidates:
@@ -606,7 +597,6 @@ class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
                tokens = self.tokenizer[i].text2tokens(text)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data


@@ -685,7 +675,6 @@ class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
    def __call__(
        self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
    ) -> Dict[str, Union[list, np.ndarray]]:
        assert check_argument_types()
        # Split words.
        if isinstance(data[self.text_name], str):
            split_text = self.split_words(data[self.text_name])

@@ -1,7 +1,6 @@
import json
from typing import Union, Dict
from pathlib import Path
from typeguard import check_argument_types

import os
import logging
@@ -26,7 +25,6 @@ class ModelExport:
        calib_num: int = 200,
        model_revision: str = None,
    ):
        assert check_argument_types()
        self.set_all_random_seed(0)

        self.cache_dir = cache_dir

@@ -2,8 +2,6 @@ from pathlib import Path
from typing import Union
import warnings

from typeguard import check_argument_types
from typeguard import check_return_type


class DatadirWriter:
@@ -20,7 +18,6 @@ class DatadirWriter:
    """

    def __init__(self, p: Union[Path, str]):
        assert check_argument_types()
        self.path = Path(p)
        self.chilidren = {}
        self.fd = None
@@ -31,7 +28,6 @@ class DatadirWriter:
        return self

    def __getitem__(self, key: str) -> "DatadirWriter":
        assert check_argument_types()
        if self.fd is not None:
            raise RuntimeError("This writer points out a file")

@@ -41,11 +37,9 @@ class DatadirWriter:
            self.has_children = True

        retval = self.chilidren[key]
        assert check_return_type(retval)
        return retval

    def __setitem__(self, key: str, value: str):
        assert check_argument_types()
        if self.has_children:
            raise RuntimeError("This writer points out a directory")
        if key in self.keys:

@@ -3,7 +3,6 @@ from pathlib import Path
from typing import Union

import numpy as np
from typeguard import check_argument_types

from funasr.fileio.read_text import read_2column_text

@@ -25,7 +24,6 @@ class NpyScpWriter:
    """

    def __init__(self, outdir: Union[Path, str], scpfile: Union[Path, str]):
        assert check_argument_types()
        self.dir = Path(outdir)
        self.dir.mkdir(parents=True, exist_ok=True)
        scpfile = Path(scpfile)
@@ -73,7 +71,6 @@ class NpyScpReader(collections.abc.Mapping):
    """

    def __init__(self, fname: Union[Path, str]):
        assert check_argument_types()
        self.fname = Path(fname)
        self.data = read_2column_text(fname)

@@ -3,7 +3,6 @@ from pathlib import Path
from typing import Union

import numpy as np
from typeguard import check_argument_types

from funasr.fileio.read_text import load_num_sequence_text

@@ -29,7 +28,6 @@ class FloatRandomGenerateDataset(collections.abc.Mapping):
        dtype: Union[str, np.dtype] = "float32",
        loader_type: str = "csv_int",
    ):
        assert check_argument_types()
        shape_file = Path(shape_file)
        self.utt2shape = load_num_sequence_text(shape_file, loader_type)
        self.dtype = np.dtype(dtype)
@@ -68,7 +66,6 @@ class IntRandomGenerateDataset(collections.abc.Mapping):
        dtype: Union[str, np.dtype] = "int64",
        loader_type: str = "csv_int",
    ):
        assert check_argument_types()
        shape_file = Path(shape_file)
        self.utt2shape = load_num_sequence_text(shape_file, loader_type)
        self.dtype = np.dtype(dtype)

@@ -4,7 +4,6 @@ from typing import Dict
from typing import List
from typing import Union

from typeguard import check_argument_types


def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
@@ -19,7 +18,6 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}

    """
    assert check_argument_types()

    data = {}
    with Path(path).open("r", encoding="utf-8") as f:
@@ -47,7 +45,6 @@ def load_num_sequence_text(
        >>> d = load_num_sequence_text('text')
        >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
    """
    assert check_argument_types()
    if loader_type == "text_int":
        delimiter = " "
        dtype = int

@@ -6,7 +6,6 @@ import random
import numpy as np
import soundfile
import librosa
from typeguard import check_argument_types

import torch
import torchaudio
@@ -106,7 +105,6 @@ class SoundScpReader(collections.abc.Mapping):
        dest_sample_rate: int = 16000,
        speed_perturb: Union[list, tuple] = None,
    ):
        assert check_argument_types()
        self.fname = fname
        self.dtype = dtype
        self.always_2d = always_2d
@@ -179,7 +177,6 @@ class SoundScpWriter:
        format="wav",
        dtype=None,
    ):
        assert check_argument_types()
        self.dir = Path(outdir)
        self.dir.mkdir(parents=True, exist_ok=True)
        scpfile = Path(scpfile)

@@ -9,7 +9,6 @@ from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types

from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
@@ -51,7 +50,6 @@ class ChunkIterFactory(AbsIterFactory):
        collate_fn=None,
        pin_memory: bool = False,
    ):
        assert check_argument_types()
        assert all(len(x) == 1 for x in batches), "batch-size must be 1"

        self.per_sample_iter_factory = SequenceIterFactory(

@@ -4,7 +4,6 @@ from typing import Collection
from typing import Iterator

import numpy as np
from typeguard import check_argument_types

from funasr.iterators.abs_iter_factory import AbsIterFactory

@@ -16,7 +15,6 @@ class MultipleIterFactory(AbsIterFactory):
        seed: int = 0,
        shuffle: bool = False,
    ):
        assert check_argument_types()
        self.build_funcs = list(build_funcs)
        self.seed = seed
        self.shuffle = shuffle

@@ -4,7 +4,6 @@ from typing import Union

import numpy as np
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.samplers.abs_sampler import AbsSampler
@@ -46,7 +45,6 @@ class SequenceIterFactory(AbsIterFactory):
        collate_fn=None,
        pin_memory: bool = False,
    ):
        assert check_argument_types()

        if not isinstance(batches, AbsSampler):
            self.sampler = RawSampler(batches)

@@ -4,7 +4,6 @@ from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.abs_normalize import AbsNormalize
@@ -28,7 +27,6 @@ class GlobalMVN(AbsNormalize, InversibleInterface):
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        assert check_argument_types()
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars

@@ -1,5 +1,4 @@
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple

@@ -13,7 +12,6 @@ class LabelAggregate(torch.nn.Module):
        hop_length: int = 128,
        center: bool = True,
    ):
        assert check_argument_types()
        super().__init__()

        self.win_length = win_length

@@ -1,6 +1,5 @@
import math
import torch
from typeguard import check_argument_types
from typing import Sequence
from typing import Union

@@ -147,7 +146,6 @@ class MaskAlongAxis(torch.nn.Module):
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        assert check_argument_types()
        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
@@ -214,7 +212,6 @@ class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        assert check_argument_types()
        if isinstance(mask_width_ratio_range, float):
            mask_width_ratio_range = (0.0, mask_width_ratio_range)
        if len(mask_width_ratio_range) != 2:
@@ -283,7 +280,6 @@ class MaskAlongAxisLFR(torch.nn.Module):
        replace_with_zero: bool = True,
        lfr_rate: int = 1,
    ):
        assert check_argument_types()
        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:

@@ -5,7 +5,6 @@
"""Sinc convolutions."""
import math
import torch
from typeguard import check_argument_types
from typing import Union


@@ -71,7 +70,6 @@ class SincConv(torch.nn.Module):
            window_func: Window function on the filter, one of ["hamming", "none"].
            fs (str, int, float): Sample rate of the input data
        """
        assert check_argument_types()
        super().__init__()
        window_funcs = {
            "none": self.none_window,
@@ -208,7 +206,6 @@ class MelScale:
            torch.Tensor: Filter start frequencíes.
            torch.Tensor: Filter stop frequencies.
        """
        assert check_argument_types()
        # min and max bandpass edge frequencies
        min_frequency = torch.tensor(30.0)
        max_frequency = torch.tensor(fs * 0.5)
@@ -257,7 +254,6 @@ class BarkScale:
            torch.Tensor: Filter start frequencíes.
            torch.Tensor: Filter stop frequencíes.
        """
        assert check_argument_types()
        # min and max BARK center frequencies by approximation
        min_center_frequency = torch.tensor(70.0)
        max_center_frequency = torch.tensor(fs * 0.45)

@@ -5,7 +5,6 @@ from typing import Union

import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.complex_utils import is_complex
@@ -30,7 +29,6 @@ class Stft(torch.nn.Module, InversibleInterface):
        normalized: bool = False,
        onesided: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.n_fft = n_fft
        if win_length is None:

@@ -1,7 +1,6 @@
from typing import Tuple

import torch
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.abs_normalize import AbsNormalize
@@ -14,7 +13,6 @@ class UtteranceMVN(AbsNormalize):
        norm_vars: bool = False,
        eps: float = 1.0e-20,
    ):
        assert check_argument_types()
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars

@@ -8,7 +8,6 @@ import os
from io import BytesIO

import torch
from typeguard import check_argument_types
from typing import Collection

from funasr.train.reporter import Reporter
@@ -34,7 +33,6 @@ def average_nbest_models(
        nbest: Number of best model files to be averaged
        suffix: A suffix added to the averaged model file name
    """
    assert check_argument_types()
    if isinstance(nbest, int):
        nbests = [nbest]
    else:

@@ -11,7 +11,6 @@ import numpy as np
import torch
from torch.nn.parallel import data_parallel
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from funasr.fileio.datadir_writer import DatadirWriter
from funasr.fileio.npy_scp import NpyScpWriter
@@ -37,7 +36,6 @@ def collect_stats(
    This method is used before executing train().

    """
    assert check_argument_types()

    npy_scp_writers = {}
    for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):

@@ -2,7 +2,6 @@ import logging

import torch
import torch.nn.functional as F
from typeguard import check_argument_types


class CTC(torch.nn.Module):
@@ -25,7 +24,6 @@ class CTC(torch.nn.Module):
        reduce: bool = True,
        ignore_nan_grad: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
@@ -41,11 +39,6 @@ class CTC(torch.nn.Module):
            if ignore_nan_grad:
                logging.warning("ignore_nan_grad option is not supported for warp_ctc")
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)

        elif self.ctc_type == "gtnctc":
            from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction

            self.ctc_loss = GTNCTCLossFunction.apply
        else:
            raise ValueError(
                f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}'
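The `ValueError` above names the supported backends; the "builtin" branch corresponds to PyTorch's native CTC loss. A minimal standalone sketch of that backend, independent of FunASR:

    import torch

    # log_probs: (T, N, C) log-softmax outputs; targets padded to (N, S).
    ctc_loss = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
    log_probs = torch.randn(50, 4, 20).log_softmax(2)  # T=50, batch=4, vocab=20
    targets = torch.randint(1, 20, (4, 30), dtype=torch.long)
    input_lengths = torch.full((4,), 50, dtype=torch.long)
    target_lengths = torch.randint(10, 30, (4,), dtype=torch.long)
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

`zero_infinity=True` mirrors the concern the `ignore_nan_grad` option addresses: samples whose target is longer than the input yield an infinite loss, which this flag zeroes out instead of propagating through the gradients.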
@ -10,7 +10,6 @@ from typing import Optional
from typing import Tuple

import torch
from typeguard import check_argument_types

from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.encoder.abs_encoder import AbsEncoder
@ -40,7 +39,6 @@ class Data2VecPretrainModel(FunASRModel):
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
):
assert check_argument_types()

super().__init__()


@ -7,7 +7,6 @@ import numpy as np

from funasr.modules.streaming_utils import utils as myutils
from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
from typeguard import check_argument_types

from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.modules.embedding import PositionalEncoding
@ -126,7 +125,6 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
kernel_size: int = 21,
sanm_shfit: int = 0,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,

@ -3,7 +3,6 @@ import random
import numpy as np
import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.nets_utils import to_device
@ -97,7 +96,6 @@ class RNNDecoder(AbsDecoder):
att_conf: dict = get_default_kwargs(build_attention_list),
):
# FIXME(kamo): The parts of num_spk should be refactored more more more
assert check_argument_types()
if rnn_type not in {"lstm", "gru"}:
raise ValueError(f"Not supported: rnn_type={rnn_type}")


@ -3,7 +3,6 @@
from typing import List, Optional, Tuple

import torch
from typeguard import check_argument_types

from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.specaug.specaug import SpecAug
@ -38,7 +37,6 @@ class RNNTDecoder(torch.nn.Module):
"""Construct a RNNDecoder object."""
super().__init__()

assert check_argument_types()

if rnn_type not in ("lstm", "gru"):
raise ValueError(f"Not supported: rnn_type={rnn_type}")

@ -7,7 +7,6 @@ import numpy as np

from funasr.modules.streaming_utils import utils as myutils
from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
from typeguard import check_argument_types

from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.modules.embedding import PositionalEncoding
@ -181,7 +180,6 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
embed_tensor_name_prefix_tf: str = None,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@ -838,7 +836,6 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
tf2torch_tensor_name_prefix_torch: str = "decoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,

@ -9,7 +9,6 @@ from typing import Tuple

import torch
from torch import nn
from typeguard import check_argument_types

from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.modules.attention import MultiHeadedAttention
@ -184,7 +183,6 @@ class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
):
assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size

@ -373,7 +371,6 @@ class TransformerDecoder(BaseTransformerDecoder):
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@ -428,7 +425,6 @@ class ParaformerDecoderSAN(BaseTransformerDecoder):
concat_after: bool = False,
embeds_id: int = -1,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@ -540,7 +536,6 @@ class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@ -602,7 +597,6 @@ class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@ -664,7 +658,6 @@ class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@ -726,7 +719,6 @@ class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@ -781,7 +773,6 @@ class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
):
assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size

@ -955,7 +946,6 @@ class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,

@ -11,7 +11,6 @@ from typing import Tuple
from typing import Union

import torch
from typeguard import check_argument_types

from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@ -65,7 +64,6 @@ class ASRModel(FunASRModel):
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight


@ -9,7 +9,6 @@ from typing import Union
import numpy as np

import torch
from typeguard import check_argument_types

from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.ctc import CTC
@ -73,7 +72,6 @@ class NeatContextualParaformer(Paraformer):
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight


@ -7,7 +7,6 @@ from typing import Tuple
from typing import Union
import logging
import torch
from typeguard import check_argument_types

from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
@ -65,7 +64,6 @@ class MFCCA(FunASRModel):
sym_blank: str = "<blank>",
preencoder: Optional[AbsPreEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert rnnt_decoder is None, "Not implemented"


@ -10,7 +10,6 @@ from typing import Union
import torch
import random
import numpy as np
from typeguard import check_argument_types

from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@ -80,7 +79,6 @@ class Paraformer(FunASRModel):
postencoder: Optional[AbsPostEncoder] = None,
use_1st_decoder_loss: bool = False,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight

@ -645,7 +643,6 @@ class ParaformerOnline(Paraformer):
postencoder: Optional[AbsPostEncoder] = None,
use_1st_decoder_loss: bool = False,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight

@ -1255,7 +1252,6 @@ class ParaformerBert(Paraformer):
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight

@ -1528,7 +1524,6 @@ class BiCifParaformer(Paraformer):
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight

@ -1806,7 +1801,6 @@ class ContextualParaformer(Paraformer):
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight


@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple, Union

import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss,  # noqa: H301
)
@ -86,8 +85,6 @@ class TransducerModel(FunASRModel):
"""Construct an ESPnetASRTransducerModel object."""
super().__init__()

assert check_argument_types()

# The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
self.blank_id = 0
self.vocab_size = vocab_size
@ -546,8 +543,6 @@ class UnifiedTransducerModel(FunASRModel):
"""Construct an ESPnetASRTransducerModel object."""
super().__init__()

assert check_argument_types()

# The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
self.blank_id = 0

@ -713,7 +708,7 @@ class UnifiedTransducerModel(FunASRModel):
loss_lm = self._calc_lm_loss(decoder_out, target)

loss_trans = loss_trans_utt + loss_trans_chunk
loss_ctc = loss_ctc + loss_ctc_chunk
loss_ctc = loss_ctc + loss_ctc_chunk
loss_ctc = loss_att + loss_att_chunk

loss = (
@ -1018,4 +1013,4 @@ class UnifiedTransducerModel(FunASRModel):
ignore_label=self.ignore_id,
)

return loss_att, acc_att
return loss_att, acc_att
@ -9,7 +9,6 @@ from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from typeguard import check_argument_types

from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
@ -48,7 +47,6 @@ class DiarEENDOLAModel(FunASRModel):
mapping_dict=None,
**kwargs,
):
assert check_argument_types()

super().__init__()
self.frontend = frontend

@ -12,7 +12,6 @@ from typing import Tuple, List
import numpy as np
import torch
from torch.nn import functional as F
from typeguard import check_argument_types

from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
@ -66,7 +65,6 @@ class DiarSondModel(FunASRModel):
inter_score_loss_weight: float = 0.0,
inputs_type: str = "raw",
):
assert check_argument_types()

super().__init__()


@ -12,7 +12,6 @@ from typing import Union

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@ -67,7 +66,6 @@ class SAASRModel(FunASRModel):
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight


@ -12,7 +12,6 @@ from typing import Tuple
from typing import Union

import torch
from typeguard import check_argument_types

from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@ -56,7 +55,6 @@ class ESPnetSVModel(FunASRModel):
pooling_layer: torch.nn.Module,
decoder: AbsDecoder,
):
assert check_argument_types()

super().__init__()
# note that eos is the same as sos (equivalent ID)

@ -9,7 +9,6 @@ from typing import Union

import torch
import numpy as np
from typeguard import check_argument_types

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
@ -42,7 +41,6 @@ class TimestampPredictor(FunASRModel):
predictor_bias: int = 0,
token_list=None,
):
assert check_argument_types()

super().__init__()
# note that eos is the same as sos (equivalent ID)

@ -8,7 +8,6 @@ from typing import Tuple
from typing import Union

import torch
from typeguard import check_argument_types

from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
@ -82,7 +81,6 @@ class UniASR(FunASRModel):
postencoder: Optional[AbsPostEncoder] = None,
encoder1_encoder2_joint_training: bool = True,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight


@ -12,7 +12,6 @@ from typing import Dict

import torch
from torch import nn
from typeguard import check_argument_types

from funasr.models.ctc import CTC
from funasr.modules.attention import (
@ -533,7 +532,6 @@ class ConformerEncoder(AbsEncoder):
interctc_use_conditioning: bool = False,
stochastic_depth_rate: Union[float, List[float]] = 0.0,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size

@ -943,7 +941,6 @@ class ConformerChunkEncoder(AbsEncoder):
"""Construct an Encoder object."""
super().__init__()

assert check_argument_types()

self.embed = StreamingConvInput(
input_size,

@ -10,7 +10,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.data2vec.data_utils import compute_mask_indices
@ -97,7 +96,6 @@ class Data2VecEncoder(AbsEncoder):
# FP16 optimization
required_seq_len_multiple: int = 2,
):
assert check_argument_types()
super().__init__()

# ConvFeatureExtractionModel

@ -5,7 +5,6 @@ import logging
import torch
from torch import nn

from typeguard import check_argument_types

from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
from funasr.modules.nets_utils import get_activation
@ -161,7 +160,6 @@ class MFCCAEncoder(AbsEncoder):
cnn_module_kernel: int = 31,
padding_idx: int = -1,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size


@ -7,7 +7,6 @@ import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.layer_norm import LayerNorm
@ -90,7 +89,6 @@ class ConvEncoder(AbsEncoder):
tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
):
assert check_argument_types()
super().__init__()
self._output_size = num_units


@ -7,7 +7,6 @@ import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.layer_norm import LayerNorm

@ -7,7 +7,6 @@ import logging
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM
@ -144,7 +143,6 @@ class SelfAttentionEncoder(AbsEncoder):
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
out_units=None,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size


@ -5,7 +5,6 @@ from typing import Tuple

import numpy as np
import torch
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.rnn.encoders import RNN
@ -37,7 +36,6 @@ class RNNEncoder(AbsEncoder):
dropout: float = 0.0,
subsample: Optional[Sequence[int]] = (2, 2, 1, 1),
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
self.rnn_type = rnn_type

@ -8,7 +8,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
@ -151,7 +150,6 @@ class SANMEncoder(AbsEncoder):
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size

@ -601,7 +599,6 @@ class SANMEncoderChunkOpt(AbsEncoder):
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size

@ -1060,7 +1057,6 @@ class SANMVadEncoder(AbsEncoder):
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size


@ -9,7 +9,6 @@ from typing import Tuple

import torch
from torch import nn
from typeguard import check_argument_types
import logging

from funasr.models.ctc import CTC
@ -189,7 +188,6 @@ class TransformerEncoder(AbsEncoder):
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size


@ -7,7 +7,6 @@ import humanfriendly
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types

from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
@ -40,7 +39,6 @@ class DefaultFrontend(AbsFrontend):
apply_stft: bool = True,
use_channel: int = None,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
@ -167,7 +165,6 @@ class MultiChannelFrontend(AbsFrontend):
cmvn_file: str = None,
mc: bool = True
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)

@ -3,7 +3,6 @@ from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.s3prl import S3prlFrontend
import numpy as np
import torch
from typeguard import check_argument_types
from typing import Tuple


@ -12,7 +11,6 @@ class FusedFrontends(AbsFrontend):
self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
):

assert check_argument_types()
super().__init__()
self.align_method = (
align_method  # fusing method : linear_projection only for now

@ -8,7 +8,6 @@ from typing import Union

import humanfriendly
import torch
from typeguard import check_argument_types

from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
@ -37,7 +36,6 @@ class S3prlFrontend(AbsFrontend):
download_dir: str = None,
multilayer_feature: bool = False,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)

@ -6,7 +6,6 @@ import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types

import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from funasr.models.frontend.abs_frontend import AbsFrontend
@ -95,7 +94,6 @@ class WavFrontend(AbsFrontend):
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
@ -227,7 +225,6 @@ class WavFrontendOnline(AbsFrontend):
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
@ -466,7 +463,6 @@ class WavFrontendMel23(AbsFrontend):
lfr_m: int = 1,
lfr_n: int = 1,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.frame_length = frame_length

@ -6,7 +6,6 @@

from funasr.models.frontend.abs_frontend import AbsFrontend
import torch
from typeguard import check_argument_types
from typing import Tuple


@ -38,7 +37,6 @@ class SlidingWindow(AbsFrontend):
padding: Padding (placeholder, currently not implemented).
fs: Sampling rate (placeholder for compatibility, not used).
"""
assert check_argument_types()
super().__init__()
self.fs = fs
self.win_length = win_length

@ -6,7 +6,6 @@

from funasr.modules.nets_utils import make_pad_mask
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from typeguard import check_argument_types
from typing import Tuple

import copy
@ -30,7 +29,6 @@ class HuggingFaceTransformersPostEncoder(AbsPostEncoder):
model_name_or_path: str,
):
"""Initialize the module."""
assert check_argument_types()
super().__init__()

if not is_transformers_available:

@ -5,7 +5,6 @@
"""Linear Projection."""

from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from typeguard import check_argument_types
from typing import Tuple

import torch
@ -20,7 +19,6 @@ class LinearProjection(AbsPreEncoder):
output_size: int,
):
"""Initialize the module."""
assert check_argument_types()
super().__init__()

self.output_dim = output_size

@ -10,7 +10,6 @@ from funasr.layers.sinc_conv import LogCompression
from funasr.layers.sinc_conv import SincConv
import humanfriendly
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
from typing import Union
@ -60,7 +59,6 @@ class LightweightSincConvs(AbsPreEncoder):
windowing_type: Choice of windowing function.
scale_type: Choice of filter-bank initialization scale.
"""
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
@ -268,7 +266,6 @@ class SpatialDropout(torch.nn.Module):
dropout_probability: Dropout probability.
shape (tuple, list): Shape of input tensors.
"""
assert check_argument_types()
super().__init__()
if shape is None:
shape = (0, 2, 1)

@ -4,7 +4,6 @@ from typing import Union

import torch
import torch.nn as nn
from typeguard import check_argument_types
from funasr.train.abs_model import AbsLM


@ -27,7 +26,6 @@ class SequentialRNNLM(AbsLM):
rnn_type: str = "lstm",
ignore_id: int = 0,
):
assert check_argument_types()
super().__init__()

ninp = unit

@ -2,7 +2,7 @@ import copy
import numpy as np
import time
import torch
from eend.utils.power import create_powerlabel
from funasr.modules.eend_ola.utils.power import create_powerlabel
from itertools import combinations

metrics = [

@ -1,5 +1,4 @@
import torch
from typeguard import check_argument_types


class SGD(torch.optim.SGD):
@ -21,7 +20,6 @@ class SGD(torch.optim.SGD):
weight_decay: float = 0.0,
nesterov: bool = False,
):
assert check_argument_types()
super().__init__(
params,
lr=lr,

214
funasr/runtime/SDK_advanced.md
Normal file
@ -0,0 +1,214 @@
# FunASR Offline File Transcription Service Development Guide

FunASR provides a Chinese offline file transcription service that can be deployed with one click on a local machine or a cloud server; its core is the already open-sourced FunASR runtime-SDK. FunASR-runtime combines the voice activity detection (VAD), Paraformer-large speech recognition (ASR), and punctuation restoration (PUNC) capabilities that the Speech Lab of DAMO Academy has open-sourced on the ModelScope community, and can transcribe audio accurately and efficiently with high concurrency.

This document is the development guide for the FunASR offline file transcription service. If you just want to try the service quickly, see the one-click deployment example ([click here](./SDK_tutorial.md)).
## Docker Installation
The steps below install docker and the docker image manually; if your docker image is already running, you can skip this section:
### Install the docker environment
```shell
# Ubuntu:
curl -fsSL https://test.docker.com -o test-docker.sh
sudo sh test-docker.sh
# Debian:
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
# CentOS:
curl -fsSL https://get.docker.com | bash -s docker --mirror Aliyun
# MacOS:
brew install --cask --appdir=/Applications docker
```
For installation details, see: https://alibaba-damo-academy.github.io/FunASR/en/installation/docker.html
### Start docker
```shell
sudo systemctl start docker
```
### Pull and start the image
Pull and start the docker image of the FunASR runtime-SDK with the following commands:
```shell
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.0.1

sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.0.1
```
```text
-p <host port>:<port mapped into docker>
In the example, host (ECS) port 10095 is mapped to docker port 10095. Make sure your ECS security rules open port 10095.
-v <host path>:<path mounted into docker>
In the example, host path /root is mounted to docker path /workspace/models.
```

## Starting the Server
After docker is up, start the funasr-wss-server service program:

funasr-wss-server can download models from ModelScope. In that case you need to set both the model download directory (--download-model-dir) and the model IDs (--model-dir, --vad-dir, --punc-dir), for example:
```shell
cd /workspace/FunASR/funasr/runtime/websocket/build/bin

./funasr-wss-server \
--download-model-dir /workspace/models \
--model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
--vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
--punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
--decoder-thread-num 32 \
--io-thread-num 8 \
--port 10095 \
--certfile ../../../ssl_key/server.crt \
--keyfile ../../../ssl_key/server.key
```
```text
--download-model-dir # model download directory; models are fetched from ModelScope by model ID
--model-dir # ModelScope model ID of the ASR model
--quantize # True to use the quantized ASR model, False for the non-quantized one; default True
--vad-dir # ModelScope model ID of the VAD model
--vad-quant # True to use the quantized VAD model, False for the non-quantized one; default True
--punc-dir # ModelScope model ID of the PUNC model
--punc-quant # True to use the quantized PUNC model, False for the non-quantized one; default True
--port # port the server listens on; default 8889
--decoder-thread-num # number of inference threads the server starts; default 8
--io-thread-num # number of IO threads the server starts; default 1
--certfile <string> # ssl certificate file; default ../../../ssl_key/server.crt
--keyfile <string> # ssl key file; default ../../../ssl_key/server.key
```
funasr-wss-server can also load models from local paths (see [Model preparation](#anchor-1) for preparing local model resources). In that case set the model paths (--model-dir, --vad-dir, --punc-dir) to local directories, for example:
```shell
cd /workspace/FunASR/funasr/runtime/websocket/build/bin

./funasr-wss-server \
--model-dir /workspace/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
--vad-dir /workspace/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
--punc-dir /workspace/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
--decoder-thread-num 32 \
--io-thread-num 8 \
--port 10095 \
--certfile ../../../ssl_key/server.crt \
--keyfile ../../../ssl_key/server.key
```
```text
--model-dir # ASR model path; default /workspace/models/asr
--quantize # True to use the quantized ASR model, False for the non-quantized one; default True
--vad-dir # VAD model path; default /workspace/models/vad
--vad-quant # True to use the quantized VAD model, False for the non-quantized one; default True
--punc-dir # PUNC model path; default /workspace/models/punc
--punc-quant # True to use the quantized PUNC model, False for the non-quantized one; default True
--port # port the server listens on; default 8889
--decoder-thread-num # number of inference threads the server starts; default 8
--io-thread-num # number of IO threads the server starts; default 1
--certfile <string> # ssl certificate file; default ../../../ssl_key/server.crt
--keyfile <string> # ssl key file; default ../../../ssl_key/server.key
```

## <a id="anchor-1">Model Preparation</a>
If you let funasr-wss-server download the models from ModelScope, you can skip this section.

The VAD, ASR, and PUNC model resources used by the offline file transcription service all come from ModelScope; the model pages are listed below:

| Model | ModelScope link |
|------|------------------------------------------------------------------------------------------------------------------|
| VAD | https://www.modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary |
| ASR | https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary |
| PUNC | https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary |

The service deploys quantized ONNX models. The following shows how to export ONNX models and quantize them; you can export from ModelScope, from local files, or from finetuned resources:
### Export an ONNX model from ModelScope
Download the model with the given model name from ModelScope and export a quantized ONNX model:
```shell
python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```
```text
--model-name    model name on ModelScope, e.g. damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
--export-dir    output directory for the ONNX model
--type          model type; currently onnx and torch are supported
--quantize      int8 model quantization
```

### Export an ONNX model from local files
Set the model name to a local model path and export a quantized ONNX model:
```shell
python -m funasr.export.export_model --model-name /workspace/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```
```text
--model-name    local model path, e.g. /workspace/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
--export-dir    output directory for the ONNX model
--type          model type; currently onnx and torch are supported
--quantize      int8 model quantization
```
### Export a model from finetuned resources
If you want to deploy a finetuned model, follow these steps:

Rename the finetuned model you want to deploy (e.g. 10epoch.pb) to model.pb and replace the original model.pb of the ModelScope model with it. Assuming the resulting model path is /path/to/finetune/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch, convert the finetuned model to an ONNX model with:
```shell
python -m funasr.export.export_model --model-name /path/to/finetune/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```

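The checkpoint swap described above can also be scripted. Below is a minimal sketch of that file preparation; both paths are hypothetical placeholders, so substitute your own checkpoint and model directories:

```python
# Sketch of the checkpoint replacement step described above. Both paths are
# hypothetical placeholders; point them at your actual finetuned checkpoint
# and the local copy of the ModelScope model.
import shutil

finetuned_ckpt = "./exp/10epoch.pb"  # checkpoint produced by finetuning
model_dir = "/path/to/finetune/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"

# Overwrite the original model.pb with the finetuned weights, then run the
# funasr.export.export_model command shown above on model_dir.
shutil.copyfile(finetuned_ckpt, f"{model_dir}/model.pb")
```
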
## Starting the Client
After the offline file transcription service is deployed on the server, you can test and use it with the steps below. FunASR currently supports several ways to start a client; the following are command-line examples for the python client and the c++ client, plus the websocket protocol for building a custom client:
### python-client
```shell
python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode offline --audio_in "./data/wav.scp" --send_without_sleep --output_dir "./results"
```
```text
--host # server ip address; use 127.0.0.1 for local tests
--port # port the server listens on
--audio_in # audio input: a wav path or a wav.scp path (kaldi-style wav list: wav_id \t wav_path)
--output_dir # output directory for recognition results
--ssl # whether to use SSL encryption; enabled by default
--mode # offline mode
```

### c++-client:
```shell
./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
```
```text
--server-ip # server ip address; use 127.0.0.1 for local tests
--port # port the server listens on
--wav-path # audio input: a wav path or a wav.scp path (kaldi-style wav list: wav_id \t wav_path)
--thread-num # number of client threads
--is-ssl # whether to use SSL encryption; enabled by default
```

### Custom client:
If you want to build your own client, the websocket protocol is:
```text
# first message
{"mode": "offline", "wav_name": wav_name, "is_speaking": True}
# send the wav data
raw bytes
# send the end-of-stream flag
{"is_speaking": False}
```

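To make the protocol concrete, here is a minimal offline client sketch in Python. It is only an illustration, assuming the third-party `websockets` package and a local server with a self-signed certificate; `wss_client_asr.py` in the repository remains the full reference implementation:

```python
# Minimal sketch of the offline websocket protocol described above.
# Assumptions: `pip install websockets`; a funasr-wss-server reachable at
# wss://127.0.0.1:10095 with the bundled self-signed certificate.
import asyncio
import json
import ssl

import websockets


async def transcribe(wav_path: str, uri: str = "wss://127.0.0.1:10095"):
    # The bundled server certificate is self-signed, so skip verification
    # for local testing.
    ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_ctx.check_hostname = False
    ssl_ctx.verify_mode = ssl.CERT_NONE

    async with websockets.connect(uri, ssl=ssl_ctx) as ws:
        # 1. First message: mode, wav name, and the speaking flag.
        await ws.send(json.dumps(
            {"mode": "offline", "wav_name": wav_path, "is_speaking": True}))
        # 2. Raw wav bytes.
        with open(wav_path, "rb") as f:
            await ws.send(f.read())
        # 3. End-of-stream flag.
        await ws.send(json.dumps({"is_speaking": False}))
        # 4. The server replies with the recognition result.
        return await ws.recv()


if __name__ == "__main__":
    print(asyncio.run(transcribe("test.wav")))
```
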
## Customizing the Service
The FunASR-runtime code is open source. If the server or the clients do not fully meet your needs, you can develop further on top of them:
### c++ client:
https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/websocket
### python client:
https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/websocket
### c++ server:
#### VAD
```c++
// Using the VAD model takes two steps, FsmnVadInit and FsmnVadInfer:
FUNASR_HANDLE vad_handle=FsmnVadInit(model_path, thread_num);
// model_path contains "model-dir" and "quantize"; thread_num is the number of onnx threads.
FUNASR_RESULT result=FsmnVadInfer(vad_handle, wav_file.c_str(), NULL, 16000);
// vad_handle is the return value of FsmnVadInit; wav_file is the audio path; the sampling rate defaults to 16k.
```
For a usage example, see: https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
#### ASR
```c++
// Using the ASR model takes two steps, FunOfflineInit and FunOfflineInfer:
FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, thread_num);
// model_path contains "model-dir" and "quantize"; thread_num is the number of onnx threads.
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_file.c_str(), RASR_NONE, NULL, 16000);
// asr_handle is the return value of FunOfflineInit; wav_file is the audio path; the sampling rate defaults to 16k.
```
For a usage example, see: https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
#### PUNC
```c++
// Using the PUNC model takes two steps, CTTransformerInit and CTTransformerInfer:
FUNASR_HANDLE punc_handle=CTTransformerInit(model_path, thread_num);
// model_path contains "model-dir" and "quantize"; thread_num is the number of onnx threads.
FUNASR_RESULT result=CTTransformerInfer(punc_handle, txt_str.c_str(), RASR_NONE, NULL);
// punc_handle is the return value of CTTransformerInit; txt_str is the input text.
```
For a usage example, see: https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
@ -3,7 +3,6 @@ from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union

import numpy as np
from typeguard import check_argument_types
import kaldi_native_fbank as knf

root_dir = Path(__file__).resolve().parent
@ -28,7 +27,6 @@ class WavFrontend():
dither: float = 1.0,
**kwargs,
) -> None:
check_argument_types()

opts = knf.FbankOptions()
opts.frame_opts.samp_freq = fs

@ -9,7 +9,6 @@ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
import yaml

from typeguard import check_argument_types

import warnings

@ -21,7 +20,6 @@ logger_initialized = {}
class TokenIDConverter():
def __init__(self, token_list: Union[List, str],
):
check_argument_types()

self.token_list = token_list
self.unk_symbol = token_list[-1]
@ -51,7 +49,6 @@ class CharTokenizer():
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
check_argument_types()

self.space_symbol = space_symbol
self.non_linguistic_symbols = self.load_symbols(symbol_value)

@ -4,7 +4,6 @@ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import copy

import numpy as np
from typeguard import check_argument_types
import kaldi_native_fbank as knf

root_dir = Path(__file__).resolve().parent
@ -29,7 +28,6 @@ class WavFrontend():
dither: float = 1.0,
**kwargs,
) -> None:
check_argument_types()

opts = knf.FbankOptions()
opts.frame_opts.samp_freq = fs

@ -10,7 +10,6 @@ import numpy as np
import yaml
from onnxruntime import (GraphOptimizationLevel, InferenceSession,
SessionOptions, get_available_providers, get_device)
from typeguard import check_argument_types

import warnings

@ -22,7 +21,6 @@ logger_initialized = {}
class TokenIDConverter():
def __init__(self, token_list: Union[List, str],
):
check_argument_types()

self.token_list = token_list
self.unk_symbol = token_list[-1]
@ -52,7 +50,6 @@ class CharTokenizer():
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
check_argument_types()

self.space_symbol = space_symbol
self.non_linguistic_symbols = self.load_symbols(symbol_value)

@ -109,7 +109,6 @@ class WavFrontend():
lfr_n: int = 6,
dither: float = 1.0
) -> None:
# check_argument_types()

self.fs = fs
self.window = window

@ -4,8 +4,6 @@ from typing import Sequence
from typing import Tuple
from typing import Union

from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.samplers.abs_sampler import AbsSampler
from funasr.samplers.folded_batch_sampler import FoldedBatchSampler
@ -104,7 +102,6 @@ def build_batch_sampler(
padding: Whether sequences are input as a padded tensor or not.
used for "numel" mode
"""
assert check_argument_types()
if len(shape_files) == 0:
raise ValueError("No shape file are given")

@ -164,5 +161,4 @@ def build_batch_sampler(

else:
raise ValueError(f"Not supported: {type}")
assert check_return_type(retval)
return retval

@ -4,7 +4,6 @@ from typing import Sequence
from typing import Tuple
from typing import Union

from typeguard import check_argument_types

from funasr.fileio.read_text import load_num_sequence_text
from funasr.fileio.read_text import read_2column_text
@ -23,7 +22,6 @@ class FoldedBatchSampler(AbsSampler):
drop_last: bool = False,
utt2category_file: str = None,
):
assert check_argument_types()
assert batch_size > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(

@ -4,7 +4,6 @@ from typing import Dict
from typing import Tuple
from typing import Union

from typeguard import check_argument_types

from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@ -21,7 +20,6 @@ class LengthBatchSampler(AbsSampler):
drop_last: bool = False,
padding: bool = True,
):
assert check_argument_types()
assert batch_bins > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(

@ -4,7 +4,6 @@ from typing import Tuple
from typing import Union

import numpy as np
from typeguard import check_argument_types

from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@ -21,7 +20,6 @@ class NumElementsBatchSampler(AbsSampler):
drop_last: bool = False,
padding: bool = True,
):
assert check_argument_types()
assert batch_bins > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(

@ -2,7 +2,6 @@ import logging
from typing import Iterator
from typing import Tuple

from typeguard import check_argument_types

from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@ -26,7 +25,6 @@ class SortedBatchSampler(AbsSampler):
sort_batch: str = "ascending",
drop_last: bool = False,
):
assert check_argument_types()
assert batch_size > 0
self.batch_size = batch_size
self.shape_file = shape_file

@ -2,7 +2,6 @@ import logging
from typing import Iterator
from typing import Tuple

from typeguard import check_argument_types

from funasr.fileio.read_text import read_2column_text
from funasr.samplers.abs_sampler import AbsSampler
@ -28,7 +27,6 @@ class UnsortedBatchSampler(AbsSampler):
drop_last: bool = False,
utt2category_file: str = None,
):
assert check_argument_types()
assert batch_size > 0
self.batch_size = batch_size
self.key_file = key_file

@ -4,7 +4,6 @@ import warnings

import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types

from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler

@ -31,7 +30,6 @@ class NoamLR(_LRScheduler, AbsBatchStepScheduler):
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
assert check_argument_types()
self.model_size = model_size
self.warmup_steps = warmup_steps


@ -8,7 +8,6 @@ from typing import Optional, List

import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types

from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler

@ -22,7 +21,6 @@ class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
init_lr_scale: float = 0.01,
final_lr_scale: float = 0.01,
):
assert check_argument_types()
self.optimizer = optimizer
self.last_epoch = last_epoch
self.phase_ratio = phase_ratio

@ -3,7 +3,6 @@ from typing import Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types

from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler

@ -30,7 +29,6 @@ class WarmupLR(_LRScheduler, AbsBatchStepScheduler):
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
assert check_argument_types()
self.warmup_steps = warmup_steps

# __init__() must be invoked before setting field

@ -32,8 +32,6 @@ import torch.optim
import yaml
from funasr.models.base_model import FunASRModel
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr import __version__
from funasr.datasets.dataset import AbsDataset
@ -269,7 +267,6 @@ class AbsTask(ABC):

@classmethod
def get_parser(cls) -> config_argparse.ArgumentParser:
assert check_argument_types()

class ArgumentDefaultsRawTextHelpFormatter(
argparse.RawTextHelpFormatter,
@ -959,7 +956,6 @@ class AbsTask(ABC):
cls.trainer.add_arguments(parser)
cls.add_task_arguments(parser)

assert check_return_type(parser)
return parser

@classmethod
@ -1007,7 +1003,6 @@ class AbsTask(ABC):
return _cls

# This method is used only for --print_config
assert check_argument_types()
parser = cls.get_parser()
args, _ = parser.parse_known_args()
config = vars(args)
@ -1047,7 +1042,6 @@ class AbsTask(ABC):

@classmethod
def check_required_command_args(cls, args: argparse.Namespace):
assert check_argument_types()
if hasattr(args, "required"):
for k in vars(args):
if "-" in k:
@ -1077,7 +1071,6 @@ class AbsTask(ABC):
inference: bool = False,
) -> None:
"""Check if the dataset satisfy the requirement of current Task"""
assert check_argument_types()
mes = (
f"If you intend to use an additional input, modify "
f'"{cls.__name__}.required_data_names()" or '
@ -1104,14 +1097,12 @@ class AbsTask(ABC):

@classmethod
def print_config(cls, file=sys.stdout) -> None:
assert check_argument_types()
# Shows the config: e.g. python train.py asr --print_config
config = cls.get_default_config()
file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))

@classmethod
def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
assert check_argument_types()
print(get_commandline_args(), file=sys.stderr)
if args is None:
parser = cls.get_parser()
@ -1148,7 +1139,6 @@ class AbsTask(ABC):

@classmethod
def main_worker(cls, args: argparse.Namespace):
assert check_argument_types()

# 0. Init distributed process
distributed_option = build_dataclass(DistributedOption, args)
@ -1556,7 +1546,6 @@ class AbsTask(ABC):
- 4 epoch with "--num_iters_per_epoch" == 4

"""
assert check_argument_types()
iter_options = cls.build_iter_options(args, distributed_option, mode)

# Overwrite iter_options if any kwargs is given
@ -1589,7 +1578,6 @@ class AbsTask(ABC):
def build_sequence_iter_factory(
cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str
) -> AbsIterFactory:
assert check_argument_types()

if hasattr(args, "frontend_conf"):
if args.frontend_conf is not None and "fs" in args.frontend_conf:
@ -1683,7 +1671,6 @@ class AbsTask(ABC):
iter_options: IteratorOptions,
mode: str,
) -> AbsIterFactory:
assert check_argument_types()

dataset = ESPnetDataset(
iter_options.data_path_and_name_and_type,
@ -1788,7 +1775,6 @@ class AbsTask(ABC):
def build_multiple_iter_factory(
cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str
):
assert check_argument_types()
iter_options = cls.build_iter_options(args, distributed_option, mode)
assert len(iter_options.data_path_and_name_and_type) > 0, len(
iter_options.data_path_and_name_and_type
@ -1885,7 +1871,6 @@ class AbsTask(ABC):
inference: bool = False,
) -> DataLoader:
"""Build DataLoader using iterable dataset"""
assert check_argument_types()
# For backward compatibility for pytorch DataLoader
if collate_fn is not None:
kwargs = dict(collate_fn=collate_fn)
@ -1935,7 +1920,6 @@ class AbsTask(ABC):
device: Device type, "cpu", "cuda", or "cuda:N".

"""
assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "

Some files were not shown because too many files have changed in this diff.