update repo

This commit is contained in:
嘉渊 2023-05-12 11:22:58 +08:00
parent dcc6002363
commit ebdf631d98
12 changed files with 44 additions and 10 deletions

View File

@ -19,6 +19,7 @@ lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=3
stop_stage=4
@ -129,6 +130,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--train_set ${train_set} \
--valid_set ${valid_set} \
--cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \

View File

@ -19,6 +19,7 @@ lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=3
stop_stage=4
@ -183,6 +184,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
--cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--speed_perturb ${speed_perturb} \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \

View File

@ -19,6 +19,7 @@ lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=3
stop_stage=4
@ -134,6 +135,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--valid_set ${valid_set} \
--init_param ${init_param} \
--cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \

View File

@ -19,6 +19,7 @@ lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=1
stop_stage=3
@ -129,6 +130,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--train_set ${train_set} \
--valid_set ${valid_set} \
--cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \

View File

@ -43,7 +43,7 @@ model_conf:
# optimization related
accum_grad: 1
grad_clip: 5
patience: 3
patience: none
max_epoch: 60
val_scheduler_criterion:
- valid

View File

@ -19,6 +19,7 @@ lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=3
stop_stage=4
@ -129,6 +130,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--train_set ${train_set} \
--valid_set ${valid_set} \
--cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \

View File

@ -19,6 +19,7 @@ lang=en
token_type=bpe
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=3
stop_stage=4
@ -139,6 +140,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--train_set ${train_set} \
--valid_set ${valid_set} \
--cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \

View File

@ -334,7 +334,13 @@ def get_parser():
default="validation",
help="dev dataset",
)
parser.add_argument(
"--speed_perturb",
type=float,
nargs="+",
default=None,
help="speed perturb",
)
parser.add_argument(
"--use_preprocessor",
type=str2bool,

View File

@ -75,7 +75,8 @@ class LargeDataLoader(AbsIterFactory):
logging.info("dataloader config: {}".format(self.dataset_conf))
batch_mode = self.dataset_conf.get("batch_mode", "padding")
self.dataset = Dataset(args.data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)
self.dataset_conf, self.frontend_conf, speed_perturb=args.speed_perturb,
mode=mode, batch_mode=batch_mode)
def build_iter(self, epoch, shuffle=True):
self.dataset.set_epoch(epoch)

View File

@ -1,20 +1,20 @@
import logging
import os
import random
import numpy
from functools import partial
import torch
import torchaudio
import torch.distributed as dist
import torchaudio
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.tokenize import tokenize
@ -28,7 +28,8 @@ def read_lists(list_file):
class AudioDataset(IterableDataset):
def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None,
mode="train"):
self.scp_lists = scp_lists
self.data_names = data_names
self.data_types = data_types
@ -40,6 +41,9 @@ class AudioDataset(IterableDataset):
self.world_size = 1
self.worker_id = 0
self.num_workers = 1
self.speed_perturb = speed_perturb
if self.speed_perturb is not None:
logging.info("Using speed_perturb: {}".format(speed_perturb))
def set_epoch(self, epoch):
self.epoch = epoch
@ -124,9 +128,14 @@ class AudioDataset(IterableDataset):
if sampling_rate != self.frontend_conf["fs"]:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
new_freq=self.frontend_conf["fs"])(waveform)
sampling_rate = self.frontend_conf["fs"]
sampling_rate = self.frontend_conf["fs"]
waveform = waveform.numpy()
mat = waveform[0]
if self.speed_perturb is not None:
speed = random.choice(self.speed_perturb)
if speed != 1.0:
mat, _ = torchaudio.sox_effects.apply_effects_tensor(
mat, sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]])
sample_dict[data_name] = mat
sample_dict["sampling_rate"] = sampling_rate
if data_name == "speech":
@ -161,13 +170,15 @@ def Dataset(data_list_file,
bpe_tokenizer,
conf,
frontend_conf,
speed_perturb=None,
mode="train",
batch_mode="padding"):
scp_lists = read_lists(data_list_file)
shuffle = conf.get('shuffle', True)
data_names = conf.get("data_names", "speech,text")
data_types = conf.get("data_types", "kaldi_ark,text")
dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode)
dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle,
speed_perturb=speed_perturb, mode=mode)
filter_conf = conf.get('filter_conf', {})
filter_fn = partial(filter, **filter_conf)

View File

@ -127,6 +127,8 @@ class ESPnetDataset(Dataset):
self.dest_sample_rate = dest_sample_rate
self.speed_perturb = speed_perturb
self.mode = mode
if self.speed_perturb is not None:
logging.info("Using speed_perturb: {}".format(speed_perturb))
self.loader_dict = {}
self.debug_info = {}
@ -151,7 +153,8 @@ class ESPnetDataset(Dataset):
"""
if loader_type == "sound":
speed_perturb = self.speed_perturb if self.mode == "train" else None
loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False, speed_perturb=speed_perturb)
loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
speed_perturb=speed_perturb)
return AdapterForSoundScpReader(loader, self.float_dtype)
elif loader_type == "kaldi_ark":
loader = kaldiio.load_scp(path)

View File

@ -57,6 +57,7 @@ class SequenceIterFactory(AbsIterFactory):
data_path_and_name_and_type,
preprocess=preprocess_fn,
dest_sample_rate=dest_sample_rate,
speed_perturb=args.speed_perturb,
)
# sampler