aishell example

This commit is contained in:
游雁 2024-02-19 21:26:25 +08:00
parent 4ebde3c4ac
commit ff4306346e
6 changed files with 105 additions and 20 deletions

View File

@ -50,6 +50,7 @@ inference_scp="wav.scp"
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "stage -1: Data Download"
mkdir -p ${raw_data}
local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi
@ -76,9 +77,8 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: Feature and CMVN Generation"
# utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$config" --scale 1.0
python ../../../funasr/bin/compute_audio_cmvn.py \
--config-path "${workspace}" \
--config-path "${workspace}/conf" \
--config-name "${config}" \
++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json" \
@ -109,13 +109,14 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "stage 4: ASR Training"
mkdir -p ${exp_dir}/exp/${model_dir}
log_file="${exp_dir}/exp/${model_dir}/train.log.txt"
echo "log_file: ${log_file}"
torchrun \
--nnodes 1 \
--nproc_per_node ${gpu_num} \
../../../funasr/bin/train.py \
--config-path "${workspace}" \
--config-path "${workspace}/conf" \
--config-name "${config}" \
++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
++tokenizer_conf.token_list="${token_list}" \

View File

@ -79,8 +79,8 @@ def main(**kwargs):
fbank = batch["speech"].numpy()[0, :, :]
if total_frames == 0:
mean_stats = fbank
var_stats = np.square(fbank)
mean_stats = np.sum(fbank, axis=0)
var_stats = np.sum(np.square(fbank), axis=0)
else:
mean_stats += np.sum(fbank, axis=0)
var_stats += np.sum(np.square(fbank), axis=0)
@ -93,6 +93,7 @@ def main(**kwargs):
'total_frames': total_frames
}
cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
# import pdb;pdb.set_trace()
with open(cmvn_file, 'w') as fout:
fout.write(json.dumps(cmvn_info))
@ -110,14 +111,14 @@ def main(**kwargs):
fout.write("</Nnet>" + '\n')
"""
python funasr/bin/compute_audio_cmvn.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
++dataset_conf.num_workers=0
"""
if __name__ == "__main__":
main_hydra()
"""
python funasr/bin/compute_status.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \
--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
++dataset_conf.num_workers=32
"""

View File

@ -79,9 +79,8 @@ def main(**kwargs):
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
# import pdb;
# pdb.set_trace()
# build model
model_class = tables.model_classes.get(kwargs["model"])
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))

View File

@ -22,12 +22,12 @@ class AudioDataset(torch.utils.data.Dataset):
self.index_ds = index_ds_class(path, **kwargs)
preprocessor_speech = kwargs.get("preprocessor_speech", None)
if preprocessor_speech:
preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
self.preprocessor_speech = preprocessor_speech
preprocessor_text = kwargs.get("preprocessor_text", None)
if preprocessor_text:
preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
self.preprocessor_text = preprocessor_text
@ -57,7 +57,7 @@ class AudioDataset(torch.utils.data.Dataset):
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src)
data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
target = item["target"]

View File

@ -0,0 +1,83 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist
from typing import Collection
import torch
import torchaudio
from torch import nn
import random
import re
from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables
@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    """Optional speed-perturbation augmentation for raw waveforms.

    When a list of speed factors is configured, one factor is drawn uniformly
    at random per call and applied via sox's ``speed`` effect, followed by a
    ``rate`` effect so the output keeps the original sampling rate ``fs``.
    """

    def __init__(self, speed_perturb: list = None, **kwargs):
        super().__init__()
        # Candidate speed factors (e.g. [0.9, 1.0, 1.1]); None disables perturbation.
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        # Fast path: augmentation disabled.
        if self.speed_perturb is None:
            return waveform
        factor = random.choice(self.speed_perturb)
        # A factor of exactly 1.0 is a no-op; skip the sox round-trip.
        if factor == 1.0:
            return waveform
        # sox effects operate on a (channels, samples) tensor; treat input as mono.
        mono = torch.tensor(waveform).view(1, -1)
        effects = [['speed', str(factor)], ['rate', str(fs)]]
        perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(mono, fs, effects)
        # Flatten back to a 1-D waveform.
        return perturbed.view(-1)
@tables.register("preprocessor_classes", "TextPreprocessSegDict")
class TextPreprocessSegDict(nn.Module):
    """Text preprocessor that cleans a transcript and re-tokenizes it with a
    segmentation dictionary (word -> space-joined sub-tokens).

    NOTE(review): this block was reconstructed from a flattened diff view; the
    statement nesting below is the most natural reading of the line order —
    confirm against the upstream file.
    """

    def __init__(self, seg_dict: str = None,
                 text_cleaner: Collection[str] = None,
                 split_with_space: bool = False,
                 **kwargs):
        super().__init__()
        # seg_dict file format: one entry per line, "<key> <sub-token> <sub-token> ...".
        self.seg_dict = None
        if seg_dict is not None:
            self.seg_dict = {}
            with open(seg_dict, "r", encoding="utf8") as f:
                lines = f.readlines()
                for line in lines:
                    s = line.strip().split()
                    key = s[0]
                    value = s[1:]
                    # Store the segmentation as a single space-joined string.
                    self.seg_dict[key] = " ".join(value)
        self.text_cleaner = TextCleaner(text_cleaner)
        self.split_with_space = split_with_space

    def forward(self, text, **kwargs):
        # NOTE(review): cleaning and splitting only happen when a seg_dict was
        # loaded; the inner seg_dict check is then redundant — presumably a
        # leftover from an earlier refactor. Confirm intent upstream.
        if self.seg_dict is not None:
            text = self.text_cleaner(text)
            if self.split_with_space:
                tokens = text.strip().split(" ")
                if self.seg_dict is not None:
                    text = seg_tokenize(tokens, self.seg_dict)
        return text
def seg_tokenize(txt, seg_dict):
    """Map each word of ``txt`` through ``seg_dict`` and return a flat token list.

    Lookup is case-insensitive (words are lowercased first). A word missing
    from the dictionary is decomposed character-by-character when it consists
    only of CJK ideographs and/or digits; any character (or whole word) still
    not found maps to ``"<unk>"``.
    """
    # Matches words made solely of CJK unified ideographs and ASCII digits.
    cjk_or_digit = re.compile(r'^[\u4E00-\u9FA50-9]+$')
    pieces = []
    for raw in txt:
        word = raw.lower()
        if word in seg_dict:
            pieces.append(seg_dict[word])
        elif cjk_or_digit.match(word):
            # Fall back to per-character lookup for CJK/digit-only words.
            pieces.extend(seg_dict.get(ch, "<unk>") for ch in word)
        else:
            pieces.append("<unk>")
    # Dictionary values may themselves contain spaces; re-split the joined
    # string so the result is one token per list element.
    return " ".join(pieces).strip().split()

View File

@ -32,6 +32,7 @@ def load_cmvn(cmvn_file):
rescale_line = line_item[3:(len(line_item) - 1)]
vars_list = list(rescale_line)
continue
import pdb;pdb.set_trace()
means = np.array(means_list).astype(np.float32)
vars = np.array(vars_list).astype(np.float32)
cmvn = np.array([means, vars])