diff --git a/examples/aishell/paraformer/run.sh b/examples/aishell/paraformer/run.sh index 410751af1..149f4d791 100755 --- a/examples/aishell/paraformer/run.sh +++ b/examples/aishell/paraformer/run.sh @@ -50,6 +50,7 @@ inference_scp="wav.scp" if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then echo "stage -1: Data Download" + mkdir -p ${raw_data} local/download_and_untar.sh ${raw_data} ${data_url} data_aishell local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell fi @@ -76,9 +77,8 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then echo "stage 1: Feature and CMVN Generation" -# utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$config" --scale 1.0 python ../../../funasr/bin/compute_audio_cmvn.py \ - --config-path "${workspace}" \ + --config-path "${workspace}/conf" \ --config-name "${config}" \ ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \ ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json" \ @@ -109,13 +109,14 @@ fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then echo "stage 4: ASR Training" + mkdir -p ${exp_dir}/exp/${model_dir} log_file="${exp_dir}/exp/${model_dir}/train.log.txt" echo "log_file: ${log_file}" torchrun \ --nnodes 1 \ --nproc_per_node ${gpu_num} \ ../../../funasr/bin/train.py \ - --config-path "${workspace}" \ + --config-path "${workspace}/conf" \ --config-name "${config}" \ ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \ ++tokenizer_conf.token_list="${token_list}" \ diff --git a/funasr/bin/compute_audio_cmvn.py b/funasr/bin/compute_audio_cmvn.py index b66bb14d6..4561bec41 100644 --- a/funasr/bin/compute_audio_cmvn.py +++ b/funasr/bin/compute_audio_cmvn.py @@ -79,8 +79,8 @@ def main(**kwargs): fbank = batch["speech"].numpy()[0, :, :] if total_frames == 0: - mean_stats = fbank - var_stats = np.square(fbank) + mean_stats = np.sum(fbank, axis=0) + var_stats = 
np.sum(np.square(fbank), axis=0) else: mean_stats += np.sum(fbank, axis=0) var_stats += np.sum(np.square(fbank), axis=0) @@ -93,6 +93,7 @@ def main(**kwargs): 'total_frames': total_frames } cmvn_file = kwargs.get("cmvn_file", "cmvn.json") + # import pdb;pdb.set_trace() with open(cmvn_file, 'w') as fout: fout.write(json.dumps(cmvn_info)) @@ -110,14 +111,14 @@ def main(**kwargs): fout.write("" + '\n') - + +""" +python funasr/bin/compute_audio_cmvn.py \ +--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \ +--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \ +++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \ +++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \ +++dataset_conf.num_workers=0 +""" if __name__ == "__main__": main_hydra() - """ - python funasr/bin/compute_status.py \ - --config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \ - --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \ - ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \ - ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \ - ++dataset_conf.num_workers=32 - """ \ No newline at end of file diff --git a/funasr/bin/train.py b/funasr/bin/train.py index c9a4a6784..d9165096f 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -79,9 +79,8 @@ def main(**kwargs): frontend = frontend_class(**kwargs["frontend_conf"]) kwargs["frontend"] = frontend kwargs["input_size"] = frontend.output_size() - - # import pdb; - # pdb.set_trace() + + # build model model_class = tables.model_classes.get(kwargs["model"]) model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)) diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py index 62acb44af..ab08fb048 100644 --- a/funasr/datasets/audio_datasets/datasets.py +++ b/funasr/datasets/audio_datasets/datasets.py @@ -22,12 +22,12 @@ class 
AudioDataset(torch.utils.data.Dataset): self.index_ds = index_ds_class(path, **kwargs) preprocessor_speech = kwargs.get("preprocessor_speech", None) if preprocessor_speech: - preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech) + preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech) preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf")) self.preprocessor_speech = preprocessor_speech preprocessor_text = kwargs.get("preprocessor_text", None) if preprocessor_text: - preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text) + preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text) preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) self.preprocessor_text = preprocessor_text @@ -57,7 +57,7 @@ class AudioDataset(torch.utils.data.Dataset): source = item["source"] data_src = load_audio_text_image_video(source, fs=self.fs) if self.preprocessor_speech: - data_src = self.preprocessor_speech(data_src) + data_src = self.preprocessor_speech(data_src, fs=self.fs) speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d] target = item["target"] diff --git a/funasr/datasets/audio_datasets/preprocessor.py b/funasr/datasets/audio_datasets/preprocessor.py new file mode 100644 index 000000000..6c21fbf0e --- /dev/null +++ b/funasr/datasets/audio_datasets/preprocessor.py @@ -0,0 +1,83 @@ +import os +import json +import torch +import logging +import concurrent.futures +import librosa +import torch.distributed as dist +from typing import Collection +import torch +import torchaudio +from torch import nn +import random +import re +from funasr.tokenizer.cleaner import TextCleaner +from funasr.register import tables + + +@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb") +class SpeechPreprocessSpeedPerturb(nn.Module): + def 
__init__(self, speed_perturb: list=None, **kwargs): + super().__init__() + self.speed_perturb = speed_perturb + + def forward(self, waveform, fs, **kwargs): + if self.speed_perturb is None: + return waveform + speed = random.choice(self.speed_perturb) + if speed != 1.0: + waveform, _ = torchaudio.sox_effects.apply_effects_tensor( + torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]]) + waveform = waveform.view(-1) + + return waveform + + +@tables.register("preprocessor_classes", "TextPreprocessSegDict") +class TextPreprocessSegDict(nn.Module): + def __init__(self, seg_dict: str = None, + text_cleaner: Collection[str] = None, + split_with_space: bool = False, + **kwargs): + super().__init__() + + self.seg_dict = None + if seg_dict is not None: + self.seg_dict = {} + with open(seg_dict, "r", encoding="utf8") as f: + lines = f.readlines() + for line in lines: + s = line.strip().split() + key = s[0] + value = s[1:] + self.seg_dict[key] = " ".join(value) + self.text_cleaner = TextCleaner(text_cleaner) + self.split_with_space = split_with_space + + def forward(self, text, **kwargs): + if self.seg_dict is not None: + text = self.text_cleaner(text) + if self.split_with_space: + tokens = text.strip().split(" ") + if self.seg_dict is not None: + text = seg_tokenize(tokens, self.seg_dict) + + return text + +def seg_tokenize(txt, seg_dict): + pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$') + out_txt = "" + for word in txt: + word = word.lower() + if word in seg_dict: + out_txt += seg_dict[word] + " " + else: + if pattern.match(word): + for char in word: + if char in seg_dict: + out_txt += seg_dict[char] + " " + else: + out_txt += "" + " " + else: + out_txt += "" + " " + return out_txt.strip().split() \ No newline at end of file diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py index c6e03e86e..71cf77a07 100644 --- a/funasr/frontends/wav_frontend.py +++ b/funasr/frontends/wav_frontend.py @@ -32,6 +32,7 @@ def 
load_cmvn(cmvn_file): rescale_line = line_item[3:(len(line_item) - 1)] vars_list = list(rescale_line) continue + # import pdb;pdb.set_trace() means = np.array(means_list).astype(np.float32) vars = np.array(vars_list).astype(np.float32) cmvn = np.array([means, vars])