Merge pull request #26 from alibaba-damo-academy/dev

Dev
zhifu gao 2022-12-30 18:58:33 +08:00 committed by GitHub
commit ddef3c9d92
277 changed files with 21 additions and 28771 deletions

BIN
.DS_Store vendored

Binary file not shown.

8
.gitignore vendored Normal file
@@ -0,0 +1,8 @@
.idea
./__pycache__/
*/__pycache__/
*/*/__pycache__/
*/*/*/__pycache__/
.DS_Store
init_model/
*.tar.gz

3
.idea/.gitignore vendored
@@ -1,3 +0,0 @@
# Default ignored files
/shelf/
/workspace.xml

@@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

@@ -1,50 +0,0 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="37">
<item index="0" class="java.lang.String" itemvalue="sentencepiece" />
<item index="1" class="java.lang.String" itemvalue="pyyaml" />
<item index="2" class="java.lang.String" itemvalue="parameterized" />
<item index="3" class="java.lang.String" itemvalue="rouge" />
<item index="4" class="java.lang.String" itemvalue="nose2" />
<item index="5" class="java.lang.String" itemvalue="timm" />
<item index="6" class="java.lang.String" itemvalue="mypy" />
<item index="7" class="java.lang.String" itemvalue="tensorboardX" />
<item index="8" class="java.lang.String" itemvalue="torch" />
<item index="9" class="java.lang.String" itemvalue="blobfile" />
<item index="10" class="java.lang.String" itemvalue="torchvision" />
<item index="11" class="java.lang.String" itemvalue="einops" />
<item index="12" class="java.lang.String" itemvalue="editdistance" />
<item index="13" class="java.lang.String" itemvalue="soundfile" />
<item index="14" class="java.lang.String" itemvalue="typeguard" />
<item index="15" class="java.lang.String" itemvalue="torchaudio" />
<item index="16" class="java.lang.String" itemvalue="kaldiio" />
<item index="17" class="java.lang.String" itemvalue="webdataset" />
<item index="18" class="java.lang.String" itemvalue="g2p" />
<item index="19" class="java.lang.String" itemvalue="flake8-executable" />
<item index="20" class="java.lang.String" itemvalue="pyflakes" />
<item index="21" class="java.lang.String" itemvalue="textgrid" />
<item index="22" class="java.lang.String" itemvalue="flake8-pyi" />
<item index="23" class="java.lang.String" itemvalue="pycodestyle" />
<item index="24" class="java.lang.String" itemvalue="flake8-bugbear" />
<item index="25" class="java.lang.String" itemvalue="flake8-comprehensions" />
<item index="26" class="java.lang.String" itemvalue="flake8" />
<item index="27" class="java.lang.String" itemvalue="shapely" />
<item index="28" class="java.lang.String" itemvalue="submitit" />
<item index="29" class="java.lang.String" itemvalue="tensorboard" />
<item index="30" class="java.lang.String" itemvalue="matplotlib" />
<item index="31" class="java.lang.String" itemvalue="ftfy" />
<item index="32" class="java.lang.String" itemvalue="sacremoses" />
<item index="33" class="java.lang.String" itemvalue="efficientnet_pytorch" />
<item index="34" class="java.lang.String" itemvalue="pytorch_lightning" />
<item index="35" class="java.lang.String" itemvalue="rouge_score" />
<item index="36" class="java.lang.String" itemvalue="torch_complex" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

@@ -1,4 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (funasr_doc)" project-jdk-type="Python SDK" />
</project>

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/FunASR.iml" filepath="$PROJECT_DIR$/.idea/FunASR.iml" />
</modules>
</component>
</project>

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

1
egs/aishell/conformer/utils Symbolic link
@@ -0,0 +1 @@
../transformer/utils

@@ -1,79 +0,0 @@
from kaldiio import ReadHelper
from kaldiio import WriteHelper
import argparse
import json
import math
import numpy as np
def get_parser():
parser = argparse.ArgumentParser(
description="apply cmvn",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--ark-file",
"-a",
default=False,
required=True,
type=str,
help="fbank ark file",
)
parser.add_argument(
"--cmvn-file",
"-c",
default=False,
required=True,
type=str,
help="cmvn file",
)
parser.add_argument(
"--ark-index",
"-i",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
with open(args.cmvn_file) as f:
cmvn_stats = json.load(f)
means = cmvn_stats['mean_stats']
vars = cmvn_stats['var_stats']
total_frames = cmvn_stats['total_frames']
for i in range(len(means)):
means[i] /= total_frames
vars[i] = vars[i] / total_frames - means[i] * means[i]
if vars[i] < 1.0e-20:
vars[i] = 1.0e-20
vars[i] = 1.0 / math.sqrt(vars[i])
with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
for key, mat in ark_reader:
mat = (mat - means) * vars
ark_writer(key, mat)
if __name__ == '__main__':
main()
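
The script above reduces the accumulated sums in cmvn.json to a per-dimension mean and inverse standard deviation, then normalizes every frame as (x - mean) * scale. A minimal numeric sketch of that reduction (stats values are made up for illustration):

import numpy as np

# Toy stats in the cmvn.json layout read above (illustrative values).
stats = {"mean_stats": [30.0, 60.0], "var_stats": [350.0, 1300.0], "total_frames": 10}

means = np.array(stats["mean_stats"]) / stats["total_frames"]
variances = np.array(stats["var_stats"]) / stats["total_frames"] - means ** 2
scales = 1.0 / np.sqrt(np.maximum(variances, 1.0e-20))  # same variance floor as above

frames = np.array([[2.0, 5.0], [4.0, 7.0]])  # two frames, two dims
print((frames - means) * scales)             # per-dimension zero-mean, unit-variance output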

@@ -1,29 +0,0 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
nj=32
cmd=./utils/run.pl
echo "$0 $@"
. utils/parse_options.sh || exit 1;
fbankdir=$1
cmvn_file=$2
logdir=$3
output_dir=$4
dump_dir=${output_dir}/ark; mkdir -p ${dump_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/apply_cmvn.JOB.log \
python utils/apply_cmvn.py -a $fbankdir/ark/feats.JOB.ark \
-c $cmvn_file -i JOB -o ${dump_dir} \
|| exit 1;
for n in $(seq $nj); do
cat ${dump_dir}/feats.$n.scp || exit 1
done > ${output_dir}/feats.scp || exit 1
echo "$0: Succeeded apply cmvn"

@@ -1,143 +0,0 @@
from kaldiio import ReadHelper, WriteHelper
import argparse
import numpy as np
def build_LFR_features(inputs, m=7, n=6):
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / n))
left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
inputs = np.vstack((left_padding, inputs))
T = T + (m - 1) // 2
for i in range(T_lfr):
if m <= T - i * n:
LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
else:
num_padding = m - (T - i * n)
frame = np.hstack(inputs[i * n:])
for _ in range(num_padding):
frame = np.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
return np.vstack(LFR_inputs)
def build_CMVN_features(inputs, mvn_file): # noqa
with open(mvn_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
add_shift_list = []
rescale_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == '<AddShift>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
add_shift_line = line_item[3:(len(line_item) - 1)]
add_shift_list = list(add_shift_line)
continue
elif line_item[0] == '<Rescale>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
rescale_line = line_item[3:(len(line_item) - 1)]
rescale_list = list(rescale_line)
continue
for j in range(inputs.shape[0]):
for k in range(inputs.shape[1]):
add_shift_value = add_shift_list[k]
rescale_value = rescale_list[k]
inputs[j, k] = float(inputs[j, k]) + float(add_shift_value)
inputs[j, k] = float(inputs[j, k]) * float(rescale_value)
return inputs
def get_parser():
parser = argparse.ArgumentParser(
description="apply low_frame_rate and cmvn",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--ark-file",
"-a",
default=False,
required=True,
type=str,
help="fbank ark file",
)
parser.add_argument(
"--lfr",
"-f",
default=True,
type=str,
help="low frame rate",
)
parser.add_argument(
"--lfr-m",
"-m",
default=7,
type=int,
help="number of frames to stack",
)
parser.add_argument(
"--lfr-n",
"-n",
default=6,
type=int,
help="number of frames to skip",
)
parser.add_argument(
"--cmvn-file",
"-c",
default=False,
required=True,
type=str,
help="global cmvn file",
)
parser.add_argument(
"--ark-index",
"-i",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
dump_ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
dump_scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
shape_file = args.output_dir + "/len." + str(args.ark_index)
ark_writer = WriteHelper('ark,scp:{},{}'.format(dump_ark_file, dump_scp_file))
shape_writer = open(shape_file, 'w')
with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
for key, mat in ark_reader:
if args.lfr:
lfr = build_LFR_features(mat, args.lfr_m, args.lfr_n)
else:
lfr = mat
cmvn = build_CMVN_features(lfr, args.cmvn_file)
dims = cmvn.shape[1]
lens = cmvn.shape[0]
shape_writer.write(key + " " + str(lens) + "," + str(dims) + '\n')
ark_writer(key, cmvn)
if __name__ == '__main__':
main()
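
build_LFR_features stacks m consecutive frames and advances n frames per output step, so a (T, d) input becomes roughly (ceil(T/n), m * d); with the defaults m=7, n=6, an 80-dim fbank stream turns into 560-dim features at one sixth of the frame rate. A quick shape check (numbers are illustrative):

import math

T, d, m, n = 100, 80, 7, 6
print(math.ceil(T / n), m * d)  # 17 output frames, each 7 * 80 = 560 dims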

@@ -1,38 +0,0 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=32
cmd=utils/run.pl
# feature configuration
lfr=True
lfr_m=7
lfr_n=6
echo "$0 $@"
. utils/parse_options.sh || exit 1;
fbankdir=$1
cmvn_file=$2
logdir=$3
output_dir=$4
dump_dir=${output_dir}/ark; mkdir -p ${dump_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/apply_lfr_and_cmvn.JOB.log \
python utils/apply_lfr_and_cmvn.py -a $fbankdir/ark/feats.JOB.ark \
-f $lfr -m $lfr_m -n $lfr_n -c $cmvn_file -i JOB -o ${dump_dir} \
|| exit 1;
for n in $(seq $nj); do
cat ${dump_dir}/feats.$n.scp || exit 1
done > ${output_dir}/feats.scp || exit 1
for n in $(seq $nj); do
cat ${dump_dir}/len.$n || exit 1
done > ${output_dir}/speech_shape || exit 1
echo "$0: Succeeded apply low frame rate and cmvn"

@@ -1,73 +0,0 @@
import argparse
import json
import numpy as np
def get_parser():
parser = argparse.ArgumentParser(
description="combine cmvn file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--dims",
"-d",
default=80,
type=int,
help="feature dims",
)
parser.add_argument(
"--cmvn-dir",
"-c",
default=False,
required=True,
type=str,
help="cmvn dir",
)
parser.add_argument(
"--nj",
"-n",
default=1,
required=True,
type=int,
help="num of cmvn file",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
total_means = np.zeros(args.dims)
total_vars = np.zeros(args.dims)
total_frames = 0
cmvn_file = args.output_dir + "/cmvn.json"
for i in range(1, args.nj+1):
with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
cmvn_stats = json.load(fin)
total_means += np.array(cmvn_stats["mean_stats"])
total_vars += np.array(cmvn_stats["var_stats"])
total_frames += cmvn_stats["total_frames"]
cmvn_info = {
'mean_stats': list(total_means.tolist()),
'var_stats': list(total_vars.tolist()),
'total_frames': total_frames
}
with open(cmvn_file, 'w') as fout:
fout.write(json.dumps(cmvn_info))
if __name__ == '__main__':
main()
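
Because each per-job cmvn.N.json stores raw sums rather than means, per-job statistics merge by plain addition; the mean/variance reduction happens later in apply_cmvn.py. A toy check with two jobs (values are illustrative):

import numpy as np

job1 = {"mean_stats": np.array([10.0]), "var_stats": np.array([60.0]), "total_frames": 2}
job2 = {"mean_stats": np.array([20.0]), "var_stats": np.array([140.0]), "total_frames": 3}

total_frames = job1["total_frames"] + job2["total_frames"]
global_mean = (job1["mean_stats"] + job2["mean_stats"]) / total_frames
print(global_mean)  # [6.], the mean over all 5 frames, same as pooling the raw features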

@@ -1,74 +0,0 @@
from kaldiio import ReadHelper
import argparse
import numpy as np
import json
def get_parser():
parser = argparse.ArgumentParser(
description="computer global cmvn",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--dims",
"-d",
default=80,
type=int,
help="feature dims",
)
parser.add_argument(
"--ark-file",
"-a",
default=False,
required=True,
type=str,
help="fbank ark file",
)
parser.add_argument(
"--ark-index",
"-i",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
ark_file = args.ark_file + "/feats." + str(args.ark_index) + ".ark"
cmvn_file = args.output_dir + "/cmvn." + str(args.ark_index) + ".json"
mean_stats = np.zeros(args.dims)
var_stats = np.zeros(args.dims)
total_frames = 0
with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
for key, mat in ark_reader:
mean_stats += np.sum(mat, axis=0)
var_stats += np.sum(np.square(mat), axis=0)
total_frames += mat.shape[0]
cmvn_info = {
'mean_stats': list(mean_stats.tolist()),
'var_stats': list(var_stats.tolist()),
'total_frames': total_frames
}
with open(cmvn_file, 'w') as fout:
fout.write(json.dumps(cmvn_info))
if __name__ == '__main__':
main()
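
The accumulators are the usual sufficient statistics: with S1 = sum(x) and S2 = sum(x^2) over N frames, the variance is recovered as S2/N - (S1/N)^2. A one-line sanity check:

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
s1, s2, n = x.sum(), np.square(x).sum(), len(x)
print(s2 / n - (s1 / n) ** 2, x.var())  # both print 1.25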

@@ -1,25 +0,0 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
nj=32
cmd=./utils/run.pl
feats_dim=80
echo "$0 $@"
. utils/parse_options.sh || exit 1;
fbankdir=$1
logdir=$2
output_dir=${fbankdir}/cmvn; mkdir -p ${output_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
python utils/compute_cmvn.py -d ${feats_dim} -a $fbankdir/ark -i JOB -o ${output_dir} \
|| exit 1;
python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
echo "$0: Succeeded compute global cmvn"

@@ -1,153 +0,0 @@
from kaldiio import WriteHelper
import argparse
import numpy as np
import json
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
def compute_fbank(wav_file,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
resample_rate=16000,
speed=1.0):
waveform, sample_rate = torchaudio.load(wav_file)
if resample_rate != sample_rate:
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
new_freq=resample_rate)(waveform)
if speed != 1.0:
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform, resample_rate,
[['speed', str(speed)], ['rate', str(resample_rate)]]
)
waveform = waveform * (1 << 15)
mat = kaldi.fbank(waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
window_type='hamming',
sample_frequency=resample_rate)
return mat.numpy()
def get_parser():
parser = argparse.ArgumentParser(
description="computer features",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--wav-lists",
"-w",
default=False,
required=True,
type=str,
help="input wav lists",
)
parser.add_argument(
"--text-files",
"-t",
default=False,
required=True,
type=str,
help="input text files",
)
parser.add_argument(
"--dims",
"-d",
default=80,
type=int,
help="feature dims",
)
parser.add_argument(
"--sample-frequency",
"-s",
default=16000,
type=int,
help="sample frequency",
)
parser.add_argument(
"--speed-perturb",
"-p",
default="1.0",
type=str,
help="speed perturb",
)
parser.add_argument(
"--ark-index",
"-a",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
ark_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".ark"
scp_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".scp"
text_file = args.output_dir + "/txt/text." + str(args.ark_index) + ".txt"
feats_shape_file = args.output_dir + "/ark/len." + str(args.ark_index)
text_shape_file = args.output_dir + "/txt/len." + str(args.ark_index)
ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
text_writer = open(text_file, 'w')
feats_shape_writer = open(feats_shape_file, 'w')
text_shape_writer = open(text_shape_file, 'w')
speed_perturb_list = args.speed_perturb.split(',')
for speed in speed_perturb_list:
with open(args.wav_lists, 'r', encoding='utf-8') as wavfile:
with open(args.text_files, 'r', encoding='utf-8') as textfile:
for wav, text in zip(wavfile, textfile):
s_w = wav.strip().split()
wav_id = s_w[0]
wav_file = s_w[1]
s_t = text.strip().split()
text_id = s_t[0]
txt = s_t[1:]
fbank = compute_fbank(wav_file,
num_mel_bins=args.dims,
resample_rate=args.sample_frequency,
speed=float(speed)
)
feats_dims = fbank.shape[1]
feats_lens = fbank.shape[0]
txt_lens = len(txt)
if speed == "1.0":
wav_id_sp = wav_id
else:
wav_id_sp = wav_id + "_sp" + speed
feats_shape_writer.write(wav_id_sp + " " + str(feats_lens) + "," + str(feats_dims) + '\n')
text_shape_writer.write(wav_id_sp + " " + str(txt_lens) + '\n')
text_writer.write(wav_id_sp + " " + " ".join(txt) + '\n')
ark_writer(wav_id_sp, fbank)
if __name__ == '__main__':
main()
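
The waveform is rescaled by 1 << 15 because torchaudio loads float samples in [-1, 1] while Kaldi-style fbank expects the 16-bit integer range. A minimal single-file sketch, assuming a 16 kHz mono file test.wav (hypothetical path) exists:

import torchaudio
import torchaudio.compliance.kaldi as kaldi

waveform, sample_rate = torchaudio.load("test.wav")  # hypothetical input file
mat = kaldi.fbank(waveform * (1 << 15),
                  num_mel_bins=80,
                  dither=0.0,
                  window_type='hamming',
                  sample_frequency=sample_rate)
print(mat.shape)  # (num_frames, 80)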

@@ -1,51 +0,0 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
nj=32
cmd=./utils/run.pl
# feature configuration
feats_dim=80
sample_frequency=16000
speed_perturb="1.0"
echo "$0 $@"
. utils/parse_options.sh || exit 1;
data=$1
logdir=$2
fbankdir=$3
[ ! -f $data/wav.scp ] && echo "$0: no such file $data/wav.scp" && exit 1;
[ ! -f $data/text ] && echo "$0: no such file $data/text" && exit 1;
python utils/split_data.py $data $data $nj
ark_dir=${fbankdir}/ark; mkdir -p ${ark_dir}
text_dir=${fbankdir}/txt; mkdir -p ${text_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/make_fbank.JOB.log \
python utils/compute_fbank.py -w $data/split${nj}/JOB/wav.scp -t $data/split${nj}/JOB/text \
-d $feats_dim -s $sample_frequency -p ${speed_perturb} -a JOB -o ${fbankdir} \
|| exit 1;
for n in $(seq $nj); do
cat ${ark_dir}/feats.$n.scp || exit 1
done > $fbankdir/feats.scp || exit 1
for n in $(seq $nj); do
cat ${text_dir}/text.$n.txt || exit 1
done > $fbankdir/text || exit 1
for n in $(seq $nj); do
cat ${ark_dir}/len.$n || exit 1
done > $fbankdir/speech_shape || exit 1
for n in $(seq $nj); do
cat ${text_dir}/len.$n || exit 1
done > $fbankdir/text_shape || exit 1
echo "$0: Succeeded compute FBANK features"

@@ -1,157 +0,0 @@
import os
import numpy as np
import sys
def compute_wer(ref_file,
hyp_file,
cer_detail_file):
rst = {
'Wrd': 0,
'Corr': 0,
'Ins': 0,
'Del': 0,
'Sub': 0,
'Snt': 0,
'Err': 0.0,
'S.Err': 0.0,
'wrong_words': 0,
'wrong_sentences': 0
}
hyp_dict = {}
ref_dict = {}
with open(hyp_file, 'r') as hyp_reader:
for line in hyp_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
hyp_dict[key] = value
with open(ref_file, 'r') as ref_reader:
for line in ref_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
ref_dict[key] = value
cer_detail_writer = open(cer_detail_file, 'w')
for hyp_key in hyp_dict:
if hyp_key in ref_dict:
out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
rst['Wrd'] += out_item['nwords']
rst['Corr'] += out_item['cor']
rst['wrong_words'] += out_item['wrong']
rst['Ins'] += out_item['ins']
rst['Del'] += out_item['del']
rst['Sub'] += out_item['sub']
rst['Snt'] += 1
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
if rst['Snt'] > 0:
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
cer_detail_writer.write('\n')
cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
def compute_wer_by_line(hyp,
ref):
hyp = list(map(lambda x: x.lower(), hyp))
ref = list(map(lambda x: x.lower(), ref))
len_hyp = len(hyp)
len_ref = len(ref)
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
for i in range(len_hyp + 1):
cost_matrix[i][0] = i
for j in range(len_ref + 1):
cost_matrix[0][j] = j
for i in range(1, len_hyp + 1):
for j in range(1, len_ref + 1):
if hyp[i - 1] == ref[j - 1]:
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
else:
substitution = cost_matrix[i - 1][j - 1] + 1
insertion = cost_matrix[i - 1][j] + 1
deletion = cost_matrix[i][j - 1] + 1
compare_val = [substitution, insertion, deletion]
min_val = min(compare_val)
operation_idx = compare_val.index(min_val) + 1
cost_matrix[i][j] = min_val
ops_matrix[i][j] = operation_idx
match_idx = []
i = len_hyp
j = len_ref
rst = {
'nwords': len_ref,
'cor': 0,
'wrong': 0,
'ins': 0,
'del': 0,
'sub': 0
}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst['cor'] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst['ins'] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst['del'] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst['sub'] += 1
if i < 0 and j >= 0:
rst['del'] += 1
elif j < 0 and i >= 0:
rst['ins'] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst['wrong'] = wrong_cnt
return rst
def print_cer_detail(rst):
return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
+ ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
+ str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
+ ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
if __name__ == '__main__':
if len(sys.argv) != 4:
print("usage : python compute-wer.py test.ref test.hyp test.wer")
sys.exit(0)
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
cer_detail_file = sys.argv[3]
compute_wer(ref_file, hyp_file, cer_detail_file)
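
The backtrace counts plug into the standard formula WER = (S + D + I) / N_ref. For example, ref "a b c d" against hyp "a x c" aligns as correct, substitution, correct, deletion. A sketch using the helper above (assuming it is imported from this script):

# Hypothetical import of the helper defined above.
rst = compute_wer_by_line(hyp="a x c".split(), ref="a b c d".split())
print(rst)  # {'nwords': 4, 'cor': 2, 'wrong': 2, 'ins': 0, 'del': 1, 'sub': 1} -> WER 2/4 = 50%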

@@ -1,370 +0,0 @@
#!/usr/bin/env python3
# coding=utf8
# Copyright 2021 Jiayu DU
import sys
import argparse
import json
import logging
logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')
DEBUG = None
def GetEditType(ref_token, hyp_token):
if ref_token == None and hyp_token != None:
return 'I'
elif ref_token != None and hyp_token == None:
return 'D'
elif ref_token == hyp_token:
return 'C'
elif ref_token != hyp_token:
return 'S'
else:
raise RuntimeError
class AlignmentArc:
def __init__(self, src, dst, ref, hyp):
self.src = src
self.dst = dst
self.ref = ref
self.hyp = hyp
self.edit_type = GetEditType(ref, hyp)
def similarity_score_function(ref_token, hyp_token):
return 0 if (ref_token == hyp_token) else -1.0
def insertion_score_function(token):
return -1.0
def deletion_score_function(token):
return -1.0
def EditDistance(
ref,
hyp,
similarity_score_function = similarity_score_function,
insertion_score_function = insertion_score_function,
deletion_score_function = deletion_score_function):
assert(len(ref) != 0)
class DPState:
def __init__(self):
self.score = -float('inf')
# backpointer
self.prev_r = None
self.prev_h = None
def print_search_grid(S, R, H, fstream):
print(file=fstream)
for r in range(R):
for h in range(H):
print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
print(file=fstream)
R = len(ref) + 1
H = len(hyp) + 1
# Construct DP search space, a (R x H) grid
S = [ [] for r in range(R) ]
for r in range(R):
S[r] = [ DPState() for x in range(H) ]
# initialize DP search grid origin, S(r = 0, h = 0)
S[0][0].score = 0.0
S[0][0].prev_r = None
S[0][0].prev_h = None
# initialize REF axis
for r in range(1, R):
S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
S[r][0].prev_r = r-1
S[r][0].prev_h = 0
# initialize HYP axis
for h in range(1, H):
S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
S[0][h].prev_r = 0
S[0][h].prev_h = h-1
best_score = S[0][0].score
best_state = (0, 0)
for r in range(1, R):
for h in range(1, H):
sub_or_cor_score = similarity_score_function(ref[r-1], hyp[h-1])
new_score = S[r-1][h-1].score + sub_or_cor_score
if new_score >= S[r][h].score:
S[r][h].score = new_score
S[r][h].prev_r = r-1
S[r][h].prev_h = h-1
del_score = deletion_score_function(ref[r-1])
new_score = S[r-1][h].score + del_score
if new_score >= S[r][h].score:
S[r][h].score = new_score
S[r][h].prev_r = r - 1
S[r][h].prev_h = h
ins_score = insertion_score_function(hyp[h-1])
new_score = S[r][h-1].score + ins_score
if new_score >= S[r][h].score:
S[r][h].score = new_score
S[r][h].prev_r = r
S[r][h].prev_h = h-1
best_score = S[R-1][H-1].score
best_state = (R-1, H-1)
if DEBUG:
print_search_grid(S, R, H, sys.stderr)
# Backtracing best alignment path, i.e. a list of arcs
# arc = (src, dst, ref, hyp, edit_type)
# src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
best_path = []
r, h = best_state[0], best_state[1]
prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
score = S[r][h].score
# loop invariant:
# 1. (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path
# 2. score is the value of point(r, h) on DP search grid
while prev_r != None or prev_h != None:
src = (prev_r, prev_h)
dst = (r, h)
if (r == prev_r + 1 and h == prev_h + 1): # Substitution or correct
arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
elif (r == prev_r + 1 and h == prev_h): # Deletion
arc = AlignmentArc(src, dst, ref[prev_r], None)
elif (r == prev_r and h == prev_h + 1): # Insertion
arc = AlignmentArc(src, dst, None, hyp[prev_h])
else:
raise RuntimeError
best_path.append(arc)
r, h = prev_r, prev_h
prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
score = S[r][h].score
best_path.reverse()
return (best_path, best_score)
def PrettyPrintAlignment(alignment, stream = sys.stderr):
def get_token_str(token):
if token == None:
return "*"
return token
def is_double_width_char(ch):
if (ch >= '\u4e00') and (ch <= '\u9fa5'): # codepoint ranges for Chinese chars
return True
# TODO: support other double-width-char language such as Japanese, Korean
else:
return False
def display_width(token_str):
m = 0
for c in token_str:
if is_double_width_char(c):
m += 2
else:
m += 1
return m
R = ' REF : '
H = ' HYP : '
E = ' EDIT : '
for arc in alignment:
r = get_token_str(arc.ref)
h = get_token_str(arc.hyp)
e = arc.edit_type if arc.edit_type != 'C' else ''
nr, nh, ne = display_width(r), display_width(h), display_width(e)
n = max(nr, nh, ne) + 1
R += r + ' ' * (n-nr)
H += h + ' ' * (n-nh)
E += e + ' ' * (n-ne)
print(R, file=stream)
print(H, file=stream)
print(E, file=stream)
def CountEdits(alignment):
c, s, i, d = 0, 0, 0, 0
for arc in alignment:
if arc.edit_type == 'C':
c += 1
elif arc.edit_type == 'S':
s += 1
elif arc.edit_type == 'I':
i += 1
elif arc.edit_type == 'D':
d += 1
else:
raise RuntimeError
return (c, s, i, d)
def ComputeTokenErrorRate(c, s, i, d):
return 100.0 * (s + d + i) / (s + d + c)
def ComputeSentenceErrorRate(num_err_utts, num_utts):
assert(num_utts != 0)
return 100.0 * num_err_utts / num_utts
class EvaluationResult:
def __init__(self):
self.num_ref_utts = 0
self.num_hyp_utts = 0
self.num_eval_utts = 0 # seen in both ref & hyp
self.num_hyp_without_ref = 0
self.C = 0
self.S = 0
self.I = 0
self.D = 0
self.token_error_rate = 0.0
self.num_utts_with_error = 0
self.sentence_error_rate = 0.0
def to_json(self):
return json.dumps(self.__dict__)
def to_kaldi(self):
info = (
F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
)
return info
def to_sclite(self):
return "TODO"
def to_espnet(self):
return "TODO"
def to_summary(self):
#return json.dumps(self.__dict__, indent=4)
summary = (
'==================== Overall Statistics ====================\n'
F'num_ref_utts: {self.num_ref_utts}\n'
F'num_hyp_utts: {self.num_hyp_utts}\n'
F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
F'num_eval_utts: {self.num_eval_utts}\n'
F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
F'token_error_rate: {self.token_error_rate:.2f}%\n'
F'token_stats:\n'
F' - tokens:{self.C + self.S + self.D:>7}\n'
F' - edits: {self.S + self.I + self.D:>7}\n'
F' - cor: {self.C:>7}\n'
F' - sub: {self.S:>7}\n'
F' - ins: {self.I:>7}\n'
F' - del: {self.D:>7}\n'
'============================================================\n'
)
return summary
class Utterance:
def __init__(self, uid, text):
self.uid = uid
self.text = text
def LoadUtterances(filepath, format):
utts = {}
if format == 'text': # utt_id word1 word2 ...
with open(filepath, 'r', encoding='utf8') as f:
for line in f:
line = line.strip()
if line:
cols = line.split(maxsplit=1)
assert(len(cols) == 2 or len(cols) == 1)
uid = cols[0]
text = cols[1] if len(cols) == 2 else ''
if utts.get(uid) != None:
raise RuntimeError(F'Found duplicated utterance id {uid}')
utts[uid] = Utterance(uid, text)
else:
raise RuntimeError(F'Unsupported text format {format}')
return utts
def tokenize_text(text, tokenizer):
if tokenizer == 'whitespace':
return text.split()
elif tokenizer == 'char':
return [ ch for ch in ''.join(text.split()) ]
else:
raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# optional
parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
# required
parser.add_argument('--ref', type=str, required=True, help='input reference file')
parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')
parser.add_argument('result_file', type=str)
args = parser.parse_args()
logging.info(args)
ref_utts = LoadUtterances(args.ref, args.ref_format)
hyp_utts = LoadUtterances(args.hyp, args.hyp_format)
r = EvaluationResult()
# check valid utterances in hyp that have matched non-empty reference
eval_utts = []
r.num_hyp_without_ref = 0
for uid in sorted(hyp_utts.keys()):
if uid in ref_utts.keys(): # TODO: efficiency
if ref_utts[uid].text.strip(): # non-empty reference
eval_utts.append(uid)
else:
logging.warning(F'Found {uid} with empty reference, skipping...')
else:
logging.warning(F'Found {uid} without reference, skipping...')
r.num_hyp_without_ref += 1
r.num_hyp_utts = len(hyp_utts)
r.num_ref_utts = len(ref_utts)
r.num_eval_utts = len(eval_utts)
with open(args.result_file, 'w+', encoding='utf8') as fo:
for uid in eval_utts:
ref = ref_utts[uid]
hyp = hyp_utts[uid]
alignment, score = EditDistance(
tokenize_text(ref.text, args.tokenizer),
tokenize_text(hyp.text, args.tokenizer)
)
c, s, i, d = CountEdits(alignment)
utt_ter = ComputeTokenErrorRate(c, s, i, d)
# utt-level evaluation result
print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
PrettyPrintAlignment(alignment, fo)
r.C += c
r.S += s
r.I += i
r.D += d
if utt_ter > 0:
r.num_utts_with_error += 1
# corpus level evaluation result
r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)
print(r.to_summary(), file=fo)
print(r.to_json())
print(r.to_kaldi())
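
Run end to end, the tool aligns every hypothesis with its reference and aggregates token- and sentence-level rates; ref "a b c" vs. hyp "a b d" aligns as C C S, a token error rate of 1/3. A sketch of the core calls, assuming the functions above are importable:

# Hypothetical import of EditDistance, CountEdits and ComputeTokenErrorRate from above.
alignment, score = EditDistance("a b c".split(), "a b d".split())
c, s, i, d = CountEdits(alignment)
print(ComputeTokenErrorRate(c, s, i, d))  # 33.33...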

@@ -1,47 +0,0 @@
from transformers import AutoTokenizer, AutoModel, pipeline
import numpy as np
import sys
import os
import torch
from kaldiio import WriteHelper
import re
text_file_json = sys.argv[1]
out_ark = sys.argv[2]
out_scp = sys.argv[3]
out_shape = sys.argv[4]
device = int(sys.argv[5])
model_path = sys.argv[6]
model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
with open(text_file_json, 'r') as f:
js = f.readlines()
f_shape = open(out_shape, "w")
with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
with torch.no_grad():
for idx, line in enumerate(js):
id, tokens = line.strip().split(" ", 1)
tokens = re.sub(" ", "", tokens.strip())
tokens = ' '.join([j for j in tokens])
token_num = len(tokens.split(" "))
outputs = extractor(tokens)
outputs = np.array(outputs)
embeds = outputs[0, 1:-1, :]
token_num_embeds, dim = embeds.shape
if token_num == token_num_embeds:
writer(id, embeds)
shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
f_shape.write(shape_line)
else:
print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
f_shape.close()
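
The slice outputs[0, 1:-1, :] drops the [CLS] and [SEP] positions so exactly one embedding row remains per input character; the token_num == token_num_embeds guard then skips utterances the tokenizer split differently. A minimal sketch of the same call (./bert_model is a hypothetical local checkpoint path):

from transformers import pipeline

extractor = pipeline(task="feature-extraction", model="./bert_model")  # hypothetical path
outputs = extractor("你 好")  # characters separated by spaces, as above
print(len(outputs[0]))        # token count plus 2, before stripping [CLS]/[SEP]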

@@ -1,87 +0,0 @@
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation
# Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script takes a list of utterance-ids or any file whose first field
# of each line is an utterance-id, and filters an scp
# file (or any file whose "n-th" field is an utterance id), printing
# out only those lines whose "n-th" field is in id_list. The index of
# the "n-th" field is 1, by default, but can be changed by using
# the -f <n> switch
$exclude = 0;
$field = 1;
$shifted = 0;
do {
$shifted=0;
if ($ARGV[0] eq "--exclude") {
$exclude = 1;
shift @ARGV;
$shifted=1;
}
if ($ARGV[0] eq "-f") {
$field = $ARGV[1];
shift @ARGV; shift @ARGV;
$shifted=1
}
} while ($shifted);
if(@ARGV < 1 || @ARGV > 2) {
die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
"Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
"Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
"only the lines that were *not* in id_list.\n" .
"Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
"If your older scripts (written before Oct 2014) stopped working and you used the\n" .
"-f option, add 1 to the argument.\n" .
"See also: scripts/filter_scp.pl .\n";
}
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
@A = split;
@A>=1 || die "Invalid id-list file line $_";
$seen{$A[0]} = 1;
}
if ($field == 1) { # Treat this as special case, since it is common.
while(<>) {
$_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
# $1 is what we filter on.
if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
print $_;
}
}
} else {
while(<>) {
@A = split;
@A > 0 || die "Invalid scp file line $_";
@A >= $field || die "Invalid scp file line $_";
if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
print $_;
}
}
}
# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2)

@@ -1,35 +0,0 @@
#!/usr/bin/env bash
echo "$0 $@"
data_dir=$1
if [ ! -f ${data_dir}/wav.scp ]; then
echo "$0: wav.scp is not found"
exit 1;
fi
if [ ! -f ${data_dir}/text ]; then
echo "$0: text is not found"
exit 1;
fi
mkdir -p ${data_dir}/.backup
awk '{print $1}' ${data_dir}/wav.scp > ${data_dir}/.backup/wav_id
awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
cp ${data_dir}/wav.scp ${data_dir}/.backup/wav.scp
cp ${data_dir}/text ${data_dir}/.backup/text
mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak
mv ${data_dir}/text ${data_dir}/text.bak
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak > ${data_dir}/wav.scp
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
rm ${data_dir}/wav.scp.bak
rm ${data_dir}/text.bak

@@ -1,52 +0,0 @@
#!/usr/bin/env bash
echo "$0 $@"
data_dir=$1
if [ ! -f ${data_dir}/feats.scp ]; then
echo "$0: feats.scp is not found"
exit 1;
fi
if [ ! -f ${data_dir}/text ]; then
echo "$0: text is not found"
exit 1;
fi
if [ ! -f ${data_dir}/speech_shape ]; then
echo "$0: feature lengths is not found"
exit 1;
fi
if [ ! -f ${data_dir}/text_shape ]; then
echo "$0: text lengths is not found"
exit 1;
fi
mkdir -p ${data_dir}/.backup
awk '{print $1}' ${data_dir}/feats.scp > ${data_dir}/.backup/wav_id
awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
cp ${data_dir}/feats.scp ${data_dir}/.backup/feats.scp
cp ${data_dir}/text ${data_dir}/.backup/text
cp ${data_dir}/speech_shape ${data_dir}/.backup/speech_shape
cp ${data_dir}/text_shape ${data_dir}/.backup/text_shape
mv ${data_dir}/feats.scp ${data_dir}/feats.scp.bak
mv ${data_dir}/text ${data_dir}/text.bak
mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak
mv ${data_dir}/text_shape ${data_dir}/text_shape.bak
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak > ${data_dir}/feats.scp
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak > ${data_dir}/speech_shape
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak > ${data_dir}/text_shape
rm ${data_dir}/feats.scp.bak
rm ${data_dir}/text.bak
rm ${data_dir}/speech_shape.bak
rm ${data_dir}/text_shape.bak

@@ -1,22 +0,0 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=32
cmd=./utils/run.pl
echo "$0 $@"
. utils/parse_options.sh || exit 1;
ark_dir=$1
txt_dir=$2
output_dir=$3
[ ! -d ${ark_dir}/ark ] && echo "$0: ark data is required" && exit 1;
[ ! -d ${txt_dir}/txt ] && echo "$0: txt data is required" && exit 1;
for n in $(seq $nj); do
echo "${ark_dir}/ark/feats.$n.ark ${txt_dir}/txt/text.$n.txt" || exit 1
done > ${output_dir}/ark_txt.scp || exit 1

@@ -1,97 +0,0 @@
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefined-- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.

@@ -1,45 +0,0 @@
#!/usr/bin/env python
import sys
def get_commandline_args(no_executable=True):
extra_chars = [
" ",
";",
"&",
"|",
"<",
">",
"?",
"*",
"~",
"`",
'"',
"'",
"\\",
"{",
"}",
"(",
")",
]
# Escape the extra characters for shell
argv = [
arg.replace("'", "'\\''")
if all(char not in arg for char in extra_chars)
else "'" + arg.replace("'", "'\\''") + "'"
for arg in sys.argv
]
if no_executable:
return " ".join(argv[1:])
else:
return sys.executable + " " + " ".join(argv)
def main():
print(get_commandline_args())
if __name__ == "__main__":
main()
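
Arguments containing shell metacharacters come back single-quoted, so the printed command line can be pasted into a shell verbatim. A quick illustration, assuming get_commandline_args from above is in scope:

import sys

sys.argv = ["prog.py", "--cmd", "queue.pl -sync y"]  # simulated command line
print(get_commandline_args())                        # --cmd 'queue.pl -sync y'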

@@ -1,35 +0,0 @@
from pathlib import Path
import torch
import yaml
class NoAliasSafeDumper(yaml.SafeDumper):
# Disable anchor/alias in yaml because looks ugly
def ignore_aliases(self, data):
return True
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
"""Safe-dump in yaml with no anchor/alias"""
return yaml.dump(
data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
)
def gen_conf(file, out_dir):
conf = torch.load(file)["config"]
conf["oss_bucket"] = "null"
print(conf)
output_dir = Path(out_dir)
output_dir.mkdir(parents=True, exist_ok=True)
with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
yaml_no_alias_safe_dump(conf, f, indent=4, sort_keys=False)
if __name__ == "__main__":
import sys
in_f = sys.argv[1]
out_f = sys.argv[2]
gen_conf(in_f, out_f)

@@ -1,31 +0,0 @@
import sys
import re
in_f = sys.argv[1]
out_f = sys.argv[2]
with open(in_f, "r", encoding="utf-8") as f:
lines = f.readlines()
with open(out_f, "w", encoding="utf-8") as f:
for line in lines:
outs = line.strip().split(" ", 1)
if len(outs) == 2:
idx, text = outs
text = re.sub("</s>", "", text)
text = re.sub("<s>", "", text)
text = re.sub("@@", "", text)
text = re.sub("@", "", text)
text = re.sub("<unk>", "", text)
text = re.sub(" ", "", text)
text = text.lower()
else:
idx = outs[0]
text = " "
text = [x for x in text]
text = " ".join(text)
out = "{} {}\n".format(idx, text)
f.write(out)
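
The net effect is to strip BPE and markup symbols and explode what remains into space-separated characters, matching the char-level scoring used by the WER tools above. For instance (illustrative input):

line = "utt1 ABC@@ 你好<unk>"
idx, text = line.strip().split(" ", 1)
for pat in ("</s>", "<s>", "@@", "@", "<unk>", " "):
    text = text.replace(pat, "")
print(idx, " ".join(list(text.lower())))  # utt1 a b c 你 好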

@@ -1,356 +0,0 @@
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# In general, doing
# run.pl some.log a b c is like running the command a b c in
# the bash shell, and putting the standard error and output into some.log.
# To run parallel jobs (backgrounded on the host machine), you can do (e.g.)
# run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB
# and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier].
# If any of the jobs fails, this script will fail.
# A typical example is:
# run.pl some.log my-prog "--opt=foo bar" foo \| other-prog baz
# and run.pl will run something like:
# ( my-prog '--opt=foo bar' foo | other-prog baz ) >& some.log
#
# Basically it takes the command-line arguments, quotes them
# as necessary to preserve spaces, and evaluates them with bash.
# In addition it puts the command line at the top of the log, and
# the start and end times of the command at the beginning and end.
# The reason why this is useful is so that we can create a different
# version of this program that uses a queueing system instead.
#use Data::Dumper;
@ARGV < 2 && die "usage: run.pl log-file command-line arguments...";
#print STDERR "COMMAND-LINE: " . Dumper(\@ARGV) . "\n";
$job_pick = 'all';
$max_jobs_run = -1;
$jobstart = 1;
$jobend = 1;
$ignored_opts = ""; # These will be ignored.
# First parse an option like JOB=1:4, and any
# options that would normally be given to
# queue.pl, which we will just discard.
for (my $x = 1; $x <= 2; $x++) { # This for-loop is to
# allow the JOB=1:n option to be interleaved with the
# options to qsub.
while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) {
# parse any options that would normally go to qsub, but which will be ignored here.
my $switch = shift @ARGV;
if ($switch eq "-V") {
$ignored_opts .= "-V ";
} elsif ($switch eq "--max-jobs-run" || $switch eq "-tc") {
# we do support the option --max-jobs-run n, and its GridEngine form -tc n.
# if the command appears multiple times uses the smallest option.
if ( $max_jobs_run <= 0 ) {
$max_jobs_run = shift @ARGV;
} else {
my $new_constraint = shift @ARGV;
if ( ($new_constraint < $max_jobs_run) ) {
$max_jobs_run = $new_constraint;
}
}
if (! ($max_jobs_run > 0)) {
die "run.pl: invalid option --max-jobs-run $max_jobs_run";
}
} else {
my $argument = shift @ARGV;
if ($argument =~ m/^--/) {
print STDERR "run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\n";
}
if ($switch eq "-sync" && $argument =~ m/^[yY]/) {
$ignored_opts .= "-sync "; # Note: in the
# corresponding code in queue.pl it says instead, just "$sync = 1;".
} elsif ($switch eq "-pe") { # e.g. -pe smp 5
my $argument2 = shift @ARGV;
$ignored_opts .= "$switch $argument $argument2 ";
} elsif ($switch eq "--gpu") {
$using_gpu = $argument;
} elsif ($switch eq "--pick") {
if($argument =~ m/^(all|failed|incomplete)$/) {
$job_pick = $argument;
} else {
print STDERR "run.pl: ERROR: --pick argument must be one of 'all', 'failed' or 'incomplete'"
}
} else {
# Ignore option.
$ignored_opts .= "$switch $argument ";
}
}
}
if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20
$jobname = $1;
$jobstart = $2;
$jobend = $3;
if ($jobstart > $jobend) {
die "run.pl: invalid job range $ARGV[0]";
}
if ($jobstart <= 0) {
die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility).";
}
shift;
} elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1.
$jobname = $1;
$jobstart = $2;
$jobend = $2;
shift;
} elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) {
print STDERR "run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\n";
}
}
# Users found this message confusing so we are removing it.
# if ($ignored_opts ne "") {
# print STDERR "run.pl: Warning: ignoring options \"$ignored_opts\"\n";
# }
if ($max_jobs_run == -1) { # If --max-jobs-run option not set,
# then work out the number of processors if possible,
# and set it based on that.
$max_jobs_run = 0;
if ($using_gpu) {
if (open(P, "nvidia-smi -L |")) {
$max_jobs_run++ while (<P>);
close(P);
}
if ($max_jobs_run == 0) {
$max_jobs_run = 1;
print STDERR "run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\n";
}
} elsif (open(P, "</proc/cpuinfo")) { # Linux
while (<P>) { if (m/^processor/) { $max_jobs_run++; } }
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from /proc/cpuinfo\n";
$max_jobs_run = 10; # reasonable default.
}
close(P);
} elsif (open(P, "sysctl -a |")) { # BSD/Darwin
while (<P>) {
if (m/hw\.ncpu\s*[:=]\s*(\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4
$max_jobs_run = $1;
last;
}
}
close(P);
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from sysctl -a\n";
$max_jobs_run = 10; # reasonable default.
}
} else {
# allow at most 32 jobs at once, on non-UNIX systems; change this code
# if you need to change this default.
$max_jobs_run = 32;
}
# The just-computed value of $max_jobs_run is just the number of processors
# (or our best guess); and if it happens that the number of jobs we need to
# run is just slightly above $max_jobs_run, it will make sense to increase
# $max_jobs_run to equal the number of jobs, so we don't have a small number
# of leftover jobs.
$num_jobs = $jobend - $jobstart + 1;
if (!$using_gpu &&
$num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) {
$max_jobs_run = $num_jobs;
}
}
sub pick_or_exit {
# pick_or_exit ( $logfile )
# Invoked before each job is started helps to run jobs selectively.
#
# Given the name of the output logfile decides whether the job must be
# executed (by returning from the subroutine) or not (by terminating the
# process calling exit)
#
# PRE: $job_pick is a global variable set by command line switch --pick
# and indicates which class of jobs must be executed.
#
# 1) If a failed job is not executed the process exit code will indicate
# failure, just as if the task was just executed and failed.
#
# 2) If a task is incomplete it will be executed. Incomplete may be either
# a job whose log file does not contain the accounting notes in the end,
# or a job whose log file does not exist.
#
# 3) If the $job_pick is set to 'all' (default behavior) a task will be
# executed regardless of the result of previous attempts.
#
# This logic could have been implemented in the main execution loop
# but a subroutine to preserve the current level of readability of
# that part of the code.
#
# Alexandre Felipe, (o.alexandre.felipe@gmail.com) 14th of August of 2020
#
if($job_pick eq 'all'){
return; # no need to bother with the previous log
}
open my $fh, "<", $_[0] or return; # job not executed yet
my $log_line;
my $cur_line;
while ($cur_line = <$fh>) {
if( $cur_line =~ m/# Ended \(code .*/ ) {
$log_line = $cur_line;
}
}
close $fh;
if (! defined($log_line)){
return; # incomplete
}
if ( $log_line =~ m/# Ended \(code 0\).*/ ) {
exit(0); # complete
} elsif ( $log_line =~ m/# Ended \(code \d+(; signal \d+)?\).*/ ){
if ($job_pick !~ m/^(failed|all)$/) {
exit(1); # failed but not going to run
} else {
return; # failed
}
} elsif ( $log_line =~ m/.*\S.*/ ) {
return; # incomplete jobs are always run
}
}
$logfile = shift @ARGV;
if (defined $jobname && $logfile !~ m/$jobname/ &&
$jobend > $jobstart) {
print STDERR "run.pl: you are trying to run a parallel job but "
. "you are putting the output into just one log file ($logfile)\n";
exit(1);
}
$cmd = "";
foreach $x (@ARGV) {
if ($x =~ m/^\S+$/) { $cmd .= $x . " "; }
elsif ($x =~ m:\":) { $cmd .= "'$x' "; }
else { $cmd .= "\"$x\" "; }
}
#$Data::Dumper::Indent=0;
$ret = 0;
$numfail = 0;
%active_pids=();
use POSIX ":sys_wait_h";
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
if (scalar(keys %active_pids) >= $max_jobs_run) {
# Lets wait for a change in any child's status
# Then we have to work out which child finished
$r = waitpid(-1, 0);
$code = $?;
if ($r < 0 ) { die "run.pl: Error waiting for child process"; } # should never happen.
if ( defined $active_pids{$r} ) {
$jid=$active_pids{$r};
$fail[$jid]=$code;
if ($code !=0) { $numfail++;}
delete $active_pids{$r};
# print STDERR "Finished: $r/$jid " . Dumper(\%active_pids) . "\n";
} else {
die "run.pl: Cannot find the PID of the child process that just finished.";
}
# In theory we could do a non-blocking waitpid over all jobs running just
# to find out if only one or more jobs finished during the previous waitpid()
# However, we just omit this and will reap the next one in the next pass
# through the for(;;) cycle
}
$childpid = fork();
if (!defined $childpid) { die "run.pl: Error forking in run.pl (writing to $logfile)"; }
if ($childpid == 0) { # We're in the child... this branch
# executes the job and returns (possibly with an error status).
if (defined $jobname) {
$cmd =~ s/$jobname/$jobid/g;
$logfile =~ s/$jobname/$jobid/g;
}
# exit if the job does not need to be executed
pick_or_exit( $logfile );
system("mkdir -p `dirname $logfile` 2>/dev/null");
open(F, ">$logfile") || die "run.pl: Error opening log file $logfile";
print F "# " . $cmd . "\n";
print F "# Started at " . `date`;
$starttime = `date +'%s'`;
print F "#\n";
close(F);
# Pipe into bash.. make sure we're not using any other shell.
open(B, "|bash") || die "run.pl: Error opening shell command";
print B "( " . $cmd . ") 2>>$logfile >> $logfile";
close(B); # If there was an error, exit status is in $?
$ret = $?;
$lowbits = $ret & 127;
$highbits = $ret >> 8;
if ($lowbits != 0) { $return_str = "code $highbits; signal $lowbits" }
else { $return_str = "code $highbits"; }
$endtime = `date +'%s'`;
open(F, ">>$logfile") || die "run.pl: Error opening log file $logfile (again)";
$enddate = `date`;
chop $enddate;
print F "# Accounting: time=" . ($endtime - $starttime) . " threads=1\n";
print F "# Ended ($return_str) at " . $enddate . ", elapsed time " . ($endtime-$starttime) . " seconds\n";
close(F);
exit($ret == 0 ? 0 : 1);
} else {
$pid[$jobid] = $childpid;
$active_pids{$childpid} = $jobid;
# print STDERR "Queued: " . Dumper(\%active_pids) . "\n";
}
}
# Now we have submitted all the jobs, lets wait until all the jobs finish
foreach $child (keys %active_pids) {
$jobid=$active_pids{$child};
$r = waitpid($pid[$jobid], 0);
$code = $?;
if ($r == -1) { die "run.pl: Error waiting for child process"; } # should never happen.
if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # Completed successfully
}
# Some sanity checks:
# The $fail array should not contain undefined codes
# The number of non-zeros in that array should be equal to $numfail
# We cannot do foreach() here, as the JOB ids do not start at zero
$failed_jids=0;
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
$job_return = $fail[$jobid];
if (not defined $job_return ) {
# print Dumper(\@fail);
die "run.pl: Sanity check failed: we have indication that some jobs are running " .
"even after we waited for all jobs to finish" ;
}
if ($job_return != 0 ){ $failed_jids++;}
}
if ($failed_jids != $numfail) {
die "run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail)."
}
if ($numfail > 0) { $ret = 1; }
if ($ret != 0) {
$njobs = $jobend - $jobstart + 1;
if ($njobs == 1) {
if (defined $jobname) {
$logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with
# that job.
}
print STDERR "run.pl: job failed, log is in $logfile\n";
if ($logfile =~ m/JOB/) {
print STDERR "run.pl: probably you forgot to put JOB=1:\$nj in your script.";
}
}
else {
$logfile =~ s/$jobname/*/g;
print STDERR "run.pl: $numfail / $njobs failed, log is in $logfile\n";
}
}
exit ($ret);

@@ -1,44 +0,0 @@
#!/usr/bin/env perl
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
if ($ARGV[0] eq "--srand") {
$n = $ARGV[1];
$n =~ m/\d+/ || die "Bad argument to --srand option: \"$n\"";
srand($ARGV[1]);
shift;
shift;
} else {
srand(0); # Gives inconsistent behavior if we don't seed.
}
if (@ARGV > 1 || $ARGV[0] =~ m/^-.+/) { # >1 args, or an option we
# don't understand.
print "Usage: shuffle_list.pl [--srand N] [input file] > output\n";
print "randomizes the order of lines of input.\n";
exit(1);
}
@lines = ();
while (<>) {
push @lines, [ (rand(), $_)] ;
}
@lines = sort { $a->[0] cmp $b->[0] } @lines;
foreach $l (@lines) {
print $l->[1];
}
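Note: shuffle_list.pl is a seeded decorate-sort-undecorate shuffle: pair each line with a random key, sort on the key, print the lines. A rough Python equivalent (a sketch, not part of the toolkit):

import random
import sys

def shuffle_lines(lines, seed=0):
    rng = random.Random(seed)   # fixed seed for reproducible output, like --srand
    keyed = [(rng.random(), line) for line in lines]
    keyed.sort(key=lambda kv: kv[0])
    return [line for _, line in keyed]

if __name__ == "__main__":
    sys.stdout.writelines(shuffle_lines(sys.stdin.readlines()))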

View File

@ -1,60 +0,0 @@
import os
import sys
import random
in_dir = sys.argv[1]
out_dir = sys.argv[2]
num_split = sys.argv[3]
def split_scp(scp, num):
assert len(scp) >= num
avg = len(scp) // num
out = []
begin = 0
for i in range(num):
if i == num - 1:
out.append(scp[begin:])
else:
out.append(scp[begin:begin+avg])
begin += avg
return out
os.path.exists("{}/wav.scp".format(in_dir))
os.path.exists("{}/text".format(in_dir))
with open("{}/wav.scp".format(in_dir), 'r') as infile:
wav_list = infile.readlines()
with open("{}/text".format(in_dir), 'r') as infile:
text_list = infile.readlines()
assert len(wav_list) == len(text_list)
x = list(zip(wav_list, text_list))
random.shuffle(x)
wav_shuffle_list, text_shuffle_list = zip(*x)
num_split = int(num_split)
wav_split_list = split_scp(wav_shuffle_list, num_split)
text_split_list = split_scp(text_shuffle_list, num_split)
for idx, wav_list in enumerate(wav_split_list, 1):
path = out_dir + "/split" + str(num_split) + "/" + str(idx)
if not os.path.exists(path):
os.makedirs(path)
with open("{}/wav.scp".format(path), 'w') as wav_writer:
for line in wav_list:
wav_writer.write(line)
for idx, text_list in enumerate(text_split_list, 1):
path = out_dir + "/split" + str(num_split) + "/" + str(idx)
if not os.path.exists(path):
os.makedirs(path)
with open("{}/text".format(path), 'w') as text_writer:
for line in text_list:
text_writer.write(line)
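Note: split_scp gives every chunk len(scp) // num lines and lets the last chunk absorb the remainder, so 10 items split 3 ways yield sizes 3, 3 and 4. A quick check, assuming split_scp from this file is in scope:

chunks = split_scp(list(range(10)), 3)
assert [len(c) for c in chunks] == [3, 3, 4]   # remainder lands in the last chunk

The script itself is invoked as: python utils/split_data.py <in_dir> <out_dir> <num_split>.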

View File

@ -1,246 +0,0 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# See ../../COPYING for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program splits up any kind of .scp or archive-type file.
# If there is no utt2spk option it will work on any text file and
# will split it up with an approximately equal number of lines in
# each.
# With the --utt2spk option it will work on anything that has the
# utterance-id as the first entry on each line; the utt2spk file is
# of the form "utterance speaker" (on each line).
# It splits it into equal size chunks as far as it can. If you use the utt2spk
# option it will make sure these chunks coincide with speaker boundaries. In
# this case, if there are more chunks than speakers (and in some other
# circumstances), some of the resulting chunks will be empty and it will print
# an error message and exit with nonzero status.
# You will normally call this like:
# split_scp.pl scp scp.1 scp.2 scp.3 ...
# or
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
# Note that you can use this script to split the utt2spk file itself,
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
# You can also call the scripts like:
# split_scp.pl -j 3 0 scp scp.0
# [note: with this option, it assumes zero-based indexing of the split parts,
# i.e. the second number must be 0 <= n < num-jobs.]
use warnings;
$num_jobs = 0;
$job_id = 0;
$utt2spk_file = "";
$one_based = 0;
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
if ($ARGV[0] eq "-j") {
shift @ARGV;
$num_jobs = shift @ARGV;
$job_id = shift @ARGV;
}
if ($ARGV[0] =~ /--utt2spk=(.+)/) {
$utt2spk_file=$1;
shift;
}
if ($ARGV[0] eq '--one-based') {
$one_based = 1;
shift @ARGV;
}
}
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
$job_id - $one_based >= $num_jobs)) {
die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
($one_based ? " --one-based" : "") . "'\n"
}
$one_based
and $job_id--;
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
die
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
... where 0 <= job-id < num-jobs, or 1 <= job-id <= num-jobs if --one-based.\n";
}
$error = 0;
$inscp = shift @ARGV;
if ($num_jobs == 0) { # without -j option
@OUTPUTS = @ARGV;
} else {
for ($j = 0; $j < $num_jobs; $j++) {
if ($j == $job_id) {
if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
else { push @OUTPUTS, "-"; }
} else {
push @OUTPUTS, "/dev/null";
}
}
}
if ($utt2spk_file ne "") { # We have the --utt2spk option...
open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
while(<$u_fh>) {
@A = split;
@A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
($u,$s) = @A;
$utt2spk{$u} = $s;
}
close $u_fh;
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
@spkrs = ();
while(<$i_fh>) {
@A = split;
if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
$u = $A[0];
$s = $utt2spk{$u};
defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
if(!defined $spk_count{$s}) {
push @spkrs, $s;
$spk_count{$s} = 0;
$spk_data{$s} = []; # ref to new empty array.
}
$spk_count{$s}++;
push @{$spk_data{$s}}, $_;
}
# Now split as equally as possible ..
# First allocate spks to files by allocating an approximately
# equal number of speakers.
$numspks = @spkrs; # number of speakers.
$numscps = @OUTPUTS; # number of output files.
if ($numspks < $numscps) {
die "$0: Refusing to split data because number of speakers $numspks " .
"is less than the number of output .scp files $numscps\n";
}
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scparray[$scpidx] = []; # [] is array reference.
}
for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
$scpidx = int(($spkidx*$numscps) / $numspks);
$spk = $spkrs[$spkidx];
push @{$scparray[$scpidx]}, $spk;
$scpcount[$scpidx] += $spk_count{$spk};
}
# Now will try to reassign beginning + ending speakers
# to different scp's and see if it gets more balanced.
# Suppose the objective function we're minimizing is sum_i (num utts in scp[i] - average)^2.
# We can show that if considering changing just 2 scp's, we minimize
# this by minimizing the squared difference in sizes. This is
# equivalent to minimizing the absolute difference in sizes. This
# shows this method is bound to converge.
$changed = 1;
while($changed) {
$changed = 0;
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
# First try to reassign ending spk of this scp.
if($scpidx < $numscps-1) {
$sz = @{$scparray[$scpidx]};
if($sz > 0) {
$spk = $scparray[$scpidx]->[$sz-1];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx];
$nutt2 = $scpcount[$scpidx+1];
if( abs( ($nutt2+$count) - ($nutt1-$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx+1] += $count;
$scpcount[$scpidx] -= $count;
pop @{$scparray[$scpidx]};
unshift @{$scparray[$scpidx+1]}, $spk;
$changed = 1;
}
}
}
if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
$spk = $scparray[$scpidx]->[0];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx-1];
$nutt2 = $scpcount[$scpidx];
if( abs( ($nutt2-$count) - ($nutt1+$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx-1] += $count;
$scpcount[$scpidx] -= $count;
shift @{$scparray[$scpidx]};
push @{$scparray[$scpidx-1]}, $spk;
$changed = 1;
}
}
}
}
# Now print out the files...
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($f_fh, '>', $scpfile)
: open($f_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
$count = 0;
if(@{$scparray[$scpidx]} == 0) {
print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
"$scpfile (too many splits and too few speakers?)\n";
$error = 1;
} else {
foreach $spk ( @{$scparray[$scpidx]} ) {
print $f_fh @{$spk_data{$spk}};
$count += $spk_count{$spk};
}
$count == $scpcount[$scpidx] || die "Count mismatch [code error]";
}
close($f_fh);
}
} else {
# This block is the "normal" case where there is no --utt2spk
# option and we just break into equal size chunks.
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
$numscps = @OUTPUTS; # size of array.
@F = ();
while(<$i_fh>) {
push @F, $_;
}
$numlines = @F;
if($numlines == 0) {
print STDERR "$0: error: empty input scp file $inscp\n";
$error = 1;
}
$linesperscp = int( $numlines / $numscps); # the "whole part"..
$linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
$remainder = $numlines - ($linesperscp * $numscps);
($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
# [just doing int() rounds down].
$n = 0;
for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($o_fh, '>', $scpfile)
: open($o_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
print $o_fh $F[$n++];
}
close($o_fh) || die "$0: Error closing scp file $scpfile: $!\n";
}
$n == $numlines || die "$n != $numlines [code error]";
}
exit ($error);
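Note: in outline, the --utt2spk path assigns speakers to chunks proportionally and then greedily moves boundary speakers between neighbouring chunks whenever that shrinks the size difference, per the objective-function argument in the comments above. A condensed Python sketch of that loop (only the end-of-chunk move is shown; the Perl also tries the mirrored start-of-chunk move):

def split_speakers(spk_counts, num_chunks):
    # spk_counts: list of (speaker, utt_count) pairs in input order
    chunks = [[] for _ in range(num_chunks)]
    sizes = [0] * num_chunks
    for i, (spk, cnt) in enumerate(spk_counts):
        idx = i * num_chunks // len(spk_counts)   # proportional initial assignment
        chunks[idx].append(spk)
        sizes[idx] += cnt
    counts = dict(spk_counts)
    changed = True
    while changed:
        changed = False
        for i in range(num_chunks - 1):
            if not chunks[i]:
                continue
            c = counts[chunks[i][-1]]
            # move this chunk's last speaker right if it narrows the size gap
            if abs((sizes[i + 1] + c) - (sizes[i] - c)) < abs(sizes[i + 1] - sizes[i]):
                chunks[i + 1].insert(0, chunks[i].pop())
                sizes[i] -= c
                sizes[i + 1] += c
                changed = True
    return chunks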

View File

@ -1,30 +0,0 @@
#!/usr/bin/env bash
dev_num_utt=1000
echo "$0 $@"
. utils/parse_options.sh || exit 1;
train_data=$1
out_dir=$2
[ ! -f ${train_data}/wav.scp ] && echo "$0: no such file ${train_data}/wav.scp" && exit 1;
[ ! -f ${train_data}/text ] && echo "$0: no such file ${train_data}/text" && exit 1;
mkdir -p ${out_dir}/train && mkdir -p ${out_dir}/dev
cp ${train_data}/wav.scp ${out_dir}/train/wav.scp.bak
cp ${train_data}/text ${out_dir}/train/text.bak
num_utt=$(wc -l <${out_dir}/train/wav.scp.bak)
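# NOTE: shuffling wav.scp and text with the same --srand seed applies the same
# permutation to both files, so utterance/transcript pairs stay aligned by line.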
utils/shuffle_list.pl --srand 1 ${out_dir}/train/wav.scp.bak > ${out_dir}/train/wav.scp.shuf
head -n ${dev_num_utt} ${out_dir}/train/wav.scp.shuf > ${out_dir}/dev/wav.scp
tail -n $((${num_utt}-${dev_num_utt})) ${out_dir}/train/wav.scp.shuf > ${out_dir}/train/wav.scp
utils/shuffle_list.pl --srand 1 ${out_dir}/train/text.bak > ${out_dir}/train/text.shuf
head -n ${dev_num_utt} ${out_dir}/train/text.shuf > ${out_dir}/dev/text
tail -n $((${num_utt}-${dev_num_utt})) ${out_dir}/train/text.shuf > ${out_dir}/train/text
rm ${out_dir}/train/wav.scp.bak ${out_dir}/train/text.bak
rm ${out_dir}/train/wav.scp.shuf ${out_dir}/train/text.shuf

View File

@ -1,135 +0,0 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import re
import sys
is_python2 = sys.version_info[0] == 2
def exist_or_not(i, match_pos):
start_pos = None
end_pos = None
for pos in match_pos:
if pos[0] <= i < pos[1]:
start_pos = pos[0]
end_pos = pos[1]
break
return start_pos, end_pos
def get_parser():
parser = argparse.ArgumentParser(
description="convert raw text to tokenized text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--nchar",
"-n",
default=1,
type=int,
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2",
)
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
default=None,
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
type=str,
default="char",
choices=["char", "phn"],
help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
If trans_type is char,
read from SI1279.WRD file -> "bricks are an alternative"
Else if trans_type is phn,
read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
sil t er n ih sil t ih v sil" """,
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
rs = []
if args.non_lang_syms is not None:
with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
nls = [x.rstrip() for x in f.readlines()]
rs = [re.compile(re.escape(x)) for x in nls]
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
)
line = f.readline()
n = args.nchar
while line:
x = line.split()
print(" ".join(x[: args.skip_ncols]), end=" ")
a = " ".join(x[args.skip_ncols :])
# get all matched positions
match_pos = []
for r in rs:
i = 0
while i >= 0:
m = r.search(a, i)
if m:
match_pos.append([m.start(), m.end()])
i = m.end()
else:
break
if args.trans_type == "phn":
a = a.split(" ")
else:
if len(match_pos) > 0:
chars = []
i = 0
while i < len(a):
start_pos, end_pos = exist_or_not(i, match_pos)
if start_pos is not None:
chars.append(a[start_pos:end_pos])
i = end_pos
else:
chars.append(a[i])
i += 1
a = chars
a = [a[j : j + n] for j in range(0, len(a), n)]
a_flat = []
for z in a:
a_flat.append("".join(z))
a_chars = [z.replace(" ", args.space) for z in a_flat]
if args.trans_type == "phn":
a_chars = [z.replace("sil", args.space) for z in a_chars]
print(" ".join(a_chars))
line = f.readline()
if __name__ == "__main__":
main()
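Note: the inner transform groups every n characters and maps literal spaces to the --space symbol. A standalone sketch of that step with -n 1:

line = "你好 world"
n, space = 1, "<space>"
groups = ["".join(line[j:j + n]) for j in range(0, len(line), n)]
print(" ".join(g.replace(" ", space) for g in groups))
# -> 你 好 <space> w o r l d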

View File

@ -1,106 +0,0 @@
import re
import argparse
def load_dict(seg_file):
seg_dict = {}
with open(seg_file, 'r') as infile:
for line in infile:
s = line.strip().split()
key = s[0]
value = s[1:]
seg_dict[key] = " ".join(value)
return seg_dict
def forward_segment(text, dic):
word_list = []
i = 0
while i < len(text):
longest_word = text[i]
for j in range(i + 1, len(text) + 1):
word = text[i:j]
if word in dic:
if len(word) > len(longest_word):
longest_word = word
word_list.append(longest_word)
i += len(longest_word)
return word_list
def tokenize(txt,
seg_dict):
out_txt = ""
pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
for word in txt:
if pattern.match(word):
if word in seg_dict:
out_txt += seg_dict[word] + " "
else:
out_txt += "<unk>" + " "
else:
continue
return out_txt.strip()
def get_parser():
parser = argparse.ArgumentParser(
description="text tokenize",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--text-file",
"-t",
default=False,
required=True,
type=str,
help="input text",
)
parser.add_argument(
"--seg-file",
"-s",
default=False,
required=True,
type=str,
help="seg file",
)
parser.add_argument(
"--txt-index",
"-i",
default=1,
required=True,
type=int,
help="txt index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), 'w')
shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), 'w')
seg_dict = load_dict(args.seg_file)
with open(args.text_file, 'r') as infile:
for line in infile:
s = line.strip().split()
text_id = s[0]
text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
text = tokenize(text_list, seg_dict)
lens = len(text.strip().split())
txt_writer.write(text_id + " " + text + '\n')
shape_writer.write(text_id + " " + str(lens) + '\n')
if __name__ == '__main__':
main()
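Note: forward_segment is greedy forward maximum matching: at each position it takes the longest dictionary entry starting there, falling back to a single character. The classic worked example, assuming forward_segment from this file is in scope:

dic = {"研究": 1, "研究生": 1, "生命": 1, "起源": 1}
print(forward_segment("研究生命起源", dic))
# -> ['研究生', '命', '起源']: the greedy match prefers 研究生 over 研究 + 生命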

View File

@ -1,35 +0,0 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=32
cmd=utils/run.pl
echo "$0 $@"
. utils/parse_options.sh || exit 1;
# tokenize configuration
text_dir=$1
seg_file=$2
logdir=$3
output_dir=$4
txt_dir=${output_dir}/txt; mkdir -p ${output_dir}/txt
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/text_tokenize.JOB.log \
python utils/text_tokenize.py -t ${text_dir}/txt/text.JOB.txt \
-s ${seg_file} -i JOB -o ${txt_dir} \
|| exit 1;
# concatenate the text files together.
for n in $(seq $nj); do
cat ${txt_dir}/text.$n.txt || exit 1
done > ${output_dir}/text || exit 1
for n in $(seq $nj); do
cat ${txt_dir}/len.$n || exit 1
done > ${output_dir}/text_shape || exit 1
echo "$0: Succeeded text tokenize"

View File

@ -1,834 +0,0 @@
#!/usr/bin/env python3
# coding=utf-8
# Authors:
# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
# 2019.9 Jiayu DU
#
# requirements:
# - python 3.X
# notes: python 2.X WILL fail or produce misleading results
import sys, os, argparse, codecs, string, re
# ================================================================================ #
# basic constant
# ================================================================================ #
CHINESE_DIGIS = u'零一二三四五六七八九'
BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
ZERO_ALT = u'〇'
ONE_ALT = u'幺'
TWO_ALTS = [u'两', u'兩']
POSITIVE = [u'正', u'正']
NEGATIVE = [u'负', u'負']
POINT = [u'点', u'點']
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']
FILLER_CHARS = ['呃', '啊']
ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \
'胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \
'儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \
'佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)'
# Chinese numbering system types
NUMBERING_TYPES = ['low', 'mid', 'high']
CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
'里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
'砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
'针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
'毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
'盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
'纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
CHINESE_PUNC_STOP = '!?。。'
CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏'
CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
# ================================================================================ #
# basic class
# ================================================================================ #
class ChineseChar(object):
"""
A Chinese character with paired simplified and traditional forms,
e.g. simplified = '负', traditional = '負';
it can be rendered as either form when converting.
"""
def __init__(self, simplified, traditional):
self.simplified = simplified
self.traditional = traditional
#self.__repr__ = self.__str__
def __str__(self):
return self.simplified or self.traditional or None
def __repr__(self):
return self.__str__()
class ChineseNumberUnit(ChineseChar):
"""
Chinese number-unit character.
Besides its simplified/traditional forms, each unit also has a "big" (formal) form,
e.g. '万' and '萬'.
"""
def __init__(self, power, simplified, traditional, big_s, big_t):
super(ChineseNumberUnit, self).__init__(simplified, traditional)
self.power = power
self.big_s = big_s
self.big_t = big_t
def __str__(self):
return '10^{}'.format(self.power)
@classmethod
def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
if small_unit:
return ChineseNumberUnit(power=index + 1,
simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
elif numbering_type == NUMBERING_TYPES[0]:
return ChineseNumberUnit(power=index + 8,
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
elif numbering_type == NUMBERING_TYPES[1]:
return ChineseNumberUnit(power=(index + 2) * 4,
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
elif numbering_type == NUMBERING_TYPES[2]:
return ChineseNumberUnit(power=pow(2, index + 3),
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
else:
raise ValueError(
'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
class ChineseNumberDigit(ChineseChar):
"""
Chinese digit character.
"""
def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
super(ChineseNumberDigit, self).__init__(simplified, traditional)
self.value = value
self.big_s = big_s
self.big_t = big_t
self.alt_s = alt_s
self.alt_t = alt_t
def __str__(self):
return str(self.value)
@classmethod
def create(cls, i, v):
return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
class ChineseMath(ChineseChar):
"""
Chinese math-symbol character (sign, decimal point).
"""
def __init__(self, simplified, traditional, symbol, expression=None):
super(ChineseMath, self).__init__(simplified, traditional)
self.symbol = symbol
self.expression = expression
self.big_s = simplified
self.big_t = traditional
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
class NumberSystem(object):
"""
Chinese numbering system (units, digits, math symbols).
"""
pass
class MathSymbol(object):
"""
Math symbols used by the Chinese numbering system (traditional/simplified), e.g.
positive = ['正', '正']
negative = ['负', '負']
point = ['点', '點']
"""
def __init__(self, positive, negative, point):
self.positive = positive
self.negative = negative
self.point = point
def __iter__(self):
for v in self.__dict__.values():
yield v
# class OtherSymbol(object):
# """
# 其他符号
# """
#
# def __init__(self, sil):
# self.sil = sil
#
# def __iter__(self):
# for v in self.__dict__.values():
# yield v
# ================================================================================ #
# basic utils
# ================================================================================ #
def create_system(numbering_type=NUMBERING_TYPES[1]):
"""
Create the numbering system of the given type and return it (default: mid).
NUMBERING_TYPES = ['low', 'mid', 'high'] name how the large units scale:
low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
Returns the corresponding NumberSystem.
"""
# chinese number units of '亿' and larger
all_larger_units = zip(
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
larger_units = [CNU.create(i, v, numbering_type, False)
for i, v in enumerate(all_larger_units)]
# chinese number units of '十, 百, 千, 万'
all_smaller_units = zip(
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
smaller_units = [CNU.create(i, v, small_unit=True)
for i, v in enumerate(all_smaller_units)]
# digis
chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
# symbols
positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
point_cn = CM(POINT[0], POINT[1], '.', lambda x,
y: float(str(x) + '.' + str(y)))
# sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
system = NumberSystem()
system.units = smaller_units + larger_units
system.digits = digits
system.math = MathSymbol(positive_cn, negative_cn, point_cn)
# system.symbols = OtherSymbol(sil_cn)
return system
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
def get_symbol(char, system):
for u in system.units:
if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
return u
for d in system.digits:
if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
return d
for m in system.math:
if char in [m.traditional, m.simplified]:
return m
def string2symbols(chinese_string, system):
int_string, dec_string = chinese_string, ''
for p in [system.math.point.simplified, system.math.point.traditional]:
if p in chinese_string:
int_string, dec_string = chinese_string.split(p)
break
return [get_symbol(c, system) for c in int_string], \
[get_symbol(c, system) for c in dec_string]
def correct_symbols(integer_symbols, system):
"""
一百八 to 一百八十
一亿一千三百万 to 一亿 一千万 三百万
"""
if integer_symbols and isinstance(integer_symbols[0], CNU):
if integer_symbols[0].power == 1:
integer_symbols = [system.digits[1]] + integer_symbols
if len(integer_symbols) > 1:
if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
integer_symbols.append(
CNU(integer_symbols[-2].power - 1, None, None, None, None))
result = []
unit_count = 0
for s in integer_symbols:
if isinstance(s, CND):
result.append(s)
unit_count = 0
elif isinstance(s, CNU):
current_unit = CNU(s.power, None, None, None, None)
unit_count += 1
if unit_count == 1:
result.append(current_unit)
elif unit_count > 1:
for i in range(len(result)):
if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
result[-i - 1] = CNU(result[-i - 1].power +
current_unit.power, None, None, None, None)
return result
def compute_value(integer_symbols):
"""
Compute the value.
When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
e.g. '两千万' = 2000 * 10000 not 2000 + 10000
"""
value = [0]
last_power = 0
for s in integer_symbols:
if isinstance(s, CND):
value[-1] = s.value
elif isinstance(s, CNU):
value[-1] *= pow(10, s.power)
if s.power > last_power:
value[:-1] = list(map(lambda v: v *
pow(10, s.power), value[:-1]))
last_power = s.power
value.append(0)
return sum(value)
system = create_system(numbering_type)
int_part, dec_part = string2symbols(chinese_string, system)
int_part = correct_symbols(int_part, system)
int_str = str(compute_value(int_part))
dec_str = ''.join([str(d.value) for d in dec_part])
if dec_part:
return '{0}.{1}'.format(int_str, dec_str)
else:
return int_str
def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
traditional=False, alt_zero=False, alt_one=False, alt_two=True,
use_zeros=True, use_units=True):
def get_value(value_string, use_zeros=True):
striped_string = value_string.lstrip('0')
# record nothing if all zeros
if not striped_string:
return []
# record one digits
elif len(striped_string) == 1:
if use_zeros and len(value_string) != len(striped_string):
return [system.digits[0], system.digits[int(striped_string)]]
else:
return [system.digits[int(striped_string)]]
# recursively record multiple digits
else:
result_unit = next(u for u in reversed(
system.units) if u.power < len(striped_string))
result_string = value_string[:-result_unit.power]
return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
system = create_system(numbering_type)
int_dec = number_string.split('.')
if len(int_dec) == 1:
int_string = int_dec[0]
dec_string = ""
elif len(int_dec) == 2:
int_string = int_dec[0]
dec_string = int_dec[1]
else:
raise ValueError(
"invalid input num string with more than one dot: {}".format(number_string))
if use_units and len(int_string) > 1:
result_symbols = get_value(int_string)
else:
result_symbols = [system.digits[int(c)] for c in int_string]
dec_symbols = [system.digits[int(c)] for c in dec_string]
if dec_string:
result_symbols += [system.math.point] + dec_symbols
if alt_two:
liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
system.digits[2].big_s, system.digits[2].big_t)
for i, v in enumerate(result_symbols):
if isinstance(v, CND) and v.value == 2:
next_symbol = result_symbols[i +
1] if i < len(result_symbols) - 1 else None
previous_symbol = result_symbols[i - 1] if i > 0 else None
if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
result_symbols[i] = liang
# if big is True, '两' will not be used and `alt_two` has no impact on output
if big:
attr_name = 'big_'
if traditional:
attr_name += 't'
else:
attr_name += 's'
else:
if traditional:
attr_name = 'traditional'
else:
attr_name = 'simplified'
result = ''.join([getattr(s, attr_name) for s in result_symbols])
# if not use_zeros:
# result = result.strip(getattr(system.digits[0], attr_name))
if alt_zero:
result = result.replace(
getattr(system.digits[0], attr_name), system.digits[0].alt_s)
if alt_one:
result = result.replace(
getattr(system.digits[1], attr_name), system.digits[1].alt_s)
for i, p in enumerate(POINT):
if result.startswith(p):
return CHINESE_DIGIS[0] + result
# ^10, 11, .., 19
if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
result = result[1:]
return result
# ================================================================================ #
# different types of rewriters
# ================================================================================ #
class Cardinal:
"""
CARDINAL class
"""
def __init__(self, cardinal=None, chntext=None):
self.cardinal = cardinal
self.chntext = chntext
def chntext2cardinal(self):
return chn2num(self.chntext)
def cardinal2chntext(self):
return num2chn(self.cardinal)
class Digit:
"""
DIGIT class
"""
def __init__(self, digit=None, chntext=None):
self.digit = digit
self.chntext = chntext
# def chntext2digit(self):
# return chn2num(self.chntext)
def digit2chntext(self):
return num2chn(self.digit, alt_two=False, use_units=False)
class TelePhone:
"""
TELEPHONE class
"""
def __init__(self, telephone=None, raw_chntext=None, chntext=None):
self.telephone = telephone
self.raw_chntext = raw_chntext
self.chntext = chntext
# def chntext2telephone(self):
# sil_parts = self.raw_chntext.split('<SIL>')
# self.telephone = '-'.join([
# str(chn2num(p)) for p in sil_parts
# ])
# return self.telephone
def telephone2chntext(self, fixed=False):
if fixed:
sil_parts = self.telephone.split('-')
self.raw_chntext = '<SIL>'.join([
num2chn(part, alt_two=False, use_units=False) for part in sil_parts
])
self.chntext = self.raw_chntext.replace('<SIL>', '')
else:
sp_parts = self.telephone.strip('+').split()
self.raw_chntext = '<SP>'.join([
num2chn(part, alt_two=False, use_units=False) for part in sp_parts
])
self.chntext = self.raw_chntext.replace('<SP>', '')
return self.chntext
class Fraction:
"""
FRACTION class
"""
def __init__(self, fraction=None, chntext=None):
self.fraction = fraction
self.chntext = chntext
def chntext2fraction(self):
denominator, numerator = self.chntext.split('分之')
return chn2num(numerator) + '/' + chn2num(denominator)
def fraction2chntext(self):
numerator, denominator = self.fraction.split('/')
return num2chn(denominator) + '分之' + num2chn(numerator)
class Date:
"""
DATE class
"""
def __init__(self, date=None, chntext=None):
self.date = date
self.chntext = chntext
# def chntext2date(self):
# chntext = self.chntext
# try:
# year, other = chntext.strip().split('年', maxsplit=1)
# year = Digit(chntext=year).digit2chntext() + '年'
# except ValueError:
# other = chntext
# year = ''
# if other:
# try:
# month, day = other.strip().split('月', maxsplit=1)
# month = Cardinal(chntext=month).chntext2cardinal() + '月'
# except ValueError:
# day = chntext
# month = ''
# if day:
# day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
# else:
# month = ''
# day = ''
# date = year + month + day
# self.date = date
# return self.date
def date2chntext(self):
date = self.date
try:
year, other = date.strip().split('年', 1)
year = Digit(digit=year).digit2chntext() + '年'
except ValueError:
other = date
year = ''
if other:
try:
month, day = other.strip().split('月', 1)
month = Cardinal(cardinal=month).cardinal2chntext() + '月'
except ValueError:
day = date
month = ''
if day:
day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
else:
month = ''
day = ''
chntext = year + month + day
self.chntext = chntext
return self.chntext
class Money:
"""
MONEY class
"""
def __init__(self, money=None, chntext=None):
self.money = money
self.chntext = chntext
# def chntext2money(self):
# return self.money
def money2chntext(self):
money = self.money
pattern = re.compile(r'(\d+(\.\d+)?)')
matchers = pattern.findall(money)
if matchers:
for matcher in matchers:
money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
self.chntext = money
return self.chntext
class Percentage:
"""
PERCENTAGE class
"""
def __init__(self, percentage=None, chntext=None):
self.percentage = percentage
self.chntext = chntext
def chntext2percentage(self):
return chn2num(self.chntext.strip().strip('百分之')) + '%'
def percentage2chntext(self):
return '百分之' + num2chn(self.percentage.strip().strip('%'))
def remove_erhua(text, er_whitelist):
"""
Remove the erhua suffix '儿' except in whitelisted words:
他女儿在那边儿 -> 他女儿在那边
"""
er_pattern = re.compile(er_whitelist)
new_str = ''
while re.search('儿', text):
a = re.search('儿', text).span()
remove_er_flag = 0
if er_pattern.search(text):
b = er_pattern.search(text).span()
if b[0] <= a[0]:
remove_er_flag = 1
if remove_er_flag == 0 :
new_str = new_str + text[0:a[0]]
text = text[a[1]:]
else:
new_str = new_str + text[0:b[1]]
text = text[b[1]:]
text = new_str + text
return text
# ================================================================================ #
# NSW Normalizer
# ================================================================================ #
class NSWNormalizer:
def __init__(self, raw_text):
self.raw_text = '^' + raw_text + '$'
self.norm_text = ''
def _particular(self):
text = self.norm_text
pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
matchers = pattern.findall(text)
if matchers:
# print('particular')
for matcher in matchers:
text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
self.norm_text = text
return self.norm_text
def normalize(self):
text = self.raw_text
# normalize dates
pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
matchers = pattern.findall(text)
if matchers:
#print('date')
for matcher in matchers:
text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
# normalize money amounts
pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
matchers = pattern.findall(text)
if matchers:
#print('money')
for matcher in matchers:
text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
# normalize landline / mobile phone numbers
# mobile
# http://www.jihaoba.com/news/show/13680
# China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
# China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
# China Telecom: 133, 153, 189, 180, 181, 177
pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
matchers = pattern.findall(text)
if matchers:
#print('telephone')
for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
# landline
pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
matchers = pattern.findall(text)
if matchers:
# print('fixed telephone')
for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
# normalize fractions
pattern = re.compile(r"(\d+/\d+)")
matchers = pattern.findall(text)
if matchers:
#print('fraction')
for matcher in matchers:
text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
# normalize percentages
text = text.replace('％', '%')
pattern = re.compile(r"(\d+(\.\d+)?%)")
matchers = pattern.findall(text)
if matchers:
#print('percentage')
for matcher in matchers:
text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
# normalize cardinal + quantifier
pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
matchers = pattern.findall(text)
if matchers:
#print('cardinal+quantifier')
for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
# normalize long digit sequences (IDs)
pattern = re.compile(r"(\d{4,32})")
matchers = pattern.findall(text)
if matchers:
#print('digit')
for matcher in matchers:
text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
# normalize plain cardinals
pattern = re.compile(r"(\d+(\.\d+)?)")
matchers = pattern.findall(text)
if matchers:
#print('cardinal')
for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
self.norm_text = text
self._particular()
return self.norm_text.lstrip('^').rstrip('$')
def nsw_test_case(raw_text):
print('I:' + raw_text)
print('O:' + NSWNormalizer(raw_text).normalize())
print('')
def nsw_test():
nsw_test_case('固话0595-23865596或23880880。')
nsw_test_case('固话0595-23865596或23880880。')
nsw_test_case('手机:+86 19859213959或15659451527。')
nsw_test_case('分数32477/76391。')
nsw_test_case('百分数80.03%')
nsw_test_case('编号31520181154418。')
nsw_test_case('纯数2983.07克或12345.60米。')
nsw_test_case('日期1999年2月20日或09年3月15号。')
nsw_test_case('金钱12块534.5元20.1万')
nsw_test_case('特殊O2O或B2C。')
nsw_test_case('3456万吨')
nsw_test_case('2938个')
nsw_test_case('938')
nsw_test_case('今天吃了115个小笼包231个馒头')
nsw_test_case('有62的概率')
if __name__ == '__main__':
#nsw_test()
p = argparse.ArgumentParser()
p.add_argument('ifile', help='input filename, assume utf-8 encoding')
p.add_argument('ofile', help='output filename')
p.add_argument('--to_upper', action='store_true', help='convert to upper case')
p.add_argument('--to_lower', action='store_true', help='convert to lower case')
p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
p.add_argument('--remove_fillers', type=bool, default=True, help='remove filler chars such as "呃, 啊"')
p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "这儿"')
p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
args = p.parse_args()
ifile = codecs.open(args.ifile, 'r', 'utf8')
ofile = codecs.open(args.ofile, 'w+', 'utf8')
n = 0
for l in ifile:
key = ''
text = ''
if args.has_key:
cols = l.split(maxsplit=1)
key = cols[0]
if len(cols) == 2:
text = cols[1].strip()
else:
text = ''
else:
text = l.strip()
# cases
if args.to_upper and args.to_lower:
sys.stderr.write('text norm: to_upper OR to_lower?')
exit(1)
if args.to_upper:
text = text.upper()
if args.to_lower:
text = text.lower()
# Filler chars removal
if args.remove_fillers:
for ch in FILLER_CHARS:
text = text.replace(ch, '')
if args.remove_erhua:
text = remove_erhua(text, ER_WHITELIST)
# NSW(Non-Standard-Word) normalization
text = NSWNormalizer(text).normalize()
# Punctuations removal
old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations
new_chars = ' ' * len(old_chars)
del_chars = ''
text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
#
if args.has_key:
ofile.write(key + '\t' + text + '\n')
else:
ofile.write(text + '\n')
n += 1
if n % args.log_interval == 0:
sys.stderr.write("text norm: {} lines done.\n".format(n))
sys.stderr.write("text norm: {} lines done in total.\n".format(n))
ifile.close()
ofile.close()
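Note: a short usage sketch for the pieces above. The chn2num/num2chn outputs follow from the compute_value/get_value logic; the NSWNormalizer line is shown without asserting its exact output:

print(chn2num('一亿一千三百万'))   # -> '113000000'
print(num2chn('123'))              # -> '一百二十三'
print(NSWNormalizer('今天吃了115个小笼包').normalize())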

View File

@ -0,0 +1 @@
../tranformer/utils

View File

@ -1,79 +0,0 @@
from kaldiio import ReadHelper
from kaldiio import WriteHelper
import argparse
import json
import math
import numpy as np
def get_parser():
parser = argparse.ArgumentParser(
description="apply cmvn",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--ark-file",
"-a",
default=False,
required=True,
type=str,
help="fbank ark file",
)
parser.add_argument(
"--cmvn-file",
"-c",
default=False,
required=True,
type=str,
help="cmvn file",
)
parser.add_argument(
"--ark-index",
"-i",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
with open(args.cmvn_file) as f:
cmvn_stats = json.load(f)
means = cmvn_stats['mean_stats']
vars = cmvn_stats['var_stats']
total_frames = cmvn_stats['total_frames']
for i in range(len(means)):
means[i] /= total_frames
vars[i] = vars[i] / total_frames - means[i] * means[i]
if vars[i] < 1.0e-20:
vars[i] = 1.0e-20
vars[i] = 1.0 / math.sqrt(vars[i])
with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
for key, mat in ark_reader:
mat = (mat - means) * vars
ark_writer(key, mat)
if __name__ == '__main__':
main()
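Note: the per-dimension loop above turns accumulated statistics into a normalizer: mean = sum_x / N, var = sum_x2 / N - mean^2 (floored), scale = 1 / sqrt(var), after which each frame is normalized as (x - mean) * scale. The same math, vectorized with numpy (a sketch):

import numpy as np

def cmvn_params(mean_stats, var_stats, total_frames, floor=1.0e-20):
    mean = np.asarray(mean_stats) / total_frames              # E[x]
    var = np.asarray(var_stats) / total_frames - mean ** 2    # E[x^2] - E[x]^2
    istd = 1.0 / np.sqrt(np.maximum(var, floor))              # floored inverse stddev
    return mean, istd

# then: normalized = (mat - mean) * istd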

View File

@ -1,29 +0,0 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
nj=32
cmd=./utils/run.pl
echo "$0 $@"
. utils/parse_options.sh || exit 1;
fbankdir=$1
cmvn_file=$2
logdir=$3
output_dir=$4
dump_dir=${output_dir}/ark; mkdir -p ${dump_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/apply_cmvn.JOB.log \
python utils/apply_cmvn.py -a $fbankdir/ark/feats.JOB.ark \
-c $cmvn_file -i JOB -o ${dump_dir} \
|| exit 1;
for n in $(seq $nj); do
cat ${dump_dir}/feats.$n.scp || exit 1
done > ${output_dir}/feats.scp || exit 1
echo "$0: Succeeded apply cmvn"

View File

@ -1,143 +0,0 @@
from kaldiio import ReadHelper, WriteHelper
import argparse
import numpy as np
def build_LFR_features(inputs, m=7, n=6):
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / n))
left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
inputs = np.vstack((left_padding, inputs))
T = T + (m - 1) // 2
for i in range(T_lfr):
if m <= T - i * n:
LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
else:
num_padding = m - (T - i * n)
frame = np.hstack(inputs[i * n:])
for _ in range(num_padding):
frame = np.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
return np.vstack(LFR_inputs)
def build_CMVN_features(inputs, mvn_file): # noqa
with open(mvn_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
add_shift_list = []
rescale_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == '<AddShift>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
add_shift_line = line_item[3:(len(line_item) - 1)]
add_shift_list = list(add_shift_line)
continue
elif line_item[0] == '<Rescale>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
rescale_line = line_item[3:(len(line_item) - 1)]
rescale_list = list(rescale_line)
continue
for j in range(inputs.shape[0]):
for k in range(inputs.shape[1]):
add_shift_value = add_shift_list[k]
rescale_value = rescale_list[k]
inputs[j, k] = float(inputs[j, k]) + float(add_shift_value)
inputs[j, k] = float(inputs[j, k]) * float(rescale_value)
return inputs
def get_parser():
parser = argparse.ArgumentParser(
description="apply low_frame_rate and cmvn",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--ark-file",
"-a",
default=False,
required=True,
type=str,
help="fbank ark file",
)
parser.add_argument(
"--lfr",
"-f",
default="True",
type=str,
help="whether to apply low frame rate, 'True' or 'False' (parsed as a string)",
)
parser.add_argument(
"--lfr-m",
"-m",
default=7,
type=int,
help="number of frames to stack",
)
parser.add_argument(
"--lfr-n",
"-n",
default=6,
type=int,
help="number of frames to skip",
)
parser.add_argument(
"--cmvn-file",
"-c",
default=False,
required=True,
type=str,
help="global cmvn file",
)
parser.add_argument(
"--ark-index",
"-i",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
dump_ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
dump_scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
shape_file = args.output_dir + "/len." + str(args.ark_index)
ark_writer = WriteHelper('ark,scp:{},{}'.format(dump_ark_file, dump_scp_file))
shape_writer = open(shape_file, 'w')
with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
for key, mat in ark_reader:
if str(args.lfr).lower() in ("true", "1", "yes"):  # string flag: a bare "if args.lfr" would treat "False" as truthy
lfr = build_LFR_features(mat, args.lfr_m, args.lfr_n)
else:
lfr = mat
cmvn = build_CMVN_features(lfr, args.cmvn_file)
dims = cmvn.shape[1]
lens = cmvn.shape[0]
shape_writer.write(key + " " + str(lens) + "," + str(dims) + '\n')
ark_writer(key, cmvn)
if __name__ == '__main__':
main()
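Note: low frame rate (LFR) stacks m consecutive frames and advances by n, so a (T, 80) input becomes (ceil(T / n), m * 80). A quick shape check with the defaults m=7, n=6, assuming build_LFR_features from this file is in scope:

import numpy as np

feats = np.zeros((100, 80), dtype=np.float32)
lfr = build_LFR_features(feats, m=7, n=6)
assert lfr.shape == (17, 560)   # ceil(100 / 6) = 17 frames of 7 * 80 = 560 dims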

View File

@ -1,38 +0,0 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=32
cmd=utils/run.pl
# feature configuration
lfr=True
lfr_m=7
lfr_n=6
echo "$0 $@"
. utils/parse_options.sh || exit 1;
fbankdir=$1
cmvn_file=$2
logdir=$3
output_dir=$4
dump_dir=${output_dir}/ark; mkdir -p ${dump_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/apply_lfr_and_cmvn.JOB.log \
python utils/apply_lfr_and_cmvn.py -a $fbankdir/ark/feats.JOB.ark \
-f $lfr -m $lfr_m -n $lfr_n -c $cmvn_file -i JOB -o ${dump_dir} \
|| exit 1;
for n in $(seq $nj); do
cat ${dump_dir}/feats.$n.scp || exit 1
done > ${output_dir}/feats.scp || exit 1
for n in $(seq $nj); do
cat ${dump_dir}/len.$n || exit 1
done > ${output_dir}/speech_shape || exit 1
echo "$0: Succeeded apply low frame rate and cmvn"

View File

@ -1,73 +0,0 @@
import argparse
import json
import numpy as np
def get_parser():
parser = argparse.ArgumentParser(
description="combine cmvn file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--dims",
"-d",
default=80,
type=int,
help="feature dims",
)
parser.add_argument(
"--cmvn-dir",
"-c",
default=False,
required=True,
type=str,
help="cmvn dir",
)
parser.add_argument(
"--nj",
"-n",
default=1,
required=True,
type=int,
help="num of cmvn file",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
total_means = np.zeros(args.dims)
total_vars = np.zeros(args.dims)
total_frames = 0
cmvn_file = args.output_dir + "/cmvn.json"
for i in range(1, args.nj+1):
with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
cmvn_stats = json.load(fin)
total_means += np.array(cmvn_stats["mean_stats"])
total_vars += np.array(cmvn_stats["var_stats"])
total_frames += cmvn_stats["total_frames"]
cmvn_info = {
'mean_stats': list(total_means.tolist()),
'var_stats': list(total_vars.tolist()),
'total_frames': total_frames
}
with open(cmvn_file, 'w') as fout:
fout.write(json.dumps(cmvn_info))
if __name__ == '__main__':
main()
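Note: plain addition is valid here because each cmvn.<i>.json holds raw sufficient statistics (sum of x, sum of x^2, frame count), which are additive across shards; means and variances are only derived later, when the stats are applied. An equivalent sketch over a hypothetical stats_list of loaded per-job dicts:

combined = {
    "mean_stats": [sum(v) for v in zip(*(d["mean_stats"] for d in stats_list))],
    "var_stats": [sum(v) for v in zip(*(d["var_stats"] for d in stats_list))],
    "total_frames": sum(d["total_frames"] for d in stats_list),
}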

View File

@ -1,74 +0,0 @@
from kaldiio import ReadHelper
import argparse
import numpy as np
import json
def get_parser():
parser = argparse.ArgumentParser(
description="computer global cmvn",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--dims",
"-d",
default=80,
type=int,
help="feature dims",
)
parser.add_argument(
"--ark-file",
"-a",
default=False,
required=True,
type=str,
help="fbank ark file",
)
parser.add_argument(
"--ark-index",
"-i",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
ark_file = args.ark_file + "/feats." + str(args.ark_index) + ".ark"
cmvn_file = args.output_dir + "/cmvn." + str(args.ark_index) + ".json"
mean_stats = np.zeros(args.dims)
var_stats = np.zeros(args.dims)
total_frames = 0
with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
for key, mat in ark_reader:
mean_stats += np.sum(mat, axis=0)
var_stats += np.sum(np.square(mat), axis=0)
total_frames += mat.shape[0]
cmvn_info = {
'mean_stats': list(mean_stats.tolist()),
'var_stats': list(var_stats.tolist()),
'total_frames': total_frames
}
with open(cmvn_file, 'w') as fout:
fout.write(json.dumps(cmvn_info))
if __name__ == '__main__':
main()

View File

@ -1,25 +0,0 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
nj=32
cmd=./utils/run.pl
feats_dim=80
echo "$0 $@"
. utils/parse_options.sh || exit 1;
fbankdir=$1
logdir=$2
output_dir=${fbankdir}/cmvn; mkdir -p ${output_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
python utils/compute_cmvn.py -d ${feats_dim} -a $fbankdir/ark -i JOB -o ${output_dir} \
|| exit 1;
python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
echo "$0: Succeeded compute global cmvn"

View File

@ -1,153 +0,0 @@
from kaldiio import WriteHelper
import argparse
import numpy as np
import json
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
def compute_fbank(wav_file,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
resample_rate=16000,
speed=1.0):
waveform, sample_rate = torchaudio.load(wav_file)
if resample_rate != sample_rate:
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
new_freq=resample_rate)(waveform)
if speed != 1.0:
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform, resample_rate,
[['speed', str(speed)], ['rate', str(resample_rate)]]
)
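# torchaudio.load returns float samples in [-1, 1]; scale to the 16-bit integer
# range that Kaldi-style fbank extraction expects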
waveform = waveform * (1 << 15)
mat = kaldi.fbank(waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
window_type='hamming',
sample_frequency=resample_rate)
return mat.numpy()
def get_parser():
parser = argparse.ArgumentParser(
description="computer features",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--wav-lists",
"-w",
default=False,
required=True,
type=str,
help="input wav lists",
)
parser.add_argument(
"--text-files",
"-t",
default=False,
required=True,
type=str,
help="input text files",
)
parser.add_argument(
"--dims",
"-d",
default=80,
type=int,
help="feature dims",
)
parser.add_argument(
"--sample-frequency",
"-s",
default=16000,
type=int,
help="sample frequency",
)
parser.add_argument(
"--speed-perturb",
"-p",
default="1.0",
type=str,
help="speed perturb",
)
parser.add_argument(
"--ark-index",
"-a",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
ark_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".ark"
scp_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".scp"
text_file = args.output_dir + "/txt/text." + str(args.ark_index) + ".txt"
feats_shape_file = args.output_dir + "/ark/len." + str(args.ark_index)
text_shape_file = args.output_dir + "/txt/len." + str(args.ark_index)
ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
text_writer = open(text_file, 'w')
feats_shape_writer = open(feats_shape_file, 'w')
text_shape_writer = open(text_shape_file, 'w')
speed_perturb_list = args.speed_perturb.split(',')
for speed in speed_perturb_list:
with open(args.wav_lists, 'r', encoding='utf-8') as wavfile:
with open(args.text_files, 'r', encoding='utf-8') as textfile:
for wav, text in zip(wavfile, textfile):
s_w = wav.strip().split()
wav_id = s_w[0]
wav_file = s_w[1]
s_t = text.strip().split()
text_id = s_t[0]
txt = s_t[1:]
fbank = compute_fbank(wav_file,
num_mel_bins=args.dims,
resample_rate=args.sample_frequency,
speed=float(speed)
)
feats_dims = fbank.shape[1]
feats_lens = fbank.shape[0]
txt_lens = len(txt)
if speed == "1.0":
wav_id_sp = wav_id
else:
wav_id_sp = wav_id + "_sp" + speed
feats_shape_writer.write(wav_id_sp + " " + str(feats_lens) + "," + str(feats_dims) + '\n')
text_shape_writer.write(wav_id_sp + " " + str(txt_lens) + '\n')
text_writer.write(wav_id_sp + " " + " ".join(txt) + '\n')
ark_writer(wav_id_sp, fbank)
if __name__ == '__main__':
main()
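Note: a minimal standalone use of compute_fbank above; "example.wav" is a hypothetical 16 kHz mono file:

feats = compute_fbank("example.wav", num_mel_bins=80, speed=1.1)
print(feats.shape)   # (num_frames, 80); speed > 1.0 yields proportionally fewer frames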

View File

@ -1,51 +0,0 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
nj=32
cmd=./utils/run.pl
# feature configuration
feats_dim=80
sample_frequency=16000
speed_perturb="1.0"
echo "$0 $@"
. utils/parse_options.sh || exit 1;
data=$1
logdir=$2
fbankdir=$3
[ ! -f $data/wav.scp ] && echo "$0: no such file $data/wav.scp" && exit 1;
[ ! -f $data/text ] && echo "$0: no such file $data/text" && exit 1;
python utils/split_data.py $data $data $nj
ark_dir=${fbankdir}/ark; mkdir -p ${ark_dir}
text_dir=${fbankdir}/txt; mkdir -p ${text_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/make_fbank.JOB.log \
python utils/compute_fbank.py -w $data/split${nj}/JOB/wav.scp -t $data/split${nj}/JOB/text \
-d $feats_dim -s $sample_frequency -p ${speed_perturb} -a JOB -o ${fbankdir} \
|| exit 1;
for n in $(seq $nj); do
cat ${ark_dir}/feats.$n.scp || exit 1
done > $fbankdir/feats.scp || exit 1
for n in $(seq $nj); do
cat ${text_dir}/text.$n.txt || exit 1
done > $fbankdir/text || exit 1
for n in $(seq $nj); do
cat ${ark_dir}/len.$n || exit 1
done > $fbankdir/speech_shape || exit 1
for n in $(seq $nj); do
cat ${text_dir}/len.$n || exit 1
done > $fbankdir/text_shape || exit 1
echo "$0: Succeeded compute FBANK features"

View File

@ -1,157 +0,0 @@
import os
import numpy as np
import sys
def compute_wer(ref_file,
hyp_file,
cer_detail_file):
rst = {
'Wrd': 0,
'Corr': 0,
'Ins': 0,
'Del': 0,
'Sub': 0,
'Snt': 0,
'Err': 0.0,
'S.Err': 0.0,
'wrong_words': 0,
'wrong_sentences': 0
}
hyp_dict = {}
ref_dict = {}
with open(hyp_file, 'r') as hyp_reader:
for line in hyp_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
hyp_dict[key] = value
with open(ref_file, 'r') as ref_reader:
for line in ref_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
ref_dict[key] = value
cer_detail_writer = open(cer_detail_file, 'w')
for hyp_key in hyp_dict:
if hyp_key in ref_dict:
out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
rst['Wrd'] += out_item['nwords']
rst['Corr'] += out_item['cor']
rst['wrong_words'] += out_item['wrong']
rst['Ins'] += out_item['ins']
rst['Del'] += out_item['del']
rst['Sub'] += out_item['sub']
rst['Snt'] += 1
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
if rst['Snt'] > 0:
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
cer_detail_writer.write('\n')
cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
def compute_wer_by_line(hyp,
ref):
hyp = list(map(lambda x: x.lower(), hyp))
ref = list(map(lambda x: x.lower(), ref))
len_hyp = len(hyp)
len_ref = len(ref)
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
for i in range(len_hyp + 1):
cost_matrix[i][0] = i
for j in range(len_ref + 1):
cost_matrix[0][j] = j
for i in range(1, len_hyp + 1):
for j in range(1, len_ref + 1):
if hyp[i - 1] == ref[j - 1]:
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
else:
substitution = cost_matrix[i - 1][j - 1] + 1
insertion = cost_matrix[i - 1][j] + 1
deletion = cost_matrix[i][j - 1] + 1
compare_val = [substitution, insertion, deletion]
min_val = min(compare_val)
operation_idx = compare_val.index(min_val) + 1
cost_matrix[i][j] = min_val
ops_matrix[i][j] = operation_idx
match_idx = []
i = len_hyp
j = len_ref
rst = {
'nwords': len_ref,
'cor': 0,
'wrong': 0,
'ins': 0,
'del': 0,
'sub': 0
}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst['cor'] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst['ins'] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst['del'] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst['sub'] += 1
if i < 0 and j >= 0:
rst['del'] += 1
elif j < 0 and i >= 0:
rst['ins'] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst['wrong'] = wrong_cnt
return rst
def print_cer_detail(rst):
return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
+ ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
+ str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
+ ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
if __name__ == '__main__':
if len(sys.argv) != 4:
print("usage : python compute-wer.py test.ref test.hyp test.wer")
sys.exit(0)
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
cer_detail_file = sys.argv[3]
compute_wer(ref_file, hyp_file, cer_detail_file)
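# A quick sanity check for compute_wer_by_line (toy tokens, not from any repo
# test set): hyp ['a', 'b', 'd'] vs ref ['a', 'c', 'd'] differ by a single
# substitution, so:
#   out = compute_wer_by_line(['a', 'b', 'd'], ['a', 'c', 'd'])
#   # out == {'nwords': 3, 'cor': 2, 'wrong': 1, 'ins': 0, 'del': 0, 'sub': 1}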

View File

@ -1,370 +0,0 @@
#!/usr/bin/env python3
# coding=utf8
# Copyright 2021 Jiayu DU
import sys
import argparse
import json
import logging
logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')
DEBUG = None
def GetEditType(ref_token, hyp_token):
if ref_token is None and hyp_token is not None:
return 'I'
elif ref_token is not None and hyp_token is None:
return 'D'
elif ref_token == hyp_token:
return 'C'
elif ref_token != hyp_token:
return 'S'
else:
raise RuntimeError
class AlignmentArc:
def __init__(self, src, dst, ref, hyp):
self.src = src
self.dst = dst
self.ref = ref
self.hyp = hyp
self.edit_type = GetEditType(ref, hyp)
def similarity_score_function(ref_token, hyp_token):
return 0 if (ref_token == hyp_token) else -1.0
def insertion_score_function(token):
return -1.0
def deletion_score_function(token):
return -1.0
def EditDistance(
ref,
hyp,
similarity_score_function = similarity_score_function,
insertion_score_function = insertion_score_function,
deletion_score_function = deletion_score_function):
assert(len(ref) != 0)
class DPState:
def __init__(self):
self.score = -float('inf')
# backpointer
self.prev_r = None
self.prev_h = None
def print_search_grid(S, R, H, fstream):
print(file=fstream)
for r in range(R):
for h in range(H):
print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
print(file=fstream)
R = len(ref) + 1
H = len(hyp) + 1
# Construct DP search space, a (R x H) grid
S = [ [] for r in range(R) ]
for r in range(R):
S[r] = [ DPState() for x in range(H) ]
# initialize DP search grid origin, S(r = 0, h = 0)
S[0][0].score = 0.0
S[0][0].prev_r = None
S[0][0].prev_h = None
# initialize REF axis
for r in range(1, R):
S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
S[r][0].prev_r = r-1
S[r][0].prev_h = 0
# initialize HYP axis
for h in range(1, H):
S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
S[0][h].prev_r = 0
S[0][h].prev_h = h-1
best_score = S[0][0].score
best_state = (0, 0)
for r in range(1, R):
for h in range(1, H):
sub_or_cor_score = similarity_score_function(ref[r-1], hyp[h-1])
new_score = S[r-1][h-1].score + sub_or_cor_score
if new_score >= S[r][h].score:
S[r][h].score = new_score
S[r][h].prev_r = r-1
S[r][h].prev_h = h-1
del_score = deletion_score_function(ref[r-1])
new_score = S[r-1][h].score + del_score
if new_score >= S[r][h].score:
S[r][h].score = new_score
S[r][h].prev_r = r - 1
S[r][h].prev_h = h
ins_score = insertion_score_function(hyp[h-1])
new_score = S[r][h-1].score + ins_score
if new_score >= S[r][h].score:
S[r][h].score = new_score
S[r][h].prev_r = r
S[r][h].prev_h = h-1
best_score = S[R-1][H-1].score
best_state = (R-1, H-1)
if DEBUG:
print_search_grid(S, R, H, sys.stderr)
# Backtracing best alignment path, i.e. a list of arcs
# arc = (src, dst, ref, hyp, edit_type)
# src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
best_path = []
r, h = best_state[0], best_state[1]
prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
score = S[r][h].score
# loop invariant:
# 1. (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path
# 2. score is the value of point(r, h) on DP search grid
while prev_r is not None or prev_h is not None:
src = (prev_r, prev_h)
dst = (r, h)
if (r == prev_r + 1 and h == prev_h + 1): # Substitution or correct
arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
elif (r == prev_r + 1 and h == prev_h): # Deletion
arc = AlignmentArc(src, dst, ref[prev_r], None)
elif (r == prev_r and h == prev_h + 1): # Insertion
arc = AlignmentArc(src, dst, None, hyp[prev_h])
else:
raise RuntimeError
best_path.append(arc)
r, h = prev_r, prev_h
prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
score = S[r][h].score
best_path.reverse()
return (best_path, best_score)
def PrettyPrintAlignment(alignment, stream = sys.stderr):
def get_token_str(token):
if token is None:
return "*"
return token
def is_double_width_char(ch):
if (ch >= '\u4e00') and (ch <= '\u9fa5'): # codepoint ranges for Chinese chars
return True
# TODO: support other double-width-char language such as Japanese, Korean
else:
return False
def display_width(token_str):
m = 0
for c in token_str:
if is_double_width_char(c):
m += 2
else:
m += 1
return m
R = ' REF : '
H = ' HYP : '
E = ' EDIT : '
for arc in alignment:
r = get_token_str(arc.ref)
h = get_token_str(arc.hyp)
e = arc.edit_type if arc.edit_type != 'C' else ''
nr, nh, ne = display_width(r), display_width(h), display_width(e)
n = max(nr, nh, ne) + 1
R += r + ' ' * (n-nr)
H += h + ' ' * (n-nh)
E += e + ' ' * (n-ne)
print(R, file=stream)
print(H, file=stream)
print(E, file=stream)
def CountEdits(alignment):
c, s, i, d = 0, 0, 0, 0
for arc in alignment:
if arc.edit_type == 'C':
c += 1
elif arc.edit_type == 'S':
s += 1
elif arc.edit_type == 'I':
i += 1
elif arc.edit_type == 'D':
d += 1
else:
raise RuntimeError
return (c, s, i, d)
def ComputeTokenErrorRate(c, s, i, d):
return 100.0 * (s + d + i) / (s + d + c)
def ComputeSentenceErrorRate(num_err_utts, num_utts):
assert(num_utts != 0)
return 100.0 * num_err_utts / num_utts
class EvaluationResult:
def __init__(self):
self.num_ref_utts = 0
self.num_hyp_utts = 0
self.num_eval_utts = 0 # seen in both ref & hyp
self.num_hyp_without_ref = 0
self.C = 0
self.S = 0
self.I = 0
self.D = 0
self.token_error_rate = 0.0
self.num_utts_with_error = 0
self.sentence_error_rate = 0.0
def to_json(self):
return json.dumps(self.__dict__)
def to_kaldi(self):
info = (
F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
)
return info
def to_sclite(self):
return "TODO"
def to_espnet(self):
return "TODO"
def to_summary(self):
#return json.dumps(self.__dict__, indent=4)
summary = (
'==================== Overall Statistics ====================\n'
F'num_ref_utts: {self.num_ref_utts}\n'
F'num_hyp_utts: {self.num_hyp_utts}\n'
F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
F'num_eval_utts: {self.num_eval_utts}\n'
F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
F'token_error_rate: {self.token_error_rate:.2f}%\n'
F'token_stats:\n'
F' - tokens:{self.C + self.S + self.D:>7}\n'
F' - edits: {self.S + self.I + self.D:>7}\n'
F' - cor: {self.C:>7}\n'
F' - sub: {self.S:>7}\n'
F' - ins: {self.I:>7}\n'
F' - del: {self.D:>7}\n'
'============================================================\n'
)
return summary
class Utterance:
def __init__(self, uid, text):
self.uid = uid
self.text = text
def LoadUtterances(filepath, format):
utts = {}
if format == 'text': # utt_id word1 word2 ...
with open(filepath, 'r', encoding='utf8') as f:
for line in f:
line = line.strip()
if line:
cols = line.split(maxsplit=1)
assert(len(cols) == 2 or len(cols) == 1)
uid = cols[0]
text = cols[1] if len(cols) == 2 else ''
if utts.get(uid) is not None:
raise RuntimeError(F'Found duplicate utterance id {uid}')
utts[uid] = Utterance(uid, text)
else:
raise RuntimeError(F'Unsupported text format {format}')
return utts
def tokenize_text(text, tokenizer):
if tokenizer == 'whitespace':
return text.split()
elif tokenizer == 'char':
return [ ch for ch in ''.join(text.split()) ]
else:
raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# optional
parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
# required
parser.add_argument('--ref', type=str, required=True, help='input reference file')
parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')
parser.add_argument('result_file', type=str)
args = parser.parse_args()
logging.info(args)
ref_utts = LoadUtterances(args.ref, args.ref_format)
hyp_utts = LoadUtterances(args.hyp, args.hyp_format)
r = EvaluationResult()
# check valid utterances in hyp that have matched non-empty reference
eval_utts = []
r.num_hyp_without_ref = 0
for uid in sorted(hyp_utts.keys()):
if uid in ref_utts.keys(): # TODO: efficiency
if ref_utts[uid].text.strip(): # non-empty reference
eval_utts.append(uid)
else:
logging.warning(F'Found {uid} with empty reference, skipping...')
else:
logging.warning(F'Found {uid} without reference, skipping...')
r.num_hyp_without_ref += 1
r.num_hyp_utts = len(hyp_utts)
r.num_ref_utts = len(ref_utts)
r.num_eval_utts = len(eval_utts)
with open(args.result_file, 'w+', encoding='utf8') as fo:
for uid in eval_utts:
ref = ref_utts[uid]
hyp = hyp_utts[uid]
alignment, score = EditDistance(
tokenize_text(ref.text, args.tokenizer),
tokenize_text(hyp.text, args.tokenizer)
)
c, s, i, d = CountEdits(alignment)
utt_ter = ComputeTokenErrorRate(c, s, i, d)
# utt-level evaluation result
print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
PrettyPrintAlignment(alignment, fo)
r.C += c
r.S += s
r.I += i
r.D += d
if utt_ter > 0:
r.num_utts_with_error += 1
# corpus level evaluation result
r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)
print(r.to_summary(), file=fo)
print(r.to_json())
print(r.to_kaldi())
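# A minimal sketch of the alignment API on toy tokens (not from the original
# script): deleting 'b' from the reference side costs one edit, so:
#   alignment, score = EditDistance(['a', 'b', 'c'], ['a', 'c'])
#   c, s, i, d = CountEdits(alignment)   # -> (2, 0, 0, 1), score == -1.0
#   PrettyPrintAlignment(alignment)      # REF: a b c | HYP: a * c | EDIT: D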

View File

@ -1,47 +0,0 @@
from transformers import AutoTokenizer, AutoModel, pipeline
import numpy as np
import sys
import os
import torch
from kaldiio import WriteHelper
import re
text_file_json = sys.argv[1]
out_ark = sys.argv[2]
out_scp = sys.argv[3]
out_shape = sys.argv[4]
device = int(sys.argv[5])
model_path = sys.argv[6]
model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
with open(text_file_json, 'r') as f:
js = f.readlines()
f_shape = open(out_shape, "w")
with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
with torch.no_grad():
for idx, line in enumerate(js):
id, tokens = line.strip().split(" ", 1)
tokens = re.sub(" ", "", tokens.strip())
tokens = ' '.join([j for j in tokens])
token_num = len(tokens.split(" "))
outputs = extractor(tokens)
outputs = np.array(outputs)
embeds = outputs[0, 1:-1, :]
token_num_embeds, dim = embeds.shape
if token_num == token_num_embeds:
writer(id, embeds)
shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
f_shape.write(shape_line)
else:
print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
f_shape.close()
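# Input/output sketch (names illustrative): each input line is
# "<utt_id> <text>"; the text is split into single characters before feature
# extraction, and the per-character embeddings are written as a Kaldi ark/scp
# pair that can be read back lazily with kaldiio, e.g.:
#   from kaldiio import load_scp
#   feats = load_scp(out_scp)   # feats[utt_id] -> (num_chars, dim) array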

View File

@ -1,87 +0,0 @@
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation
# Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script takes a list of utterance-ids or any file whose first field
# of each line is an utterance-id, and filters an scp
# file (or any file whose "n-th" field is an utterance id), printing
# out only those lines whose "n-th" field is in id_list. The index of
# the "n-th" field is 1, by default, but can be changed by using
# the -f <n> switch
$exclude = 0;
$field = 1;
$shifted = 0;
do {
$shifted=0;
if ($ARGV[0] eq "--exclude") {
$exclude = 1;
shift @ARGV;
$shifted=1;
}
if ($ARGV[0] eq "-f") {
$field = $ARGV[1];
shift @ARGV; shift @ARGV;
$shifted=1
}
} while ($shifted);
if(@ARGV < 1 || @ARGV > 2) {
die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
"Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
"Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
"only the lines that were *not* in id_list.\n" .
"Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
"If your older scripts (written before Oct 2014) stopped working and you used the\n" .
"-f option, add 1 to the argument.\n" .
"See also: scripts/filter_scp.pl .\n";
}
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
@A = split;
@A>=1 || die "Invalid id-list file line $_";
$seen{$A[0]} = 1;
}
if ($field == 1) { # Treat this as special case, since it is common.
while(<>) {
$_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
# $1 is what we filter on.
if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
print $_;
}
}
} else {
while(<>) {
@A = split;
@A > 0 || die "Invalid scp file line $_";
@A >= $field || die "Invalid scp file line $_";
if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
print $_;
}
}
}
# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2)

View File

@ -1,35 +0,0 @@
#!/usr/bin/env bash
echo "$0 $@"
data_dir=$1
if [ ! -f ${data_dir}/wav.scp ]; then
echo "$0: wav.scp is not found"
exit 1;
fi
if [ ! -f ${data_dir}/text ]; then
echo "$0: text is not found"
exit 1;
fi
mkdir -p ${data_dir}/.backup
awk '{print $1}' ${data_dir}/wav.scp > ${data_dir}/.backup/wav_id
awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
cp ${data_dir}/wav.scp ${data_dir}/.backup/wav.scp
cp ${data_dir}/text ${data_dir}/.backup/text
mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak
mv ${data_dir}/text ${data_dir}/text.bak
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak > ${data_dir}/wav.scp
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
rm ${data_dir}/wav.scp.bak
rm ${data_dir}/text.bak

View File

@ -1,52 +0,0 @@
#!/usr/bin/env bash
echo "$0 $@"
data_dir=$1
if [ ! -f ${data_dir}/feats.scp ]; then
echo "$0: feats.scp is not found"
exit 1;
fi
if [ ! -f ${data_dir}/text ]; then
echo "$0: text is not found"
exit 1;
fi
if [ ! -f ${data_dir}/speech_shape ]; then
echo "$0: feature lengths is not found"
exit 1;
fi
if [ ! -f ${data_dir}/text_shape ]; then
echo "$0: text lengths is not found"
exit 1;
fi
mkdir -p ${data_dir}/.backup
awk '{print $1}' ${data_dir}/feats.scp > ${data_dir}/.backup/wav_id
awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
cp ${data_dir}/feats.scp ${data_dir}/.backup/feats.scp
cp ${data_dir}/text ${data_dir}/.backup/text
cp ${data_dir}/speech_shape ${data_dir}/.backup/speech_shape
cp ${data_dir}/text_shape ${data_dir}/.backup/text_shape
mv ${data_dir}/feats.scp ${data_dir}/feats.scp.bak
mv ${data_dir}/text ${data_dir}/text.bak
mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak
mv ${data_dir}/text_shape ${data_dir}/text_shape.bak
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak > ${data_dir}/feats.scp
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak > ${data_dir}/speech_shape
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak > ${data_dir}/text_shape
rm ${data_dir}/feats.scp.bak
rm ${data_dir}/text.bak
rm ${data_dir}/speech_shape.bak
rm ${data_dir}/text_shape.bak

View File

@ -1,22 +0,0 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=32
cmd=./utils/run.pl
echo "$0 $@"
. utils/parse_options.sh || exit 1;
ark_dir=$1
txt_dir=$2
output_dir=$3
[ ! -d ${ark_dir}/ark ] && echo "$0: ark data is required" && exit 1;
[ ! -d ${txt_dir}/txt ] && echo "$0: txt data is required" && exit 1;
for n in $(seq $nj); do
echo "${ark_dir}/ark/feats.$n.ark ${txt_dir}/txt/text.$n.txt" || exit 1
done > ${output_dir}/ark_txt.scp || exit 1

View File

@ -1,97 +0,0 @@
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefined -- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.

View File

@ -1,45 +0,0 @@
#!/usr/bin/env python
import sys
def get_commandline_args(no_executable=True):
extra_chars = [
" ",
";",
"&",
"|",
"<",
">",
"?",
"*",
"~",
"`",
'"',
"'",
"\\",
"{",
"}",
"(",
")",
]
# Escape the extra characters for shell
argv = [
arg.replace("'", "'\\''")
if all(char not in arg for char in extra_chars)
else "'" + arg.replace("'", "'\\''") + "'"
for arg in sys.argv
]
if no_executable:
return " ".join(argv[1:])
else:
return sys.executable + " " + " ".join(argv)
def main():
print(get_commandline_args())
if __name__ == "__main__":
main()
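# Behavior sketch (script name is hypothetical): invoking
#   python print_args.py --foo "bar baz"
# prints the reconstructed, shell-quoted command line without the executable:
#   --foo 'bar baz'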

View File

@ -1,35 +0,0 @@
from pathlib import Path
import torch
import yaml
class NoAliasSafeDumper(yaml.SafeDumper):
# Disable anchor/alias in yaml because looks ugly
def ignore_aliases(self, data):
return True
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
"""Safe-dump in yaml with no anchor/alias"""
return yaml.dump(
data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
)
def gen_conf(file, out_dir):
conf = torch.load(file)["config"]
conf["oss_bucket"] = "null"
print(conf)
output_dir = Path(out_dir)
output_dir.mkdir(parents=True, exist_ok=True)
with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
yaml_no_alias_safe_dump(conf, f, indent=4, sort_keys=False)
if __name__ == "__main__":
import sys
in_f = sys.argv[1]
out_f = sys.argv[2]
gen_conf(in_f, out_f)
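# Usage sketch (file names are illustrative): given a torch checkpoint saved
# with a "config" entry, write that config to <out_dir>/config.yaml:
#   python gen_config.py exp/model.pb exp/decode_conf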

View File

@ -1,31 +0,0 @@
import sys
import re
in_f = sys.argv[1]
out_f = sys.argv[2]
with open(in_f, "r", encoding="utf-8") as f:
lines = f.readlines()
with open(out_f, "w", encoding="utf-8") as f:
for line in lines:
outs = line.strip().split(" ", 1)
if len(outs) == 2:
idx, text = outs
text = re.sub("</s>", "", text)
text = re.sub("<s>", "", text)
text = re.sub("@@", "", text)
text = re.sub("@", "", text)
text = re.sub("<unk>", "", text)
text = re.sub(" ", "", text)
text = text.lower()
else:
idx = outs[0]
text = " "
text = [x for x in text]
text = " ".join(text)
out = "{} {}\n".format(idx, text)
f.write(out)
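# Transformation sketch (toy line): an input line such as
#   utt1 HE@@ LLO <s> 世界 </s>
# has the meta symbols and BPE continuation marks stripped, all spaces
# removed, the text lowercased, and is then re-split into single characters:
#   utt1 h e l l o 世 界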

View File

@ -1,356 +0,0 @@
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# In general, doing
# run.pl some.log a b c is like running the command a b c in
# the bash shell, and putting the standard error and output into some.log.
# To run parallel jobs (backgrounded on the host machine), you can do (e.g.)
# run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB
# and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier].
# If any of the jobs fails, this script will fail.
# A typical example is:
# run.pl some.log my-prog "--opt=foo bar" foo \| other-prog baz
# and run.pl will run something like:
# ( my-prog '--opt=foo bar' foo | other-prog baz ) >& some.log
#
# Basically it takes the command-line arguments, quotes them
# as necessary to preserve spaces, and evaluates them with bash.
# In addition it puts the command line at the top of the log, and
# the start and end times of the command at the beginning and end.
# The reason why this is useful is so that we can create a different
# version of this program that uses a queueing system instead.
#use Data::Dumper;
@ARGV < 2 && die "usage: run.pl log-file command-line arguments...";
#print STDERR "COMMAND-LINE: " . Dumper(\@ARGV) . "\n";
$job_pick = 'all';
$max_jobs_run = -1;
$jobstart = 1;
$jobend = 1;
$ignored_opts = ""; # These will be ignored.
# First parse an option like JOB=1:4, and any
# options that would normally be given to
# queue.pl, which we will just discard.
for (my $x = 1; $x <= 2; $x++) { # This for-loop is to
# allow the JOB=1:n option to be interleaved with the
# options to qsub.
while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) {
# parse any options that would normally go to qsub, but which will be ignored here.
my $switch = shift @ARGV;
if ($switch eq "-V") {
$ignored_opts .= "-V ";
} elsif ($switch eq "--max-jobs-run" || $switch eq "-tc") {
# we do support the option --max-jobs-run n, and its GridEngine form -tc n.
# if the command appears multiple times uses the smallest option.
if ( $max_jobs_run <= 0 ) {
$max_jobs_run = shift @ARGV;
} else {
my $new_constraint = shift @ARGV;
if ( ($new_constraint < $max_jobs_run) ) {
$max_jobs_run = $new_constraint;
}
}
if (! ($max_jobs_run > 0)) {
die "run.pl: invalid option --max-jobs-run $max_jobs_run";
}
} else {
my $argument = shift @ARGV;
if ($argument =~ m/^--/) {
print STDERR "run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\n";
}
if ($switch eq "-sync" && $argument =~ m/^[yY]/) {
$ignored_opts .= "-sync "; # Note: in the
# corresponding code in queue.pl it says instead, just "$sync = 1;".
} elsif ($switch eq "-pe") { # e.g. -pe smp 5
my $argument2 = shift @ARGV;
$ignored_opts .= "$switch $argument $argument2 ";
} elsif ($switch eq "--gpu") {
$using_gpu = $argument;
} elsif ($switch eq "--pick") {
if($argument =~ m/^(all|failed|incomplete)$/) {
$job_pick = $argument;
} else {
print STDERR "run.pl: ERROR: --pick argument must be one of 'all', 'failed' or 'incomplete'"
}
} else {
# Ignore option.
$ignored_opts .= "$switch $argument ";
}
}
}
if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20
$jobname = $1;
$jobstart = $2;
$jobend = $3;
if ($jobstart > $jobend) {
die "run.pl: invalid job range $ARGV[0]";
}
if ($jobstart <= 0) {
die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility).";
}
shift;
} elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1.
$jobname = $1;
$jobstart = $2;
$jobend = $2;
shift;
} elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) {
print STDERR "run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\n";
}
}
# Users found this message confusing so we are removing it.
# if ($ignored_opts ne "") {
# print STDERR "run.pl: Warning: ignoring options \"$ignored_opts\"\n";
# }
if ($max_jobs_run == -1) { # If --max-jobs-run option not set,
# then work out the number of processors if possible,
# and set it based on that.
$max_jobs_run = 0;
if ($using_gpu) {
if (open(P, "nvidia-smi -L |")) {
$max_jobs_run++ while (<P>);
close(P);
}
if ($max_jobs_run == 0) {
$max_jobs_run = 1;
print STDERR "run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\n";
}
} elsif (open(P, "</proc/cpuinfo")) { # Linux
while (<P>) { if (m/^processor/) { $max_jobs_run++; } }
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from /proc/cpuinfo\n";
$max_jobs_run = 10; # reasonable default.
}
close(P);
} elsif (open(P, "sysctl -a |")) { # BSD/Darwin
while (<P>) {
if (m/hw\.ncpu\s*[:=]\s*(\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4
$max_jobs_run = $1;
last;
}
}
close(P);
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from sysctl -a\n";
$max_jobs_run = 10; # reasonable default.
}
} else {
# allow at most 32 jobs at once, on non-UNIX systems; change this code
# if you need to change this default.
$max_jobs_run = 32;
}
# The just-computed value of $max_jobs_run is just the number of processors
# (or our best guess); and if it happens that the number of jobs we need to
# run is just slightly above $max_jobs_run, it will make sense to increase
# $max_jobs_run to equal the number of jobs, so we don't have a small number
# of leftover jobs.
$num_jobs = $jobend - $jobstart + 1;
if (!$using_gpu &&
$num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) {
$max_jobs_run = $num_jobs;
}
}
sub pick_or_exit {
# pick_or_exit ( $logfile )
# Invoked before each job is started helps to run jobs selectively.
#
# Given the name of the output logfile decides whether the job must be
# executed (by returning from the subroutine) or not (by terminating the
# process calling exit)
#
# PRE: $job_pick is a global variable set by command line switch --pick
# and indicates which class of jobs must be executed.
#
# 1) If a failed job is not executed the process exit code will indicate
# failure, just as if the task was just executed and failed.
#
# 2) If a task is incomplete it will be executed. Incomplete may be either
# a job whose log file does not contain the accounting notes in the end,
# or a job whose log file does not exist.
#
# 3) If the $job_pick is set to 'all' (default behavior) a task will be
# executed regardless of the result of previous attempts.
#
# This logic could have been implemented in the main execution loop
# but a subroutine to preserve the current level of readability of
# that part of the code.
#
# Alexandre Felipe, (o.alexandre.felipe@gmail.com) 14th of August of 2020
#
if($job_pick eq 'all'){
return; # no need to bother with the previous log
}
open my $fh, "<", $_[0] or return; # job not executed yet
my $log_line;
my $cur_line;
while ($cur_line = <$fh>) {
if( $cur_line =~ m/# Ended \(code .*/ ) {
$log_line = $cur_line;
}
}
close $fh;
if (! defined($log_line)){
return; # incomplete
}
if ( $log_line =~ m/# Ended \(code 0\).*/ ) {
exit(0); # complete
} elsif ( $log_line =~ m/# Ended \(code \d+(; signal \d+)?\).*/ ){
if ($job_pick !~ m/^(failed|all)$/) {
exit(1); # failed but not going to run
} else {
return; # failed
}
} elsif ( $log_line =~ m/.*\S.*/ ) {
return; # incomplete jobs are always run
}
}
$logfile = shift @ARGV;
if (defined $jobname && $logfile !~ m/$jobname/ &&
$jobend > $jobstart) {
print STDERR "run.pl: you are trying to run a parallel job but "
. "you are putting the output into just one log file ($logfile)\n";
exit(1);
}
$cmd = "";
foreach $x (@ARGV) {
if ($x =~ m/^\S+$/) { $cmd .= $x . " "; }
elsif ($x =~ m:\":) { $cmd .= "'$x' "; }
else { $cmd .= "\"$x\" "; }
}
#$Data::Dumper::Indent=0;
$ret = 0;
$numfail = 0;
%active_pids=();
use POSIX ":sys_wait_h";
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
if (scalar(keys %active_pids) >= $max_jobs_run) {
# Lets wait for a change in any child's status
# Then we have to work out which child finished
$r = waitpid(-1, 0);
$code = $?;
if ($r < 0 ) { die "run.pl: Error waiting for child process"; } # should never happen.
if ( defined $active_pids{$r} ) {
$jid=$active_pids{$r};
$fail[$jid]=$code;
if ($code !=0) { $numfail++;}
delete $active_pids{$r};
# print STDERR "Finished: $r/$jid " . Dumper(\%active_pids) . "\n";
} else {
die "run.pl: Cannot find the PID of the child process that just finished.";
}
# In theory we could do a non-blocking waitpid over all jobs running just
# to find out if only one or more jobs finished during the previous waitpid()
# However, we just omit this and will reap the next one in the next pass
# through the for(;;) cycle
}
$childpid = fork();
if (!defined $childpid) { die "run.pl: Error forking in run.pl (writing to $logfile)"; }
if ($childpid == 0) { # We're in the child... this branch
# executes the job and returns (possibly with an error status).
if (defined $jobname) {
$cmd =~ s/$jobname/$jobid/g;
$logfile =~ s/$jobname/$jobid/g;
}
# exit if the job does not need to be executed
pick_or_exit( $logfile );
system("mkdir -p `dirname $logfile` 2>/dev/null");
open(F, ">$logfile") || die "run.pl: Error opening log file $logfile";
print F "# " . $cmd . "\n";
print F "# Started at " . `date`;
$starttime = `date +'%s'`;
print F "#\n";
close(F);
# Pipe into bash.. make sure we're not using any other shell.
open(B, "|bash") || die "run.pl: Error opening shell command";
print B "( " . $cmd . ") 2>>$logfile >> $logfile";
close(B); # If there was an error, exit status is in $?
$ret = $?;
$lowbits = $ret & 127;
$highbits = $ret >> 8;
if ($lowbits != 0) { $return_str = "code $highbits; signal $lowbits" }
else { $return_str = "code $highbits"; }
$endtime = `date +'%s'`;
open(F, ">>$logfile") || die "run.pl: Error opening log file $logfile (again)";
$enddate = `date`;
chop $enddate;
print F "# Accounting: time=" . ($endtime - $starttime) . " threads=1\n";
print F "# Ended ($return_str) at " . $enddate . ", elapsed time " . ($endtime-$starttime) . " seconds\n";
close(F);
exit($ret == 0 ? 0 : 1);
} else {
$pid[$jobid] = $childpid;
$active_pids{$childpid} = $jobid;
# print STDERR "Queued: " . Dumper(\%active_pids) . "\n";
}
}
# Now we have submitted all the jobs, lets wait until all the jobs finish
foreach $child (keys %active_pids) {
$jobid=$active_pids{$child};
$r = waitpid($pid[$jobid], 0);
$code = $?;
if ($r == -1) { die "run.pl: Error waiting for child process"; } # should never happen.
if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # Completed successfully
}
# Some sanity checks:
# The $fail array should not contain undefined codes
# The number of non-zeros in that array should be equal to $numfail
# We cannot do foreach() here, as the JOB ids do not start at zero
$failed_jids=0;
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
$job_return = $fail[$jobid];
if (not defined $job_return ) {
# print Dumper(\@fail);
die "run.pl: Sanity check failed: we have indication that some jobs are running " .
"even after we waited for all jobs to finish" ;
}
if ($job_return != 0 ){ $failed_jids++;}
}
if ($failed_jids != $numfail) {
die "run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail)."
}
if ($numfail > 0) { $ret = 1; }
if ($ret != 0) {
$njobs = $jobend - $jobstart + 1;
if ($njobs == 1) {
if (defined $jobname) {
$logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with
# that job.
}
print STDERR "run.pl: job failed, log is in $logfile\n";
if ($logfile =~ m/JOB/) {
print STDERR "run.pl: probably you forgot to put JOB=1:\$nj in your script.";
}
}
else {
$logfile =~ s/$jobname/*/g;
print STDERR "run.pl: $numfail / $njobs failed, log is in $logfile\n";
}
}
exit ($ret);

View File

@ -1,44 +0,0 @@
#!/usr/bin/env perl
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
if ($ARGV[0] eq "--srand") {
$n = $ARGV[1];
$n =~ m/\d+/ || die "Bad argument to --srand option: \"$n\"";
srand($ARGV[1]);
shift;
shift;
} else {
srand(0); # Gives inconsistent behavior if we don't seed.
}
if (@ARGV > 1 || $ARGV[0] =~ m/^-.+/) { # >1 args, or an option we
# don't understand.
print "Usage: shuffle_list.pl [--srand N] [input file] > output\n";
print "randomizes the order of lines of input.\n";
exit(1);
}
@lines;
while (<>) {
push @lines, [ (rand(), $_)] ;
}
@lines = sort { $a->[0] cmp $b->[0] } @lines;
foreach $l (@lines) {
print $l->[1];
}

View File

@ -1,60 +0,0 @@
import os
import sys
import random
in_dir = sys.argv[1]
out_dir = sys.argv[2]
num_split = sys.argv[3]
def split_scp(scp, num):
assert len(scp) >= num
avg = len(scp) // num
out = []
begin = 0
for i in range(num):
if i == num - 1:
out.append(scp[begin:])
else:
out.append(scp[begin:begin+avg])
begin += avg
return out
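# e.g. split_scp(list(range(10)), 3) -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]];
# the last split absorbs the remainder.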
os.path.exists("{}/wav.scp".format(in_dir))
os.path.exists("{}/text".format(in_dir))
with open("{}/wav.scp".format(in_dir), 'r') as infile:
wav_list = infile.readlines()
with open("{}/text".format(in_dir), 'r') as infile:
text_list = infile.readlines()
assert len(wav_list) == len(text_list)
x = list(zip(wav_list, text_list))
random.shuffle(x)
wav_shuffle_list, text_shuffle_list = zip(*x)
num_split = int(num_split)
wav_split_list = split_scp(wav_shuffle_list, num_split)
text_split_list = split_scp(text_shuffle_list, num_split)
for idx, wav_list in enumerate(wav_split_list, 1):
path = out_dir + "/split" + str(num_split) + "/" + str(idx)
if not os.path.exists(path):
os.makedirs(path)
with open("{}/wav.scp".format(path), 'w') as wav_writer:
for line in wav_list:
wav_writer.write(line)
for idx, text_list in enumerate(text_split_list, 1):
path = out_dir + "/split" + str(num_split) + "/" + str(idx)
if not os.path.exists(path):
os.makedirs(path)
with open("{}/text".format(path), 'w') as text_writer:
for line in text_list:
text_writer.write(line)

View File

@ -1,246 +0,0 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# See ../../COPYING for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program splits up any kind of .scp or archive-type file.
# If there is no utt2spk option it will work on any text file and
# will split it up with an approximately equal number of lines in
# each one.
# With the --utt2spk option it will work on anything that has the
# utterance-id as the first entry on each line; the utt2spk file is
# of the form "utterance speaker" (on each line).
# It splits it into equal size chunks as far as it can. If you use the utt2spk
# option it will make sure these chunks coincide with speaker boundaries. In
# this case, if there are more chunks than speakers (and in some other
# circumstances), some of the resulting chunks will be empty and it will print
# an error message and exit with nonzero status.
# You will normally call this like:
# split_scp.pl scp scp.1 scp.2 scp.3 ...
# or
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
# Note that you can use this script to split the utt2spk file itself,
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
# You can also call the scripts like:
# split_scp.pl -j 3 0 scp scp.0
# [note: with this option, it assumes zero-based indexing of the split parts,
# i.e. the second number must be 0 <= n < num-jobs.]
use warnings;
$num_jobs = 0;
$job_id = 0;
$utt2spk_file = "";
$one_based = 0;
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
if ($ARGV[0] eq "-j") {
shift @ARGV;
$num_jobs = shift @ARGV;
$job_id = shift @ARGV;
}
if ($ARGV[0] =~ /--utt2spk=(.+)/) {
$utt2spk_file=$1;
shift;
}
if ($ARGV[0] eq '--one-based') {
$one_based = 1;
shift @ARGV;
}
}
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
$job_id - $one_based >= $num_jobs)) {
die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
($one_based ? " --one-based" : "") . "'\n"
}
$one_based and $job_id--;
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
die
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
... where 0 <= job-id < num-jobs, or 1 <= job-id <= num-jobs if --one-based.\n";
}
$error = 0;
$inscp = shift @ARGV;
if ($num_jobs == 0) { # without -j option
@OUTPUTS = @ARGV;
} else {
for ($j = 0; $j < $num_jobs; $j++) {
if ($j == $job_id) {
if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
else { push @OUTPUTS, "-"; }
} else {
push @OUTPUTS, "/dev/null";
}
}
}
if ($utt2spk_file ne "") { # We have the --utt2spk option...
open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
while(<$u_fh>) {
@A = split;
@A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
($u,$s) = @A;
$utt2spk{$u} = $s;
}
close $u_fh;
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
@spkrs = ();
while(<$i_fh>) {
@A = split;
if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
$u = $A[0];
$s = $utt2spk{$u};
defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
if(!defined $spk_count{$s}) {
push @spkrs, $s;
$spk_count{$s} = 0;
$spk_data{$s} = []; # ref to new empty array.
}
$spk_count{$s}++;
push @{$spk_data{$s}}, $_;
}
# Now split as equally as possible ..
# First allocate spks to files by allocating an approximately
# equal number of speakers.
$numspks = @spkrs; # number of speakers.
$numscps = @OUTPUTS; # number of output files.
if ($numspks < $numscps) {
die "$0: Refusing to split data because number of speakers $numspks " .
"is less than the number of output .scp files $numscps\n";
}
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scparray[$scpidx] = []; # [] is array reference.
}
for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
$scpidx = int(($spkidx*$numscps) / $numspks);
$spk = $spkrs[$spkidx];
push @{$scparray[$scpidx]}, $spk;
$scpcount[$scpidx] += $spk_count{$spk};
}
# Now will try to reassign beginning + ending speakers
# to different scp's and see if it gets more balanced.
# Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
# We can show that if considering changing just 2 scp's, we minimize
# this by minimizing the squared difference in sizes. This is
# equivalent to minimizing the absolute difference in sizes. This
# shows this method is bound to converge.
$changed = 1;
while($changed) {
$changed = 0;
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
# First try to reassign ending spk of this scp.
if($scpidx < $numscps-1) {
$sz = @{$scparray[$scpidx]};
if($sz > 0) {
$spk = $scparray[$scpidx]->[$sz-1];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx];
$nutt2 = $scpcount[$scpidx+1];
if( abs( ($nutt2+$count) - ($nutt1-$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx+1] += $count;
$scpcount[$scpidx] -= $count;
pop @{$scparray[$scpidx]};
unshift @{$scparray[$scpidx+1]}, $spk;
$changed = 1;
}
}
}
if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
$spk = $scparray[$scpidx]->[0];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx-1];
$nutt2 = $scpcount[$scpidx];
if( abs( ($nutt2-$count) - ($nutt1+$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx-1] += $count;
$scpcount[$scpidx] -= $count;
shift @{$scparray[$scpidx]};
push @{$scparray[$scpidx-1]}, $spk;
$changed = 1;
}
}
}
}
# Now print out the files...
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($f_fh, '>', $scpfile)
: open($f_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
$count = 0;
if(@{$scparray[$scpidx]} == 0) {
print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
"$scpfile (too many splits and too few speakers?)\n";
$error = 1;
} else {
foreach $spk ( @{$scparray[$scpidx]} ) {
print $f_fh @{$spk_data{$spk}};
$count += $spk_count{$spk};
}
$count == $scpcount[$scpidx] || die "Count mismatch [code error]";
}
close($f_fh);
}
} else {
# This block is the "normal" case where there is no --utt2spk
# option and we just break into equal size chunks.
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
$numscps = @OUTPUTS; # size of array.
@F = ();
while(<$i_fh>) {
push @F, $_;
}
$numlines = @F;
if($numlines == 0) {
print STDERR "$0: error: empty input scp file $inscp\n";
$error = 1;
}
$linesperscp = int( $numlines / $numscps); # the "whole part"..
$linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
$remainder = $numlines - ($linesperscp * $numscps);
($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
# [just doing int() rounds down].
$n = 0;
for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($o_fh, '>', $scpfile)
: open($o_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
print $o_fh $F[$n++];
}
close($o_fh) || die "$0: Error closing scp file $scpfile: $!\n";
}
$n == $numlines || die "$n != $numlines [code error]";
}
exit ($error);

View File

@ -1,30 +0,0 @@
#!/usr/bin/env bash
dev_num_utt=1000
echo "$0 $@"
. utils/parse_options.sh || exit 1;
train_data=$1
out_dir=$2
[ ! -f ${train_data}/wav.scp ] && echo "$0: no such file ${train_data}/wav.scp" && exit 1;
[ ! -f ${train_data}/text ] && echo "$0: no such file ${train_data}/text" && exit 1;
mkdir -p ${out_dir}/train && mkdir -p ${out_dir}/dev
cp ${train_data}/wav.scp ${out_dir}/train/wav.scp.bak
cp ${train_data}/text ${out_dir}/train/text.bak
num_utt=$(wc -l <${out_dir}/train/wav.scp.bak)
utils/shuffle_list.pl --srand 1 ${out_dir}/train/wav.scp.bak > ${out_dir}/train/wav.scp.shuf
head -n ${dev_num_utt} ${out_dir}/train/wav.scp.shuf > ${out_dir}/dev/wav.scp
tail -n $((${num_utt}-${dev_num_utt})) ${out_dir}/train/wav.scp.shuf > ${out_dir}/train/wav.scp
utils/shuffle_list.pl --srand 1 ${out_dir}/train/text.bak > ${out_dir}/train/text.shuf
head -n ${dev_num_utt} ${out_dir}/train/text.shuf > ${out_dir}/dev/text
tail -n $((${num_utt}-${dev_num_utt})) ${out_dir}/train/text.shuf > ${out_dir}/train/text
rm ${out_dir}/train/wav.scp.bak ${out_dir}/train/text.bak
rm ${out_dir}/train/wav.scp.shuf ${out_dir}/train/text.shuf

View File

@ -1,135 +0,0 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import re
import sys
is_python2 = sys.version_info[0] == 2
def exist_or_not(i, match_pos):
start_pos = None
end_pos = None
for pos in match_pos:
if pos[0] <= i < pos[1]:
start_pos = pos[0]
end_pos = pos[1]
break
return start_pos, end_pos
def get_parser():
parser = argparse.ArgumentParser(
description="convert raw text to tokenized text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--nchar",
"-n",
default=1,
type=int,
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2",
)
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
default=None,
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
type=str,
default="char",
choices=["char", "phn"],
help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
If trans_type is char,
read from SI1279.WRD file -> "bricks are an alternative"
Else if trans_type is phn,
read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
sil t er n ih sil t ih v sil" """,
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
rs = []
if args.non_lang_syms is not None:
with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
nls = [x.rstrip() for x in f.readlines()]
rs = [re.compile(re.escape(x)) for x in nls]
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
)
line = f.readline()
n = args.nchar
while line:
x = line.split()
print(" ".join(x[: args.skip_ncols]), end=" ")
a = " ".join(x[args.skip_ncols :])
# get all matched positions
match_pos = []
for r in rs:
i = 0
while i >= 0:
m = r.search(a, i)
if m:
match_pos.append([m.start(), m.end()])
i = m.end()
else:
break
if args.trans_type == "phn":
a = a.split(" ")
else:
if len(match_pos) > 0:
chars = []
i = 0
while i < len(a):
start_pos, end_pos = exist_or_not(i, match_pos)
if start_pos is not None:
chars.append(a[start_pos:end_pos])
i = end_pos
else:
chars.append(a[i])
i += 1
a = chars
a = [a[j : j + n] for j in range(0, len(a), n)]
a_flat = []
for z in a:
a_flat.append("".join(z))
a_chars = [z.replace(" ", args.space) for z in a_flat]
if args.trans_type == "phn":
a_chars = [z.replace("sil", args.space) for z in a_chars]
print(" ".join(a_chars))
line = f.readline()
if __name__ == "__main__":
main()
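# Behavior sketch (toy input): with "-s 1 -n 1", a line "uid1 aabb" becomes
# "uid1 a a b b"; with "-s 1 -n 2" it becomes "uid1 aa bb". Whitespace left
# inside the text is rendered as the --space symbol, e.g. "uid1 aa bb" with
# "-s 1 -n 1" becomes "uid1 a a <space> b b".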

View File

@ -1,106 +0,0 @@
import re
import argparse
def load_dict(seg_file):
seg_dict = {}
with open(seg_file, 'r') as infile:
for line in infile:
s = line.strip().split()
key = s[0]
value = s[1:]
seg_dict[key] = " ".join(value)
return seg_dict
def forward_segment(text, dic):
word_list = []
i = 0
while i < len(text):
longest_word = text[i]
for j in range(i + 1, len(text) + 1):
word = text[i:j]
if word in dic:
if len(word) > len(longest_word):
longest_word = word
word_list.append(longest_word)
i += len(longest_word)
return word_list
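# e.g. with a dictionary containing the keys {'研究', '研究生', '生命'}:
#   forward_segment('研究生命', dic) -> ['研究生', '命']
# (greedy longest match from the left; characters with no dictionary match
# pass through as single-character words)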
def tokenize(txt,
seg_dict):
out_txt = ""
pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
for word in txt:
if pattern.match(word):
if word in seg_dict:
out_txt += seg_dict[word] + " "
else:
out_txt += "<unk>" + " "
else:
continue
return out_txt.strip()
def get_parser():
parser = argparse.ArgumentParser(
description="text tokenize",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--text-file",
"-t",
default=False,
required=True,
type=str,
help="input text",
)
parser.add_argument(
"--seg-file",
"-s",
default=False,
required=True,
type=str,
help="seg file",
)
parser.add_argument(
"--txt-index",
"-i",
default=1,
required=True,
type=int,
help="txt index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), 'w')
shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), 'w')
seg_dict = load_dict(args.seg_file)
with open(args.text_file, 'r') as infile:
for line in infile:
s = line.strip().split()
text_id = s[0]
text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
text = tokenize(text_list, seg_dict)
lens = len(text.strip().split())
txt_writer.write(text_id + " " + text + '\n')
shape_writer.write(text_id + " " + str(lens) + '\n')
if __name__ == '__main__':
main()

View File

@ -1,35 +0,0 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=32
cmd=utils/run.pl
echo "$0 $@"
. utils/parse_options.sh || exit 1;
# tokenize configuration
text_dir=$1
seg_file=$2
logdir=$3
output_dir=$4
txt_dir=${output_dir}/txt; mkdir -p ${output_dir}/txt
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/text_tokenize.JOB.log \
python utils/text_tokenize.py -t ${text_dir}/txt/text.JOB.txt \
-s ${seg_file} -i JOB -o ${txt_dir} \
|| exit 1;
# concatenate the text files together.
for n in $(seq $nj); do
cat ${txt_dir}/text.$n.txt || exit 1
done > ${output_dir}/text || exit 1
for n in $(seq $nj); do
cat ${txt_dir}/len.$n || exit 1
done > ${output_dir}/text_shape || exit 1
echo "$0: Succeeded text tokenize"

View File

@ -1,834 +0,0 @@
#!/usr/bin/env python3
# coding=utf-8
# Authors:
# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
# 2019.9 Jiayu DU
#
# requirements:
# - python 3.X
# notes: python 2.X WILL fail or produce misleading results
import sys, os, argparse, codecs, string, re
# ================================================================================ #
# basic constant
# ================================================================================ #
CHINESE_DIGIS = u'零一二三四五六七八九'
BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
ZERO_ALT = u'〇'
ONE_ALT = u'幺'
TWO_ALTS = [u'两', u'兩']
POSITIVE = [u'正', u'正']
NEGATIVE = [u'负', u'負']
POINT = [u'点', u'點']
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']
FILLER_CHARS = ['呃', '啊']
ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \
'胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \
'儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \
'佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)'
# Chinese numbering system types
NUMBERING_TYPES = ['low', 'mid', 'high']
CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
'里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
'砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
'针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
'毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
'盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
'纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
# punctuation information is based on the Zhon project (https://github.com/tsroten/zhon.git)
CHINESE_PUNC_STOP = '!?。。'
CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏'
CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
# ================================================================================ #
# basic class
# ================================================================================ #
class ChineseChar(object):
"""
A Chinese character: each character has a simplified and a traditional form,
e.g. simplified = '负', traditional = '負';
it can be converted to either form.
"""
def __init__(self, simplified, traditional):
self.simplified = simplified
self.traditional = traditional
#self.__repr__ = self.__str__
def __str__(self):
return self.simplified or self.traditional or None
def __repr__(self):
return self.__str__()
class ChineseNumberUnit(ChineseChar):
"""
A Chinese number/number-unit character.
Besides the simplified/traditional forms, each one has an extra "big" (banker's) form,
e.g. '陆' and '陸'.
"""
def __init__(self, power, simplified, traditional, big_s, big_t):
super(ChineseNumberUnit, self).__init__(simplified, traditional)
self.power = power
self.big_s = big_s
self.big_t = big_t
def __str__(self):
return '10^{}'.format(self.power)
@classmethod
def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
if small_unit:
return ChineseNumberUnit(power=index + 1,
simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
elif numbering_type == NUMBERING_TYPES[0]:
return ChineseNumberUnit(power=index + 8,
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
elif numbering_type == NUMBERING_TYPES[1]:
return ChineseNumberUnit(power=(index + 2) * 4,
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
elif numbering_type == NUMBERING_TYPES[2]:
return ChineseNumberUnit(power=pow(2, index + 3),
simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
else:
raise ValueError(
'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
class ChineseNumberDigit(ChineseChar):
"""
A Chinese digit character.
"""
def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
super(ChineseNumberDigit, self).__init__(simplified, traditional)
self.value = value
self.big_s = big_s
self.big_t = big_t
self.alt_s = alt_s
self.alt_t = alt_t
def __str__(self):
return str(self.value)
@classmethod
def create(cls, i, v):
return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
class ChineseMath(ChineseChar):
"""
A Chinese math symbol character (sign or decimal point).
"""
def __init__(self, simplified, traditional, symbol, expression=None):
super(ChineseMath, self).__init__(simplified, traditional)
self.symbol = symbol
self.expression = expression
self.big_s = simplified
self.big_t = traditional
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
class NumberSystem(object):
"""
The Chinese numbering system.
"""
pass
class MathSymbol(object):
"""
Math symbols used by the Chinese numbering system (traditional/simplified), e.g.
positive = ['正', '正']
negative = ['负', '負']
point = ['点', '點']
"""
def __init__(self, positive, negative, point):
self.positive = positive
self.negative = negative
self.point = point
def __iter__(self):
for v in self.__dict__.values():
yield v
# class OtherSymbol(object):
# """
# Other symbols
# """
#
# def __init__(self, sil):
# self.sil = sil
#
# def __iter__(self):
# for v in self.__dict__.values():
# yield v
# ================================================================================ #
# basic utils
# ================================================================================ #
def create_system(numbering_type=NUMBERING_TYPES[1]):
"""
Create and return the numbering system for the given type (default: mid).
NUMBERING_TYPES = ['low', 'mid', 'high'] are the Chinese numbering-system variants:
low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
Returns the corresponding number system.
"""
# chinese number units of '亿' and larger
all_larger_units = zip(
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
larger_units = [CNU.create(i, v, numbering_type, False)
for i, v in enumerate(all_larger_units)]
# chinese number units of '十, 百, 千, 万'
all_smaller_units = zip(
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
smaller_units = [CNU.create(i, v, small_unit=True)
for i, v in enumerate(all_smaller_units)]
    # digits
chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
# symbols
positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
    point_cn = CM(POINT[0], POINT[1], '.',
                  lambda x, y: float(str(x) + '.' + str(y)))
# sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
system = NumberSystem()
system.units = smaller_units + larger_units
system.digits = digits
system.math = MathSymbol(positive_cn, negative_cn, point_cn)
# system.symbols = OtherSymbol(sil_cn)
return system
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
def get_symbol(char, system):
for u in system.units:
if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
return u
for d in system.digits:
if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
return d
for m in system.math:
if char in [m.traditional, m.simplified]:
return m
def string2symbols(chinese_string, system):
int_string, dec_string = chinese_string, ''
for p in [system.math.point.simplified, system.math.point.traditional]:
if p in chinese_string:
int_string, dec_string = chinese_string.split(p)
break
return [get_symbol(c, system) for c in int_string], \
[get_symbol(c, system) for c in dec_string]
    def correct_symbols(integer_symbols, system):
        """
        Insert implied digits/units, e.g.
        一百八 -> 一百八十
        一亿一千三百万 -> 一亿 一千万 三百万
        """
if integer_symbols and isinstance(integer_symbols[0], CNU):
if integer_symbols[0].power == 1:
integer_symbols = [system.digits[1]] + integer_symbols
if len(integer_symbols) > 1:
if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
integer_symbols.append(
CNU(integer_symbols[-2].power - 1, None, None, None, None))
result = []
unit_count = 0
for s in integer_symbols:
if isinstance(s, CND):
result.append(s)
unit_count = 0
elif isinstance(s, CNU):
current_unit = CNU(s.power, None, None, None, None)
unit_count += 1
if unit_count == 1:
result.append(current_unit)
elif unit_count > 1:
for i in range(len(result)):
if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
result[-i - 1] = CNU(result[-i - 1].power +
current_unit.power, None, None, None, None)
return result
    def compute_value(integer_symbols):
        """
        Compute the integer value.
        When the current unit is larger than the previous one, it multiplies
        everything accumulated so far,
        e.g. '两千万' = 2000 * 10000, not 2000 + 10000.
        """
value = [0]
last_power = 0
for s in integer_symbols:
if isinstance(s, CND):
value[-1] = s.value
elif isinstance(s, CNU):
value[-1] *= pow(10, s.power)
if s.power > last_power:
                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
last_power = s.power
value.append(0)
return sum(value)
system = create_system(numbering_type)
int_part, dec_part = string2symbols(chinese_string, system)
int_part = correct_symbols(int_part, system)
int_str = str(compute_value(int_part))
dec_str = ''.join([str(d.value) for d in dec_part])
if dec_part:
return '{0}.{1}'.format(int_str, dec_str)
else:
return int_str
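# Illustrative behaviour of chn2num (an editorial sketch; the exact outputs
# assume the character tables defined earlier in this file):
#   chn2num('两千万')   -> '20000000'
#   chn2num('一百八')   -> '180'   (correct_symbols supplies the implied '十')
#   chn2num('三点一四') -> '3.14'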
def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
traditional=False, alt_zero=False, alt_one=False, alt_two=True,
use_zeros=True, use_units=True):
    def get_value(value_string, use_zeros=True):
        stripped_string = value_string.lstrip('0')
        # record nothing if all zeros
        if not stripped_string:
            return []
        # record one digit
        elif len(stripped_string) == 1:
            if use_zeros and len(value_string) != len(stripped_string):
                return [system.digits[0], system.digits[int(stripped_string)]]
            else:
                return [system.digits[int(stripped_string)]]
        # recursively record multiple digits
        else:
            result_unit = next(u for u in reversed(system.units) if u.power < len(stripped_string))
            result_string = value_string[:-result_unit.power]
            return get_value(result_string) + [result_unit] + get_value(stripped_string[-result_unit.power:])
system = create_system(numbering_type)
int_dec = number_string.split('.')
if len(int_dec) == 1:
int_string = int_dec[0]
dec_string = ""
elif len(int_dec) == 2:
int_string = int_dec[0]
dec_string = int_dec[1]
else:
raise ValueError(
"invalid input num string with more than one dot: {}".format(number_string))
if use_units and len(int_string) > 1:
result_symbols = get_value(int_string)
else:
result_symbols = [system.digits[int(c)] for c in int_string]
dec_symbols = [system.digits[int(c)] for c in dec_string]
if dec_string:
result_symbols += [system.math.point] + dec_symbols
if alt_two:
liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
system.digits[2].big_s, system.digits[2].big_t)
for i, v in enumerate(result_symbols):
if isinstance(v, CND) and v.value == 2:
                next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
previous_symbol = result_symbols[i - 1] if i > 0 else None
if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
result_symbols[i] = liang
# if big is True, '两' will not be used and `alt_two` has no impact on output
if big:
attr_name = 'big_'
if traditional:
attr_name += 't'
else:
attr_name += 's'
else:
if traditional:
attr_name = 'traditional'
else:
attr_name = 'simplified'
result = ''.join([getattr(s, attr_name) for s in result_symbols])
# if not use_zeros:
# result = result.strip(getattr(system.digits[0], attr_name))
if alt_zero:
result = result.replace(
getattr(system.digits[0], attr_name), system.digits[0].alt_s)
if alt_one:
result = result.replace(
getattr(system.digits[1], attr_name), system.digits[1].alt_s)
for i, p in enumerate(POINT):
if result.startswith(p):
return CHINESE_DIGIS[0] + result
    # 10..19: drop the leading '一' from a result starting with '一十'
if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
result = result[1:]
return result
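# Illustrative behaviour of num2chn (an editorial sketch; outputs assume the
# character tables defined earlier in this file):
#   num2chn('20000000')                -> '两千万'  (alt_two=True by default)
#   num2chn('20000000', alt_two=False) -> '二千万'
#   num2chn('123', big=True)           -> '壹佰贰拾叁'
#   num2chn('0.5')                     -> '零点五'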
# ================================================================================ #
# different types of rewriters
# ================================================================================ #
class Cardinal:
    """
    Cardinal rewriter: plain integers and decimals.
    """
def __init__(self, cardinal=None, chntext=None):
self.cardinal = cardinal
self.chntext = chntext
def chntext2cardinal(self):
return chn2num(self.chntext)
def cardinal2chntext(self):
return num2chn(self.cardinal)
class Digit:
    """
    Digit rewriter: digit-by-digit readout (no place-value units).
    """
def __init__(self, digit=None, chntext=None):
self.digit = digit
self.chntext = chntext
# def chntext2digit(self):
# return chn2num(self.chntext)
def digit2chntext(self):
return num2chn(self.digit, alt_two=False, use_units=False)
class TelePhone:
    """
    Telephone-number rewriter.
    """
def __init__(self, telephone=None, raw_chntext=None, chntext=None):
self.telephone = telephone
self.raw_chntext = raw_chntext
self.chntext = chntext
# def chntext2telephone(self):
# sil_parts = self.raw_chntext.split('<SIL>')
# self.telephone = '-'.join([
# str(chn2num(p)) for p in sil_parts
# ])
# return self.telephone
def telephone2chntext(self, fixed=False):
if fixed:
sil_parts = self.telephone.split('-')
self.raw_chntext = '<SIL>'.join([
num2chn(part, alt_two=False, use_units=False) for part in sil_parts
])
self.chntext = self.raw_chntext.replace('<SIL>', '')
else:
sp_parts = self.telephone.strip('+').split()
self.raw_chntext = '<SP>'.join([
num2chn(part, alt_two=False, use_units=False) for part in sp_parts
])
self.chntext = self.raw_chntext.replace('<SP>', '')
return self.chntext
class Fraction:
    """
    Fraction rewriter.
    """
def __init__(self, fraction=None, chntext=None):
self.fraction = fraction
self.chntext = chntext
def chntext2fraction(self):
denominator, numerator = self.chntext.split('分之')
return chn2num(numerator) + '/' + chn2num(denominator)
def fraction2chntext(self):
numerator, denominator = self.fraction.split('/')
return num2chn(denominator) + '分之' + num2chn(numerator)
class Date:
    """
    Date rewriter.
    """
def __init__(self, date=None, chntext=None):
self.date = date
self.chntext = chntext
# def chntext2date(self):
# chntext = self.chntext
# try:
# year, other = chntext.strip().split('年', maxsplit=1)
# year = Digit(chntext=year).digit2chntext() + '年'
# except ValueError:
# other = chntext
# year = ''
# if other:
# try:
# month, day = other.strip().split('月', maxsplit=1)
# month = Cardinal(chntext=month).chntext2cardinal() + '月'
# except ValueError:
# day = chntext
# month = ''
# if day:
# day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
# else:
# month = ''
# day = ''
# date = year + month + day
# self.date = date
# return self.date
    def date2chntext(self):
        date = self.date
        try:
            year, other = date.strip().split('年', 1)
            year = Digit(digit=year).digit2chntext() + '年'
        except ValueError:
            other = date
            year = ''
        if other:
            try:
                month, day = other.strip().split('月', 1)
                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
            except ValueError:
                day = date
                month = ''
            if day:
                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
        else:
            month = ''
            day = ''
        chntext = year + month + day
        self.chntext = chntext
        return self.chntext
class Money:
    """
    Money rewriter.
    """
def __init__(self, money=None, chntext=None):
self.money = money
self.chntext = chntext
# def chntext2money(self):
# return self.money
def money2chntext(self):
money = self.money
pattern = re.compile(r'(\d+(\.\d+)?)')
matchers = pattern.findall(money)
if matchers:
for matcher in matchers:
money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
self.chntext = money
return self.chntext
class Percentage:
    """
    Percentage rewriter.
    """
def __init__(self, percentage=None, chntext=None):
self.percentage = percentage
self.chntext = chntext
def chntext2percentage(self):
return chn2num(self.chntext.strip().strip('百分之')) + '%'
def percentage2chntext(self):
return '百分之' + num2chn(self.percentage.strip().strip('%'))
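# Illustrative round trips (an editorial sketch of how the rewriters wrap
# chn2num/num2chn):
#   Fraction(fraction='3/4').fraction2chntext()          -> '四分之三'
#   Percentage(percentage='80.03%').percentage2chntext() -> '百分之八十点零三'
#   Digit(digit='2018').digit2chntext()                  -> '二零一八'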
def remove_erhua(text, er_whitelist):
    """
    Strip the erhua suffix '儿' unless the word is protected by the whitelist:
    他女儿在那边儿 -> 他女儿在那边
    """
    er_pattern = re.compile(er_whitelist)
    new_str = ''
    while re.search('儿', text):
        a = re.search('儿', text).span()
        remove_er_flag = 0
        if er_pattern.search(text):
            b = er_pattern.search(text).span()
            if b[0] <= a[0]:
                remove_er_flag = 1
        if remove_er_flag == 0:
            new_str = new_str + text[0:a[0]]
            text = text[a[1]:]
        else:
            new_str = new_str + text[0:b[1]]
            text = text[b[1]:]
    text = new_str + text
    return text
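# Illustrative call (editorial; assumes the ER_WHITELIST pattern defined
# earlier in this file matches '女儿'):
#   remove_erhua('他女儿在那边儿', ER_WHITELIST) -> '他女儿在那边'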
# ================================================================================ #
# NSW Normalizer
# ================================================================================ #
class NSWNormalizer:
def __init__(self, raw_text):
self.raw_text = '^' + raw_text + '$'
self.norm_text = ''
def _particular(self):
text = self.norm_text
pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
matchers = pattern.findall(text)
if matchers:
# print('particular')
for matcher in matchers:
text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
self.norm_text = text
return self.norm_text
def normalize(self):
text = self.raw_text
        # normalize dates
pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
matchers = pattern.findall(text)
if matchers:
#print('date')
for matcher in matchers:
text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
        # normalize money
pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
matchers = pattern.findall(text)
if matchers:
#print('money')
for matcher in matchers:
text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
        # normalize telephone numbers (landline and mobile)
        # mobile prefixes (source: http://www.jihaoba.com/news/show/13680)
        # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
        # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
        # China Telecom: 133, 153, 189, 180, 181, 177
pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
matchers = pattern.findall(text)
if matchers:
#print('telephone')
for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
        # landline
pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
matchers = pattern.findall(text)
if matchers:
# print('fixed telephone')
for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
        # normalize fractions
pattern = re.compile(r"(\d+/\d+)")
matchers = pattern.findall(text)
if matchers:
#print('fraction')
for matcher in matchers:
text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
        # normalize percentages
        text = text.replace('％', '%')
pattern = re.compile(r"(\d+(\.\d+)?%)")
matchers = pattern.findall(text)
if matchers:
#print('percentage')
for matcher in matchers:
text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
        # normalize number + quantifier
pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
matchers = pattern.findall(text)
if matchers:
#print('cardinal+quantifier')
for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
        # normalize long digit strings (IDs, serial numbers)
pattern = re.compile(r"(\d{4,32})")
matchers = pattern.findall(text)
if matchers:
#print('digit')
for matcher in matchers:
text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
        # normalize plain numbers
pattern = re.compile(r"(\d+(\.\d+)?)")
matchers = pattern.findall(text)
if matchers:
#print('cardinal')
for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
self.norm_text = text
self._particular()
return self.norm_text.lstrip('^').rstrip('$')
def nsw_test_case(raw_text):
print('I:' + raw_text)
print('O:' + NSWNormalizer(raw_text).normalize())
print('')
def nsw_test():
nsw_test_case('固话0595-23865596或23880880。')
nsw_test_case('固话0595-23865596或23880880。')
nsw_test_case('手机:+86 19859213959或15659451527。')
nsw_test_case('分数32477/76391。')
nsw_test_case('百分数80.03%')
nsw_test_case('编号31520181154418。')
nsw_test_case('纯数2983.07克或12345.60米。')
nsw_test_case('日期1999年2月20日或09年3月15号。')
nsw_test_case('金钱12块534.5元20.1万')
nsw_test_case('特殊O2O或B2C。')
nsw_test_case('3456万吨')
nsw_test_case('2938个')
nsw_test_case('938')
nsw_test_case('今天吃了115个小笼包231个馒头')
    nsw_test_case('有62％的概率')
if __name__ == '__main__':
#nsw_test()
p = argparse.ArgumentParser()
p.add_argument('ifile', help='input filename, assume utf-8 encoding')
p.add_argument('ofile', help='output filename')
p.add_argument('--to_upper', action='store_true', help='convert to upper case')
p.add_argument('--to_lower', action='store_true', help='convert to lower case')
p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
p.add_argument('--remove_fillers', type=bool, default=True, help='remove filler chars such as "呃, 啊"')
p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "这儿"')
p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
args = p.parse_args()
ifile = codecs.open(args.ifile, 'r', 'utf8')
ofile = codecs.open(args.ofile, 'w+', 'utf8')
n = 0
for l in ifile:
key = ''
text = ''
if args.has_key:
cols = l.split(maxsplit=1)
key = cols[0]
if len(cols) == 2:
text = cols[1].strip()
else:
text = ''
else:
text = l.strip()
# cases
if args.to_upper and args.to_lower:
sys.stderr.write('text norm: to_upper OR to_lower?')
exit(1)
if args.to_upper:
text = text.upper()
if args.to_lower:
text = text.lower()
# Filler chars removal
if args.remove_fillers:
for ch in FILLER_CHARS:
text = text.replace(ch, '')
if args.remove_erhua:
text = remove_erhua(text, ER_WHITELIST)
# NSW(Non-Standard-Word) normalization
text = NSWNormalizer(text).normalize()
# Punctuations removal
old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations
new_chars = ' ' * len(old_chars)
del_chars = ''
text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
#
if args.has_key:
ofile.write(key + '\t' + text + '\n')
else:
ofile.write(text + '\n')
n += 1
if n % args.log_interval == 0:
sys.stderr.write("text norm: {} lines done.\n".format(n))
sys.stderr.write("text norm: {} lines done in total.\n".format(n))
ifile.close()
ofile.close()
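# Typical invocation (an editorial example; the script filename here is an
# assumption, not part of the original file):
#   python text_norm.py --has_key --to_lower input.txt output.txt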

View File

@ -0,0 +1 @@
../tranformer/utils

View File

@ -1,79 +0,0 @@
from kaldiio import ReadHelper
from kaldiio import WriteHelper
import argparse
import json
import math
import numpy as np
def get_parser():
parser = argparse.ArgumentParser(
description="apply cmvn",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--ark-file",
"-a",
default=False,
required=True,
type=str,
help="fbank ark file",
)
parser.add_argument(
"--cmvn-file",
"-c",
default=False,
required=True,
type=str,
help="cmvn file",
)
parser.add_argument(
"--ark-index",
"-i",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
with open(args.cmvn_file) as f:
cmvn_stats = json.load(f)
means = cmvn_stats['mean_stats']
vars = cmvn_stats['var_stats']
total_frames = cmvn_stats['total_frames']
for i in range(len(means)):
means[i] /= total_frames
vars[i] = vars[i] / total_frames - means[i] * means[i]
if vars[i] < 1.0e-20:
vars[i] = 1.0e-20
vars[i] = 1.0 / math.sqrt(vars[i])
with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
for key, mat in ark_reader:
mat = (mat - means) * vars
ark_writer(key, mat)
if __name__ == '__main__':
main()
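# Editorial sketch: the per-dimension loop above amounts to this vectorized
# form (means_raw/vars_raw are hypothetical names for the stats as loaded,
# before the in-place division):
#   mean = np.array(means_raw) / total_frames
#   istd = 1.0 / np.sqrt(np.maximum(np.array(vars_raw) / total_frames - mean ** 2, 1.0e-20))
#   normalized = (mat - mean) * istd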

View File

@ -1,29 +0,0 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
nj=32
cmd=./utils/run.pl
echo "$0 $@"
. utils/parse_options.sh || exit 1;
fbankdir=$1
cmvn_file=$2
logdir=$3
output_dir=$4
dump_dir=${output_dir}/ark; mkdir -p ${dump_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/apply_cmvn.JOB.log \
python utils/apply_cmvn.py -a $fbankdir/ark/feats.JOB.ark \
-c $cmvn_file -i JOB -o ${dump_dir} \
|| exit 1;
for n in $(seq $nj); do
cat ${dump_dir}/feats.$n.scp || exit 1
done > ${output_dir}/feats.scp || exit 1
echo "$0: Succeeded in applying CMVN"

View File

@ -1,143 +0,0 @@
from kaldiio import ReadHelper, WriteHelper
import argparse
import numpy as np
def build_LFR_features(inputs, m=7, n=6):
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / n))
left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
inputs = np.vstack((left_padding, inputs))
T = T + (m - 1) // 2
for i in range(T_lfr):
if m <= T - i * n:
LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
else:
num_padding = m - (T - i * n)
frame = np.hstack(inputs[i * n:])
for _ in range(num_padding):
frame = np.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
return np.vstack(LFR_inputs)
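# Shape sketch (editorial): stacking m=7 frames every n=6 frames maps an input
# of shape (T, D) to (ceil(T / n), m * D); e.g. 100 x 80 fbank frames become
# 17 x 560 low-frame-rate features.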
def build_CMVN_features(inputs, mvn_file): # noqa
with open(mvn_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
add_shift_list = []
rescale_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == '<AddShift>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
add_shift_line = line_item[3:(len(line_item) - 1)]
add_shift_list = list(add_shift_line)
continue
elif line_item[0] == '<Rescale>':
line_item = lines[i + 1].split()
if line_item[0] == '<LearnRateCoef>':
rescale_line = line_item[3:(len(line_item) - 1)]
rescale_list = list(rescale_line)
continue
for j in range(inputs.shape[0]):
for k in range(inputs.shape[1]):
add_shift_value = add_shift_list[k]
rescale_value = rescale_list[k]
inputs[j, k] = float(inputs[j, k]) + float(add_shift_value)
inputs[j, k] = float(inputs[j, k]) * float(rescale_value)
return inputs
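# Editorial note: this parser assumes a Kaldi nnet-format global CMVN file,
# where an <AddShift> component holds the negated means and a <Rescale>
# component holds the inverse standard deviations, applied per dimension as
# (x + shift) * scale.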
def get_parser():
parser = argparse.ArgumentParser(
description="apply low_frame_rate and cmvn",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--ark-file",
"-a",
default=False,
required=True,
type=str,
help="fbank ark file",
)
parser.add_argument(
"--lfr",
"-f",
default=True,
type=str,
help="low frame rate",
)
parser.add_argument(
"--lfr-m",
"-m",
default=7,
type=int,
help="number of frames to stack",
)
parser.add_argument(
"--lfr-n",
"-n",
default=6,
type=int,
help="number of frames to skip",
)
parser.add_argument(
"--cmvn-file",
"-c",
default=False,
required=True,
type=str,
help="global cmvn file",
)
parser.add_argument(
"--ark-index",
"-i",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
dump_ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
dump_scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
shape_file = args.output_dir + "/len." + str(args.ark_index)
ark_writer = WriteHelper('ark,scp:{},{}'.format(dump_ark_file, dump_scp_file))
shape_writer = open(shape_file, 'w')
with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
for key, mat in ark_reader:
if args.lfr:
lfr = build_LFR_features(mat, args.lfr_m, args.lfr_n)
else:
lfr = mat
cmvn = build_CMVN_features(lfr, args.cmvn_file)
dims = cmvn.shape[1]
lens = cmvn.shape[0]
shape_writer.write(key + " " + str(lens) + "," + str(dims) + '\n')
ark_writer(key, cmvn)
if __name__ == '__main__':
main()

View File

@ -1,38 +0,0 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=32
cmd=utils/run.pl
# feature configuration
lfr=True
lfr_m=7
lfr_n=6
echo "$0 $@"
. utils/parse_options.sh || exit 1;
fbankdir=$1
cmvn_file=$2
logdir=$3
output_dir=$4
dump_dir=${output_dir}/ark; mkdir -p ${dump_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/apply_lfr_and_cmvn.JOB.log \
python utils/apply_lfr_and_cmvn.py -a $fbankdir/ark/feats.JOB.ark \
-f $lfr -m $lfr_m -n $lfr_n -c $cmvn_file -i JOB -o ${dump_dir} \
|| exit 1;
for n in $(seq $nj); do
cat ${dump_dir}/feats.$n.scp || exit 1
done > ${output_dir}/feats.scp || exit 1
for n in $(seq $nj); do
cat ${dump_dir}/len.$n || exit 1
done > ${output_dir}/speech_shape || exit 1
echo "$0: Succeeded in applying LFR and CMVN"

View File

@ -1,73 +0,0 @@
import argparse
import json
import numpy as np
def get_parser():
parser = argparse.ArgumentParser(
description="combine cmvn file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--dims",
"-d",
default=80,
type=int,
help="feature dims",
)
parser.add_argument(
"--cmvn-dir",
"-c",
default=False,
required=True,
type=str,
help="cmvn dir",
)
parser.add_argument(
"--nj",
"-n",
default=1,
required=True,
type=int,
        help="number of cmvn files",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
total_means = np.zeros(args.dims)
total_vars = np.zeros(args.dims)
total_frames = 0
cmvn_file = args.output_dir + "/cmvn.json"
for i in range(1, args.nj+1):
with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
cmvn_stats = json.load(fin)
total_means += np.array(cmvn_stats["mean_stats"])
total_vars += np.array(cmvn_stats["var_stats"])
total_frames += cmvn_stats["total_frames"]
    cmvn_info = {
        'mean_stats': total_means.tolist(),
        'var_stats': total_vars.tolist(),
        'total_frames': total_frames
    }
with open(cmvn_file, 'w') as fout:
fout.write(json.dumps(cmvn_info))
if __name__ == '__main__':
main()
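# Editorial note: mean_stats/var_stats are raw sums over frames, so per-job
# statistics combine by element-wise addition here; the division by
# total_frames is deferred to apply_cmvn.py.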

View File

@ -1,74 +0,0 @@
from kaldiio import ReadHelper
import argparse
import numpy as np
import json
def get_parser():
parser = argparse.ArgumentParser(
        description="compute global cmvn",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--dims",
"-d",
default=80,
type=int,
help="feature dims",
)
parser.add_argument(
"--ark-file",
"-a",
default=False,
required=True,
type=str,
help="fbank ark file",
)
parser.add_argument(
"--ark-index",
"-i",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
ark_file = args.ark_file + "/feats." + str(args.ark_index) + ".ark"
cmvn_file = args.output_dir + "/cmvn." + str(args.ark_index) + ".json"
mean_stats = np.zeros(args.dims)
var_stats = np.zeros(args.dims)
total_frames = 0
with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
for key, mat in ark_reader:
mean_stats += np.sum(mat, axis=0)
var_stats += np.sum(np.square(mat), axis=0)
total_frames += mat.shape[0]
    cmvn_info = {
        'mean_stats': mean_stats.tolist(),
        'var_stats': var_stats.tolist(),
        'total_frames': total_frames
    }
with open(cmvn_file, 'w') as fout:
fout.write(json.dumps(cmvn_info))
if __name__ == '__main__':
main()

View File

@ -1,25 +0,0 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
nj=32
cmd=./utils/run.pl
feats_dim=80
echo "$0 $@"
. utils/parse_options.sh || exit 1;
fbankdir=$1
logdir=$2
output_dir=${fbankdir}/cmvn; mkdir -p ${output_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
python utils/compute_cmvn.py -d ${feats_dim} -a $fbankdir/ark -i JOB -o ${output_dir} \
|| exit 1;
python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
echo "$0: Succeeded in computing global CMVN"

View File

@ -1,153 +0,0 @@
from kaldiio import WriteHelper
import argparse
import numpy as np
import json
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
def compute_fbank(wav_file,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
resample_rate=16000,
speed=1.0):
waveform, sample_rate = torchaudio.load(wav_file)
if resample_rate != sample_rate:
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
new_freq=resample_rate)(waveform)
if speed != 1.0:
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform, resample_rate,
[['speed', str(speed)], ['rate', str(resample_rate)]]
)
waveform = waveform * (1 << 15)
mat = kaldi.fbank(waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
window_type='hamming',
sample_frequency=resample_rate)
return mat.numpy()
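# Editorial note: torchaudio.load returns float waveforms scaled to [-1, 1];
# multiplying by (1 << 15) restores the 16-bit integer range so kaldi.fbank
# yields values comparable to Kaldi's own extractor.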
def get_parser():
parser = argparse.ArgumentParser(
        description="compute features",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--wav-lists",
"-w",
default=False,
required=True,
type=str,
help="input wav lists",
)
parser.add_argument(
"--text-files",
"-t",
default=False,
required=True,
type=str,
help="input text files",
)
parser.add_argument(
"--dims",
"-d",
default=80,
type=int,
help="feature dims",
)
parser.add_argument(
"--sample-frequency",
"-s",
default=16000,
type=int,
help="sample frequency",
)
parser.add_argument(
"--speed-perturb",
"-p",
default="1.0",
type=str,
help="speed perturb",
)
parser.add_argument(
"--ark-index",
"-a",
default=1,
required=True,
type=int,
help="ark index",
)
parser.add_argument(
"--output-dir",
"-o",
default=False,
required=True,
type=str,
help="output dir",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
ark_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".ark"
scp_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".scp"
text_file = args.output_dir + "/txt/text." + str(args.ark_index) + ".txt"
feats_shape_file = args.output_dir + "/ark/len." + str(args.ark_index)
text_shape_file = args.output_dir + "/txt/len." + str(args.ark_index)
ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
text_writer = open(text_file, 'w')
feats_shape_writer = open(feats_shape_file, 'w')
text_shape_writer = open(text_shape_file, 'w')
speed_perturb_list = args.speed_perturb.split(',')
for speed in speed_perturb_list:
with open(args.wav_lists, 'r', encoding='utf-8') as wavfile:
with open(args.text_files, 'r', encoding='utf-8') as textfile:
for wav, text in zip(wavfile, textfile):
s_w = wav.strip().split()
wav_id = s_w[0]
wav_file = s_w[1]
s_t = text.strip().split()
text_id = s_t[0]
txt = s_t[1:]
fbank = compute_fbank(wav_file,
num_mel_bins=args.dims,
resample_rate=args.sample_frequency,
speed=float(speed)
)
feats_dims = fbank.shape[1]
feats_lens = fbank.shape[0]
txt_lens = len(txt)
if speed == "1.0":
wav_id_sp = wav_id
else:
wav_id_sp = wav_id + "_sp" + speed
feats_shape_writer.write(wav_id_sp + " " + str(feats_lens) + "," + str(feats_dims) + '\n')
text_shape_writer.write(wav_id_sp + " " + str(txt_lens) + '\n')
text_writer.write(wav_id_sp + " " + " ".join(txt) + '\n')
ark_writer(wav_id_sp, fbank)
if __name__ == '__main__':
main()

View File

@ -1,51 +0,0 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
nj=32
cmd=./utils/run.pl
# feature configuration
feats_dim=80
sample_frequency=16000
speed_perturb="1.0"
echo "$0 $@"
. utils/parse_options.sh || exit 1;
data=$1
logdir=$2
fbankdir=$3
[ ! -f $data/wav.scp ] && echo "$0: no such file $data/wav.scp" && exit 1;
[ ! -f $data/text ] && echo "$0: no such file $data/text" && exit 1;
python utils/split_data.py $data $data $nj
ark_dir=${fbankdir}/ark; mkdir -p ${ark_dir}
text_dir=${fbankdir}/txt; mkdir -p ${text_dir}
mkdir -p ${logdir}
$cmd JOB=1:$nj $logdir/make_fbank.JOB.log \
python utils/compute_fbank.py -w $data/split${nj}/JOB/wav.scp -t $data/split${nj}/JOB/text \
-d $feats_dim -s $sample_frequency -p ${speed_perturb} -a JOB -o ${fbankdir} \
|| exit 1;
for n in $(seq $nj); do
cat ${ark_dir}/feats.$n.scp || exit 1
done > $fbankdir/feats.scp || exit 1
for n in $(seq $nj); do
cat ${text_dir}/text.$n.txt || exit 1
done > $fbankdir/text || exit 1
for n in $(seq $nj); do
cat ${ark_dir}/len.$n || exit 1
done > $fbankdir/speech_shape || exit 1
for n in $(seq $nj); do
cat ${text_dir}/len.$n || exit 1
done > $fbankdir/text_shape || exit 1
echo "$0: Succeeded in computing fbank features"

View File

@ -1,157 +0,0 @@
import os
import numpy as np
import sys
def compute_wer(ref_file,
hyp_file,
cer_detail_file):
rst = {
'Wrd': 0,
'Corr': 0,
'Ins': 0,
'Del': 0,
'Sub': 0,
'Snt': 0,
'Err': 0.0,
'S.Err': 0.0,
'wrong_words': 0,
'wrong_sentences': 0
}
hyp_dict = {}
ref_dict = {}
with open(hyp_file, 'r') as hyp_reader:
for line in hyp_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
hyp_dict[key] = value
with open(ref_file, 'r') as ref_reader:
for line in ref_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
ref_dict[key] = value
cer_detail_writer = open(cer_detail_file, 'w')
for hyp_key in hyp_dict:
if hyp_key in ref_dict:
out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
rst['Wrd'] += out_item['nwords']
rst['Corr'] += out_item['cor']
rst['wrong_words'] += out_item['wrong']
rst['Ins'] += out_item['ins']
rst['Del'] += out_item['del']
rst['Sub'] += out_item['sub']
rst['Snt'] += 1
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
if rst['Snt'] > 0:
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
cer_detail_writer.write('\n')
cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
    cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in ref." + '\n')
def compute_wer_by_line(hyp,
ref):
hyp = list(map(lambda x: x.lower(), hyp))
ref = list(map(lambda x: x.lower(), ref))
len_hyp = len(hyp)
len_ref = len(ref)
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
for i in range(len_hyp + 1):
cost_matrix[i][0] = i
for j in range(len_ref + 1):
cost_matrix[0][j] = j
for i in range(1, len_hyp + 1):
for j in range(1, len_ref + 1):
if hyp[i - 1] == ref[j - 1]:
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
else:
substitution = cost_matrix[i - 1][j - 1] + 1
insertion = cost_matrix[i - 1][j] + 1
deletion = cost_matrix[i][j - 1] + 1
compare_val = [substitution, insertion, deletion]
min_val = min(compare_val)
operation_idx = compare_val.index(min_val) + 1
cost_matrix[i][j] = min_val
ops_matrix[i][j] = operation_idx
match_idx = []
i = len_hyp
j = len_ref
rst = {
'nwords': len_ref,
'cor': 0,
'wrong': 0,
'ins': 0,
'del': 0,
'sub': 0
}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst['cor'] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst['ins'] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst['del'] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst['sub'] += 1
if i < 0 and j >= 0:
rst['del'] += 1
elif j < 0 and i >= 0:
rst['ins'] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst['wrong'] = wrong_cnt
return rst
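# Worked example (editorial): hyp=['a', 'b', 'd'] vs ref=['a', 'c', 'd'] is a
# single substitution, so compute_wer_by_line returns
#   {'nwords': 3, 'cor': 2, 'wrong': 1, 'ins': 0, 'del': 0, 'sub': 1}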
def print_cer_detail(rst):
return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
+ ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
+ str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
+ ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
if __name__ == '__main__':
if len(sys.argv) != 4:
print("usage : python compute-wer.py test.ref test.hyp test.wer")
sys.exit(0)
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
cer_detail_file = sys.argv[3]
compute_wer(ref_file, hyp_file, cer_detail_file)

View File

@ -1,370 +0,0 @@
#!/usr/bin/env python3
# coding=utf8
# Copyright 2021 Jiayu DU
import sys
import argparse
import json
import logging
logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')
DEBUG = None
def GetEditType(ref_token, hyp_token):
if ref_token == None and hyp_token != None:
return 'I'
elif ref_token != None and hyp_token == None:
return 'D'
elif ref_token == hyp_token:
return 'C'
elif ref_token != hyp_token:
return 'S'
else:
raise RuntimeError
class AlignmentArc:
def __init__(self, src, dst, ref, hyp):
self.src = src
self.dst = dst
self.ref = ref
self.hyp = hyp
self.edit_type = GetEditType(ref, hyp)
def similarity_score_function(ref_token, hyp_token):
return 0 if (ref_token == hyp_token) else -1.0
def insertion_score_function(token):
return -1.0
def deletion_score_function(token):
return -1.0
def EditDistance(
ref,
hyp,
similarity_score_function = similarity_score_function,
insertion_score_function = insertion_score_function,
deletion_score_function = deletion_score_function):
assert(len(ref) != 0)
class DPState:
def __init__(self):
self.score = -float('inf')
# backpointer
self.prev_r = None
self.prev_h = None
def print_search_grid(S, R, H, fstream):
print(file=fstream)
for r in range(R):
for h in range(H):
print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
print(file=fstream)
R = len(ref) + 1
H = len(hyp) + 1
# Construct DP search space, a (R x H) grid
S = [ [] for r in range(R) ]
for r in range(R):
S[r] = [ DPState() for x in range(H) ]
# initialize DP search grid origin, S(r = 0, h = 0)
S[0][0].score = 0.0
S[0][0].prev_r = None
S[0][0].prev_h = None
# initialize REF axis
for r in range(1, R):
S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
S[r][0].prev_r = r-1
S[r][0].prev_h = 0
# initialize HYP axis
for h in range(1, H):
S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
S[0][h].prev_r = 0
S[0][h].prev_h = h-1
best_score = S[0][0].score
best_state = (0, 0)
for r in range(1, R):
for h in range(1, H):
sub_or_cor_score = similarity_score_function(ref[r-1], hyp[h-1])
new_score = S[r-1][h-1].score + sub_or_cor_score
if new_score >= S[r][h].score:
S[r][h].score = new_score
S[r][h].prev_r = r-1
S[r][h].prev_h = h-1
del_score = deletion_score_function(ref[r-1])
new_score = S[r-1][h].score + del_score
if new_score >= S[r][h].score:
S[r][h].score = new_score
S[r][h].prev_r = r - 1
S[r][h].prev_h = h
ins_score = insertion_score_function(hyp[h-1])
new_score = S[r][h-1].score + ins_score
if new_score >= S[r][h].score:
S[r][h].score = new_score
S[r][h].prev_r = r
S[r][h].prev_h = h-1
best_score = S[R-1][H-1].score
best_state = (R-1, H-1)
if DEBUG:
print_search_grid(S, R, H, sys.stderr)
# Backtracing best alignment path, i.e. a list of arcs
# arc = (src, dst, ref, hyp, edit_type)
# src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
best_path = []
r, h = best_state[0], best_state[1]
prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
score = S[r][h].score
# loop invariant:
# 1. (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path
# 2. score is the value of point(r, h) on DP search grid
while prev_r != None or prev_h != None:
src = (prev_r, prev_h)
dst = (r, h)
if (r == prev_r + 1 and h == prev_h + 1): # Substitution or correct
arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
elif (r == prev_r + 1 and h == prev_h): # Deletion
arc = AlignmentArc(src, dst, ref[prev_r], None)
elif (r == prev_r and h == prev_h + 1): # Insertion
arc = AlignmentArc(src, dst, None, hyp[prev_h])
else:
raise RuntimeError
best_path.append(arc)
r, h = prev_r, prev_h
prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
score = S[r][h].score
best_path.reverse()
return (best_path, best_score)
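# Illustrative use (editorial): aligning ref=['a', 'b'] with hyp=['a', 'c']
# yields one correct arc and one substitution, with best_score -1.0:
#   path, score = EditDistance(['a', 'b'], ['a', 'c'])
#   [arc.edit_type for arc in path]   # -> ['C', 'S']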
def PrettyPrintAlignment(alignment, stream = sys.stderr):
def get_token_str(token):
if token == None:
return "*"
return token
def is_double_width_char(ch):
if (ch >= '\u4e00') and (ch <= '\u9fa5'): # codepoint ranges for Chinese chars
return True
# TODO: support other double-width-char language such as Japanese, Korean
else:
return False
def display_width(token_str):
m = 0
for c in token_str:
if is_double_width_char(c):
m += 2
else:
m += 1
return m
R = ' REF : '
H = ' HYP : '
E = ' EDIT : '
for arc in alignment:
r = get_token_str(arc.ref)
h = get_token_str(arc.hyp)
e = arc.edit_type if arc.edit_type != 'C' else ''
nr, nh, ne = display_width(r), display_width(h), display_width(e)
n = max(nr, nh, ne) + 1
R += r + ' ' * (n-nr)
H += h + ' ' * (n-nh)
E += e + ' ' * (n-ne)
print(R, file=stream)
print(H, file=stream)
print(E, file=stream)
def CountEdits(alignment):
c, s, i, d = 0, 0, 0, 0
for arc in alignment:
if arc.edit_type == 'C':
c += 1
elif arc.edit_type == 'S':
s += 1
elif arc.edit_type == 'I':
i += 1
elif arc.edit_type == 'D':
d += 1
else:
raise RuntimeError
return (c, s, i, d)
def ComputeTokenErrorRate(c, s, i, d):
return 100.0 * (s + d + i) / (s + d + c)
def ComputeSentenceErrorRate(num_err_utts, num_utts):
assert(num_utts != 0)
return 100.0 * num_err_utts / num_utts
class EvaluationResult:
def __init__(self):
self.num_ref_utts = 0
self.num_hyp_utts = 0
self.num_eval_utts = 0 # seen in both ref & hyp
self.num_hyp_without_ref = 0
self.C = 0
self.S = 0
self.I = 0
self.D = 0
self.token_error_rate = 0.0
self.num_utts_with_error = 0
self.sentence_error_rate = 0.0
def to_json(self):
return json.dumps(self.__dict__)
def to_kaldi(self):
info = (
F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
)
return info
def to_sclite(self):
return "TODO"
def to_espnet(self):
return "TODO"
def to_summary(self):
#return json.dumps(self.__dict__, indent=4)
summary = (
'==================== Overall Statistics ====================\n'
F'num_ref_utts: {self.num_ref_utts}\n'
F'num_hyp_utts: {self.num_hyp_utts}\n'
F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
F'num_eval_utts: {self.num_eval_utts}\n'
F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
F'token_error_rate: {self.token_error_rate:.2f}%\n'
F'token_stats:\n'
F' - tokens:{self.C + self.S + self.D:>7}\n'
F' - edits: {self.S + self.I + self.D:>7}\n'
F' - cor: {self.C:>7}\n'
F' - sub: {self.S:>7}\n'
F' - ins: {self.I:>7}\n'
F' - del: {self.D:>7}\n'
'============================================================\n'
)
return summary
class Utterance:
def __init__(self, uid, text):
self.uid = uid
self.text = text
def LoadUtterances(filepath, format):
utts = {}
if format == 'text': # utt_id word1 word2 ...
with open(filepath, 'r', encoding='utf8') as f:
for line in f:
line = line.strip()
if line:
cols = line.split(maxsplit=1)
assert(len(cols) == 2 or len(cols) == 1)
uid = cols[0]
text = cols[1] if len(cols) == 2 else ''
if utts.get(uid) != None:
                        raise RuntimeError(F'Found duplicated utterance id {uid}')
utts[uid] = Utterance(uid, text)
else:
raise RuntimeError(F'Unsupported text format {format}')
return utts
def tokenize_text(text, tokenizer):
if tokenizer == 'whitespace':
return text.split()
elif tokenizer == 'char':
return [ ch for ch in ''.join(text.split()) ]
else:
raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
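# Editorial example: tokenize_text('你好 world', 'char') returns
# ['你', '好', 'w', 'o', 'r', 'l', 'd'], while the 'whitespace' tokenizer
# returns ['你好', 'world'].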
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# optional
parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
# required
parser.add_argument('--ref', type=str, required=True, help='input reference file')
parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')
parser.add_argument('result_file', type=str)
args = parser.parse_args()
logging.info(args)
ref_utts = LoadUtterances(args.ref, args.ref_format)
hyp_utts = LoadUtterances(args.hyp, args.hyp_format)
r = EvaluationResult()
# check valid utterances in hyp that have matched non-empty reference
eval_utts = []
r.num_hyp_without_ref = 0
for uid in sorted(hyp_utts.keys()):
if uid in ref_utts.keys(): # TODO: efficiency
if ref_utts[uid].text.strip(): # non-empty reference
eval_utts.append(uid)
else:
                logging.warning(F'Found {uid} with empty reference, skipping...')
else:
            logging.warning(F'Found {uid} without reference, skipping...')
r.num_hyp_without_ref += 1
r.num_hyp_utts = len(hyp_utts)
r.num_ref_utts = len(ref_utts)
r.num_eval_utts = len(eval_utts)
with open(args.result_file, 'w+', encoding='utf8') as fo:
for uid in eval_utts:
ref = ref_utts[uid]
hyp = hyp_utts[uid]
alignment, score = EditDistance(
tokenize_text(ref.text, args.tokenizer),
tokenize_text(hyp.text, args.tokenizer)
)
c, s, i, d = CountEdits(alignment)
utt_ter = ComputeTokenErrorRate(c, s, i, d)
# utt-level evaluation result
print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
PrettyPrintAlignment(alignment, fo)
r.C += c
r.S += s
r.I += i
r.D += d
if utt_ter > 0:
r.num_utts_with_error += 1
# corpus level evaluation result
r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)
print(r.to_summary(), file=fo)
print(r.to_json())
print(r.to_kaldi())

View File

@ -1,47 +0,0 @@
from transformers import AutoTokenizer, AutoModel, pipeline
import numpy as np
import sys
import os
import torch
from kaldiio import WriteHelper
import re
text_file_json = sys.argv[1]
out_ark = sys.argv[2]
out_scp = sys.argv[3]
out_shape = sys.argv[4]
device = int(sys.argv[5])
model_path = sys.argv[6]
model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
with open(text_file_json, 'r') as f:
js = f.readlines()
f_shape = open(out_shape, "w")
with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
with torch.no_grad():
for idx, line in enumerate(js):
id, tokens = line.strip().split(" ", 1)
tokens = re.sub(" ", "", tokens.strip())
tokens = ' '.join([j for j in tokens])
token_num = len(tokens.split(" "))
outputs = extractor(tokens)
outputs = np.array(outputs)
embeds = outputs[0, 1:-1, :]
token_num_embeds, dim = embeds.shape
if token_num == token_num_embeds:
writer(id, embeds)
shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
f_shape.write(shape_line)
else:
                print("{}: token count mismatch: {} tokens vs {} embeddings ({})".format(id, token_num, token_num_embeds, tokens))
f_shape.close()

View File

@ -1,87 +0,0 @@
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation
# Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script takes a list of utterance-ids or any file whose first field
# of each line is an utterance-id, and filters an scp
# file (or any file whose "n-th" field is an utterance id), printing
# out only those lines whose "n-th" field is in id_list. The index of
# the "n-th" field is 1, by default, but can be changed by using
# the -f <n> switch
$exclude = 0;
$field = 1;
$shifted = 0;
do {
$shifted=0;
if ($ARGV[0] eq "--exclude") {
$exclude = 1;
shift @ARGV;
$shifted=1;
}
if ($ARGV[0] eq "-f") {
$field = $ARGV[1];
shift @ARGV; shift @ARGV;
$shifted=1
}
} while ($shifted);
if(@ARGV < 1 || @ARGV > 2) {
die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
"Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
"Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
"only the lines that were *not* in id_list.\n" .
"Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
"If your older scripts (written before Oct 2014) stopped working and you used the\n" .
"-f option, add 1 to the argument.\n" .
"See also: scripts/filter_scp.pl .\n";
}
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
@A = split;
@A>=1 || die "Invalid id-list file line $_";
$seen{$A[0]} = 1;
}
if ($field == 1) { # Treat this as special case, since it is common.
while(<>) {
$_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
# $1 is what we filter on.
if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
print $_;
}
}
} else {
while(<>) {
@A = split;
@A > 0 || die "Invalid scp file line $_";
@A >= $field || die "Invalid scp file line $_";
if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
print $_;
}
}
}
# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2)

View File

@ -1,35 +0,0 @@
#!/usr/bin/env bash
echo "$0 $@"
data_dir=$1
if [ ! -f ${data_dir}/wav.scp ]; then
echo "$0: wav.scp is not found"
exit 1;
fi
if [ ! -f ${data_dir}/text ]; then
echo "$0: text is not found"
exit 1;
fi
mkdir -p ${data_dir}/.backup
awk '{print $1}' ${data_dir}/wav.scp > ${data_dir}/.backup/wav_id
awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
cp ${data_dir}/wav.scp ${data_dir}/.backup/wav.scp
cp ${data_dir}/text ${data_dir}/.backup/text
mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak
mv ${data_dir}/text ${data_dir}/text.bak
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak > ${data_dir}/wav.scp
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
rm ${data_dir}/wav.scp.bak
rm ${data_dir}/text.bak

View File

@ -1,52 +0,0 @@
#!/usr/bin/env bash
echo "$0 $@"
data_dir=$1
if [ ! -f ${data_dir}/feats.scp ]; then
echo "$0: feats.scp is not found"
exit 1;
fi
if [ ! -f ${data_dir}/text ]; then
echo "$0: text is not found"
exit 1;
fi
if [ ! -f ${data_dir}/speech_shape ]; then
  echo "$0: speech_shape (feature lengths) is not found"
exit 1;
fi
if [ ! -f ${data_dir}/text_shape ]; then
  echo "$0: text_shape (text lengths) is not found"
exit 1;
fi
mkdir -p ${data_dir}/.backup
awk '{print $1}' ${data_dir}/feats.scp > ${data_dir}/.backup/wav_id
awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
cp ${data_dir}/feats.scp ${data_dir}/.backup/feats.scp
cp ${data_dir}/text ${data_dir}/.backup/text
cp ${data_dir}/speech_shape ${data_dir}/.backup/speech_shape
cp ${data_dir}/text_shape ${data_dir}/.backup/text_shape
mv ${data_dir}/feats.scp ${data_dir}/feats.scp.bak
mv ${data_dir}/text ${data_dir}/text.bak
mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak
mv ${data_dir}/text_shape ${data_dir}/text_shape.bak
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak > ${data_dir}/feats.scp
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak > ${data_dir}/speech_shape
utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak > ${data_dir}/text_shape
rm ${data_dir}/feats.scp.bak
rm ${data_dir}/text.bak
rm ${data_dir}/speech_shape.bak
rm ${data_dir}/text_shape.bak

View File

@ -1,22 +0,0 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=32
cmd=./utils/run.pl
echo "$0 $@"
. utils/parse_options.sh || exit 1;
ark_dir=$1
txt_dir=$2
output_dir=$3
[ ! -d ${ark_dir}/ark ] && echo "$0: ark data is required" && exit 1;
[ ! -d ${txt_dir}/txt ] && echo "$0: txt data is required" && exit 1;
for n in $(seq $nj); do
echo "${ark_dir}/ark/feats.$n.ark ${txt_dir}/txt/text.$n.txt" || exit 1
done > ${output_dir}/ark_txt.scp || exit 1

View File

@ -1,97 +0,0 @@
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
      # Next we test whether the variable in question is undefined -- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.

View File

@ -1,45 +0,0 @@
#!/usr/bin/env python
import sys
def get_commandline_args(no_executable=True):
extra_chars = [
" ",
";",
"&",
"|",
"<",
">",
"?",
"*",
"~",
"`",
'"',
"'",
"\\",
"{",
"}",
"(",
")",
]
# Escape the extra characters for shell
argv = [
arg.replace("'", "'\\''")
if all(char not in arg for char in extra_chars)
else "'" + arg.replace("'", "'\\''") + "'"
for arg in sys.argv
]
if no_executable:
return " ".join(argv[1:])
else:
return sys.executable + " " + " ".join(argv)
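# Editorial example: with sys.argv == ['script.py', '--foo', 'a b'],
# get_commandline_args() returns "--foo 'a b'", since the embedded space
# triggers shell quoting.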
def main():
print(get_commandline_args())
if __name__ == "__main__":
main()

View File

@ -1,35 +0,0 @@
from pathlib import Path
import torch
import yaml
class NoAliasSafeDumper(yaml.SafeDumper):
    # Disable anchors/aliases in yaml output because they look ugly
def ignore_aliases(self, data):
return True
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
"""Safe-dump in yaml with no anchor/alias"""
return yaml.dump(
data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
)
def gen_conf(file, out_dir):
conf = torch.load(file)["config"]
conf["oss_bucket"] = "null"
print(conf)
output_dir = Path(out_dir)
output_dir.mkdir(parents=True, exist_ok=True)
with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
yaml_no_alias_safe_dump(conf, f, indent=4, sort_keys=False)
if __name__ == "__main__":
import sys
in_f = sys.argv[1]
out_f = sys.argv[2]
gen_conf(in_f, out_f)

View File

@ -1,31 +0,0 @@
import sys
import re
in_f = sys.argv[1]
out_f = sys.argv[2]
with open(in_f, "r", encoding="utf-8") as f:
lines = f.readlines()
with open(out_f, "w", encoding="utf-8") as f:
for line in lines:
outs = line.strip().split(" ", 1)
if len(outs) == 2:
idx, text = outs
text = re.sub("</s>", "", text)
text = re.sub("<s>", "", text)
text = re.sub("@@", "", text)
text = re.sub("@", "", text)
text = re.sub("<unk>", "", text)
text = re.sub(" ", "", text)
text = text.lower()
else:
idx = outs[0]
text = " "
text = [x for x in text]
text = " ".join(text)
out = "{} {}\n".format(idx, text)
f.write(out)

View File

@ -1,356 +0,0 @@
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# In general, doing
# run.pl some.log a b c is like running the command a b c in
# the bash shell, and putting the standard error and output into some.log.
# To run parallel jobs (backgrounded on the host machine), you can do (e.g.)
# run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB
# and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier].
# If any of the jobs fails, this script will fail.
# A typical example is:
# run.pl some.log my-prog "--opt=foo bar" foo \| other-prog baz
# and run.pl will run something like:
# ( my-prog '--opt=foo bar' foo | other-prog baz ) >& some.log
#
# Basically it takes the command-line arguments, quotes them
# as necessary to preserve spaces, and evaluates them with bash.
# In addition it puts the command line at the top of the log, and
# the start and end times of the command at the beginning and end.
# The reason why this is useful is so that we can create a different
# version of this program that uses a queueing system instead.
#use Data::Dumper;
@ARGV < 2 && die "usage: run.pl log-file command-line arguments...";
#print STDERR "COMMAND-LINE: " . Dumper(\@ARGV) . "\n";
$job_pick = 'all';
$max_jobs_run = -1;
$jobstart = 1;
$jobend = 1;
$ignored_opts = ""; # These will be ignored.
# First parse an option like JOB=1:4, and any
# options that would normally be given to
# queue.pl, which we will just discard.
for (my $x = 1; $x <= 2; $x++) { # This for-loop is to
# allow the JOB=1:n option to be interleaved with the
# options to qsub.
while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) {
# parse any options that would normally go to qsub, but which will be ignored here.
my $switch = shift @ARGV;
if ($switch eq "-V") {
$ignored_opts .= "-V ";
} elsif ($switch eq "--max-jobs-run" || $switch eq "-tc") {
# we do support the option --max-jobs-run n, and its GridEngine form -tc n.
# If the option appears multiple times, the smallest value is used.
if ( $max_jobs_run <= 0 ) {
$max_jobs_run = shift @ARGV;
} else {
my $new_constraint = shift @ARGV;
if ( ($new_constraint < $max_jobs_run) ) {
$max_jobs_run = $new_constraint;
}
}
if (! ($max_jobs_run > 0)) {
die "run.pl: invalid option --max-jobs-run $max_jobs_run";
}
} else {
my $argument = shift @ARGV;
if ($argument =~ m/^--/) {
print STDERR "run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\n";
}
if ($switch eq "-sync" && $argument =~ m/^[yY]/) {
$ignored_opts .= "-sync "; # Note: in the
# corresponding code in queue.pl it says instead, just "$sync = 1;".
} elsif ($switch eq "-pe") { # e.g. -pe smp 5
my $argument2 = shift @ARGV;
$ignored_opts .= "$switch $argument $argument2 ";
} elsif ($switch eq "--gpu") {
$using_gpu = $argument;
} elsif ($switch eq "--pick") {
if($argument =~ m/^(all|failed|incomplete)$/) {
$job_pick = $argument;
} else {
print STDERR "run.pl: ERROR: --pick argument must be one of 'all', 'failed' or 'incomplete'"
}
} else {
# Ignore option.
$ignored_opts .= "$switch $argument ";
}
}
}
if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20
$jobname = $1;
$jobstart = $2;
$jobend = $3;
if ($jobstart > $jobend) {
die "run.pl: invalid job range $ARGV[0]";
}
if ($jobstart <= 0) {
die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility).";
}
shift;
} elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1.
$jobname = $1;
$jobstart = $2;
$jobend = $2;
shift;
} elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) {
print STDERR "run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\n";
}
}
# Users found this message confusing so we are removing it.
# if ($ignored_opts ne "") {
# print STDERR "run.pl: Warning: ignoring options \"$ignored_opts\"\n";
# }
if ($max_jobs_run == -1) { # If --max-jobs-run option not set,
# then work out the number of processors if possible,
# and set it based on that.
$max_jobs_run = 0;
if ($using_gpu) {
if (open(P, "nvidia-smi -L |")) {
$max_jobs_run++ while (<P>);
close(P);
}
if ($max_jobs_run == 0) {
$max_jobs_run = 1;
print STDERR "run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\n";
}
} elsif (open(P, "</proc/cpuinfo")) { # Linux
while (<P>) { if (m/^processor/) { $max_jobs_run++; } }
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from /proc/cpuinfo\n";
$max_jobs_run = 10; # reasonable default.
}
close(P);
} elsif (open(P, "sysctl -a |")) { # BSD/Darwin
while (<P>) {
if (m/hw\.ncpu\s*[:=]\s*(\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4
$max_jobs_run = $1;
last;
}
}
close(P);
if ($max_jobs_run == 0) {
print STDERR "run.pl: Warning: failed to detect any processors from sysctl -a\n";
$max_jobs_run = 10; # reasonable default.
}
} else {
# allow at most 32 jobs at once, on non-UNIX systems; change this code
# if you need to change this default.
$max_jobs_run = 32;
}
# The just-computed value of $max_jobs_run is just the number of processors
# (or our best guess); and if it happens that the number of jobs we need to
# run is just slightly above $max_jobs_run, it will make sense to increase
# $max_jobs_run to equal the number of jobs, so we don't have a small number
# of leftover jobs.
$num_jobs = $jobend - $jobstart + 1;
if (!$using_gpu &&
$num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) {
$max_jobs_run = $num_jobs;
}
}
sub pick_or_exit {
# pick_or_exit ( $logfile )
# Invoked before each job is started; it helps to run jobs selectively.
#
# Given the name of the output logfile, it decides whether the job must be
# executed (by returning from the subroutine) or not (by terminating the
# calling process with exit).
#
# PRE: $job_pick is a global variable set by command line switch --pick
# and indicates which class of jobs must be executed.
#
# 1) If a failed job is not executed the process exit code will indicate
# failure, just as if the task was just executed and failed.
#
# 2) If a task is incomplete it will be executed. Incomplete may be either
# a job whose log file does not contain the accounting notes in the end,
# or a job whose log file does not exist.
#
# 3) If the $job_pick is set to 'all' (default behavior) a task will be
# executed regardless of the result of previous attempts.
#
# This logic could have been implemented in the main execution loop,
# but it is kept in a subroutine to preserve the current level of
# readability of that part of the code.
#
# Alexandre Felipe, (o.alexandre.felipe@gmail.com) 14th of August of 2020
#
if($job_pick eq 'all'){
return; # no need to bother with the previous log
}
open my $fh, "<", $_[0] or return; # job not executed yet
my $log_line;
my $cur_line;
while ($cur_line = <$fh>) {
if( $cur_line =~ m/# Ended \(code .*/ ) {
$log_line = $cur_line;
}
}
close $fh;
if (! defined($log_line)){
return; # incomplete
}
if ( $log_line =~ m/# Ended \(code 0\).*/ ) {
exit(0); # complete
} elsif ( $log_line =~ m/# Ended \(code \d+(; signal \d+)?\).*/ ){
if ($job_pick !~ m/^(failed|all)$/) {
exit(1); # failed but not going to run
} else {
return; # failed
}
} elsif ( $log_line =~ m/.*\S.*/ ) {
return; # incomplete jobs are always run
}
}
$logfile = shift @ARGV;
if (defined $jobname && $logfile !~ m/$jobname/ &&
$jobend > $jobstart) {
print STDERR "run.pl: you are trying to run a parallel job but "
. "you are putting the output into just one log file ($logfile)\n";
exit(1);
}
$cmd = "";
foreach $x (@ARGV) {
if ($x =~ m/^\S+$/) { $cmd .= $x . " "; }
elsif ($x =~ m:\":) { $cmd .= "'$x' "; }
else { $cmd .= "\"$x\" "; }
}
#$Data::Dumper::Indent=0;
$ret = 0;
$numfail = 0;
%active_pids=();
use POSIX ":sys_wait_h";
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
if (scalar(keys %active_pids) >= $max_jobs_run) {
# Let's wait for a change in any child's status
# Then we have to work out which child finished
$r = waitpid(-1, 0);
$code = $?;
if ($r < 0 ) { die "run.pl: Error waiting for child process"; } # should never happen.
if ( defined $active_pids{$r} ) {
$jid=$active_pids{$r};
$fail[$jid]=$code;
if ($code !=0) { $numfail++;}
delete $active_pids{$r};
# print STDERR "Finished: $r/$jid " . Dumper(\%active_pids) . "\n";
} else {
die "run.pl: Cannot find the PID of the child process that just finished.";
}
# In theory we could do a non-blocking waitpid over all jobs running just
# to find out if only one or more jobs finished during the previous waitpid()
# However, we just omit this and will reap the next one in the next pass
# through the for(;;) cycle
}
$childpid = fork();
if (!defined $childpid) { die "run.pl: Error forking in run.pl (writing to $logfile)"; }
if ($childpid == 0) { # We're in the child... this branch
# executes the job and returns (possibly with an error status).
if (defined $jobname) {
$cmd =~ s/$jobname/$jobid/g;
$logfile =~ s/$jobname/$jobid/g;
}
# exit if the job does not need to be executed
pick_or_exit( $logfile );
system("mkdir -p `dirname $logfile` 2>/dev/null");
open(F, ">$logfile") || die "run.pl: Error opening log file $logfile";
print F "# " . $cmd . "\n";
print F "# Started at " . `date`;
$starttime = `date +'%s'`;
print F "#\n";
close(F);
# Pipe into bash.. make sure we're not using any other shell.
open(B, "|bash") || die "run.pl: Error opening shell command";
print B "( " . $cmd . ") 2>>$logfile >> $logfile";
close(B); # If there was an error, exit status is in $?
$ret = $?;
$lowbits = $ret & 127;
$highbits = $ret >> 8;
if ($lowbits != 0) { $return_str = "code $highbits; signal $lowbits" }
else { $return_str = "code $highbits"; }
$endtime = `date +'%s'`;
open(F, ">>$logfile") || die "run.pl: Error opening log file $logfile (again)";
$enddate = `date`;
chop $enddate;
print F "# Accounting: time=" . ($endtime - $starttime) . " threads=1\n";
print F "# Ended ($return_str) at " . $enddate . ", elapsed time " . ($endtime-$starttime) . " seconds\n";
close(F);
exit($ret == 0 ? 0 : 1);
} else {
$pid[$jobid] = $childpid;
$active_pids{$childpid} = $jobid;
# print STDERR "Queued: " . Dumper(\%active_pids) . "\n";
}
}
# Now that we have submitted all the jobs, let's wait until they all finish
foreach $child (keys %active_pids) {
$jobid=$active_pids{$child};
$r = waitpid($pid[$jobid], 0);
$code = $?;
if ($r == -1) { die "run.pl: Error waiting for child process"; } # should never happen.
if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # record this job's exit code
}
# Some sanity checks:
# The $fail array should not contain undefined codes
# The number of non-zeros in that array should be equal to $numfail
# We cannot do foreach() here, as the JOB ids do not start at zero
$failed_jids=0;
for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
$job_return = $fail[$jobid];
if (not defined $job_return ) {
# print Dumper(\@fail);
die "run.pl: Sanity check failed: we have indication that some jobs are running " .
"even after we waited for all jobs to finish" ;
}
if ($job_return != 0 ){ $failed_jids++;}
}
if ($failed_jids != $numfail) {
die "run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail)."
}
if ($numfail > 0) { $ret = 1; }
if ($ret != 0) {
$njobs = $jobend - $jobstart + 1;
if ($njobs == 1) {
if (defined $jobname) {
$logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with
# that job.
}
print STDERR "run.pl: job failed, log is in $logfile\n";
if ($logfile =~ m/JOB/) {
print STDERR "run.pl: probably you forgot to put JOB=1:\$nj in your script.";
}
}
else {
$logfile =~ s/$jobname/*/g;
print STDERR "run.pl: $numfail / $njobs failed, log is in $logfile\n";
}
}
exit ($ret);
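
The job bookkeeping above (an %active_pids table, a blocking waitpid, reap-then-refill) is a classic throttling pattern. A minimal Python sketch of the same pattern, for illustration only on a Unix system (run_throttled and its arguments are invented names, not part of run.pl):

import os

def run_throttled(commands, max_jobs_run):
    active = {}                 # pid -> command, mirrors %active_pids above
    failures = 0
    for cmd in commands:
        if len(active) >= max_jobs_run:
            pid, status = os.waitpid(-1, 0)    # block until any child exits
            active.pop(pid)
            if status != 0:
                failures += 1
        pid = os.fork()
        if pid == 0:                           # child: run the job, report status
            os._exit(0 if os.system(cmd) == 0 else 1)
        active[pid] = cmd
    while active:                              # reap the stragglers
        pid, status = os.waitpid(-1, 0)
        active.pop(pid)
        if status != 0:
            failures += 1
    return failures

# e.g. run_throttled(["sleep 1"] * 8, max_jobs_run=4)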

View File

@ -1,44 +0,0 @@
#!/usr/bin/env perl
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
if ($ARGV[0] eq "--srand") {
$n = $ARGV[1];
$n =~ m/\d+/ || die "Bad argument to --srand option: \"$n\"";
srand($ARGV[1]);
shift;
shift;
} else {
srand(0); # Gives inconsistent behavior if we don't seed.
}
if (@ARGV > 1 || $ARGV[0] =~ m/^-.+/) { # >1 args, or an option we
# don't understand.
print "Usage: shuffle_list.pl [--srand N] [input file] > output\n";
print "randomizes the order of lines of input.\n";
exit(1);
}
@lines = ();
while (<>) {
push @lines, [ (rand(), $_)] ;
}
@lines = sort { $a->[0] cmp $b->[0] } @lines;
foreach $l (@lines) {
print $l->[1];
}

View File

@ -1,60 +0,0 @@
import os
import sys
import random

in_dir = sys.argv[1]
out_dir = sys.argv[2]
num_split = sys.argv[3]


def split_scp(scp, num):
    assert len(scp) >= num
    avg = len(scp) // num
    out = []
    begin = 0
    for i in range(num):
        if i == num - 1:
            out.append(scp[begin:])
        else:
            out.append(scp[begin:begin + avg])
        begin += avg
    return out


assert os.path.exists("{}/wav.scp".format(in_dir))
assert os.path.exists("{}/text".format(in_dir))

with open("{}/wav.scp".format(in_dir), 'r') as infile:
    wav_list = infile.readlines()
with open("{}/text".format(in_dir), 'r') as infile:
    text_list = infile.readlines()
assert len(wav_list) == len(text_list)

x = list(zip(wav_list, text_list))
random.shuffle(x)
wav_shuffle_list, text_shuffle_list = zip(*x)

num_split = int(num_split)
wav_split_list = split_scp(wav_shuffle_list, num_split)
text_split_list = split_scp(text_shuffle_list, num_split)

for idx, wav_list in enumerate(wav_split_list, 1):
    path = out_dir + "/split" + str(num_split) + "/" + str(idx)
    if not os.path.exists(path):
        os.makedirs(path)
    with open("{}/wav.scp".format(path), 'w') as wav_writer:
        for line in wav_list:
            wav_writer.write(line)

for idx, text_list in enumerate(text_split_list, 1):
    path = out_dir + "/split" + str(num_split) + "/" + str(idx)
    if not os.path.exists(path):
        os.makedirs(path)
    with open("{}/text".format(path), 'w') as text_writer:
        for line in text_list:
            text_writer.write(line)
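
A quick check of the split_scp() helper above (runnable with the definition above in scope): the remainder is folded into the last chunk.

chunks = split_scp(list(range(10)), 3)
print([len(c) for c in chunks])    # -> [3, 3, 4]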

View File

@ -1,246 +0,0 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# See ../../COPYING for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program splits up any kind of .scp or archive-type file.
# If there is no utt2spk option it will work on any text file and
# will split it up with an approximately equal number of lines in
# each one.
# With the --utt2spk option it will work on anything that has the
# utterance-id as the first entry on each line; the utt2spk file is
# of the form "utterance speaker" (on each line).
# It splits it into equal size chunks as far as it can. If you use the utt2spk
# option it will make sure these chunks coincide with speaker boundaries. In
# this case, if there are more chunks than speakers (and in some other
# circumstances), some of the resulting chunks will be empty and it will print
# an error message and exit with nonzero status.
# You will normally call this like:
# split_scp.pl scp scp.1 scp.2 scp.3 ...
# or
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
# Note that you can use this script to split the utt2spk file itself,
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
# You can also call the script like:
# split_scp.pl -j 3 0 scp scp.0
# [note: with this option, it assumes zero-based indexing of the split parts,
# i.e. the second number must be 0 <= n < num-jobs.]
use warnings;
$num_jobs = 0;
$job_id = 0;
$utt2spk_file = "";
$one_based = 0;
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
if ($ARGV[0] eq "-j") {
shift @ARGV;
$num_jobs = shift @ARGV;
$job_id = shift @ARGV;
}
if ($ARGV[0] =~ /--utt2spk=(.+)/) {
$utt2spk_file=$1;
shift;
}
if ($ARGV[0] eq '--one-based') {
$one_based = 1;
shift @ARGV;
}
}
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
$job_id - $one_based >= $num_jobs)) {
die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
($one_based ? " --one-based" : "") . "'\n"
}
$one_based and $job_id--;
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
die
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
}
$error = 0;
$inscp = shift @ARGV;
if ($num_jobs == 0) { # without -j option
@OUTPUTS = @ARGV;
} else {
for ($j = 0; $j < $num_jobs; $j++) {
if ($j == $job_id) {
if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
else { push @OUTPUTS, "-"; }
} else {
push @OUTPUTS, "/dev/null";
}
}
}
if ($utt2spk_file ne "") { # We have the --utt2spk option...
open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
while(<$u_fh>) {
@A = split;
@A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
($u,$s) = @A;
$utt2spk{$u} = $s;
}
close $u_fh;
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
@spkrs = ();
while(<$i_fh>) {
@A = split;
if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
$u = $A[0];
$s = $utt2spk{$u};
defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
if(!defined $spk_count{$s}) {
push @spkrs, $s;
$spk_count{$s} = 0;
$spk_data{$s} = []; # ref to new empty array.
}
$spk_count{$s}++;
push @{$spk_data{$s}}, $_;
}
# Now split as equally as possible ..
# First allocate spks to files by allocating an approximately
# equal number of speakers.
$numspks = @spkrs; # number of speakers.
$numscps = @OUTPUTS; # number of output files.
if ($numspks < $numscps) {
die "$0: Refusing to split data because number of speakers $numspks " .
"is less than the number of output .scp files $numscps\n";
}
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scparray[$scpidx] = []; # [] is array reference.
}
for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
$scpidx = int(($spkidx*$numscps) / $numspks);
$spk = $spkrs[$spkidx];
push @{$scparray[$scpidx]}, $spk;
$scpcount[$scpidx] += $spk_count{$spk};
}
# Now will try to reassign beginning + ending speakers
# to different scp's and see if it gets more balanced.
# Suppose the objective function we're minimizing is sum_i (num utts in scp[i] - average)^2.
# We can show that if considering changing just 2 scp's, we minimize
# this by minimizing the squared difference in sizes. This is
# equivalent to minimizing the absolute difference in sizes. This
# shows this method is bound to converge.
$changed = 1;
while($changed) {
$changed = 0;
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
# First try to reassign ending spk of this scp.
if($scpidx < $numscps-1) {
$sz = @{$scparray[$scpidx]};
if($sz > 0) {
$spk = $scparray[$scpidx]->[$sz-1];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx];
$nutt2 = $scpcount[$scpidx+1];
if( abs( ($nutt2+$count) - ($nutt1-$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx+1] += $count;
$scpcount[$scpidx] -= $count;
pop @{$scparray[$scpidx]};
unshift @{$scparray[$scpidx+1]}, $spk;
$changed = 1;
}
}
}
if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
$spk = $scparray[$scpidx]->[0];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx-1];
$nutt2 = $scpcount[$scpidx];
if( abs( ($nutt2-$count) - ($nutt1+$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx-1] += $count;
$scpcount[$scpidx] -= $count;
shift @{$scparray[$scpidx]};
push @{$scparray[$scpidx-1]}, $spk;
$changed = 1;
}
}
}
}
# Now print out the files...
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($f_fh, '>', $scpfile)
: open($f_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
$count = 0;
if(@{$scparray[$scpidx]} == 0) {
print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
"$scpfile (too many splits and too few speakers?)\n";
$error = 1;
} else {
foreach $spk ( @{$scparray[$scpidx]} ) {
print $f_fh @{$spk_data{$spk}};
$count += $spk_count{$spk};
}
$count == $scpcount[$scpidx] || die "Count mismatch [code error]";
}
close($f_fh);
}
} else {
# This block is the "normal" case where there is no --utt2spk
# option and we just break into equal size chunks.
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
$numscps = @OUTPUTS; # size of array.
@F = ();
while(<$i_fh>) {
push @F, $_;
}
$numlines = @F;
if($numlines == 0) {
print STDERR "$0: error: empty input scp file $inscp\n";
$error = 1;
}
$linesperscp = int( $numlines / $numscps); # the "whole part"..
$linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
$remainder = $numlines - ($linesperscp * $numscps);
($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
# [just doing int() rounds down].
$n = 0;
for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($o_fh, '>', $scpfile)
: open($o_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
print $o_fh $F[$n++];
}
close($o_fh) || die "$0: Error closing scp file $scpfile: $!\n";
}
$n == $numlines || die "$n != $numlines [code error]";
}
exit ($error);
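
The rebalancing loop above rests on one observation: moving a boundary speaker between two adjacent chunks is worthwhile exactly when it shrinks the absolute difference of their utterance counts, and since that difference can only decrease, the loop must terminate. A simplified one-directional Python sketch of the idea (illustration only, not a translation of the Perl):

def rebalance(chunks, count):
    # chunks: list of lists of speaker ids; count: speaker id -> utterance count
    changed = True
    while changed:
        changed = False
        for i in range(len(chunks) - 1):
            left, right = chunks[i], chunks[i + 1]
            if not left:
                continue
            n1 = sum(count[s] for s in left)
            n2 = sum(count[s] for s in right)
            c = count[left[-1]]
            # move the boundary speaker only if it shrinks |n2 - n1|
            if abs((n2 + c) - (n1 - c)) < abs(n2 - n1):
                right.insert(0, left.pop())
                changed = True
    return chunks

# e.g. rebalance([["s1", "s2", "s3"], ["s4"]], {"s1": 5, "s2": 5, "s3": 5, "s4": 1})
# -> [["s1", "s2"], ["s3", "s4"]]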

View File

@ -1,30 +0,0 @@
#!/usr/bin/env bash
dev_num_utt=1000
echo "$0 $@"
. utils/parse_options.sh || exit 1;
train_data=$1
out_dir=$2
[ ! -f ${train_data}/wav.scp ] && echo "$0: no such file ${train_data}/wav.scp" && exit 1;
[ ! -f ${train_data}/text ] && echo "$0: no such file ${train_data}/text" && exit 1;
mkdir -p ${out_dir}/train && mkdir -p ${out_dir}/dev
cp ${train_data}/wav.scp ${out_dir}/train/wav.scp.bak
cp ${train_data}/text ${out_dir}/train/text.bak
num_utt=$(wc -l <${out_dir}/train/wav.scp.bak)
utils/shuffle_list.pl --srand 1 ${out_dir}/train/wav.scp.bak > ${out_dir}/train/wav.scp.shuf
head -n ${dev_num_utt} ${out_dir}/train/wav.scp.shuf > ${out_dir}/dev/wav.scp
tail -n $((${num_utt}-${dev_num_utt})) ${out_dir}/train/wav.scp.shuf > ${out_dir}/train/wav.scp
utils/shuffle_list.pl --srand 1 ${out_dir}/train/text.bak > ${out_dir}/train/text.shuf
head -n ${dev_num_utt} ${out_dir}/train/text.shuf > ${out_dir}/dev/text
tail -n $((${num_utt}-${dev_num_utt})) ${out_dir}/train/text.shuf > ${out_dir}/train/text
rm ${out_dir}/train/wav.scp.bak ${out_dir}/train/text.bak
rm ${out_dir}/train/wav.scp.shuf ${out_dir}/train/text.shuf
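
Note that the script keeps wav.scp and text aligned only because both files are shuffled with the same seed and share the same utterance order. A Python sketch that makes the pairing explicit instead (toy data; illustration only):

import random

random.seed(1)
wavs = ["utt1 a.wav\n", "utt2 b.wav\n", "utt3 c.wav\n"]
texts = ["utt1 hello\n", "utt2 world\n", "utt3 again\n"]
pairs = list(zip(wavs, texts))           # shuffle wav/text lines together
random.shuffle(pairs)
dev_num_utt = 1                          # mirrors the dev_num_utt variable above
dev, train = pairs[:dev_num_utt], pairs[dev_num_utt:]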

View File

@ -1,135 +0,0 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import re
import sys

is_python2 = sys.version_info[0] == 2


def exist_or_not(i, match_pos):
    start_pos = None
    end_pos = None
    for pos in match_pos:
        if pos[0] <= i < pos[1]:
            start_pos = pos[0]
            end_pos = pos[1]
            break
    return start_pos, end_pos


def get_parser():
    parser = argparse.ArgumentParser(
        description="convert raw text to tokenized text",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--nchar",
        "-n",
        default=1,
        type=int,
        help="number of characters to split, i.e., \
        aabb -> a a b b with -n 1 and aa bb with -n 2",
    )
    parser.add_argument(
        "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
    )
    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
    parser.add_argument(
        "--non-lang-syms",
        "-l",
        default=None,
        type=str,
        help="list of non-linguistic symbols, e.g., <NOISE> etc.",
    )
    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
    parser.add_argument(
        "--trans_type",
        "-t",
        type=str,
        default="char",
        choices=["char", "phn"],
        help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
        If trans_type is char,
        read from SI1279.WRD file -> "bricks are an alternative"
        Else if trans_type is phn,
        read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
        sil t er n ih sil t ih v sil" """,
    )
    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    rs = []
    if args.non_lang_syms is not None:
        with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
            nls = [x.rstrip() for x in f.readlines()]
            rs = [re.compile(re.escape(x)) for x in nls]

    if args.text:
        f = codecs.open(args.text, encoding="utf-8")
    else:
        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
    sys.stdout = codecs.getwriter("utf-8")(
        sys.stdout if is_python2 else sys.stdout.buffer
    )
    line = f.readline()
    n = args.nchar
    while line:
        x = line.split()
        print(" ".join(x[: args.skip_ncols]), end=" ")
        a = " ".join(x[args.skip_ncols :])

        # get all matched positions
        match_pos = []
        for r in rs:
            i = 0
            while i >= 0:
                m = r.search(a, i)
                if m:
                    match_pos.append([m.start(), m.end()])
                    i = m.end()
                else:
                    break
        if args.trans_type == "phn":
            a = a.split(" ")
        else:
            if len(match_pos) > 0:
                chars = []
                i = 0
                while i < len(a):
                    start_pos, end_pos = exist_or_not(i, match_pos)
                    if start_pos is not None:
                        chars.append(a[start_pos:end_pos])
                        i = end_pos
                    else:
                        chars.append(a[i])
                        i += 1
                a = chars
            a = [a[j : j + n] for j in range(0, len(a), n)]

        a_flat = []
        for z in a:
            a_flat.append("".join(z))

        a_chars = [z.replace(" ", args.space) for z in a_flat]
        if args.trans_type == "phn":
            a_chars = [z.replace("sil", args.space) for z in a_chars]
        print(" ".join(a_chars))
        line = f.readline()


if __name__ == "__main__":
    main()
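
A worked example of the --nchar grouping above: the character list is cut into non-overlapping groups of n before joining (standalone sketch):

a = list("aabb")
n = 2
a = [a[j : j + n] for j in range(0, len(a), n)]
print(" ".join("".join(z) for z in a))   # -> "aa bb"; with n = 1: "a a b b"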

View File

@ -1,106 +0,0 @@
import re
import argparse


def load_dict(seg_file):
    seg_dict = {}
    with open(seg_file, 'r') as infile:
        for line in infile:
            s = line.strip().split()
            key = s[0]
            value = s[1:]
            seg_dict[key] = " ".join(value)
    return seg_dict


def forward_segment(text, dic):
    word_list = []
    i = 0
    while i < len(text):
        longest_word = text[i]
        for j in range(i + 1, len(text) + 1):
            word = text[i:j]
            if word in dic:
                if len(word) > len(longest_word):
                    longest_word = word
        word_list.append(longest_word)
        i += len(longest_word)
    return word_list


def tokenize(txt,
             seg_dict):
    out_txt = ""
    pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
    for word in txt:
        if pattern.match(word):
            if word in seg_dict:
                out_txt += seg_dict[word] + " "
            else:
                out_txt += "<unk>" + " "
        else:
            continue
    return out_txt.strip()


def get_parser():
    parser = argparse.ArgumentParser(
        description="text tokenize",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--text-file",
        "-t",
        default=False,
        required=True,
        type=str,
        help="input text",
    )
    parser.add_argument(
        "--seg-file",
        "-s",
        default=False,
        required=True,
        type=str,
        help="seg file",
    )
    parser.add_argument(
        "--txt-index",
        "-i",
        default=1,
        required=True,
        type=int,
        help="txt index",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        default=False,
        required=True,
        type=str,
        help="output dir",
    )
    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()
    txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), 'w')
    shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), 'w')
    seg_dict = load_dict(args.seg_file)
    with open(args.text_file, 'r') as infile:
        for line in infile:
            s = line.strip().split()
            text_id = s[0]
            text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
            text = tokenize(text_list, seg_dict)
            lens = len(text.strip().split())
            txt_writer.write(text_id + " " + text + '\n')
            shape_writer.write(text_id + " " + str(lens) + '\n')


if __name__ == '__main__':
    main()
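
A worked example of forward_segment() above (runnable with the definition above in scope; the toy lexicon is made up): at each position the longest dictionary entry starting there wins.

dic = {"nihao": "ni hao", "ni": "ni", "hao": "hao", "ma": "ma"}
print(forward_segment("nihaoma", dic))   # -> ['nihao', 'ma']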

Some files were not shown because too many files have changed in this diff.