mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
aishell example
parent 4ebde3c4ac
commit ff4306346e
@@ -50,6 +50,7 @@ inference_scp="wav.scp"

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    echo "stage -1: Data Download"
    mkdir -p ${raw_data}
    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi
@@ -76,9 +77,8 @@ fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature and CMVN Generation"
    # utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$config" --scale 1.0
    python ../../../funasr/bin/compute_audio_cmvn.py \
-       --config-path "${workspace}" \
+       --config-path "${workspace}/conf" \
        --config-name "${config}" \
        ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
        ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json" \
@@ -109,13 +109,14 @@ fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: ASR Training"

    mkdir -p ${exp_dir}/exp/${model_dir}
    log_file="${exp_dir}/exp/${model_dir}/train.log.txt"
    echo "log_file: ${log_file}"
    torchrun \
        --nnodes 1 \
        --nproc_per_node ${gpu_num} \
        ../../../funasr/bin/train.py \
-       --config-path "${workspace}" \
+       --config-path "${workspace}/conf" \
        --config-name "${config}" \
        ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
        ++tokenizer_conf.token_list="${token_list}" \
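
The --config-path / --config-name pair points the Hydra entry point at the recipe's YAML config, and each ++key=value argument is a Hydra override that forces the key into the loaded config. Roughly, the overrides amount to patching the config before training starts; a minimal sketch with OmegaConf (the library Hydra builds on), using placeholder values rather than this recipe's real paths:

from omegaconf import OmegaConf

# Stand-in for the YAML named by --config-name (placeholder content only).
cfg = OmegaConf.create({"dataset_conf": {"num_workers": 4}})
overrides = OmegaConf.create({
    "train_data_set_list": "data/train/audio_datasets.jsonl",   # ++train_data_set_list=...
    "tokenizer_conf": {"token_list": "/path/to/tokens.txt"},    # ++tokenizer_conf.token_list=...
})
cfg = OmegaConf.merge(cfg, overrides)  # '++' forces keys in even if absent from the YAML
print(OmegaConf.to_yaml(cfg))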
@@ -79,8 +79,8 @@ def main(**kwargs):

        fbank = batch["speech"].numpy()[0, :, :]
        if total_frames == 0:
-           mean_stats = fbank
-           var_stats = np.square(fbank)
+           mean_stats = np.sum(fbank, axis=0)
+           var_stats = np.sum(np.square(fbank), axis=0)
        else:
            mean_stats += np.sum(fbank, axis=0)
            var_stats += np.sum(np.square(fbank), axis=0)
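
The change above stops keeping the whole [T, D] feature matrix of the first utterance and instead accumulates per-dimension sums from the start, so mean_stats and var_stats stay [D]-sized. A minimal sketch of the same accumulation in isolation (the fbank_iter generator is a hypothetical stand-in for the dataloader used in compute_audio_cmvn.py):

import numpy as np

def accumulate_cmvn(fbank_iter):
    """Accumulate per-dimension sum and squared-sum over all frames."""
    mean_stats, var_stats, total_frames = None, None, 0
    for fbank in fbank_iter:                      # each fbank: [T, D]
        if total_frames == 0:
            mean_stats = np.sum(fbank, axis=0)
            var_stats = np.sum(np.square(fbank), axis=0)
        else:
            mean_stats += np.sum(fbank, axis=0)
            var_stats += np.sum(np.square(fbank), axis=0)
        total_frames += fbank.shape[0]
    return mean_stats, var_stats, total_frames

stats = accumulate_cmvn(np.random.randn(10, 80) for _ in range(3))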
@@ -93,6 +93,7 @@ def main(**kwargs):

            'total_frames': total_frames
        }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
+   # import pdb;pdb.set_trace()
    with open(cmvn_file, 'w') as fout:
        fout.write(json.dumps(cmvn_info))
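
cmvn_info is written out as plain JSON. One common way to turn such accumulated statistics into the mean / inverse-stddev vectors used for normalization is sketched below; it assumes the JSON keys mirror the variable names above (mean_stats, var_stats, total_frames) and is not necessarily FunASR's own loader:

import json
import numpy as np

def load_cmvn_json(path):
    with open(path) as f:
        info = json.load(f)
    n = info["total_frames"]
    mean = np.asarray(info["mean_stats"], dtype=np.float64) / n     # key names assumed
    var = np.asarray(info["var_stats"], dtype=np.float64) / n - mean * mean
    istd = 1.0 / np.sqrt(np.maximum(var, 1e-20))
    return mean, istd          # normalize with (feats - mean) * istd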
@@ -110,14 +111,14 @@ def main(**kwargs):

        fout.write("</Nnet>" + '\n')


"""
python funasr/bin/compute_audio_cmvn.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
++dataset_conf.num_workers=0
"""

if __name__ == "__main__":
    main_hydra()

"""
python funasr/bin/compute_status.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \
--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
++dataset_conf.num_workers=32
"""
@@ -79,9 +79,8 @@ def main(**kwargs):

    frontend = frontend_class(**kwargs["frontend_conf"])
    kwargs["frontend"] = frontend
    kwargs["input_size"] = frontend.output_size()

    # import pdb;
    # pdb.set_trace()

    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
@@ -22,12 +22,12 @@ class AudioDataset(torch.utils.data.Dataset):

        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
-           preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
+           preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
-           preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
+           preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
        self.preprocessor_text = preprocessor_text
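
Both lookups now resolve from the single preprocessor_classes table, which is what the @tables.register("preprocessor_classes", ...) decorators in the new preprocessor.py below register into. A stripped-down sketch of that name-to-class register/lookup pattern (illustration only, not FunASR's actual tables implementation):

class Registry:
    """Minimal name -> class registry."""
    def __init__(self):
        self._classes = {}

    def register(self, name):
        def decorator(cls):
            self._classes[name] = cls
            return cls
        return decorator

    def get(self, name):
        return self._classes[name]

preprocessor_classes = Registry()

@preprocessor_classes.register("SpeechPreprocessSpeedPerturb")
class DummyPreprocessor:
    pass

assert preprocessor_classes.get("SpeechPreprocessSpeedPerturb") is DummyPreprocessor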
@@ -57,7 +57,7 @@ class AudioDataset(torch.utils.data.Dataset):

        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
-           data_src = self.preprocessor_speech(data_src)
+           data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True)  # speech: [b, T, d]

        target = item["target"]
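
Each dataset item comes from the audio_datasets.jsonl manifest passed via ++train_data_set_list; the accesses above imply at least a source (audio path) and a target (transcript) field per line. A minimal reader under that assumption (the real schema may carry additional fields):

import json

def read_jsonl(path):
    """Yield (source, target) pairs from a jsonl manifest."""
    with open(path, encoding="utf8") as f:
        for line in f:
            item = json.loads(line)
            yield item["source"], item["target"]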
funasr/datasets/audio_datasets/preprocessor.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist
from typing import Collection
import torch
import torchaudio
from torch import nn
import random
import re
from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables


@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    def __init__(self, speed_perturb: list = None, **kwargs):
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        if self.speed_perturb is None:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
            waveform = waveform.view(-1)

        return waveform
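
SpeechPreprocessSpeedPerturb applies sox "speed" and "rate" effects through torchaudio, so the output length scales with the sampled factor. A minimal usage sketch with synthetic audio, assuming the repo layout maps to the import path and a torchaudio build with sox effects; the [0.9, 1.0, 1.1] factors are a common choice, not something this diff sets:

import torch
from funasr.datasets.audio_datasets.preprocessor import SpeechPreprocessSpeedPerturb

perturb = SpeechPreprocessSpeedPerturb(speed_perturb=[0.9, 1.0, 1.1])
waveform = torch.randn(16000)        # 1 s of fake 16 kHz audio
out = perturb(waveform, fs=16000)    # length changes unless speed == 1.0
print(waveform.shape, out.shape)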
@tables.register("preprocessor_classes", "TextPreprocessSegDict")
class TextPreprocessSegDict(nn.Module):
    def __init__(self, seg_dict: str = None,
                 text_cleaner: Collection[str] = None,
                 split_with_space: bool = False,
                 **kwargs):
        super().__init__()

        self.seg_dict = None
        if seg_dict is not None:
            self.seg_dict = {}
            with open(seg_dict, "r", encoding="utf8") as f:
                lines = f.readlines()
            for line in lines:
                s = line.strip().split()
                key = s[0]
                value = s[1:]
                self.seg_dict[key] = " ".join(value)
        self.text_cleaner = TextCleaner(text_cleaner)
        self.split_with_space = split_with_space

    def forward(self, text, **kwargs):
        if self.seg_dict is not None:
            text = self.text_cleaner(text)
            if self.split_with_space:
                tokens = text.strip().split(" ")
                if self.seg_dict is not None:
                    text = seg_tokenize(tokens, self.seg_dict)

        return text
def seg_tokenize(txt, seg_dict):
    pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
    out_txt = ""
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            out_txt += seg_dict[word] + " "
        else:
            if pattern.match(word):
                for char in word:
                    if char in seg_dict:
                        out_txt += seg_dict[char] + " "
                    else:
                        out_txt += "<unk>" + " "
            else:
                out_txt += "<unk>" + " "
    return out_txt.strip().split()
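
seg_tokenize maps each word through the segmentation dictionary, falls back to per-character lookup for words made of CJK characters or digits, and emits <unk> otherwise. A toy run with an assumed in-memory seg_dict (real dictionaries come from the seg_dict file loaded in __init__ above):

from funasr.datasets.audio_datasets.preprocessor import seg_tokenize

seg_dict = {"hello": "hel lo", "你": "你", "好": "好"}   # toy dictionary, not a real seg_dict file
print(seg_tokenize(["hello", "你好", "world"], seg_dict))
# -> ['hel', 'lo', '你', '好', '<unk>']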
@@ -32,6 +32,7 @@ def load_cmvn(cmvn_file):

            rescale_line = line_item[3:(len(line_item) - 1)]
            vars_list = list(rescale_line)
            continue
    import pdb;pdb.set_trace()
    means = np.array(means_list).astype(np.float32)
    vars = np.array(vars_list).astype(np.float32)
    cmvn = np.array([means, vars])
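
load_cmvn builds a [2, D] array from a Kaldi-nnet-style text file, where the first row typically holds an add-shift (negated mean) and the second a rescale (inverse stddev). A hedged sketch of applying such an array, under that assumption about the file's contents:

import numpy as np

def apply_cmvn(feats, cmvn):
    """feats: [T, D]; cmvn: [2, D] as returned by load_cmvn above (layout assumed)."""
    return (feats + cmvn[0]) * cmvn[1]   # add-shift, then rescale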