From ebdf631d98bc5eeae086a4cf036dedb0dc6aa58f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=98=89=E6=B8=8A?= Date: Fri, 12 May 2023 11:22:58 +0800 Subject: [PATCH] update repo --- egs/aishell/conformer/run.sh | 2 ++ .../data2vec_paraformer_finetune/run.sh | 2 ++ .../data2vec_transformer_finetune/run.sh | 2 ++ egs/aishell/paraformer/run.sh | 2 ++ .../conf/train_asr_transformer.yaml | 2 +- egs/aishell/transformer/run.sh | 2 ++ egs/librispeech_100h/conformer/run.sh | 2 ++ funasr/bin/train.py | 8 ++++++- .../large_datasets/build_dataloader.py | 3 ++- funasr/datasets/large_datasets/dataset.py | 23 ++++++++++++++----- funasr/datasets/small_datasets/dataset.py | 5 +++- .../small_datasets/sequence_iter_factory.py | 1 + 12 files changed, 44 insertions(+), 10 deletions(-) diff --git a/egs/aishell/conformer/run.sh b/egs/aishell/conformer/run.sh index eb3e13c8e..8db0d97c9 100755 --- a/egs/aishell/conformer/run.sh +++ b/egs/aishell/conformer/run.sh @@ -19,6 +19,7 @@ lang=zh token_type=char type=sound scp=wav.scp +speed_perturb="0.9 1.0 1.1" stage=3 stop_stage=4 @@ -129,6 +130,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --train_set ${train_set} \ --valid_set ${valid_set} \ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \ + --speed_perturb ${speed_perturb} \ --resume true \ --output_dir ${exp_dir}/exp/${model_dir} \ --config $asr_config \ diff --git a/egs/aishell/data2vec_paraformer_finetune/run.sh b/egs/aishell/data2vec_paraformer_finetune/run.sh index 42b44258c..b9d166893 100755 --- a/egs/aishell/data2vec_paraformer_finetune/run.sh +++ b/egs/aishell/data2vec_paraformer_finetune/run.sh @@ -19,6 +19,7 @@ lang=zh token_type=char type=sound scp=wav.scp +speed_perturb="0.9 1.0 1.1" stage=3 stop_stage=4 @@ -183,6 +184,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --gpuid_list ${gpuid_list} \ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \ + --speed_perturb ${speed_perturb} \ --key_file "${_logdir}"/keys.JOB.scp \ --asr_train_config "${asr_exp}"/config.yaml \ --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ diff --git a/egs/aishell/data2vec_transformer_finetune/run.sh b/egs/aishell/data2vec_transformer_finetune/run.sh index 65dd71b09..e040290a3 100755 --- a/egs/aishell/data2vec_transformer_finetune/run.sh +++ b/egs/aishell/data2vec_transformer_finetune/run.sh @@ -19,6 +19,7 @@ lang=zh token_type=char type=sound scp=wav.scp +speed_perturb="0.9 1.0 1.1" stage=3 stop_stage=4 @@ -134,6 +135,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --valid_set ${valid_set} \ --init_param ${init_param} \ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \ + --speed_perturb ${speed_perturb} \ --resume true \ --output_dir ${exp_dir}/exp/${model_dir} \ --config $asr_config \ diff --git a/egs/aishell/paraformer/run.sh b/egs/aishell/paraformer/run.sh index 5094a728e..ccf8f6e25 100755 --- a/egs/aishell/paraformer/run.sh +++ b/egs/aishell/paraformer/run.sh @@ -19,6 +19,7 @@ lang=zh token_type=char type=sound scp=wav.scp +speed_perturb="0.9 1.0 1.1" stage=1 stop_stage=3 @@ -129,6 +130,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --train_set ${train_set} \ --valid_set ${valid_set} \ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \ + --speed_perturb ${speed_perturb} \ --resume true \ --output_dir ${exp_dir}/exp/${model_dir} \ --config $asr_config \ diff --git a/egs/aishell/transformer/conf/train_asr_transformer.yaml b/egs/aishell/transformer/conf/train_asr_transformer.yaml index c62bdac91..22e651bd9 100644 --- a/egs/aishell/transformer/conf/train_asr_transformer.yaml +++ b/egs/aishell/transformer/conf/train_asr_transformer.yaml @@ -43,7 +43,7 @@ model_conf: # optimization related accum_grad: 1 grad_clip: 5 -patience: 3 +patience: none max_epoch: 60 val_scheduler_criterion: - valid diff --git a/egs/aishell/transformer/run.sh b/egs/aishell/transformer/run.sh index 76a76618e..b7ad9cde5 100755 --- a/egs/aishell/transformer/run.sh +++ b/egs/aishell/transformer/run.sh @@ -19,6 +19,7 @@ lang=zh token_type=char type=sound scp=wav.scp +speed_perturb="0.9 1.0 1.1" stage=3 stop_stage=4 @@ -129,6 +130,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --train_set ${train_set} \ --valid_set ${valid_set} \ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \ + --speed_perturb ${speed_perturb} \ --resume true \ --output_dir ${exp_dir}/exp/${model_dir} \ --config $asr_config \ diff --git a/egs/librispeech_100h/conformer/run.sh b/egs/librispeech_100h/conformer/run.sh index 7eee9a885..e9808065d 100755 --- a/egs/librispeech_100h/conformer/run.sh +++ b/egs/librispeech_100h/conformer/run.sh @@ -19,6 +19,7 @@ lang=en token_type=bpe type=sound scp=wav.scp +speed_perturb="0.9 1.0 1.1" stage=3 stop_stage=4 @@ -139,6 +140,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --train_set ${train_set} \ --valid_set ${valid_set} \ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \ + --speed_perturb ${speed_perturb} \ --resume true \ --output_dir ${exp_dir}/exp/${model_dir} \ --config $asr_config \ diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 8436dd59f..ba5df1db3 100755 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -334,7 +334,13 @@ def get_parser(): default="validation", help="dev dataset", ) - + parser.add_argument( + "--speed_perturb", + type=float, + nargs="+", + default=None, + help="speed perturb", + ) parser.add_argument( "--use_preprocessor", type=str2bool, diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py index f1ec00587..7889e7051 100644 --- a/funasr/datasets/large_datasets/build_dataloader.py +++ b/funasr/datasets/large_datasets/build_dataloader.py @@ -75,7 +75,8 @@ class LargeDataLoader(AbsIterFactory): logging.info("dataloader config: {}".format(self.dataset_conf)) batch_mode = self.dataset_conf.get("batch_mode", "padding") self.dataset = Dataset(args.data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer, - self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode) + self.dataset_conf, self.frontend_conf, speed_perturb=args.speed_perturb, + mode=mode, batch_mode=batch_mode) def build_iter(self, epoch, shuffle=True): self.dataset.set_epoch(epoch) diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py index b0e1b8f31..33ed13ab5 100644 --- a/funasr/datasets/large_datasets/dataset.py +++ b/funasr/datasets/large_datasets/dataset.py @@ -1,20 +1,20 @@ +import logging import os import random -import numpy from functools import partial import torch -import torchaudio import torch.distributed as dist +import torchaudio from kaldiio import ReadHelper from torch.utils.data import IterableDataset from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe +from funasr.datasets.large_datasets.utils.clipping import clipping from funasr.datasets.large_datasets.utils.filter import filter from funasr.datasets.large_datasets.utils.padding import padding -from funasr.datasets.large_datasets.utils.clipping import clipping from funasr.datasets.large_datasets.utils.tokenize import tokenize @@ -28,7 +28,8 @@ def read_lists(list_file): class AudioDataset(IterableDataset): - def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"): + def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None, + mode="train"): self.scp_lists = scp_lists self.data_names = data_names self.data_types = data_types @@ -40,6 +41,9 @@ class AudioDataset(IterableDataset): self.world_size = 1 self.worker_id = 0 self.num_workers = 1 + self.speed_perturb = speed_perturb + if self.speed_perturb is not None: + logging.info("Using speed_perturb: {}".format(speed_perturb)) def set_epoch(self, epoch): self.epoch = epoch @@ -124,9 +128,14 @@ class AudioDataset(IterableDataset): if sampling_rate != self.frontend_conf["fs"]: waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=self.frontend_conf["fs"])(waveform) - sampling_rate = self.frontend_conf["fs"] + sampling_rate = self.frontend_conf["fs"] waveform = waveform.numpy() mat = waveform[0] + if self.speed_perturb is not None: + speed = random.choice(self.speed_perturb) + if speed != 1.0: + mat, _ = torchaudio.sox_effects.apply_effects_tensor( + mat, sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]]) sample_dict[data_name] = mat sample_dict["sampling_rate"] = sampling_rate if data_name == "speech": @@ -161,13 +170,15 @@ def Dataset(data_list_file, bpe_tokenizer, conf, frontend_conf, + speed_perturb=None, mode="train", batch_mode="padding"): scp_lists = read_lists(data_list_file) shuffle = conf.get('shuffle', True) data_names = conf.get("data_names", "speech,text") data_types = conf.get("data_types", "kaldi_ark,text") - dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode) + dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, + speed_perturb=speed_perturb, mode=mode) filter_conf = conf.get('filter_conf', {}) filter_fn = partial(filter, **filter_conf) diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py index 33b9276bf..5d61d8860 100644 --- a/funasr/datasets/small_datasets/dataset.py +++ b/funasr/datasets/small_datasets/dataset.py @@ -127,6 +127,8 @@ class ESPnetDataset(Dataset): self.dest_sample_rate = dest_sample_rate self.speed_perturb = speed_perturb self.mode = mode + if self.speed_perturb is not None: + logging.info("Using speed_perturb: {}".format(speed_perturb)) self.loader_dict = {} self.debug_info = {} @@ -151,7 +153,8 @@ class ESPnetDataset(Dataset): """ if loader_type == "sound": speed_perturb = self.speed_perturb if self.mode == "train" else None - loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False, speed_perturb=speed_perturb) + loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False, + speed_perturb=speed_perturb) return AdapterForSoundScpReader(loader, self.float_dtype) elif loader_type == "kaldi_ark": loader = kaldiio.load_scp(path) diff --git a/funasr/datasets/small_datasets/sequence_iter_factory.py b/funasr/datasets/small_datasets/sequence_iter_factory.py index c35314fad..f524c7856 100644 --- a/funasr/datasets/small_datasets/sequence_iter_factory.py +++ b/funasr/datasets/small_datasets/sequence_iter_factory.py @@ -57,6 +57,7 @@ class SequenceIterFactory(AbsIterFactory): data_path_and_name_and_type, preprocess=preprocess_fn, dest_sample_rate=dest_sample_rate, + speed_perturb=args.speed_perturb, ) # sampler