Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
Commit ce4235b1c8 (parent 25fbf7110b): aishell example
@@ -99,7 +99,10 @@ dataset_conf:
    max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
    buffer_size: 1024
    shuffle: True
    num_workers: 0
    num_workers: 4
    preprocessor_speech: SpeechPreprocessSpeedPerturb
    preprocessor_speech_conf:
        speed_perturb: [0.9, 1.0, 1.1]

tokenizer: CharTokenizer
tokenizer_conf:
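The preprocessor_speech lines added above enable on-the-fly speed perturbation of the training audio with factors 0.9/1.0/1.1. Below is a minimal sketch of what such a preprocessor typically does, assuming torchaudio's sox effects are available; the class name and interface here are illustrative, not FunASR's actual SpeechPreprocessSpeedPerturb implementation.

import random

import torch
import torchaudio


class SpeedPerturbSketch:
    """Pick one speed factor per utterance and resample back to the original rate."""

    def __init__(self, speed_perturb=(0.9, 1.0, 1.1)):
        self.speeds = list(speed_perturb)

    def __call__(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        # waveform: (channels, samples); a factor of 1.0 leaves the audio untouched.
        speed = random.choice(self.speeds)
        if speed == 1.0:
            return waveform
        # sox "speed" changes tempo/pitch; "rate" resamples back to sample_rate.
        effects = [["speed", str(speed)], ["rate", str(sample_rate)]]
        perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sample_rate, effects)
        return perturbed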
@@ -1,13 +1,8 @@
#!/usr/bin/env bash

workspace=`pwd`

# machines configuration

CUDA_VISIBLE_DEVICES="0,1"
gpu_num=2
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=1

# general configuration
feats_dir="../DATA" # feature output directory
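The comment above fixes the decoding parallelism rule: with GPU decoding every visible GPU runs njob decoding jobs, with CPU decoding only njob jobs run in total. A small Python illustration of that arithmetic, with variable names mirroring the script's:

def inference_nj(gpu_inference: bool, gpu_num: int, njob: int) -> int:
    # gpu decoding: inference_nj = ngpu * njob; cpu decoding: inference_nj = njob
    return gpu_num * njob if gpu_inference else njob


# With the defaults above (gpu_inference=true, gpu_num=2, njob=1) -> 2 parallel decoders.
assert inference_nj(True, 2, 1) == 2
assert inference_nj(False, 2, 1) == 1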
@@ -18,7 +13,11 @@ stage=0
stop_stage=5

# feature configuration
nj=64
nj=32

inference_device="cuda" #"cpu"
inference_checkpoint="model.pt"
inference_scp="wav.scp"

# data
raw_data=../raw_data
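inference_scp="wav.scp" points the decoding stage at a Kaldi-style script file with one "<utt_id> <wav_path>" pair per line. A minimal sketch of reading such a file, assuming that conventional format; FunASR's own loader may handle extra cases (pipes, ark offsets) differently.

def read_wav_scp(scp_path: str) -> dict:
    wavs = {}
    with open(scp_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if len(parts) == 2:
                utt_id, wav_path = parts
                wavs[utt_id] = wav_path
    return wavs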
@@ -26,6 +25,7 @@ data_url=www.openslr.org/resources/33

# exp tag
tag="exp1"
workspace=`pwd`

. utils/parse_options.sh || exit 1;

@@ -42,11 +42,6 @@ test_sets="dev test"
config=train_asr_paraformer_conformer_12e_6d_2048_256.yaml
model_dir="baseline_$(basename "${config}" .yaml)_${lang}_${token_type}_${tag}"

inference_device="cuda" #"cpu"
inference_checkpoint="model.pt"
inference_scp="wav.scp"


if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    echo "stage -1: Data Download"
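The model_dir line above derives the experiment directory name from the config file name. A hypothetical Python equivalent, with illustrative values for lang/token_type/tag (they are set elsewhere in the script, not in this diff):

import os

config = "train_asr_paraformer_conformer_12e_6d_2048_256.yaml"
lang, token_type, tag = "zh", "char", "exp1"  # illustrative values only
model_dir = f"baseline_{os.path.splitext(os.path.basename(config))[0]}_{lang}_{token_type}_{tag}"
# -> baseline_train_asr_paraformer_conformer_12e_6d_2048_256_zh_char_exp1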
@@ -112,6 +107,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    mkdir -p ${exp_dir}/exp/${model_dir}
    log_file="${exp_dir}/exp/${model_dir}/train.log.txt"
    echo "log_file: ${log_file}"

    gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
    torchrun \
        --nnodes 1 \
        --nproc_per_node ${gpu_num} \
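Before launching torchrun, gpu_num is re-derived by counting the comma-separated entries of CUDA_VISIBLE_DEVICES. A small Python sketch of the same logic, useful for checking what --nproc_per_node will receive:

import os


def visible_gpu_count(default: int = 1) -> int:
    devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    ids = [d for d in devices.split(",") if d.strip()]
    return len(ids) or default


# CUDA_VISIBLE_DEVICES="0,1" -> 2, matching --nproc_per_node above.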
@@ -41,43 +41,9 @@ class TextPreprocessSegDict(nn.Module):
                 **kwargs):
        super().__init__()

        self.seg_dict = None
        if seg_dict is not None:
            self.seg_dict = {}
            with open(seg_dict, "r", encoding="utf8") as f:
                lines = f.readlines()
            for line in lines:
                s = line.strip().split()
                key = s[0]
                value = s[1:]
                self.seg_dict[key] = " ".join(value)
        self.text_cleaner = TextCleaner(text_cleaner)
        self.split_with_space = split_with_space

    def forward(self, text, **kwargs):
        if self.seg_dict is not None:
            text = self.text_cleaner(text)
            if self.split_with_space:
                tokens = text.strip().split(" ")
                if self.seg_dict is not None:
                    text = seg_tokenize(tokens, self.seg_dict)

        text = self.text_cleaner(text)

        return text

def seg_tokenize(txt, seg_dict):
    pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
    out_txt = ""
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            out_txt += seg_dict[word] + " "
        else:
            if pattern.match(word):
                for char in word:
                    if char in seg_dict:
                        out_txt += seg_dict[char] + " "
                    else:
                        out_txt += "<unk>" + " "
            else:
                out_txt += "<unk>" + " "
    return out_txt.strip().split()
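seg_tokenize above maps each whitespace-separated word through a segmentation dictionary, falling back to character-level lookup for purely CJK/digit words and to <unk> otherwise. A small usage sketch with an illustrative dictionary (not entries from the real AISHELL seg_dict):

seg_dict = {
    "hello": "he@@ llo",  # word -> space-joined sub-tokens
    "你": "你",
    "好": "好",
}

print(seg_tokenize(["hello", "你好", "world"], seg_dict))
# -> ['he@@', 'llo', '你', '好', '<unk>']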
@@ -32,7 +32,6 @@ def load_cmvn(cmvn_file):
                rescale_line = line_item[3:(len(line_item) - 1)]
                vars_list = list(rescale_line)
                continue
    import pdb;pdb.set_trace()
    means = np.array(means_list).astype(np.float32)
    vars = np.array(vars_list).astype(np.float32)
    cmvn = np.array([means, vars])

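load_cmvn stacks the <AddShift> and <Rescale> statistics from an am.mvn file into a (2, dim) array. A hedged sketch of how such an array is typically applied to features, assuming the am.mvn convention that the add-shift row already holds negated means and the rescale row holds inverse standard deviations; check FunASR's own apply_cmvn for the exact rule.

import numpy as np


def apply_cmvn_sketch(feats: np.ndarray, cmvn: np.ndarray) -> np.ndarray:
    # feats: (frames, dim); cmvn: (2, dim) = [add_shift, rescale]
    add_shift, rescale = cmvn[0], cmvn[1]
    return (feats + add_shift) * rescale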
@@ -3,60 +3,105 @@ from typing import Iterable
from typing import List
from typing import Union
import warnings
import re

from funasr.tokenizer.abs_tokenizer import BaseTokenizer
from funasr.register import tables

@tables.register("tokenizer_classes", "CharTokenizer")
class CharTokenizer(BaseTokenizer):
    def __init__(
        self,
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.space_symbol = space_symbol
        if non_linguistic_symbols is None:
            self.non_linguistic_symbols = set()
        elif isinstance(non_linguistic_symbols, (Path, str)):
            non_linguistic_symbols = Path(non_linguistic_symbols)
            try:
                with non_linguistic_symbols.open("r", encoding="utf-8") as f:
                    self.non_linguistic_symbols = set(line.rstrip() for line in f)
            except FileNotFoundError:
                warnings.warn(f"{non_linguistic_symbols} doesn't exist.")
                self.non_linguistic_symbols = set()
        else:
            self.non_linguistic_symbols = set(non_linguistic_symbols)
        self.remove_non_linguistic_symbols = remove_non_linguistic_symbols
    def __init__(
        self,
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
        split_with_space: bool = False,
        seg_dict: str = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.space_symbol = space_symbol
        if non_linguistic_symbols is None:
            self.non_linguistic_symbols = set()
        elif isinstance(non_linguistic_symbols, (Path, str)):
            non_linguistic_symbols = Path(non_linguistic_symbols)
            try:
                with non_linguistic_symbols.open("r", encoding="utf-8") as f:
                    self.non_linguistic_symbols = set(line.rstrip() for line in f)
            except FileNotFoundError:
                warnings.warn(f"{non_linguistic_symbols} doesn't exist.")
                self.non_linguistic_symbols = set()
        else:
            self.non_linguistic_symbols = set(non_linguistic_symbols)
        self.remove_non_linguistic_symbols = remove_non_linguistic_symbols
        self.split_with_space = split_with_space
        self.seg_dict = None
        if seg_dict is not None:
            self.seg_dict = load_seg_dict(seg_dict)


    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f'space_symbol="{self.space_symbol}"'
            f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
            f")"
        )

    def text2tokens(self, line: Union[str, list]) -> List[str]:

        if self.split_with_space:
            tokens = line.strip().split(" ")
            if self.seg_dict is not None:
                tokens = seg_tokenize(tokens, self.seg_dict)
        else:
            tokens = []
            while len(line) != 0:
                for w in self.non_linguistic_symbols:
                    if line.startswith(w):
                        if not self.remove_non_linguistic_symbols:
                            tokens.append(line[: len(w)])
                        line = line[len(w) :]
                        break
                else:
                    t = line[0]
                    if t == " ":
                        t = "<space>"
                    tokens.append(t)
                    line = line[1:]
        return tokens

    def tokens2text(self, tokens: Iterable[str]) -> str:
        tokens = [t if t != self.space_symbol else " " for t in tokens]
        return "".join(tokens)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f'space_symbol="{self.space_symbol}"'
            f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
            f")"
        )

    def text2tokens(self, line: Union[str, list]) -> List[str]:
        tokens = []
        while len(line) != 0:
            for w in self.non_linguistic_symbols:
                if line.startswith(w):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(line[: len(w)])
                    line = line[len(w) :]
                    break
            else:
                t = line[0]
                if t == " ":
                    t = "<space>"
                tokens.append(t)
                line = line[1:]
        return tokens
def load_seg_dict(seg_dict_file):
    seg_dict = {}
    assert isinstance(seg_dict_file, str)
    with open(seg_dict_file, "r", encoding="utf8") as f:
        lines = f.readlines()
    for line in lines:
        s = line.strip().split()
        key = s[0]
        value = s[1:]
        seg_dict[key] = " ".join(value)
    return seg_dict

    def tokens2text(self, tokens: Iterable[str]) -> str:
        tokens = [t if t != self.space_symbol else " " for t in tokens]
        return "".join(tokens)
def seg_tokenize(txt, seg_dict):
    pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
    out_txt = ""
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            out_txt += seg_dict[word] + " "
        else:
            if pattern.match(word):
                for char in word:
                    if char in seg_dict:
                        out_txt += seg_dict[char] + " "
                    else:
                        out_txt += "<unk>" + " "
            else:
                out_txt += "<unk>" + " "
    return out_txt.strip().split()
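A short usage sketch for the updated CharTokenizer, assuming BaseTokenizer needs no required constructor arguments and that "seg_dict.txt" is an illustrative file in the load_seg_dict format ("<word> <sub-token> <sub-token> ..." per line):

# Word/segment mode: split on spaces, then map words through the seg_dict.
tokenizer = CharTokenizer(split_with_space=True, seg_dict="seg_dict.txt")
tokens = tokenizer.text2tokens("你好 hello")
text = tokenizer.tokens2text(tokens)

# Default mode: per-character splitting, with spaces mapped to <space>.
char_tokenizer = CharTokenizer()
assert char_tokenizer.text2tokens("你好 ok") == ["你", "好", "<space>", "o", "k"]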