Merge branch 'main' of github.com:alibaba-damo-academy/FunASR

游雁 2023-06-29 17:20:46 +08:00
commit 1bdb956318
123 changed files with 235 additions and 689 deletions

View File

@ -11,7 +11,6 @@ import numpy as np
import resampy
import soundfile
from tqdm import tqdm
from typeguard import check_argument_types
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.read_text import read_2column_text
@ -31,7 +30,6 @@ def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
(3, 4, 5)
"""
assert check_argument_types()
if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
return None
return tuple(map(int, integers.strip().split(",")))
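Nearly every hunk in this commit deletes the same typeguard pattern that the function above used; for reference, a minimal sketch of that pattern under the typeguard 2.x API (illustrative only, not part of the diff):

from typing import Optional, Tuple
from typeguard import check_argument_types, check_return_type

def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
    # Removed style: validate the annotated argument types at call time ...
    assert check_argument_types()
    if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    result = tuple(map(int, integers.strip().split(",")))
    # ... and validate the annotated return type before returning.
    assert check_return_type(result)
    return result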

View File

@ -1,5 +1,4 @@
import os
<<<<<<< HEAD
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
@ -21,50 +20,17 @@ def modelscope_finetune(params):
batch_bins=params.batch_bins,
max_epoch=params.max_epoch,
lr=params.lr)
=======
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from funasr.datasets.ms_dataset import MsDataset
def modelscope_finetune(params):
if not os.path.exists(params["output_dir"]):
os.makedirs(params["output_dir"], exist_ok=True)
# dataset split ["train", "validation"]
ds_dict = MsDataset.load(params["data_dir"])
kwargs = dict(
model=params["model"],
model_revision=params["model_revision"],
data_dir=ds_dict,
dataset_type=params["dataset_type"],
work_dir=params["output_dir"],
batch_bins=params["batch_bins"],
max_epoch=params["max_epoch"],
lr=params["lr"])
>>>>>>> main
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
if __name__ == '__main__':
<<<<<<< HEAD
params = modelscope_args(model="damo/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch", data_path="./data")
params.output_dir = "./checkpoint"  # path to save the model
params.data_path = "./example_data/"  # data path
params.dataset_type = "small"  # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
params.batch_bins = 2000  # batch size: with dataset_type="small", batch_bins is in fbank feature frames; with dataset_type="large", it is in milliseconds
params.max_epoch = 50  # maximum number of training epochs
params.max_epoch = 20  # maximum number of training epochs
params.lr = 0.00005  # learning rate
=======
params = {}
params["output_dir"] = "./checkpoint"
params["data_dir"] = "./data"
params["batch_bins"] = 2000
params["dataset_type"] = "small"
params["max_epoch"] = 50
params["lr"] = 0.00005
params["model"] = "damo/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch"
params["model_revision"] = None
>>>>>>> main
modelscope_finetune(params)
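The resolved script loads params["data_dir"] through MsDataset.load with "train" and "validation" splits (see the comment above); a plausible on-disk layout for that directory, assumed here rather than taken from the diff:

from funasr.datasets.ms_dataset import MsDataset

# Hypothetical layout under ./data (assumption, not specified in this commit):
#   data/
#     train/        wav.scp, text
#     validation/   wav.scp, text
ds_dict = MsDataset.load("./data")  # one dataset per split, keyed by split name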

View File

@ -1,33 +1,3 @@
<<<<<<< HEAD
import os
import shutil
import argparse
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
def modelscope_infer(args):
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model=args.model,
output_dir=args.output_dir,
batch_size=args.batch_size,
param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
)
inference_pipeline(audio_in=args.audio_in)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default="damo/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch")
parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
parser.add_argument('--output_dir', type=str, default="./results/")
parser.add_argument('--decoding_mode', type=str, default="normal")
parser.add_argument('--hotword_txt', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--gpuid', type=str, default="0")
args = parser.parse_args()
modelscope_infer(args)
=======
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
@ -40,5 +10,4 @@ if __name__ == "__main__":
output_dir=output_dir,
)
rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
>>>>>>> main
print(rec_result)
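The main-branch side of this file is only partially visible in the hunk above; a minimal sketch of how the resolved inference flow fits together, with the audio_in and output_dir values assumed (carried over from the removed HEAD-side defaults):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

audio_in = "./data/test/wav.scp"  # assumed; mirrors the removed HEAD defaults
output_dir = "./results/"         # assumed

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model="damo/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch",
    output_dir=output_dir,
)
# "offline" is assumed to select the non-streaming pass of the 2-pass UniASR model.
rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "offline"})
print(rec_result)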

View File

@ -22,9 +22,7 @@ import numpy as np
import requests
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@ -78,7 +76,6 @@ class Speech2Text:
frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
@ -192,7 +189,6 @@ class Speech2Text:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@ -248,7 +244,6 @@ class Speech2Text:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
@ -288,7 +283,6 @@ class Speech2TextParaformer:
decoding_ind: int = 0,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
@ -413,7 +407,6 @@ class Speech2TextParaformer:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@ -516,7 +509,6 @@ class Speech2TextParaformer:
vad_offset=begin_time)
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
# assert check_return_type(results)
return results
def generate_hotwords_list(self, hotword_list_or_file):
@ -656,7 +648,6 @@ class Speech2TextParaformerOnline:
hotword_list_or_file: str = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
@ -776,7 +767,6 @@ class Speech2TextParaformerOnline:
text, token, token_int, hyp
"""
assert check_argument_types()
results = []
cache_en = cache["encoder"]
if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
@ -871,7 +861,6 @@ class Speech2TextParaformerOnline:
results.append(postprocessed_result)
# assert check_return_type(results)
return results
@ -912,7 +901,6 @@ class Speech2TextUniASR:
frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
@ -1036,7 +1024,6 @@ class Speech2TextUniASR:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@ -1104,7 +1091,6 @@ class Speech2TextUniASR:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
@ -1143,7 +1129,6 @@ class Speech2TextMFCCA:
streaming: bool = False,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
@ -1248,7 +1233,6 @@ class Speech2TextMFCCA:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@ -1298,7 +1282,6 @@ class Speech2TextMFCCA:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
@ -1355,7 +1338,6 @@ class Speech2TextTransducer:
"""Construct a Speech2Text object."""
super().__init__()
assert check_argument_types()
asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
@ -1534,7 +1516,6 @@ class Speech2TextTransducer:
Returns:
nbest_hypothesis: N-best hypothesis.
"""
assert check_argument_types()
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@ -1566,7 +1547,6 @@ class Speech2TextTransducer:
Returns:
nbest_hypothesis: N-best hypothesis.
"""
assert check_argument_types()
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@ -1608,36 +1588,9 @@ class Speech2TextTransducer:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
@staticmethod
def from_pretrained(
model_tag: Optional[str] = None,
**kwargs: Optional[Any],
) -> Speech2Text:
"""Build Speech2Text instance from the pretrained model.
Args:
model_tag: Model tag of the pretrained models.
Return:
: Speech2Text instance.
"""
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
except ImportError:
logging.error(
"`espnet_model_zoo` is not installed. "
"Please install via `pip install -U espnet_model_zoo`."
)
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2TextTransducer(**kwargs)
class Speech2TextSAASR:
"""Speech2Text class
@ -1675,7 +1628,6 @@ class Speech2TextSAASR:
frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
@ -1793,7 +1745,6 @@ class Speech2TextSAASR:
text, text_id, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@ -1886,5 +1837,4 @@ class Speech2TextSAASR:
results.append((text, text_id, token, token_int, hyp))
assert check_return_type(results)
return results

View File

@ -21,7 +21,6 @@ import torch
import torchaudio
import soundfile
import yaml
from typeguard import check_argument_types
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextMFCCA
@ -80,7 +79,6 @@ def inference_asr(
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@ -240,7 +238,6 @@ def inference_paraformer(
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
@ -481,7 +478,6 @@ def inference_paraformer_vad_punc(
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
@ -749,7 +745,6 @@ def inference_paraformer_online(
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
@ -957,7 +952,6 @@ def inference_uniasr(
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@ -1126,7 +1120,6 @@ def inference_mfcca(
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@ -1314,7 +1307,6 @@ def inference_transducer(
right_context: Number of frames in right context AFTER subsampling.
display_partial_hypotheses: Whether to display partial hypotheses.
"""
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
@ -1464,7 +1456,6 @@ def inference_sa_asr(
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:

View File

@ -15,7 +15,6 @@ import numpy as np
import torch
from scipy.ndimage import median_filter
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import DiarTask
@ -45,7 +44,6 @@ class Speech2DiarizationEEND:
device: str = "cpu",
dtype: str = "float32",
):
assert check_argument_types()
# 1. Build Diarization model
diar_model, diar_train_args = build_model_from_file(
@ -88,7 +86,6 @@ class Speech2DiarizationEEND:
diarization results
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@ -107,36 +104,6 @@ class Speech2DiarizationEEND:
return results
@staticmethod
def from_pretrained(
model_tag: Optional[str] = None,
**kwargs: Optional[Any],
):
"""Build Speech2Diarization instance from the pretrained model.
Args:
model_tag (Optional[str]): Model tag of the pretrained models.
Currently, the tags of espnet_model_zoo are supported.
Returns:
Speech2Diarization: Speech2Diarization instance.
"""
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
except ImportError:
logging.error(
"`espnet_model_zoo` is not installed. "
"Please install via `pip install -U espnet_model_zoo`."
)
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2DiarizationEEND(**kwargs)
class Speech2DiarizationSOND:
"""Speech2Xvector class
@ -163,7 +130,6 @@ class Speech2DiarizationSOND:
smooth_size: int = 83,
dur_threshold: float = 10,
):
assert check_argument_types()
# TODO: 1. Build Diarization model
diar_model, diar_train_args = build_model_from_file(
@ -283,7 +249,6 @@ class Speech2DiarizationSOND:
diarization results for each speaker
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@ -305,33 +270,3 @@ class Speech2DiarizationSOND:
results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
return results, pse_labels
@staticmethod
def from_pretrained(
model_tag: Optional[str] = None,
**kwargs: Optional[Any],
):
"""Build Speech2Xvector instance from the pretrained model.
Args:
model_tag (Optional[str]): Model tag of the pretrained models.
Currently, the tags of espnet_model_zoo are supported.
Returns:
Speech2Xvector: Speech2Xvector instance.
"""
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
except ImportError:
logging.error(
"`espnet_model_zoo` is not installed. "
"Please install via `pip install -U espnet_model_zoo`."
)
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2DiarizationSOND(**kwargs)

View File

@ -18,7 +18,6 @@ import numpy as np
import soundfile
import torch
from scipy.signal import medfilt
from typeguard import check_argument_types
from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
from funasr.datasets.iterable_dataset import load_bytes
@ -52,7 +51,6 @@ def inference_sond(
mode: str = "sond",
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@ -233,7 +231,6 @@ def inference_eend(
param_dict: Optional[dict] = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:

View File

@ -15,7 +15,6 @@ from typing import Union
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
@ -50,7 +49,6 @@ def inference_lm(
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)

View File

@ -14,7 +14,6 @@ from typing import Optional
from typing import Union
import torch
from typeguard import check_argument_types
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
@ -38,7 +37,6 @@ def inference_punc(
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@ -118,7 +116,6 @@ def inference_punc_vad_realtime(
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)

View File

@ -12,8 +12,6 @@ from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
@ -42,7 +40,6 @@ class Speech2Xvector:
streaming: bool = False,
embedding_node: str = "resnet1_dense",
):
assert check_argument_types()
# TODO: 1. Build SV model
sv_model, sv_train_args = build_model_from_file(
@ -108,7 +105,6 @@ class Speech2Xvector:
embedding, ref_embedding, similarity_score
"""
assert check_argument_types()
self.sv_model.eval()
embedding = self.calculate_embedding(speech)
ref_emb, score = None, None
@ -117,35 +113,4 @@ class Speech2Xvector:
score = torch.cosine_similarity(embedding, ref_emb)
results = (embedding, ref_emb, score)
assert check_return_type(results)
return results
@staticmethod
def from_pretrained(
model_tag: Optional[str] = None,
**kwargs: Optional[Any],
):
"""Build Speech2Xvector instance from the pretrained model.
Args:
model_tag (Optional[str]): Model tag of the pretrained models.
Currently, the tags of espnet_model_zoo are supported.
Returns:
Speech2Xvector: Speech2Xvector instance.
"""
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
except ImportError:
logging.error(
"`espnet_model_zoo` is not installed. "
"Please install via `pip install -U espnet_model_zoo`."
)
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2Xvector(**kwargs)

View File

@ -15,7 +15,6 @@ from typing import Union
import numpy as np
import torch
from kaldiio import WriteHelper
from typeguard import check_argument_types
from funasr.bin.sv_infer import Speech2Xvector
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
@ -46,7 +45,6 @@ def inference_sv(
param_dict: Optional[dict] = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
@ -79,10 +77,7 @@ def inference_sv(
embedding_node=embedding_node
)
logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
speech2xvector = Speech2Xvector.from_pretrained(
model_tag=model_tag,
**speech2xvector_kwargs,
)
speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
speech2xvector.sv_model.eval()
def _forward(

View File

@ -7,7 +7,6 @@ import sys
from typing import List
from typing import Optional
from typeguard import check_argument_types
from funasr.utils.cli_utils import get_commandline_args
from funasr.text.build_tokenizer import build_tokenizer
@ -81,7 +80,6 @@ def tokenize(
cleaner: Optional[str],
g2p: Optional[str],
):
assert check_argument_types()
logging.basicConfig(
level=log_level,

View File

@ -9,7 +9,6 @@ from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
@ -26,7 +25,6 @@ class Speech2Timestamp:
dtype: str = "float32",
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
tp_model, tp_train_args = build_model_from_file(
timestamp_infer_config, timestamp_model_file, cmvn_file=None, device=device, task_name="asr", mode="tp"
@ -64,7 +62,6 @@ class Speech2Timestamp:
speech_lengths: Union[torch.Tensor, np.ndarray] = None,
text_lengths: Union[torch.Tensor, np.ndarray] = None
):
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):

View File

@ -13,7 +13,6 @@ from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
@ -47,7 +46,6 @@ def inference_tp(
seg_dict_file: Optional[str] = None,
**kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)

View File

@ -13,7 +13,6 @@ from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@ -42,7 +41,6 @@ class Speech2VadSegment:
dtype: str = "float32",
**kwargs,
):
assert check_argument_types()
# 1. Build vad model
vad_model, vad_infer_args = build_model_from_file(
@ -76,7 +74,6 @@ class Speech2VadSegment:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@ -149,7 +146,6 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):

View File

@ -18,7 +18,6 @@ from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
@ -47,7 +46,6 @@ def inference_vad(
num_workers: int = 1,
**kwargs,
):
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
@ -148,7 +146,6 @@ def inference_vad_online(
num_workers: int = 1,
**kwargs,
):
assert check_argument_types()
logging.basicConfig(
level=log_level,

View File

@ -6,7 +6,6 @@ from typing import Union
import torch
import yaml
from typeguard import check_argument_types
from funasr.build_utils.build_model import build_model
from funasr.models.base_model import FunASRModel
@ -30,7 +29,6 @@ def build_model_from_file(
device: Device type, "cpu", "cuda", or "cuda:N".
"""
assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "

View File

@ -1,6 +1,5 @@
import numpy as np
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from funasr.datasets.iterable_dataset import IterableESPnetDataset
from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
@ -23,7 +22,6 @@ def build_streaming_iterator(
train: bool = False,
) -> DataLoader:
"""Build DataLoader using iterable dataset"""
assert check_argument_types()
# preprocess
if preprocess_fn is not None:

View File

@ -1,7 +1,6 @@
import logging
import torch
from typeguard import check_return_type
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
@ -254,5 +253,4 @@ def build_sv_model(args):
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model

View File

@ -25,7 +25,6 @@ import oss2
import torch
import torch.nn
import torch.optim
from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
@ -118,7 +117,6 @@ class Trainer:
def build_options(self, args: argparse.Namespace) -> TrainerOptions:
"""Build options consumed by train(), eval()"""
assert check_argument_types()
return build_dataclass(TrainerOptions, args)
@classmethod
@ -156,7 +154,6 @@ class Trainer:
def run(self) -> None:
"""Perform training. This method performs the main process of training."""
assert check_argument_types()
# NOTE(kamo): Don't check the type more strictly as far trainer_options
model = self.model
optimizers = self.optimizers
@ -522,7 +519,6 @@ class Trainer:
options: TrainerOptions,
distributed_option: DistributedOption,
) -> Tuple[bool, bool]:
assert check_argument_types()
grad_noise = options.grad_noise
accum_grad = options.accum_grad
@ -758,7 +754,6 @@ class Trainer:
options: TrainerOptions,
distributed_option: DistributedOption,
) -> None:
assert check_argument_types()
ngpu = options.ngpu
# no_forward_run = options.no_forward_run
distributed = distributed_option.distributed

View File

@ -6,8 +6,6 @@ from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.modules.nets_utils import pad_list
@ -22,7 +20,6 @@ class CommonCollateFn:
not_sequence: Collection[str] = (),
max_sample_size=None
):
assert check_argument_types()
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)
@ -53,7 +50,6 @@ def common_collate_fn(
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
"""
assert check_argument_types()
uttids = [u for u, _ in data]
data = [d for _, d in data]
@ -79,7 +75,6 @@ def common_collate_fn(
output[key + "_lengths"] = lens
output = (uttids, output)
assert check_return_type(output)
return output
def crop_to_max_size(feature, target_size):
@ -99,7 +94,6 @@ def clipping_collate_fn(
not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
# mainly for pre-training
assert check_argument_types()
uttids = [u for u, _ in data]
data = [d for _, d in data]
@ -131,5 +125,4 @@ def clipping_collate_fn(
output[key + "_lengths"] = lens
output = (uttids, output)
assert check_return_type(output)
return output

View File

@ -23,8 +23,6 @@ import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
@ -37,7 +35,6 @@ from funasr.utils.sized_dict import SizedDict
class AdapterForSoundScpReader(collections.abc.Mapping):
def __init__(self, loader, dtype=None):
assert check_argument_types()
self.loader = loader
self.dtype = dtype
self.rate = None
@ -284,7 +281,6 @@ class ESPnetDataset(AbsDataset):
max_cache_fd: int = 0,
dest_sample_rate: int = 16000,
):
assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"'
@ -379,7 +375,6 @@ class ESPnetDataset(AbsDataset):
return _mes
def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
assert check_argument_types()
# Change integer-id to string-id
if isinstance(uid, int):
@ -444,5 +439,4 @@ class ESPnetDataset(AbsDataset):
self.cache[uid] = data
retval = uid, data
assert check_return_type(retval)
return retval

View File

@ -16,7 +16,6 @@ import torch
import torchaudio
import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
from funasr.datasets.dataset import ESPnetDataset
@ -121,7 +120,6 @@ class IterableESPnetDataset(IterableDataset):
int_dtype: str = "long",
key_file: str = None,
):
assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"'

View File

@ -6,7 +6,6 @@ from typing import Union
import sentencepiece as spm
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from funasr.datasets.large_datasets.dataset import Dataset
from funasr.iterators.abs_iter_factory import AbsIterFactory
@ -43,7 +42,6 @@ def load_seg_dict(seg_dict_file):
class SentencepiecesTokenizer(AbsTokenizer):
def __init__(self, model: Union[Path, str]):
assert check_argument_types()
self.model = str(model)
self.sp = None

View File

@ -11,8 +11,6 @@ from typing import Union
import numpy as np
import scipy.signal
import soundfile
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
@ -268,7 +266,6 @@ class CommonPreprocessor(AbsPreprocessor):
def _speech_process(
self, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, Union[str, np.ndarray]]:
assert check_argument_types()
if self.speech_name in data:
if self.train and (self.rirs is not None or self.noises is not None):
speech = data[self.speech_name]
@ -355,7 +352,6 @@ class CommonPreprocessor(AbsPreprocessor):
speech = data[self.speech_name]
ma = np.max(np.abs(speech))
data[self.speech_name] = speech * self.speech_volume_normalize / ma
assert check_return_type(data)
return data
def _text_process(
@ -372,13 +368,11 @@ class CommonPreprocessor(AbsPreprocessor):
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
assert check_argument_types()
data = self._speech_process(data)
data = self._text_process(data)
@ -445,7 +439,6 @@ class LMPreprocessor(CommonPreprocessor):
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
@ -502,13 +495,11 @@ class CommonPreprocessor_multi(AbsPreprocessor):
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[text_n] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
assert check_argument_types()
if self.speech_name in data:
# Nothing now: candidates:
@ -612,7 +603,6 @@ class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
tokens = self.tokenizer[i].text2tokens(text)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
@ -690,7 +680,6 @@ class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
def __call__(
self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
) -> Dict[str, Union[list, np.ndarray]]:
assert check_argument_types()
# Split words.
if isinstance(data[self.text_name], str):
split_text = self.split_words(data[self.text_name])

View File

@ -6,8 +6,6 @@ from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.modules.nets_utils import pad_list
@ -22,7 +20,6 @@ class CommonCollateFn:
not_sequence: Collection[str] = (),
max_sample_size=None
):
assert check_argument_types()
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)
@ -53,7 +50,6 @@ def common_collate_fn(
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
"""
assert check_argument_types()
uttids = [u for u, _ in data]
data = [d for _, d in data]
@ -79,7 +75,6 @@ def common_collate_fn(
output[key + "_lengths"] = lens
output = (uttids, output)
assert check_return_type(output)
return output
def crop_to_max_size(feature, target_size):

View File

@ -15,8 +15,6 @@ import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.sound_scp import SoundScpReader
@ -24,7 +22,6 @@ from funasr.fileio.sound_scp import SoundScpReader
class AdapterForSoundScpReader(collections.abc.Mapping):
def __init__(self, loader, dtype=None):
assert check_argument_types()
self.loader = loader
self.dtype = dtype
self.rate = None
@ -112,7 +109,6 @@ class ESPnetDataset(Dataset):
speed_perturb: Union[list, tuple] = None,
mode: str = "train",
):
assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"'
@ -207,7 +203,6 @@ class ESPnetDataset(Dataset):
return _mes
def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
assert check_argument_types()
# Change integer-id to string-id
if isinstance(uid, int):
@ -265,5 +260,4 @@ class ESPnetDataset(Dataset):
data[name] = value
retval = uid, data
assert check_return_type(retval)
return retval

View File

@ -4,7 +4,6 @@ from typing import Dict
from typing import Tuple
from typing import Union
from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@ -21,7 +20,6 @@ class LengthBatchSampler(AbsSampler):
drop_last: bool = False,
padding: bool = True,
):
assert check_argument_types()
assert batch_bins > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(

View File

@ -10,8 +10,6 @@ from typing import Union
import numpy as np
import scipy.signal
import soundfile
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
@ -260,7 +258,6 @@ class CommonPreprocessor(AbsPreprocessor):
def _speech_process(
self, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, Union[str, np.ndarray]]:
assert check_argument_types()
if self.speech_name in data:
if self.train and (self.rirs is not None or self.noises is not None):
speech = data[self.speech_name]
@ -347,7 +344,6 @@ class CommonPreprocessor(AbsPreprocessor):
speech = data[self.speech_name]
ma = np.max(np.abs(speech))
data[self.speech_name] = speech * self.speech_volume_normalize / ma
assert check_return_type(data)
return data
def _text_process(
@ -365,13 +361,11 @@ class CommonPreprocessor(AbsPreprocessor):
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
assert check_argument_types()
data = self._speech_process(data)
data = self._text_process(data)
@ -439,7 +433,6 @@ class LMPreprocessor(CommonPreprocessor):
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
@ -496,13 +489,11 @@ class CommonPreprocessor_multi(AbsPreprocessor):
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[text_n] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
assert check_argument_types()
if self.speech_name in data:
# Nothing now: candidates:
@ -606,7 +597,6 @@ class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
tokens = self.tokenizer[i].text2tokens(text)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
@ -685,7 +675,6 @@ class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
def __call__(
self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
) -> Dict[str, Union[list, np.ndarray]]:
assert check_argument_types()
# Split words.
if isinstance(data[self.text_name], str):
split_text = self.split_words(data[self.text_name])

View File

@ -1,7 +1,6 @@
import json
from typing import Union, Dict
from pathlib import Path
from typeguard import check_argument_types
import os
import logging
@ -26,7 +25,6 @@ class ModelExport:
calib_num: int = 200,
model_revision: str = None,
):
assert check_argument_types()
self.set_all_random_seed(0)
self.cache_dir = cache_dir

View File

@ -2,8 +2,6 @@ from pathlib import Path
from typing import Union
import warnings
from typeguard import check_argument_types
from typeguard import check_return_type
class DatadirWriter:
@ -20,7 +18,6 @@ class DatadirWriter:
"""
def __init__(self, p: Union[Path, str]):
assert check_argument_types()
self.path = Path(p)
self.chilidren = {}
self.fd = None
@ -31,7 +28,6 @@ class DatadirWriter:
return self
def __getitem__(self, key: str) -> "DatadirWriter":
assert check_argument_types()
if self.fd is not None:
raise RuntimeError("This writer points out a file")
@ -41,11 +37,9 @@ class DatadirWriter:
self.has_children = True
retval = self.chilidren[key]
assert check_return_type(retval)
return retval
def __setitem__(self, key: str, value: str):
assert check_argument_types()
if self.has_children:
raise RuntimeError("This writer points out a directory")
if key in self.keys:

View File

@ -3,7 +3,6 @@ from pathlib import Path
from typing import Union
import numpy as np
from typeguard import check_argument_types
from funasr.fileio.read_text import read_2column_text
@ -25,7 +24,6 @@ class NpyScpWriter:
"""
def __init__(self, outdir: Union[Path, str], scpfile: Union[Path, str]):
assert check_argument_types()
self.dir = Path(outdir)
self.dir.mkdir(parents=True, exist_ok=True)
scpfile = Path(scpfile)
@ -73,7 +71,6 @@ class NpyScpReader(collections.abc.Mapping):
"""
def __init__(self, fname: Union[Path, str]):
assert check_argument_types()
self.fname = Path(fname)
self.data = read_2column_text(fname)

View File

@ -3,7 +3,6 @@ from pathlib import Path
from typing import Union
import numpy as np
from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
@ -29,7 +28,6 @@ class FloatRandomGenerateDataset(collections.abc.Mapping):
dtype: Union[str, np.dtype] = "float32",
loader_type: str = "csv_int",
):
assert check_argument_types()
shape_file = Path(shape_file)
self.utt2shape = load_num_sequence_text(shape_file, loader_type)
self.dtype = np.dtype(dtype)
@ -68,7 +66,6 @@ class IntRandomGenerateDataset(collections.abc.Mapping):
dtype: Union[str, np.dtype] = "int64",
loader_type: str = "csv_int",
):
assert check_argument_types()
shape_file = Path(shape_file)
self.utt2shape = load_num_sequence_text(shape_file, loader_type)
self.dtype = np.dtype(dtype)

View File

@ -4,7 +4,6 @@ from typing import Dict
from typing import List
from typing import Union
from typeguard import check_argument_types
def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
@ -19,7 +18,6 @@ def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
{'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
"""
assert check_argument_types()
data = {}
with Path(path).open("r", encoding="utf-8") as f:
@ -47,7 +45,6 @@ def load_num_sequence_text(
>>> d = load_num_sequence_text('text')
>>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
"""
assert check_argument_types()
if loader_type == "text_int":
delimiter = " "
dtype = int

View File

@ -6,7 +6,6 @@ import random
import numpy as np
import soundfile
import librosa
from typeguard import check_argument_types
import torch
import torchaudio
@ -106,7 +105,6 @@ class SoundScpReader(collections.abc.Mapping):
dest_sample_rate: int = 16000,
speed_perturb: Union[list, tuple] = None,
):
assert check_argument_types()
self.fname = fname
self.dtype = dtype
self.always_2d = always_2d
@ -179,7 +177,6 @@ class SoundScpWriter:
format="wav",
dtype=None,
):
assert check_argument_types()
self.dir = Path(outdir)
self.dir.mkdir(parents=True, exist_ok=True)
scpfile = Path(scpfile)

View File

@ -9,7 +9,6 @@ from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
@ -51,7 +50,6 @@ class ChunkIterFactory(AbsIterFactory):
collate_fn=None,
pin_memory: bool = False,
):
assert check_argument_types()
assert all(len(x) == 1 for x in batches), "batch-size must be 1"
self.per_sample_iter_factory = SequenceIterFactory(

View File

@ -4,7 +4,6 @@ from typing import Collection
from typing import Iterator
import numpy as np
from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
@ -16,7 +15,6 @@ class MultipleIterFactory(AbsIterFactory):
seed: int = 0,
shuffle: bool = False,
):
assert check_argument_types()
self.build_funcs = list(build_funcs)
self.seed = seed
self.shuffle = shuffle

View File

@ -4,7 +4,6 @@ from typing import Union
import numpy as np
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.samplers.abs_sampler import AbsSampler
@ -46,7 +45,6 @@ class SequenceIterFactory(AbsIterFactory):
collate_fn=None,
pin_memory: bool = False,
):
assert check_argument_types()
if not isinstance(batches, AbsSampler):
self.sampler = RawSampler(batches)

View File

@ -4,7 +4,6 @@ from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.abs_normalize import AbsNormalize
@ -28,7 +27,6 @@ class GlobalMVN(AbsNormalize, InversibleInterface):
norm_vars: bool = True,
eps: float = 1.0e-20,
):
assert check_argument_types()
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars

View File

@ -1,5 +1,4 @@
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
@ -13,7 +12,6 @@ class LabelAggregate(torch.nn.Module):
hop_length: int = 128,
center: bool = True,
):
assert check_argument_types()
super().__init__()
self.win_length = win_length

View File

@ -1,6 +1,5 @@
import math
import torch
from typeguard import check_argument_types
from typing import Sequence
from typing import Union
@ -147,7 +146,6 @@ class MaskAlongAxis(torch.nn.Module):
dim: Union[int, str] = "time",
replace_with_zero: bool = True,
):
assert check_argument_types()
if isinstance(mask_width_range, int):
mask_width_range = (0, mask_width_range)
if len(mask_width_range) != 2:
@ -214,7 +212,6 @@ class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
dim: Union[int, str] = "time",
replace_with_zero: bool = True,
):
assert check_argument_types()
if isinstance(mask_width_ratio_range, float):
mask_width_ratio_range = (0.0, mask_width_ratio_range)
if len(mask_width_ratio_range) != 2:
@ -283,7 +280,6 @@ class MaskAlongAxisLFR(torch.nn.Module):
replace_with_zero: bool = True,
lfr_rate: int = 1,
):
assert check_argument_types()
if isinstance(mask_width_range, int):
mask_width_range = (0, mask_width_range)
if len(mask_width_range) != 2:

View File

@ -5,7 +5,6 @@
"""Sinc convolutions."""
import math
import torch
from typeguard import check_argument_types
from typing import Union
@ -71,7 +70,6 @@ class SincConv(torch.nn.Module):
window_func: Window function on the filter, one of ["hamming", "none"].
fs (str, int, float): Sample rate of the input data
"""
assert check_argument_types()
super().__init__()
window_funcs = {
"none": self.none_window,
@ -208,7 +206,6 @@ class MelScale:
torch.Tensor: Filter start frequencies.
torch.Tensor: Filter stop frequencies.
"""
assert check_argument_types()
# min and max bandpass edge frequencies
min_frequency = torch.tensor(30.0)
max_frequency = torch.tensor(fs * 0.5)
@ -257,7 +254,6 @@ class BarkScale:
torch.Tensor: Filter start frequencies.
torch.Tensor: Filter stop frequencies.
"""
assert check_argument_types()
# min and max BARK center frequencies by approximation
min_center_frequency = torch.tensor(70.0)
max_center_frequency = torch.tensor(fs * 0.45)

View File

@ -5,7 +5,6 @@ from typing import Union
import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.complex_utils import is_complex
@ -30,7 +29,6 @@ class Stft(torch.nn.Module, InversibleInterface):
normalized: bool = False,
onesided: bool = True,
):
assert check_argument_types()
super().__init__()
self.n_fft = n_fft
if win_length is None:

View File

@ -1,7 +1,6 @@
from typing import Tuple
import torch
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.abs_normalize import AbsNormalize
@ -14,7 +13,6 @@ class UtteranceMVN(AbsNormalize):
norm_vars: bool = False,
eps: float = 1.0e-20,
):
assert check_argument_types()
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars

View File

@ -8,7 +8,6 @@ import os
from io import BytesIO
import torch
from typeguard import check_argument_types
from typing import Collection
from funasr.train.reporter import Reporter
@ -34,7 +33,6 @@ def average_nbest_models(
nbest: Number of best model files to be averaged
suffix: A suffix added to the averaged model file name
"""
assert check_argument_types()
if isinstance(nbest, int):
nbests = [nbest]
else:

View File

@ -11,7 +11,6 @@ import numpy as np
import torch
from torch.nn.parallel import data_parallel
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.fileio.npy_scp import NpyScpWriter
@ -37,7 +36,6 @@ def collect_stats(
This method is used before executing train().
"""
assert check_argument_types()
npy_scp_writers = {}
for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):

View File

@ -2,7 +2,6 @@ import logging
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
class CTC(torch.nn.Module):
@ -25,7 +24,6 @@ class CTC(torch.nn.Module):
reduce: bool = True,
ignore_nan_grad: bool = True,
):
assert check_argument_types()
super().__init__()
eprojs = encoder_output_size
self.dropout_rate = dropout_rate
@ -41,11 +39,6 @@ class CTC(torch.nn.Module):
if ignore_nan_grad:
logging.warning("ignore_nan_grad option is not supported for warp_ctc")
self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
elif self.ctc_type == "gtnctc":
from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction
self.ctc_loss = GTNCTCLossFunction.apply
else:
raise ValueError(
f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}'
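With the gtnctc branch deleted, ctc_type reduces to "builtin" or "warpctc"; a minimal sketch of what the builtin path boils down to (plain torch.nn.CTCLoss; shapes and reduction settings are assumptions, not taken from this file):

import torch

ctc_loss = torch.nn.CTCLoss(reduction="sum", zero_infinity=True)
log_probs = torch.randn(50, 4, 20).log_softmax(2)          # (T, N, C), C includes blank
targets = torch.randint(1, 20, (4, 12), dtype=torch.long)  # blank id 0 is excluded
input_lengths = torch.full((4,), 50, dtype=torch.long)
target_lengths = torch.full((4,), 12, dtype=torch.long)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)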

View File

@ -10,7 +10,6 @@ from typing import Optional
from typing import Tuple
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.encoder.abs_encoder import AbsEncoder
@ -40,7 +39,6 @@ class Data2VecPretrainModel(FunASRModel):
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
):
assert check_argument_types()
super().__init__()

View File

@ -7,7 +7,6 @@ import numpy as np
from funasr.modules.streaming_utils import utils as myutils
from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
from typeguard import check_argument_types
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.modules.embedding import PositionalEncoding
@ -126,7 +125,6 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
kernel_size: int = 21,
sanm_shfit: int = 0,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,

View File

@ -3,7 +3,6 @@ import random
import numpy as np
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.nets_utils import to_device
@ -97,7 +96,6 @@ class RNNDecoder(AbsDecoder):
att_conf: dict = get_default_kwargs(build_attention_list),
):
# FIXME(kamo): The parts of num_spk should be refactored more more more
assert check_argument_types()
if rnn_type not in {"lstm", "gru"}:
raise ValueError(f"Not supported: rnn_type={rnn_type}")

View File

@ -3,7 +3,6 @@
from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.specaug.specaug import SpecAug
@ -38,7 +37,6 @@ class RNNTDecoder(torch.nn.Module):
"""Construct a RNNDecoder object."""
super().__init__()
assert check_argument_types()
if rnn_type not in ("lstm", "gru"):
raise ValueError(f"Not supported: rnn_type={rnn_type}")

View File

@ -7,7 +7,6 @@ import numpy as np
from funasr.modules.streaming_utils import utils as myutils
from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
from typeguard import check_argument_types
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.modules.embedding import PositionalEncoding
@ -181,7 +180,6 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
embed_tensor_name_prefix_tf: str = None,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@ -838,7 +836,6 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
tf2torch_tensor_name_prefix_torch: str = "decoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,

View File

@ -9,7 +9,6 @@ from typing import Tuple
import torch
from torch import nn
from typeguard import check_argument_types
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.modules.attention import MultiHeadedAttention
@ -184,7 +183,6 @@ class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
):
assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size
@ -373,7 +371,6 @@ class TransformerDecoder(BaseTransformerDecoder):
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@ -428,7 +425,6 @@ class ParaformerDecoderSAN(BaseTransformerDecoder):
concat_after: bool = False,
embeds_id: int = -1,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@ -540,7 +536,6 @@ class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@ -602,7 +597,6 @@ class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@ -664,7 +658,6 @@ class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@ -726,7 +719,6 @@ class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@ -781,7 +773,6 @@ class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
):
assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size
@ -955,7 +946,6 @@ class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,

View File

@ -11,7 +11,6 @@ from typing import Tuple
from typing import Union
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@ -65,7 +64,6 @@ class ASRModel(FunASRModel):
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight

View File

@ -9,7 +9,6 @@ from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.ctc import CTC
@ -73,7 +72,6 @@ class NeatContextualParaformer(Paraformer):
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight

View File

@ -7,7 +7,6 @@ from typing import Tuple
from typing import Union
import logging
import torch
from typeguard import check_argument_types
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
@ -65,7 +64,6 @@ class MFCCA(FunASRModel):
sym_blank: str = "<blank>",
preencoder: Optional[AbsPreEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert rnnt_decoder is None, "Not implemented"

View File

@ -10,7 +10,6 @@ from typing import Union
import torch
import random
import numpy as np
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@ -80,7 +79,6 @@ class Paraformer(FunASRModel):
postencoder: Optional[AbsPostEncoder] = None,
use_1st_decoder_loss: bool = False,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@ -645,7 +643,6 @@ class ParaformerOnline(Paraformer):
postencoder: Optional[AbsPostEncoder] = None,
use_1st_decoder_loss: bool = False,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@ -1255,7 +1252,6 @@ class ParaformerBert(Paraformer):
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@ -1528,7 +1524,6 @@ class BiCifParaformer(Paraformer):
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@ -1806,7 +1801,6 @@ class ContextualParaformer(Paraformer):
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight

View File

@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
@ -86,8 +85,6 @@ class TransducerModel(FunASRModel):
"""Construct an ESPnetASRTransducerModel object."""
super().__init__()
assert check_argument_types()
# The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
self.blank_id = 0
self.vocab_size = vocab_size
@ -546,8 +543,6 @@ class UnifiedTransducerModel(FunASRModel):
"""Construct an ESPnetASRTransducerModel object."""
super().__init__()
assert check_argument_types()
# The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
self.blank_id = 0
@ -713,7 +708,7 @@ class UnifiedTransducerModel(FunASRModel):
loss_lm = self._calc_lm_loss(decoder_out, target)
loss_trans = loss_trans_utt + loss_trans_chunk
loss_ctc = loss_ctc + loss_ctc_chunk
loss_att = loss_att + loss_att_chunk
loss = (
@ -1018,4 +1013,4 @@ class UnifiedTransducerModel(FunASRModel):
ignore_label=self.ignore_id,
)
return loss_att, acc_att

View File

@ -9,7 +9,6 @@ from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
@ -48,7 +47,6 @@ class DiarEENDOLAModel(FunASRModel):
mapping_dict=None,
**kwargs,
):
assert check_argument_types()
super().__init__()
self.frontend = frontend

View File

@ -12,7 +12,6 @@ from typing import Tuple, List
import numpy as np
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
@ -66,7 +65,6 @@ class DiarSondModel(FunASRModel):
inter_score_loss_weight: float = 0.0,
inputs_type: str = "raw",
):
assert check_argument_types()
super().__init__()

View File

@ -12,7 +12,6 @@ from typing import Union
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@ -67,7 +66,6 @@ class SAASRModel(FunASRModel):
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight

View File

@ -12,7 +12,6 @@ from typing import Tuple
from typing import Union
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@ -56,7 +55,6 @@ class ESPnetSVModel(FunASRModel):
pooling_layer: torch.nn.Module,
decoder: AbsDecoder,
):
assert check_argument_types()
super().__init__()
# note that eos is the same as sos (equivalent ID)

View File

@ -9,7 +9,6 @@ from typing import Union
import torch
import numpy as np
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
@ -42,7 +41,6 @@ class TimestampPredictor(FunASRModel):
predictor_bias: int = 0,
token_list=None,
):
assert check_argument_types()
super().__init__()
# note that eos is the same as sos (equivalent ID)

View File

@ -8,7 +8,6 @@ from typing import Tuple
from typing import Union
import torch
from typeguard import check_argument_types
from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
@ -82,7 +81,6 @@ class UniASR(FunASRModel):
postencoder: Optional[AbsPostEncoder] = None,
encoder1_encoder2_joint_training: bool = True,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight

View File

@ -12,7 +12,6 @@ from typing import Dict
import torch
from torch import nn
from typeguard import check_argument_types
from funasr.models.ctc import CTC
from funasr.modules.attention import (
@ -533,7 +532,6 @@ class ConformerEncoder(AbsEncoder):
interctc_use_conditioning: bool = False,
stochastic_depth_rate: Union[float, List[float]] = 0.0,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
@ -943,7 +941,6 @@ class ConformerChunkEncoder(AbsEncoder):
"""Construct an Encoder object."""
super().__init__()
assert check_argument_types()
self.embed = StreamingConvInput(
input_size,

View File

@ -10,7 +10,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.data2vec.data_utils import compute_mask_indices
@ -97,7 +96,6 @@ class Data2VecEncoder(AbsEncoder):
# FP16 optimization
required_seq_len_multiple: int = 2,
):
assert check_argument_types()
super().__init__()
# ConvFeatureExtractionModel

View File

@ -5,7 +5,6 @@ import logging
import torch
from torch import nn
from typeguard import check_argument_types
from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
from funasr.modules.nets_utils import get_activation
@ -161,7 +160,6 @@ class MFCCAEncoder(AbsEncoder):
cnn_module_kernel: int = 31,
padding_idx: int = -1,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size

View File

@ -7,7 +7,6 @@ import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.layer_norm import LayerNorm
@ -90,7 +89,6 @@ class ConvEncoder(AbsEncoder):
tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
):
assert check_argument_types()
super().__init__()
self._output_size = num_units

View File

@ -7,7 +7,6 @@ import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.layer_norm import LayerNorm

View File

@ -7,7 +7,6 @@ import logging
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM
@ -144,7 +143,6 @@ class SelfAttentionEncoder(AbsEncoder):
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
out_units=None,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size

View File

@ -5,7 +5,6 @@ from typing import Tuple
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.rnn.encoders import RNN
@ -37,7 +36,6 @@ class RNNEncoder(AbsEncoder):
dropout: float = 0.0,
subsample: Optional[Sequence[int]] = (2, 2, 1, 1),
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
self.rnn_type = rnn_type

View File

@ -8,7 +8,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
@ -151,7 +150,6 @@ class SANMEncoder(AbsEncoder):
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
@ -601,7 +599,6 @@ class SANMEncoderChunkOpt(AbsEncoder):
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
@ -1060,7 +1057,6 @@ class SANMVadEncoder(AbsEncoder):
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size

View File

@ -9,7 +9,6 @@ from typing import Tuple
import torch
from torch import nn
from typeguard import check_argument_types
import logging
from funasr.models.ctc import CTC
@ -189,7 +188,6 @@ class TransformerEncoder(AbsEncoder):
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size

View File

@ -7,7 +7,6 @@ import humanfriendly
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
@ -40,7 +39,6 @@ class DefaultFrontend(AbsFrontend):
apply_stft: bool = True,
use_channel: int = None,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
@ -167,7 +165,6 @@ class MultiChannelFrontend(AbsFrontend):
cmvn_file: str = None,
mc: bool = True
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)

View File

@ -3,7 +3,6 @@ from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.s3prl import S3prlFrontend
import numpy as np
import torch
from typeguard import check_argument_types
from typing import Tuple
@ -12,7 +11,6 @@ class FusedFrontends(AbsFrontend):
self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
):
assert check_argument_types()
super().__init__()
self.align_method = (
align_method # fusing method : linear_projection only for now

View File

@ -8,7 +8,6 @@ from typing import Union
import humanfriendly
import torch
from typeguard import check_argument_types
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
@ -37,7 +36,6 @@ class S3prlFrontend(AbsFrontend):
download_dir: str = None,
multilayer_feature: bool = False,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)

View File

@ -6,7 +6,6 @@ import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types
import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from funasr.models.frontend.abs_frontend import AbsFrontend
@ -95,7 +94,6 @@ class WavFrontend(AbsFrontend):
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
@ -227,7 +225,6 @@ class WavFrontendOnline(AbsFrontend):
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
@ -466,7 +463,6 @@ class WavFrontendMel23(AbsFrontend):
lfr_m: int = 1,
lfr_n: int = 1,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.frame_length = frame_length

View File

@ -6,7 +6,6 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
import torch
from typeguard import check_argument_types
from typing import Tuple
@ -38,7 +37,6 @@ class SlidingWindow(AbsFrontend):
padding: Padding (placeholder, currently not implemented).
fs: Sampling rate (placeholder for compatibility, not used).
"""
assert check_argument_types()
super().__init__()
self.fs = fs
self.win_length = win_length

View File

@ -6,7 +6,6 @@
from funasr.modules.nets_utils import make_pad_mask
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from typeguard import check_argument_types
from typing import Tuple
import copy
@ -30,7 +29,6 @@ class HuggingFaceTransformersPostEncoder(AbsPostEncoder):
model_name_or_path: str,
):
"""Initialize the module."""
assert check_argument_types()
super().__init__()
if not is_transformers_available:

View File

@ -5,7 +5,6 @@
"""Linear Projection."""
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from typeguard import check_argument_types
from typing import Tuple
import torch
@ -20,7 +19,6 @@ class LinearProjection(AbsPreEncoder):
output_size: int,
):
"""Initialize the module."""
assert check_argument_types()
super().__init__()
self.output_dim = output_size

View File

@ -10,7 +10,6 @@ from funasr.layers.sinc_conv import LogCompression
from funasr.layers.sinc_conv import SincConv
import humanfriendly
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
from typing import Union
@ -60,7 +59,6 @@ class LightweightSincConvs(AbsPreEncoder):
windowing_type: Choice of windowing function.
scale_type: Choice of filter-bank initialization scale.
"""
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
@ -268,7 +266,6 @@ class SpatialDropout(torch.nn.Module):
dropout_probability: Dropout probability.
shape (tuple, list): Shape of input tensors.
"""
assert check_argument_types()
super().__init__()
if shape is None:
shape = (0, 2, 1)

View File

@ -4,7 +4,6 @@ from typing import Union
import torch
import torch.nn as nn
from typeguard import check_argument_types
from funasr.train.abs_model import AbsLM
@ -27,7 +26,6 @@ class SequentialRNNLM(AbsLM):
rnn_type: str = "lstm",
ignore_id: int = 0,
):
assert check_argument_types()
super().__init__()
ninp = unit

View File

@ -2,7 +2,7 @@ import copy
import numpy as np
import time
import torch
from eend.utils.power import create_powerlabel
from funasr.modules.eend_ola.utils.power import create_powerlabel
from itertools import combinations
metrics = [

View File

@ -1,5 +1,4 @@
import torch
from typeguard import check_argument_types
class SGD(torch.optim.SGD):
@ -21,7 +20,6 @@ class SGD(torch.optim.SGD):
weight_decay: float = 0.0,
nesterov: bool = False,
):
assert check_argument_types()
super().__init__(
params,
lr=lr,

View File

@ -0,0 +1,214 @@
# FunASR Offline File Transcription Service Development Guide
FunASR provides a Mandarin offline file transcription service that can be deployed locally or on a cloud server with one click; its core is the open-source FunASR runtime SDK. FunASR-runtime combines the voice activity detection (VAD), Paraformer-large speech recognition (ASR), and punctuation restoration (PUNC) capabilities open-sourced by the DAMO Academy speech lab on the ModelScope community, so it can transcribe audio accurately and efficiently with high concurrency.
This document is the development guide for the FunASR offline file transcription service. If you want to quickly try the service, see the one-click deployment example ([click here](./SDK_tutorial.md)).
## Docker Installation
The steps below install Docker and the Docker image manually; if your Docker image is already up and running, you can skip this section.
### Installing Docker
```shell
# Ubuntu
curl -fsSL https://test.docker.com -o test-docker.sh
sudo sh test-docker.sh
# Debian
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
# CentOS
curl -fsSL https://get.docker.com | bash -s docker --mirror Aliyun
# MacOS
brew install --cask --appdir=/Applications docker
```
For installation details, see: https://alibaba-damo-academy.github.io/FunASR/en/installation/docker.html
### Starting Docker
```shell
sudo systemctl start docker
```
### Pulling and Starting the Image
Pull and start the FunASR runtime-SDK Docker image with the following commands:
```shell
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.0.1
sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.0.1
```
```text
-p <host port>:<docker port>
As in the example, host (ECS) port 10095 is mapped to docker port 10095. Make sure your ECS security rules open port 10095.
-v <host path>:<path mounted inside docker>
As in the example, host path /root is mounted to docker path /workspace/models.
```
## Starting the Server
After the Docker container is up, start the funasr-wss-server service program:
funasr-wss-server supports downloading models from ModelScope. To do so, set the model download directory (--download-model-dir) together with the model IDs (--model-dir, --vad-dir, --punc-dir), for example:
```shell
cd /workspace/FunASR/funasr/runtime/websocket/build/bin
./funasr-wss-server \
--download-model-dir /workspace/models \
--model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
--vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
--punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
--decoder-thread-num 32 \
--io-thread-num 8 \
--port 10095 \
--certfile ../../../ssl_key/server.crt \
--keyfile ../../../ssl_key/server.key
```
```text
--download-model-dir # model download directory; models are fetched from ModelScope by model ID
--model-dir # ModelScope model ID
--quantize # True for the quantized ASR model, False for the non-quantized one; default True
--vad-dir # ModelScope model ID
--vad-quant # True for the quantized VAD model, False for the non-quantized one; default True
--punc-dir # ModelScope model ID
--punc-quant # True for the quantized PUNC model, False for the non-quantized one; default True
--port # port the server listens on; default 8889
--decoder-thread-num # number of inference threads the server starts; default 8
--io-thread-num # number of IO threads the server starts; default 1
--certfile <string> # SSL certificate file; default ../../../ssl_key/server.crt
--keyfile <string> # SSL key file; default ../../../ssl_key/server.key
```
funasr-wss-server can also load models from local paths (see [Preparing Model Resources](#anchor-1) for how to prepare them); in that case, set --model-dir, --vad-dir, and --punc-dir to the local model paths, for example:
```shell
cd /workspace/FunASR/funasr/runtime/websocket/build/bin
./funasr-wss-server \
--model-dir /workspace/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
--vad-dir /workspace/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
--punc-dir /workspace/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
--decoder-thread-num 32 \
--io-thread-num 8 \
--port 10095 \
--certfile ../../../ssl_key/server.crt \
--keyfile ../../../ssl_key/server.key
```
```text
--model-dir # ASR model path; default /workspace/models/asr
--quantize # True for the quantized ASR model, False for the non-quantized one; default True
--vad-dir # VAD model path; default /workspace/models/vad
--vad-quant # True for the quantized VAD model, False for the non-quantized one; default True
--punc-dir # PUNC model path; default /workspace/models/punc
--punc-quant # True for the quantized PUNC model, False for the non-quantized one; default True
--port # port the server listens on; default 8889
--decoder-thread-num # number of inference threads the server starts; default 8
--io-thread-num # number of IO threads the server starts; default 1
--certfile <string> # SSL certificate file; default ../../../ssl_key/server.crt
--keyfile <string> # SSL key file; default ../../../ssl_key/server.key
```
## <a id="anchor-1">模型资源准备</a>
如果您选择通过funasr-wss-server从Modelscope下载模型可以跳过本步骤。
FunASR离线文件转写服务中的vad、asr和punc模型资源均来自Modelscope模型地址详见下表
| 模型 | Modelscope链接 |
|------|------------------------------------------------------------------------------------------------------------------|
| VAD | https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary |
| ASR | https://www.modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary |
| PUNC | https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary |
离线文件转写服务中部署的是量化后的ONNX模型下面介绍下如何导出ONNX模型及其量化您可以选择从Modelscope导出ONNX模型、从本地文件导出ONNX模型或者从finetune后的资源导出模型
### 从Modelscope导出ONNX模型
从Modelscope网站下载对应model name的模型然后导出量化后的ONNX模型
```shell
python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```
```text
--model-name   model name on ModelScope, e.g. damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
--export-dir   output directory for the exported ONNX model
--type         model type; currently supports onnx and torch
--quantize     enable int8 model quantization
```
### Exporting an ONNX model from a local file
Set the model name to a local model path and export the quantized ONNX model:
```shell
python -m funasr.export.export_model --model-name /workspace/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```
```text
--model-name   local model path, e.g. /workspace/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
--export-dir   output directory for the exported ONNX model
--type         model type; currently supports onnx and torch
--quantize     enable int8 model quantization
```
### Exporting a finetuned model
If you want to deploy a finetuned model, follow these steps:
Rename the finetuned model you want to deploy (e.g. 10epoch.pb) to model.pb and use it to replace the original model.pb in the ModelScope model directory. Assuming the resulting model path is /path/to/finetune/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch, convert the finetuned model to ONNX with:
```shell
python -m funasr.export.export_model --model-name /path/to/finetune/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```
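If helpful, the replacement step above can be scripted. A minimal sketch, where both paths are placeholders for your own finetune output and local model copy:
```shell
# Placeholder paths: point these at your finetune output and your local model copy
cp /path/to/finetune_exp/10epoch.pb \
   /path/to/finetune/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.pb
```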
## Starting the Client
After the FunASR offline file transcription service has been deployed on the server, you can test and use it through the following steps. FunASR currently supports starting a client in several ways; below are command-line examples for the Python and C++ clients, as well as the websocket protocol for building a custom client:
### python-client
```shell
python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode offline --audio_in "./data/wav.scp" --send_without_sleep --output_dir "./results"
```
```text
--host       # server IP address; for local testing use 127.0.0.1
--port       # port the server listens on
--audio_in   # audio input: either a wav file path or a wav.scp path (Kaldi-style wav list: wav_id \t wav_path)
--output_dir # output directory for recognition results
--ssl        # whether to use SSL encryption; enabled by default
--mode       # offline mode
```
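For reference, a wav.scp file is a plain-text list with one utterance per line, a wav_id and the wav path separated by a tab; the IDs and paths below are placeholders:
```text
wav_id_1	/path/to/first.wav
wav_id_2	/path/to/second.wav
```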
### c++-client
```shell
./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
```
```text
--server-ip  # server IP address; for local testing use 127.0.0.1
--port       # port the server listens on
--wav-path   # audio input: either a wav file path or a wav.scp path (Kaldi-style wav list: wav_id \t wav_path)
--thread-num # number of client threads
--is-ssl     # whether to use SSL encryption; enabled by default
```
### Custom client
If you want to build your own client, the websocket protocol is:
```text
# First message
{"mode": "offline", "wav_name": wav_name, "is_speaking": True}
# Send the wav data
raw bytes
# Send the end-of-stream flag
{"is_speaking": False}
```
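As an illustration, here is a minimal Python sketch of this protocol. It assumes the third-party `websockets` package, a self-signed server certificate (SSL is on by default), and that the server replies with a single JSON text frame; adapt as needed:
```python
import asyncio
import json
import ssl

import websockets  # pip install websockets


async def transcribe(uri: str, wav_path: str) -> None:
    # The server enables SSL by default; for a self-signed certificate,
    # skip verification (local testing only).
    ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_ctx.check_hostname = False
    ssl_ctx.verify_mode = ssl.CERT_NONE

    async with websockets.connect(uri, ssl=ssl_ctx) as ws:
        # 1. First message: declare offline mode and the wav name
        await ws.send(json.dumps(
            {"mode": "offline", "wav_name": "test", "is_speaking": True}))
        # 2. Send the raw wav bytes
        with open(wav_path, "rb") as f:
            await ws.send(f.read())
        # 3. Send the end-of-stream flag
        await ws.send(json.dumps({"is_speaking": False}))
        # 4. Wait for the recognition result
        #    (assumed here to arrive as one JSON text frame)
        print(await ws.recv())


asyncio.run(transcribe("wss://127.0.0.1:10095", "test.wav"))
```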
## Customizing the Service
The FunASR-runtime code is open source. If the server and clients do not fully meet your needs, you can develop them further to fit your own requirements:
### C++ client
https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/websocket
### Python client
https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/websocket
### C++ server
#### VAD
```c++
// Using the VAD model involves two steps: FsmnVadInit and FsmnVadInfer
FUNASR_HANDLE vad_handle = FsmnVadInit(model_path, thread_num);
// model_path contains "model-dir" and "quantize"; thread_num is the number of ONNX threads
FUNASR_RESULT result = FsmnVadInfer(vad_handle, wav_file.c_str(), NULL, 16000);
// vad_handle is the value returned by FsmnVadInit; wav_file is the audio path; the last argument is the sampling rate (default 16k)
```
For a usage example, see: https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
#### ASR
```c++
// Using the ASR model involves two steps: FunOfflineInit and FunOfflineInfer
FUNASR_HANDLE asr_handle = FunOfflineInit(model_path, thread_num);
// model_path contains "model-dir" and "quantize"; thread_num is the number of ONNX threads
FUNASR_RESULT result = FunOfflineInfer(asr_handle, wav_file.c_str(), RASR_NONE, NULL, 16000);
// asr_handle is the value returned by FunOfflineInit; wav_file is the audio path; the last argument is the sampling rate (default 16k)
```
For a usage example, see: https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
#### PUNC
```c++
// Using the PUNC model involves two steps: CTTransformerInit and CTTransformerInfer
FUNASR_HANDLE punc_handle = CTTransformerInit(model_path, thread_num);
// model_path contains "model-dir" and "quantize"; thread_num is the number of ONNX threads
FUNASR_RESULT result = CTTransformerInfer(punc_handle, txt_str.c_str(), RASR_NONE, NULL);
// punc_handle is the value returned by CTTransformerInit; txt_str is the input text
```
For a usage example, see: https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp

View File

@ -3,7 +3,6 @@ from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
from typeguard import check_argument_types
import kaldi_native_fbank as knf
root_dir = Path(__file__).resolve().parent
@ -28,7 +27,6 @@ class WavFrontend():
dither: float = 1.0,
**kwargs,
) -> None:
check_argument_types()
opts = knf.FbankOptions()
opts.frame_opts.samp_freq = fs

View File

@ -9,7 +9,6 @@ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
import yaml
from typeguard import check_argument_types
import warnings
@ -21,7 +20,6 @@ logger_initialized = {}
class TokenIDConverter():
def __init__(self, token_list: Union[List, str],
):
check_argument_types()
self.token_list = token_list
self.unk_symbol = token_list[-1]
@ -51,7 +49,6 @@ class CharTokenizer():
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
check_argument_types()
self.space_symbol = space_symbol
self.non_linguistic_symbols = self.load_symbols(symbol_value)

View File

@ -4,7 +4,6 @@ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import copy
import numpy as np
from typeguard import check_argument_types
import kaldi_native_fbank as knf
root_dir = Path(__file__).resolve().parent
@ -29,7 +28,6 @@ class WavFrontend():
dither: float = 1.0,
**kwargs,
) -> None:
check_argument_types()
opts = knf.FbankOptions()
opts.frame_opts.samp_freq = fs

View File

@ -10,7 +10,6 @@ import numpy as np
import yaml
from onnxruntime import (GraphOptimizationLevel, InferenceSession,
SessionOptions, get_available_providers, get_device)
from typeguard import check_argument_types
import warnings
@ -22,7 +21,6 @@ logger_initialized = {}
class TokenIDConverter():
def __init__(self, token_list: Union[List, str],
):
check_argument_types()
self.token_list = token_list
self.unk_symbol = token_list[-1]
@ -52,7 +50,6 @@ class CharTokenizer():
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
check_argument_types()
self.space_symbol = space_symbol
self.non_linguistic_symbols = self.load_symbols(symbol_value)

View File

@ -109,7 +109,6 @@ class WavFrontend():
lfr_n: int = 6,
dither: float = 1.0
) -> None:
# check_argument_types()
self.fs = fs
self.window = window

View File

@ -4,8 +4,6 @@ from typing import Sequence
from typing import Tuple
from typing import Union
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.samplers.abs_sampler import AbsSampler
from funasr.samplers.folded_batch_sampler import FoldedBatchSampler
@ -104,7 +102,6 @@ def build_batch_sampler(
padding: Whether sequences are input as a padded tensor or not.
used for "numel" mode
"""
assert check_argument_types()
if len(shape_files) == 0:
raise ValueError("No shape file are given")
@ -164,5 +161,4 @@ def build_batch_sampler(
else:
raise ValueError(f"Not supported: {type}")
assert check_return_type(retval)
return retval

View File

@ -4,7 +4,6 @@ from typing import Sequence
from typing import Tuple
from typing import Union
from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.fileio.read_text import read_2column_text
@ -23,7 +22,6 @@ class FoldedBatchSampler(AbsSampler):
drop_last: bool = False,
utt2category_file: str = None,
):
assert check_argument_types()
assert batch_size > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(

View File

@ -4,7 +4,6 @@ from typing import Dict
from typing import Tuple
from typing import Union
from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@ -21,7 +20,6 @@ class LengthBatchSampler(AbsSampler):
drop_last: bool = False,
padding: bool = True,
):
assert check_argument_types()
assert batch_bins > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(

View File

@ -4,7 +4,6 @@ from typing import Tuple
from typing import Union
import numpy as np
from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@ -21,7 +20,6 @@ class NumElementsBatchSampler(AbsSampler):
drop_last: bool = False,
padding: bool = True,
):
assert check_argument_types()
assert batch_bins > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(

View File

@ -2,7 +2,6 @@ import logging
from typing import Iterator
from typing import Tuple
from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@ -26,7 +25,6 @@ class SortedBatchSampler(AbsSampler):
sort_batch: str = "ascending",
drop_last: bool = False,
):
assert check_argument_types()
assert batch_size > 0
self.batch_size = batch_size
self.shape_file = shape_file

View File

@ -2,7 +2,6 @@ import logging
from typing import Iterator
from typing import Tuple
from typeguard import check_argument_types
from funasr.fileio.read_text import read_2column_text
from funasr.samplers.abs_sampler import AbsSampler
@ -28,7 +27,6 @@ class UnsortedBatchSampler(AbsSampler):
drop_last: bool = False,
utt2category_file: str = None,
):
assert check_argument_types()
assert batch_size > 0
self.batch_size = batch_size
self.key_file = key_file

View File

@ -4,7 +4,6 @@ import warnings
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
@ -31,7 +30,6 @@ class NoamLR(_LRScheduler, AbsBatchStepScheduler):
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
assert check_argument_types()
self.model_size = model_size
self.warmup_steps = warmup_steps

View File

@ -8,7 +8,6 @@ from typing import Optional, List
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
@ -22,7 +21,6 @@ class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
init_lr_scale: float = 0.01,
final_lr_scale: float = 0.01,
):
assert check_argument_types()
self.optimizer = optimizer
self.last_epoch = last_epoch
self.phase_ratio = phase_ratio

View File

@ -3,7 +3,6 @@ from typing import Union
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
@ -30,7 +29,6 @@ class WarmupLR(_LRScheduler, AbsBatchStepScheduler):
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
assert check_argument_types()
self.warmup_steps = warmup_steps
# __init__() must be invoked before setting field

View File

@ -32,8 +32,6 @@ import torch.optim
import yaml
from funasr.models.base_model import FunASRModel
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr import __version__
from funasr.datasets.dataset import AbsDataset
@ -269,7 +267,6 @@ class AbsTask(ABC):
@classmethod
def get_parser(cls) -> config_argparse.ArgumentParser:
assert check_argument_types()
class ArgumentDefaultsRawTextHelpFormatter(
argparse.RawTextHelpFormatter,
@ -959,7 +956,6 @@ class AbsTask(ABC):
cls.trainer.add_arguments(parser)
cls.add_task_arguments(parser)
assert check_return_type(parser)
return parser
@classmethod
@ -1007,7 +1003,6 @@ class AbsTask(ABC):
return _cls
# This method is used only for --print_config
assert check_argument_types()
parser = cls.get_parser()
args, _ = parser.parse_known_args()
config = vars(args)
@ -1047,7 +1042,6 @@ class AbsTask(ABC):
@classmethod
def check_required_command_args(cls, args: argparse.Namespace):
assert check_argument_types()
if hasattr(args, "required"):
for k in vars(args):
if "-" in k:
@ -1077,7 +1071,6 @@ class AbsTask(ABC):
inference: bool = False,
) -> None:
"""Check if the dataset satisfy the requirement of current Task"""
assert check_argument_types()
mes = (
f"If you intend to use an additional input, modify "
f'"{cls.__name__}.required_data_names()" or '
@ -1104,14 +1097,12 @@ class AbsTask(ABC):
@classmethod
def print_config(cls, file=sys.stdout) -> None:
assert check_argument_types()
# Shows the config: e.g. python train.py asr --print_config
config = cls.get_default_config()
file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))
@classmethod
def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
assert check_argument_types()
print(get_commandline_args(), file=sys.stderr)
if args is None:
parser = cls.get_parser()
@ -1148,7 +1139,6 @@ class AbsTask(ABC):
@classmethod
def main_worker(cls, args: argparse.Namespace):
assert check_argument_types()
# 0. Init distributed process
distributed_option = build_dataclass(DistributedOption, args)
@ -1556,7 +1546,6 @@ class AbsTask(ABC):
- 4 epoch with "--num_iters_per_epoch" == 4
"""
assert check_argument_types()
iter_options = cls.build_iter_options(args, distributed_option, mode)
# Overwrite iter_options if any kwargs is given
@ -1589,7 +1578,6 @@ class AbsTask(ABC):
def build_sequence_iter_factory(
cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str
) -> AbsIterFactory:
assert check_argument_types()
if hasattr(args, "frontend_conf"):
if args.frontend_conf is not None and "fs" in args.frontend_conf:
@ -1683,7 +1671,6 @@ class AbsTask(ABC):
iter_options: IteratorOptions,
mode: str,
) -> AbsIterFactory:
assert check_argument_types()
dataset = ESPnetDataset(
iter_options.data_path_and_name_and_type,
@ -1788,7 +1775,6 @@ class AbsTask(ABC):
def build_multiple_iter_factory(
cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str
):
assert check_argument_types()
iter_options = cls.build_iter_options(args, distributed_option, mode)
assert len(iter_options.data_path_and_name_and_type) > 0, len(
iter_options.data_path_and_name_and_type
@ -1885,7 +1871,6 @@ class AbsTask(ABC):
inference: bool = False,
) -> DataLoader:
"""Build DataLoader using iterable dataset"""
assert check_argument_types()
# For backward compatibility for pytorch DataLoader
if collate_fn is not None:
kwargs = dict(collate_fn=collate_fn)
@ -1935,7 +1920,6 @@ class AbsTask(ABC):
device: Device type, "cpu", "cuda", or "cuda:N".
"""
assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "

Some files were not shown because too many files have changed in this diff.