diff --git a/egs/aishell2/transformer/utils b/egs/aishell2/transformer/utils
deleted file mode 120000
index fe070dd3a..000000000
--- a/egs/aishell2/transformer/utils
+++ /dev/null
@@ -1 +0,0 @@
-../../aishell/transformer/utils
\ No newline at end of file
diff --git a/egs/aishell2/transformer/utils/__init__.py b/egs/aishell2/transformer/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/aishell2/transformer/utils/apply_cmvn.py b/egs/aishell2/transformer/utils/apply_cmvn.py
new file mode 100755
index 000000000..b5c5086b3
--- /dev/null
+++ b/egs/aishell2/transformer/utils/apply_cmvn.py
@@ -0,0 +1,79 @@
+from kaldiio import ReadHelper
+from kaldiio import WriteHelper
+
+import argparse
+import json
+import math
+import numpy as np
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="apply cmvn",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--ark-file",
+        "-a",
+        default=False,
+        required=True,
+        type=str,
+        help="fbank ark file",
+    )
+    parser.add_argument(
+        "--cmvn-file",
+        "-c",
+        default=False,
+        required=True,
+        type=str,
+        help="cmvn file",
+    )
+    parser.add_argument(
+        "--ark-index",
+        "-i",
+        default=1,
+        required=True,
+        type=int,
+        help="ark index",
+    )
+    parser.add_argument(
+        "--output-dir",
+        "-o",
+        default=False,
+        required=True,
+        type=str,
+        help="output dir",
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
+    scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
+    ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
+
+    with open(args.cmvn_file) as f:
+        cmvn_stats = json.load(f)
+
+    means = cmvn_stats['mean_stats']
+    vars = cmvn_stats['var_stats']
+    total_frames = cmvn_stats['total_frames']
+
+    for i in range(len(means)):
+        means[i] /= total_frames
+        vars[i] = vars[i] / total_frames - means[i] * means[i]
+        if vars[i] < 1.0e-20:
+            vars[i] = 1.0e-20
+        vars[i] = 1.0 / math.sqrt(vars[i])
+
+    with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
+        for key, mat in ark_reader:
+            mat = (mat - means) * vars
+            ark_writer(key, mat)
+
+
+if __name__ == '__main__':
+    main()
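+# For reference (illustrative, with made-up numbers): after the loop in
+# main(), means[i] holds the per-dim mean and vars[i] holds 1/stddev, so each
+# frame is mapped to (x - mean) * (1 / stddev). E.g. a value x = 3.0 in a dim
+# with mean 1.0 and variance 4.0 becomes (3.0 - 1.0) * 0.5 = 1.0.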
diff --git a/egs/aishell2/transformer/utils/apply_cmvn.sh b/egs/aishell2/transformer/utils/apply_cmvn.sh
new file mode 100755
index 000000000..f8fd1d140
--- /dev/null
+++ b/egs/aishell2/transformer/utils/apply_cmvn.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+# Begin configuration section.
+nj=32
+cmd=./utils/run.pl
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+fbankdir=$1
+cmvn_file=$2
+logdir=$3
+output_dir=$4
+
+dump_dir=${output_dir}/ark; mkdir -p ${dump_dir}
+mkdir -p ${logdir}
+
+$cmd JOB=1:$nj $logdir/apply_cmvn.JOB.log \
+  python utils/apply_cmvn.py -a $fbankdir/ark/feats.JOB.ark \
+    -c $cmvn_file -i JOB -o ${dump_dir} \
+  || exit 1;
+
+for n in $(seq $nj); do
+  cat ${dump_dir}/feats.$n.scp || exit 1
+done > ${output_dir}/feats.scp || exit 1
+
+echo "$0: Succeeded apply cmvn"
diff --git a/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py b/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py
new file mode 100755
index 000000000..50d18d1a4
--- /dev/null
+++ b/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py
@@ -0,0 +1,143 @@
+from kaldiio import ReadHelper, WriteHelper
+
+import argparse
+import numpy as np
+
+
+def build_LFR_features(inputs, m=7, n=6):
+    LFR_inputs = []
+    T = inputs.shape[0]
+    T_lfr = int(np.ceil(T / n))
+    left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
+    inputs = np.vstack((left_padding, inputs))
+    T = T + (m - 1) // 2
+    for i in range(T_lfr):
+        if m <= T - i * n:
+            LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
+        else:
+            num_padding = m - (T - i * n)
+            frame = np.hstack(inputs[i * n:])
+            for _ in range(num_padding):
+                frame = np.hstack((frame, inputs[-1]))
+            LFR_inputs.append(frame)
+    return np.vstack(LFR_inputs)
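+# Quick shape check for build_LFR_features (illustrative only, not part of
+# the recipe): stacking m=7 frames every n=6 maps T input frames to
+# ceil(T/6) output frames of dimension 7 * feat_dim, e.g.
+#   build_LFR_features(np.zeros((100, 80))).shape == (17, 560)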
+def build_CMVN_features(inputs, mvn_file):  # noqa
+    with open(mvn_file, 'r', encoding='utf-8') as f:
+        lines = f.readlines()
+
+    add_shift_list = []
+    rescale_list = []
+    for i in range(len(lines)):
+        line_item = lines[i].split()
+        if line_item[0] == '<AddShift>':
+            line_item = lines[i + 1].split()
+            if line_item[0] == '<LearnRateCoef>':
+                add_shift_line = line_item[3:(len(line_item) - 1)]
+                add_shift_list = list(add_shift_line)
+                continue
+        elif line_item[0] == '<Rescale>':
+            line_item = lines[i + 1].split()
+            if line_item[0] == '<LearnRateCoef>':
+                rescale_line = line_item[3:(len(line_item) - 1)]
+                rescale_list = list(rescale_line)
+                continue
+
+    for j in range(inputs.shape[0]):
+        for k in range(inputs.shape[1]):
+            add_shift_value = add_shift_list[k]
+            rescale_value = rescale_list[k]
+            inputs[j, k] = float(inputs[j, k]) + float(add_shift_value)
+            inputs[j, k] = float(inputs[j, k]) * float(rescale_value)
+
+    return inputs
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="apply low_frame_rate and cmvn",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--ark-file",
+        "-a",
+        default=False,
+        required=True,
+        type=str,
+        help="fbank ark file",
+    )
+    # note: --lfr is parsed as a raw string, so any non-empty value
+    # (including "False") is truthy when tested with `if args.lfr:`
+    parser.add_argument(
+        "--lfr",
+        "-f",
+        default=True,
+        type=str,
+        help="low frame rate",
+    )
+    parser.add_argument(
+        "--lfr-m",
+        "-m",
+        default=7,
+        type=int,
+        help="number of frames to stack",
+    )
+    parser.add_argument(
+        "--lfr-n",
+        "-n",
+        default=6,
+        type=int,
+        help="number of frames to skip",
+    )
+    parser.add_argument(
+        "--cmvn-file",
+        "-c",
+        default=False,
+        required=True,
+        type=str,
+        help="global cmvn file",
+    )
+    parser.add_argument(
+        "--ark-index",
+        "-i",
+        default=1,
+        required=True,
+        type=int,
+        help="ark index",
+    )
+    parser.add_argument(
+        "--output-dir",
+        "-o",
+        default=False,
+        required=True,
+        type=str,
+        help="output dir",
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    dump_ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
+    dump_scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
+    shape_file = args.output_dir + "/len." + str(args.ark_index)
+    ark_writer = WriteHelper('ark,scp:{},{}'.format(dump_ark_file, dump_scp_file))
+
+    shape_writer = open(shape_file, 'w')
+    with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
+        for key, mat in ark_reader:
+            if args.lfr:
+                lfr = build_LFR_features(mat, args.lfr_m, args.lfr_n)
+            else:
+                lfr = mat
+            cmvn = build_CMVN_features(lfr, args.cmvn_file)
+            dims = cmvn.shape[1]
+            lens = cmvn.shape[0]
+            shape_writer.write(key + " " + str(lens) + "," + str(dims) + '\n')
+            ark_writer(key, cmvn)
+
+
+if __name__ == '__main__':
+    main()
+
diff --git a/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.sh b/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.sh
new file mode 100755
index 000000000..3119fdb8f
--- /dev/null
+++ b/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+
+
+# Begin configuration section.
+nj=32
+cmd=utils/run.pl
+
+# feature configuration
+lfr=True
+lfr_m=7
+lfr_n=6
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+fbankdir=$1
+cmvn_file=$2
+logdir=$3
+output_dir=$4
+
+dump_dir=${output_dir}/ark; mkdir -p ${dump_dir}
+mkdir -p ${logdir}
+
+$cmd JOB=1:$nj $logdir/apply_lfr_and_cmvn.JOB.log \
+  python utils/apply_lfr_and_cmvn.py -a $fbankdir/ark/feats.JOB.ark \
+    -f $lfr -m $lfr_m -n $lfr_n -c $cmvn_file -i JOB -o ${dump_dir} \
+  || exit 1;
+
+for n in $(seq $nj); do
+  cat ${dump_dir}/feats.$n.scp || exit 1
+done > ${output_dir}/feats.scp || exit 1
+
+for n in $(seq $nj); do
+  cat ${dump_dir}/len.$n || exit 1
+done > ${output_dir}/speech_shape || exit 1
+
+echo "$0: Succeeded apply low frame rate and cmvn"
diff --git a/egs/aishell2/transformer/utils/cmvn_converter.py b/egs/aishell2/transformer/utils/cmvn_converter.py
new file mode 100644
index 000000000..d405d1290
--- /dev/null
+++ b/egs/aishell2/transformer/utils/cmvn_converter.py
@@ -0,0 +1,51 @@
+import argparse
+import json
+import numpy as np
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="cmvn converter",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--cmvn_json",
+        default=False,
+        required=True,
+        type=str,
+        help="cmvn json file",
+    )
+    parser.add_argument(
+        "--am_mvn",
+        default=False,
+        required=True,
+        type=str,
+        help="am mvn file",
+    )
+    return parser
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    with open(args.cmvn_json, "r") as fin:
+        cmvn_dict = json.load(fin)
+
+    mean_stats = np.array(cmvn_dict["mean_stats"])
+    var_stats = np.array(cmvn_dict["var_stats"])
+    total_frame = np.array(cmvn_dict["total_frames"])
+
+    mean = -1.0 * mean_stats / total_frame
+    var = 1.0 / np.sqrt(var_stats / total_frame - mean * mean)
+    dims = mean.shape[0]
+    with open(args.am_mvn, 'w') as fout:
+        fout.write("<Nnet>" + "\n" + "<Splice> " + str(dims) + " " + str(dims) + '\n' + "[ 0 ]" + "\n" + "<AddShift> " + str(dims) + " " + str(dims) + "\n")
+        mean_str = str(list(mean)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+        fout.write("<LearnRateCoef> 0 " + mean_str + '\n')
+        fout.write("<Rescale> " + str(dims) + " " + str(dims) + '\n')
+        var_str = str(list(var)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+        fout.write("<LearnRateCoef> 0 " + var_str + '\n')
+        fout.write("</Nnet>" + '\n')
+
+if __name__ == '__main__':
+    main()
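+# For reference, the emitted Kaldi-nnet-style file looks like this for
+# dims=80 (values illustrative; the means are written negated and the
+# variances as inverse stddevs):
+#   <Nnet>
+#   <Splice> 80 80
+#   [ 0 ]
+#   <AddShift> 80 80
+#   <LearnRateCoef> 0 [ -m_1 ... -m_80 ]
+#   <Rescale> 80 80
+#   <LearnRateCoef> 0 [ 1/sigma_1 ... 1/sigma_80 ]
+#   </Nnet>
+# build_CMVN_features in apply_lfr_and_cmvn.py parses exactly these tags.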
diff --git a/egs/aishell2/transformer/utils/combine_cmvn_file.py b/egs/aishell2/transformer/utils/combine_cmvn_file.py
new file mode 100755
index 000000000..c52597372
--- /dev/null
+++ b/egs/aishell2/transformer/utils/combine_cmvn_file.py
@@ -0,0 +1,72 @@
+import argparse
+import json
+import os
+
+import numpy as np
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="combine cmvn file",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--dim",
+        default=80,
+        type=int,
+        help="feature dim",
+    )
+    parser.add_argument(
+        "--cmvn_dir",
+        default=False,
+        required=True,
+        type=str,
+        help="cmvn dir",
+    )
+
+    parser.add_argument(
+        "--nj",
+        default=1,
+        required=True,
+        type=int,
+        help="num of cmvn files",
+    )
+    parser.add_argument(
+        "--output_dir",
+        default=False,
+        required=True,
+        type=str,
+        help="output dir",
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    total_means = np.zeros(args.dim)
+    total_vars = np.zeros(args.dim)
+    total_frames = 0
+
+    cmvn_file = os.path.join(args.output_dir, "cmvn.json")
+
+    for i in range(1, args.nj + 1):
+        with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(str(i)))) as fin:
+            cmvn_stats = json.load(fin)
+
+            total_means += np.array(cmvn_stats["mean_stats"])
+            total_vars += np.array(cmvn_stats["var_stats"])
+            total_frames += cmvn_stats["total_frames"]
+
+    cmvn_info = {
+        'mean_stats': list(total_means.tolist()),
+        'var_stats': list(total_vars.tolist()),
+        'total_frames': total_frames
+    }
+    with open(cmvn_file, 'w') as fout:
+        fout.write(json.dumps(cmvn_info))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/egs/aishell2/transformer/utils/compute_cmvn.py b/egs/aishell2/transformer/utils/compute_cmvn.py
new file mode 100755
index 000000000..949cc084c
--- /dev/null
+++ b/egs/aishell2/transformer/utils/compute_cmvn.py
@@ -0,0 +1,104 @@
+import argparse
+import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="compute global cmvn",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--dim",
+        default=80,
+        type=int,
+        help="feature dimension",
+    )
+    parser.add_argument(
+        "--wav_path",
+        default=False,
+        required=True,
+        type=str,
+        help="the path of wav scps",
+    )
+    parser.add_argument(
+        "--idx",
+        default=1,
+        required=True,
+        type=int,
+        help="index",
+    )
+    return parser
+
+
+def compute_fbank(wav_file,
+                  num_mel_bins=80,
+                  frame_length=25,
+                  frame_shift=10,
+                  dither=0.0,
+                  resample_rate=16000,
+                  speed=1.0,
+                  window_type="hamming"):
+    waveform, sample_rate = torchaudio.load(wav_file)
+    if resample_rate != sample_rate:
+        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+                                                  new_freq=resample_rate)(waveform)
+    if speed != 1.0:
+        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+            waveform, resample_rate,
+            [['speed', str(speed)], ['rate', str(resample_rate)]]
+        )
+
+    waveform = waveform * (1 << 15)
+    mat = kaldi.fbank(waveform,
+                      num_mel_bins=num_mel_bins,
+                      frame_length=frame_length,
+                      frame_shift=frame_shift,
+                      dither=dither,
+                      energy_floor=0.0,
+                      window_type=window_type,
+                      sample_frequency=resample_rate)
+
+    return mat.numpy()
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+    cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
+
+    mean_stats = np.zeros(args.dim)
+    var_stats = np.zeros(args.dim)
+    total_frames = 0
+
+    # with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
+    #     for key, mat in ark_reader:
+    #         mean_stats += np.sum(mat, axis=0)
+    #         var_stats += np.sum(np.square(mat), axis=0)
+    #         total_frames += mat.shape[0]
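+    # Note: the loop below only accumulates the sufficient statistics sum(x)
+    # and sum(x^2) per feature dim; mean and variance are derived from them
+    # later (in combine_cmvn_file.py / apply_cmvn.py), which is why per-job
+    # statistics can simply be added together.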
+    with open(wav_scp_file) as f:
+        lines = f.readlines()
+        for line in lines:
+            _, wav_file = line.strip().split()
+            fbank = compute_fbank(wav_file, num_mel_bins=args.dim)
+            mean_stats += np.sum(fbank, axis=0)
+            var_stats += np.sum(np.square(fbank), axis=0)
+            total_frames += fbank.shape[0]
+
+    cmvn_info = {
+        'mean_stats': list(mean_stats.tolist()),
+        'var_stats': list(var_stats.tolist()),
+        'total_frames': total_frames
+    }
+    with open(cmvn_file, 'w') as fout:
+        fout.write(json.dumps(cmvn_info))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/egs/aishell2/transformer/utils/compute_cmvn.sh b/egs/aishell2/transformer/utils/compute_cmvn.sh
new file mode 100755
index 000000000..7663df992
--- /dev/null
+++ b/egs/aishell2/transformer/utils/compute_cmvn.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+# Begin configuration section.
+nj=32
+cmd=./utils/run.pl
+feats_dim=80
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+fbankdir=$1
+
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+  split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp $split_scps || exit 1;
+
+logdir=${fbankdir}/cmvn/log
+$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
+  python utils/compute_cmvn.py \
+    --dim ${feats_dim} \
+    --wav_path $split_dir \
+    --idx JOB
+
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/cmvn.mvn
+
+echo "$0: Succeeded compute global cmvn"
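+# Usage sketch (hypothetical paths): given dump/fbank/train/wav.scp,
+#   utils/compute_cmvn.sh --nj 32 --feats_dim 80 dump/fbank/train
+# leaves the merged stats in dump/fbank/train/cmvn/{cmvn.json,cmvn.mvn}.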
diff --git a/egs/aishell2/transformer/utils/compute_fbank.py b/egs/aishell2/transformer/utils/compute_fbank.py
new file mode 100755
index 000000000..9c3904f78
--- /dev/null
+++ b/egs/aishell2/transformer/utils/compute_fbank.py
@@ -0,0 +1,171 @@
+from kaldiio import WriteHelper
+
+import argparse
+import numpy as np
+import json
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+
+
+def compute_fbank(wav_file,
+                  num_mel_bins=80,
+                  frame_length=25,
+                  frame_shift=10,
+                  dither=0.0,
+                  resample_rate=16000,
+                  speed=1.0,
+                  window_type="hamming"):
+
+    waveform, sample_rate = torchaudio.load(wav_file)
+    if resample_rate != sample_rate:
+        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+                                                  new_freq=resample_rate)(waveform)
+    if speed != 1.0:
+        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+            waveform, resample_rate,
+            [['speed', str(speed)], ['rate', str(resample_rate)]]
+        )
+
+    waveform = waveform * (1 << 15)
+    mat = kaldi.fbank(waveform,
+                      num_mel_bins=num_mel_bins,
+                      frame_length=frame_length,
+                      frame_shift=frame_shift,
+                      dither=dither,
+                      energy_floor=0.0,
+                      window_type=window_type,
+                      sample_frequency=resample_rate)
+
+    return mat.numpy()
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="compute features",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--wav-lists",
+        "-w",
+        default=False,
+        required=True,
+        type=str,
+        help="input wav lists",
+    )
+    parser.add_argument(
+        "--text-files",
+        "-t",
+        default=False,
+        required=True,
+        type=str,
+        help="input text files",
+    )
+    parser.add_argument(
+        "--dims",
+        "-d",
+        default=80,
+        type=int,
+        help="feature dims",
+    )
+    parser.add_argument(
+        "--max-lengths",
+        "-m",
+        default=1500,
+        type=int,
+        help="max frame numbers",
+    )
+    parser.add_argument(
+        "--sample-frequency",
+        "-s",
+        default=16000,
+        type=int,
+        help="sample frequency",
+    )
+    parser.add_argument(
+        "--speed-perturb",
+        "-p",
+        default="1.0",
+        type=str,
+        help="speed perturb",
+    )
+    parser.add_argument(
+        "--ark-index",
+        "-a",
+        default=1,
+        required=True,
+        type=int,
+        help="ark index",
+    )
+    parser.add_argument(
+        "--output-dir",
+        "-o",
+        default=False,
+        required=True,
+        type=str,
+        help="output dir",
+    )
+    parser.add_argument(
+        "--window-type",
+        default="hamming",
+        required=False,
+        type=str,
+        help="window type"
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    ark_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".ark"
+    scp_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".scp"
+    text_file = args.output_dir + "/txt/text." + str(args.ark_index) + ".txt"
+    feats_shape_file = args.output_dir + "/ark/len." + str(args.ark_index)
+    text_shape_file = args.output_dir + "/txt/len." + str(args.ark_index)
+
+    ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
+    text_writer = open(text_file, 'w')
+    feats_shape_writer = open(feats_shape_file, 'w')
+    text_shape_writer = open(text_shape_file, 'w')
+
+    speed_perturb_list = args.speed_perturb.split(',')
+
+    for speed in speed_perturb_list:
+        with open(args.wav_lists, 'r', encoding='utf-8') as wavfile:
+            with open(args.text_files, 'r', encoding='utf-8') as textfile:
+                for wav, text in zip(wavfile, textfile):
+                    s_w = wav.strip().split()
+                    wav_id = s_w[0]
+                    wav_file = s_w[1]
+
+                    s_t = text.strip().split()
+                    text_id = s_t[0]
+                    txt = s_t[1:]
+                    fbank = compute_fbank(wav_file,
+                                          num_mel_bins=args.dims,
+                                          resample_rate=args.sample_frequency,
+                                          speed=float(speed),
+                                          window_type=args.window_type
+                                          )
+                    feats_dims = fbank.shape[1]
+                    feats_lens = fbank.shape[0]
+                    if feats_lens >= args.max_lengths:
+                        continue
+                    txt_lens = len(txt)
+                    if speed == "1.0":
+                        wav_id_sp = wav_id
+                    else:
+                        wav_id_sp = wav_id + "_sp" + speed
+
+                    feats_shape_writer.write(wav_id_sp + " " + str(feats_lens) + "," + str(feats_dims) + '\n')
+                    text_shape_writer.write(wav_id_sp + " " + str(txt_lens) + '\n')
+
+                    text_writer.write(wav_id_sp + " " + " ".join(txt) + '\n')
+                    ark_writer(wav_id_sp, fbank)
+
+
+if __name__ == '__main__':
+    main()
+
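+# Illustrative arithmetic (not part of the script): with frame_length=25 ms
+# and frame_shift=10 ms, a 1-second 16 kHz utterance yields roughly
+# (1000 - 25) / 10 + 1 = 98 frames of dimension num_mel_bins, so the default
+# --max-lengths 1500 caps utterances at about 15 s.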
diff --git a/egs/aishell2/transformer/utils/compute_fbank.sh b/egs/aishell2/transformer/utils/compute_fbank.sh
new file mode 100755
index 000000000..8704b313c
--- /dev/null
+++ b/egs/aishell2/transformer/utils/compute_fbank.sh
@@ -0,0 +1,54 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+# Begin configuration section.
+nj=32
+cmd=./utils/run.pl
+
+# feature configuration
+feats_dim=80
+sample_frequency=16000
+speed_perturb="1.0"
+window_type="hamming"
+max_lengths=1500
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+data=$1
+logdir=$2
+fbankdir=$3
+
+[ ! -f $data/wav.scp ] && echo "$0: no such file $data/wav.scp" && exit 1;
+[ ! -f $data/text ] && echo "$0: no such file $data/text" && exit 1;
+
+python utils/split_data.py $data $data $nj
+
+ark_dir=${fbankdir}/ark; mkdir -p ${ark_dir}
+text_dir=${fbankdir}/txt; mkdir -p ${text_dir}
+mkdir -p ${logdir}
+
+$cmd JOB=1:$nj $logdir/make_fbank.JOB.log \
+  python utils/compute_fbank.py -w $data/split${nj}/JOB/wav.scp -t $data/split${nj}/JOB/text \
+    -d $feats_dim -s $sample_frequency -m ${max_lengths} -p ${speed_perturb} -a JOB -o ${fbankdir} \
+    --window-type ${window_type} \
+  || exit 1;
+
+for n in $(seq $nj); do
+  cat ${ark_dir}/feats.$n.scp || exit 1
+done > $fbankdir/feats.scp || exit 1
+
+for n in $(seq $nj); do
+  cat ${text_dir}/text.$n.txt || exit 1
+done > $fbankdir/text || exit 1
+
+for n in $(seq $nj); do
+  cat ${ark_dir}/len.$n || exit 1
+done > $fbankdir/speech_shape || exit 1
+
+for n in $(seq $nj); do
+  cat ${text_dir}/len.$n || exit 1
+done > $fbankdir/text_shape || exit 1
+
+echo "$0: Succeeded compute FBANK features"
diff --git a/egs/aishell2/transformer/utils/compute_wer.py b/egs/aishell2/transformer/utils/compute_wer.py
new file mode 100755
index 000000000..26a9f491f
--- /dev/null
+++ b/egs/aishell2/transformer/utils/compute_wer.py
@@ -0,0 +1,157 @@
+import os
+import numpy as np
+import sys
+
+def compute_wer(ref_file,
+                hyp_file,
+                cer_detail_file):
+    rst = {
+        'Wrd': 0,
+        'Corr': 0,
+        'Ins': 0,
+        'Del': 0,
+        'Sub': 0,
+        'Snt': 0,
+        'Err': 0.0,
+        'S.Err': 0.0,
+        'wrong_words': 0,
+        'wrong_sentences': 0
+    }
+
+    hyp_dict = {}
+    ref_dict = {}
+    with open(hyp_file, 'r') as hyp_reader:
+        for line in hyp_reader:
+            key = line.strip().split()[0]
+            value = line.strip().split()[1:]
+            hyp_dict[key] = value
+    with open(ref_file, 'r') as ref_reader:
+        for line in ref_reader:
+            key = line.strip().split()[0]
+            value = line.strip().split()[1:]
+            ref_dict[key] = value
+
+    cer_detail_writer = open(cer_detail_file, 'w')
+    for hyp_key in hyp_dict:
+        if hyp_key in ref_dict:
+            out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
+            rst['Wrd'] += out_item['nwords']
+            rst['Corr'] += out_item['cor']
+            rst['wrong_words'] += out_item['wrong']
+            rst['Ins'] += out_item['ins']
+            rst['Del'] += out_item['del']
+            rst['Sub'] += out_item['sub']
+            rst['Snt'] += 1
+            if out_item['wrong'] > 0:
+                rst['wrong_sentences'] += 1
+            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
+            cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
+            cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
+
+    if rst['Wrd'] > 0:
+        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
+    if rst['Snt'] > 0:
+        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
+
+    cer_detail_writer.write('\n')
+    cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words']) + " / " + str(rst['Wrd']) +
+                            ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
+    cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
+    cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in ref." + '\n')
+
+
+def compute_wer_by_line(hyp,
+                        ref):
+    hyp = list(map(lambda x: x.lower(), hyp))
+    ref = list(map(lambda x: x.lower(), ref))
+
+    len_hyp = len(hyp)
+    len_ref = len(ref)
+
+    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
+
+    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
+
+    for i in range(len_hyp + 1):
+        cost_matrix[i][0] = i
+    for j in range(len_ref + 1):
+        cost_matrix[0][j] = j
+
+    for i in range(1, len_hyp + 1):
+        for j in range(1, len_ref + 1):
+            if hyp[i - 1] == ref[j - 1]:
+                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
+            else:
+                substitution = cost_matrix[i - 1][j - 1] + 1
+                insertion = cost_matrix[i - 1][j] + 1
+                deletion = cost_matrix[i][j - 1] + 1
+
+                compare_val = [substitution, insertion, deletion]
+
+                min_val = min(compare_val)
+                operation_idx = compare_val.index(min_val) + 1
+                cost_matrix[i][j] = min_val
+                ops_matrix[i][j] = operation_idx
+
+    match_idx = []
+    i = len_hyp
+    j = len_ref
+    rst = {
+        'nwords': len_ref,
+        'cor': 0,
+        'wrong': 0,
+        'ins': 0,
+        'del': 0,
+        'sub': 0
+    }
+    while i >= 0 or j >= 0:
+        i_idx = max(0, i)
+        j_idx = max(0, j)
+
+        if ops_matrix[i_idx][j_idx] == 0:  # correct
+            if i - 1 >= 0 and j - 1 >= 0:
+                match_idx.append((j - 1, i - 1))
+                rst['cor'] += 1
+
+            i -= 1
+            j -= 1
+
+        elif ops_matrix[i_idx][j_idx] == 2:  # insert
+            i -= 1
+            rst['ins'] += 1
+
+        elif ops_matrix[i_idx][j_idx] == 3:  # delete
+            j -= 1
+            rst['del'] += 1
+
+        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
+            i -= 1
+            j -= 1
+            rst['sub'] += 1
+
+        if i < 0 and j >= 0:
+            rst['del'] += 1
+        elif j < 0 and i >= 0:
+            rst['ins'] += 1
+
+    match_idx.reverse()
+    wrong_cnt = cost_matrix[len_hyp][len_ref]
+    rst['wrong'] = wrong_cnt
+
+    return rst
+
+def print_cer_detail(rst):
+    return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
+            + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
+            + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
+            + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
+
+if __name__ == '__main__':
+    if len(sys.argv) != 4:
+        print("usage : python compute_wer.py test.ref test.hyp test.wer")
+        sys.exit(0)
+
+    ref_file = sys.argv[1]
+    hyp_file = sys.argv[2]
+    cer_detail_file = sys.argv[3]
+    compute_wer(ref_file, hyp_file, cer_detail_file)
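+# Worked example (illustrative): for ref "a b c" and hyp "a x c d",
+# compute_wer_by_line gives cor=2, sub=1 (b->x), ins=1 (d), del=0, so
+# %WER = (1 + 1 + 0) * 100 / 3 = 66.67 and the utterance counts as one S.Err.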
diff --git a/egs/aishell2/transformer/utils/download_model.py b/egs/aishell2/transformer/utils/download_model.py
new file mode 100755
index 000000000..70ea17965
--- /dev/null
+++ b/egs/aishell2/transformer/utils/download_model.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+import argparse
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="download model configs",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("--model_name",
+                        type=str,
+                        default="damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch",
+                        help="model name in ModelScope")
+    args = parser.parse_args()
+
+    inference_pipeline = pipeline(
+        task=Tasks.auto_speech_recognition,
+        model=args.model_name)
diff --git a/egs/aishell2/transformer/utils/error_rate_zh b/egs/aishell2/transformer/utils/error_rate_zh
new file mode 100755
index 000000000..6871a07fa
--- /dev/null
+++ b/egs/aishell2/transformer/utils/error_rate_zh
@@ -0,0 +1,370 @@
+#!/usr/bin/env python3
+# coding=utf8
+
+# Copyright 2021 Jiayu DU
+
+import sys
+import argparse
+import json
+import logging
+logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')
+
+DEBUG = None
+
+def GetEditType(ref_token, hyp_token):
+    if ref_token == None and hyp_token != None:
+        return 'I'
+    elif ref_token != None and hyp_token == None:
+        return 'D'
+    elif ref_token == hyp_token:
+        return 'C'
+    elif ref_token != hyp_token:
+        return 'S'
+    else:
+        raise RuntimeError
+
+class AlignmentArc:
+    def __init__(self, src, dst, ref, hyp):
+        self.src = src
+        self.dst = dst
+        self.ref = ref
+        self.hyp = hyp
+        self.edit_type = GetEditType(ref, hyp)
+
+def similarity_score_function(ref_token, hyp_token):
+    return 0 if (ref_token == hyp_token) else -1.0
+
+def insertion_score_function(token):
+    return -1.0
+
+def deletion_score_function(token):
+    return -1.0
+
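+# Note (not part of the original file): with the default scoring above, a
+# correct match costs 0 and every substitution/insertion/deletion costs -1,
+# so the best alignment score returned by EditDistance() below is exactly
+# -(S + I + D), i.e. the negated edit distance.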
+def EditDistance(
+        ref,
+        hyp,
+        similarity_score_function = similarity_score_function,
+        insertion_score_function = insertion_score_function,
+        deletion_score_function = deletion_score_function):
+    assert(len(ref) != 0)
+    class DPState:
+        def __init__(self):
+            self.score = -float('inf')
+            # backpointer
+            self.prev_r = None
+            self.prev_h = None
+
+    def print_search_grid(S, R, H, fstream):
+        print(file=fstream)
+        for r in range(R):
+            for h in range(H):
+                print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
+            print(file=fstream)
+
+    R = len(ref) + 1
+    H = len(hyp) + 1
+
+    # Construct DP search space, a (R x H) grid
+    S = [ [] for r in range(R) ]
+    for r in range(R):
+        S[r] = [ DPState() for x in range(H) ]
+
+    # initialize DP search grid origin, S(r = 0, h = 0)
+    S[0][0].score = 0.0
+    S[0][0].prev_r = None
+    S[0][0].prev_h = None
+
+    # initialize REF axis
+    for r in range(1, R):
+        S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
+        S[r][0].prev_r = r-1
+        S[r][0].prev_h = 0
+
+    # initialize HYP axis
+    for h in range(1, H):
+        S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
+        S[0][h].prev_r = 0
+        S[0][h].prev_h = h-1
+
+    best_score = S[0][0].score
+    best_state = (0, 0)
+
+    for r in range(1, R):
+        for h in range(1, H):
+            sub_or_cor_score = similarity_score_function(ref[r-1], hyp[h-1])
+            new_score = S[r-1][h-1].score + sub_or_cor_score
+            if new_score >= S[r][h].score:
+                S[r][h].score = new_score
+                S[r][h].prev_r = r-1
+                S[r][h].prev_h = h-1
+
+            del_score = deletion_score_function(ref[r-1])
+            new_score = S[r-1][h].score + del_score
+            if new_score >= S[r][h].score:
+                S[r][h].score = new_score
+                S[r][h].prev_r = r - 1
+                S[r][h].prev_h = h
+
+            ins_score = insertion_score_function(hyp[h-1])
+            new_score = S[r][h-1].score + ins_score
+            if new_score >= S[r][h].score:
+                S[r][h].score = new_score
+                S[r][h].prev_r = r
+                S[r][h].prev_h = h-1
+
+    best_score = S[R-1][H-1].score
+    best_state = (R-1, H-1)
+
+    if DEBUG:
+        print_search_grid(S, R, H, sys.stderr)
+
+    # Backtracing best alignment path, i.e. a list of arcs
+    #   arc = (src, dst, ref, hyp, edit_type)
+    #   src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
+    best_path = []
+    r, h = best_state[0], best_state[1]
+    prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
+    score = S[r][h].score
+    # loop invariant:
+    #   1. (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path
+    #   2. score is the value of point(r, h) on DP search grid
+    while prev_r != None or prev_h != None:
+        src = (prev_r, prev_h)
+        dst = (r, h)
+        if (r == prev_r + 1 and h == prev_h + 1): # Substitution or correct
+            arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
+        elif (r == prev_r + 1 and h == prev_h): # Deletion
+            arc = AlignmentArc(src, dst, ref[prev_r], None)
+        elif (r == prev_r and h == prev_h + 1): # Insertion
+            arc = AlignmentArc(src, dst, None, hyp[prev_h])
+        else:
+            raise RuntimeError
+        best_path.append(arc)
+        r, h = prev_r, prev_h
+        prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
+        score = S[r][h].score
+
+    best_path.reverse()
+    return (best_path, best_score)
+
+def PrettyPrintAlignment(alignment, stream = sys.stderr):
+    def get_token_str(token):
+        if token == None:
+            return "*"
+        return token
+
+    def is_double_width_char(ch):
+        if (ch >= '\u4e00') and (ch <= '\u9fa5'): # codepoint ranges for Chinese chars
+            return True
+        # TODO: support other double-width-char language such as Japanese, Korean
+        else:
+            return False
+
+    def display_width(token_str):
+        m = 0
+        for c in token_str:
+            if is_double_width_char(c):
+                m += 2
+            else:
+                m += 1
+        return m
+
+    R = ' REF  : '
+    H = ' HYP  : '
+    E = ' EDIT : '
+    for arc in alignment:
+        r = get_token_str(arc.ref)
+        h = get_token_str(arc.hyp)
+        e = arc.edit_type if arc.edit_type != 'C' else ''
+
+        nr, nh, ne = display_width(r), display_width(h), display_width(e)
+        n = max(nr, nh, ne) + 1
+
+        R += r + ' ' * (n-nr)
+        H += h + ' ' * (n-nh)
+        E += e + ' ' * (n-ne)
+
+    print(R, file=stream)
+    print(H, file=stream)
+    print(E, file=stream)
+
+def CountEdits(alignment):
+    c, s, i, d = 0, 0, 0, 0
+    for arc in alignment:
+        if arc.edit_type == 'C':
+            c += 1
+        elif arc.edit_type == 'S':
+            s += 1
+        elif arc.edit_type == 'I':
+            i += 1
+        elif arc.edit_type == 'D':
+            d += 1
+        else:
+            raise RuntimeError
+    return (c, s, i, d)
+
+def ComputeTokenErrorRate(c, s, i, d):
+    return 100.0 * (s + d + i) / (s + d + c)
+
+def ComputeSentenceErrorRate(num_err_utts, num_utts):
+    assert(num_utts != 0)
+    return 100.0 * num_err_utts / num_utts
+
+
+class EvaluationResult:
+    def __init__(self):
+        self.num_ref_utts = 0
+        self.num_hyp_utts = 0
+        self.num_eval_utts = 0  # seen in both ref & hyp
+        self.num_hyp_without_ref = 0
+
+        self.C = 0
+        self.S = 0
+        self.I = 0
+        self.D = 0
+        self.token_error_rate = 0.0
+
+        self.num_utts_with_error = 0
+        self.sentence_error_rate = 0.0
+
+    def to_json(self):
+        return json.dumps(self.__dict__)
+
+    def to_kaldi(self):
+        info = (
+            F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
+            F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
+        )
+        return info
+
+    def to_sclite(self):
+        return "TODO"
+
+    def to_espnet(self):
+        return "TODO"
+
+    def to_summary(self):
+        #return json.dumps(self.__dict__, indent=4)
+        summary = (
+            '==================== Overall Statistics ====================\n'
+            F'num_ref_utts: {self.num_ref_utts}\n'
+            F'num_hyp_utts: {self.num_hyp_utts}\n'
+            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
+            F'num_eval_utts: {self.num_eval_utts}\n'
+            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
+            F'token_error_rate: {self.token_error_rate:.2f}%\n'
+            F'token_stats:\n'
+            F'  - tokens:{self.C + self.S + self.D:>7}\n'
+            F'  - edits: {self.S + self.I + self.D:>7}\n'
+            F'  - cor:   {self.C:>7}\n'
+            F'  - sub:   {self.S:>7}\n'
+            F'  - ins:   {self.I:>7}\n'
+            F'  - del:   {self.D:>7}\n'
+            '============================================================\n'
+        )
+        return summary
+
+
+class Utterance:
+    def __init__(self, uid, text):
+        self.uid = uid
+        self.text = text
+
+
+def LoadUtterances(filepath, format):
+    utts = {}
+    if format == 'text': # utt_id word1 word2 ...
+        with open(filepath, 'r', encoding='utf8') as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    cols = line.split(maxsplit=1)
+                    assert(len(cols) == 2 or len(cols) == 1)
+                    uid = cols[0]
+                    text = cols[1] if len(cols) == 2 else ''
+                    if utts.get(uid) != None:
+                        raise RuntimeError(F'Found duplicated utterance id {uid}')
+                    utts[uid] = Utterance(uid, text)
+    else:
+        raise RuntimeError(F'Unsupported text format {format}')
+    return utts
+
+
+def tokenize_text(text, tokenizer):
+    if tokenizer == 'whitespace':
+        return text.split()
+    elif tokenizer == 'char':
+        return [ ch for ch in ''.join(text.split()) ]
+    else:
+        raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    # optional
+    parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
+    parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
+    parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
+    # required
+    parser.add_argument('--ref', type=str, required=True, help='input reference file')
+    parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')
+
+    parser.add_argument('result_file', type=str)
+    args = parser.parse_args()
+    logging.info(args)
+
+    ref_utts = LoadUtterances(args.ref, args.ref_format)
+    hyp_utts = LoadUtterances(args.hyp, args.hyp_format)
+
+    r = EvaluationResult()
+
+    # check valid utterances in hyp that have matched non-empty reference
+    eval_utts = []
+    r.num_hyp_without_ref = 0
+    for uid in sorted(hyp_utts.keys()):
+        if uid in ref_utts.keys(): # TODO: efficiency
+            if ref_utts[uid].text.strip(): # non-empty reference
+                eval_utts.append(uid)
+            else:
+                logging.warning(F'Found {uid} with empty reference, skipping...')
+        else:
+            logging.warning(F'Found {uid} without reference, skipping...')
+            r.num_hyp_without_ref += 1
+
+    r.num_hyp_utts = len(hyp_utts)
+    r.num_ref_utts = len(ref_utts)
+    r.num_eval_utts = len(eval_utts)
+
+    with open(args.result_file, 'w+', encoding='utf8') as fo:
+        for uid in eval_utts:
+            ref = ref_utts[uid]
+            hyp = hyp_utts[uid]
+
+            alignment, score = EditDistance(
+                tokenize_text(ref.text, args.tokenizer),
+                tokenize_text(hyp.text, args.tokenizer)
+            )
+
+            c, s, i, d = CountEdits(alignment)
+            utt_ter = ComputeTokenErrorRate(c, s, i, d)
+
+            # utt-level evaluation result
+            print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
+            PrettyPrintAlignment(alignment, fo)
+
+            r.C += c
+            r.S += s
+            r.I += i
+            r.D += d
+
+            if utt_ter > 0:
+                r.num_utts_with_error += 1
+
+        # corpus level evaluation result
+        r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
+        r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)
+
+        print(r.to_summary(), file=fo)
+
+    print(r.to_json())
+    print(r.to_kaldi())
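+# Usage sketch (hypothetical paths):
+#   utils/error_rate_zh --tokenizer char --ref data/test/text \
+#       --hyp exp/decode/text exp/decode/cer_result.txt
+# writes per-utterance alignments to the result file and prints the corpus
+# JSON/Kaldi-style summaries to stdout.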
diff --git a/egs/aishell2/transformer/utils/extract_embeds.py b/egs/aishell2/transformer/utils/extract_embeds.py
new file mode 100755
index 000000000..7b817d8ca
--- /dev/null
+++ b/egs/aishell2/transformer/utils/extract_embeds.py
@@ -0,0 +1,47 @@
+from transformers import AutoTokenizer, AutoModel, pipeline
+import numpy as np
+import sys
+import os
+import torch
+from kaldiio import WriteHelper
+import re
+text_file_json = sys.argv[1]
+out_ark = sys.argv[2]
+out_scp = sys.argv[3]
+out_shape = sys.argv[4]
+device = int(sys.argv[5])
+model_path = sys.argv[6]
+
+model = AutoModel.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
+
+with open(text_file_json, 'r') as f:
+    js = f.readlines()
+
+
+f_shape = open(out_shape, "w")
+with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
+    with torch.no_grad():
+        for idx, line in enumerate(js):
+            id, tokens = line.strip().split(" ", 1)
+            tokens = re.sub(" ", "", tokens.strip())
+            tokens = ' '.join([j for j in tokens])
+            token_num = len(tokens.split(" "))
+            outputs = extractor(tokens)
+            outputs = np.array(outputs)
+            # drop the [CLS]/[SEP] positions so one embedding remains per character
+            embeds = outputs[0, 1:-1, :]
+
+            token_num_embeds, dim = embeds.shape
+            if token_num == token_num_embeds:
+                writer(id, embeds)
+                shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
+                f_shape.write(shape_line)
+            else:
+                print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
+
+
+
+f_shape.close()
+
+ "See also: scripts/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/aishell2/transformer/utils/fix_data.sh b/egs/aishell2/transformer/utils/fix_data.sh new file mode 100755 index 000000000..b1a2bb808 --- /dev/null +++ b/egs/aishell2/transformer/utils/fix_data.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +echo "$0 $@" +data_dir=$1 + +if [ ! -f ${data_dir}/wav.scp ]; then + echo "$0: wav.scp is not found" + exit 1; +fi + +if [ ! -f ${data_dir}/text ]; then + echo "$0: text is not found" + exit 1; +fi + + + +mkdir -p ${data_dir}/.backup + +awk '{print $1}' ${data_dir}/wav.scp > ${data_dir}/.backup/wav_id +awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id + +sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id + +cp ${data_dir}/wav.scp ${data_dir}/.backup/wav.scp +cp ${data_dir}/text ${data_dir}/.backup/text + +mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak +mv ${data_dir}/text ${data_dir}/text.bak + +utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak | sort -k1,1 -u > ${data_dir}/wav.scp +utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text + +rm ${data_dir}/wav.scp.bak +rm ${data_dir}/text.bak diff --git a/egs/aishell2/transformer/utils/fix_data_feat.sh b/egs/aishell2/transformer/utils/fix_data_feat.sh new file mode 100755 index 000000000..84eea36b6 --- /dev/null +++ b/egs/aishell2/transformer/utils/fix_data_feat.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash + +echo "$0 $@" +data_dir=$1 + +if [ ! -f ${data_dir}/feats.scp ]; then + echo "$0: feats.scp is not found" + exit 1; +fi + +if [ ! -f ${data_dir}/text ]; then + echo "$0: text is not found" + exit 1; +fi + +if [ ! -f ${data_dir}/speech_shape ]; then + echo "$0: feature lengths is not found" + exit 1; +fi + +if [ ! 
diff --git a/egs/aishell2/transformer/utils/fix_data_feat.sh b/egs/aishell2/transformer/utils/fix_data_feat.sh
new file mode 100755
index 000000000..84eea36b6
--- /dev/null
+++ b/egs/aishell2/transformer/utils/fix_data_feat.sh
@@ -0,0 +1,52 @@
+#!/usr/bin/env bash
+
+echo "$0 $@"
+data_dir=$1
+
+if [ ! -f ${data_dir}/feats.scp ]; then
+  echo "$0: feats.scp is not found"
+  exit 1;
+fi
+
+if [ ! -f ${data_dir}/text ]; then
+  echo "$0: text is not found"
+  exit 1;
+fi
+
+if [ ! -f ${data_dir}/speech_shape ]; then
+  echo "$0: feature lengths are not found"
+  exit 1;
+fi
+
+if [ ! -f ${data_dir}/text_shape ]; then
+  echo "$0: text lengths are not found"
+  exit 1;
+fi
+
+mkdir -p ${data_dir}/.backup
+
+awk '{print $1}' ${data_dir}/feats.scp > ${data_dir}/.backup/wav_id
+awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
+
+sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
+
+cp ${data_dir}/feats.scp ${data_dir}/.backup/feats.scp
+cp ${data_dir}/text ${data_dir}/.backup/text
+cp ${data_dir}/speech_shape ${data_dir}/.backup/speech_shape
+cp ${data_dir}/text_shape ${data_dir}/.backup/text_shape
+
+mv ${data_dir}/feats.scp ${data_dir}/feats.scp.bak
+mv ${data_dir}/text ${data_dir}/text.bak
+mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak
+mv ${data_dir}/text_shape ${data_dir}/text_shape.bak
+
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak | sort -k1,1 -u > ${data_dir}/feats.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak | sort -k1,1 -u > ${data_dir}/speech_shape
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak | sort -k1,1 -u > ${data_dir}/text_shape
+
+rm ${data_dir}/feats.scp.bak
+rm ${data_dir}/text.bak
+rm ${data_dir}/speech_shape.bak
+rm ${data_dir}/text_shape.bak
+
diff --git a/egs/aishell2/transformer/utils/gen_ark_list.sh b/egs/aishell2/transformer/utils/gen_ark_list.sh
new file mode 100755
index 000000000..aebf3562d
--- /dev/null
+++ b/egs/aishell2/transformer/utils/gen_ark_list.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+
+
+# Begin configuration section.
+nj=32
+cmd=./utils/run.pl
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+ark_dir=$1
+txt_dir=$2
+output_dir=$3
+
+[ ! -d ${ark_dir}/ark ] && echo "$0: ark data is required" && exit 1;
+[ ! -d ${txt_dir}/txt ] && echo "$0: txt data is required" && exit 1;
+
+for n in $(seq $nj); do
+  echo "${ark_dir}/ark/feats.$n.ark ${txt_dir}/txt/text.$n.txt" || exit 1
+done > ${output_dir}/ark_txt.scp || exit 1
+
diff --git a/egs/aishell2/transformer/utils/parse_options.sh b/egs/aishell2/transformer/utils/parse_options.sh
new file mode 100755
index 000000000..71fb9e5ea
--- /dev/null
+++ b/egs/aishell2/transformer/utils/parse_options.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);
+#                 Arnab Ghoshal, Karel Vesely
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parse command-line options.
+# To be sourced by another script (as in ". parse_options.sh").
+# Option format is: --option-name arg
+# and shell variable "option_name" gets set to value "arg."
+# The exception is --help, which takes no arguments, but prints the
+# $help_message variable (if defined).
+
+
+###
+### The --config file options have lower priority to command line
+### options, so we need to import them first...
+###
+
+# Now import all the configs specified by command-line, in left-to-right order
+for ((argpos=1; argpos<$#; argpos++)); do
+  if [ "${!argpos}" == "--config" ]; then
+    argpos_plus1=$((argpos+1))
+    config=${!argpos_plus1}
+    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
+    . $config  # source the config file.
+  fi
+done
+
+
+###
+### Now we process the command line options
+###
+while true; do
+  [ -z "${1:-}" ] && break;  # break if there are no arguments
+  case "$1" in
+    # If the enclosing script is called with --help option, print the help
+    # message and exit.  Scripts should put help messages in $help_message
+    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
+               else printf "$help_message\n" 1>&2 ; fi;
+               exit 0 ;;
+    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
+           exit 1 ;;
+    # If the first command-line argument begins with "--" (e.g. --foo-bar),
+    # then work out the variable name as $name, which will equal "foo_bar".
+    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
+         # Next we test whether the variable in question is undefined -- if so it's
+         # an invalid option and we die.  Note: $0 evaluates to the name of the
+         # enclosing script.
+         # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
+         # is undefined.  We then have to wrap this test inside "eval" because
+         # foo_bar is itself inside a variable ($name).
+         eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+
+         oldval="`eval echo \\$$name`";
+         # Work out whether we seem to be expecting a Boolean argument.
+         if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
+           was_bool=true;
+         else
+           was_bool=false;
+         fi
+
+         # Set the variable to the right value-- the escaped quotes make it work if
+         # the option had spaces, like --cmd "queue.pl -sync y"
+         eval $name=\"$2\";
+
+         # Check that Boolean-valued arguments are really Boolean.
+         if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+           echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+           exit 1;
+         fi
+         shift 2;
+         ;;
+    *) break;
+  esac
+done
+
+
+# Check for an empty argument to the --cmd option, which can easily occur as a
+# result of scripting errors.
+[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
+
+
+true; # so this script returns exit code 0.
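+# Usage sketch (illustrative, hypothetical caller): a script defines defaults,
+# then sources this file so the command line can override them:
+#   nj=32                       # default
+#   . utils/parse_options.sh
+# Running "caller.sh --nj 8" would then set nj=8 before the rest of the
+# script executes.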
diff --git a/egs/aishell2/transformer/utils/print_args.py b/egs/aishell2/transformer/utils/print_args.py
new file mode 100755
index 000000000..b0c61e5b4
--- /dev/null
+++ b/egs/aishell2/transformer/utils/print_args.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+import sys
+
+
+def get_commandline_args(no_executable=True):
+    extra_chars = [
+        " ",
+        ";",
+        "&",
+        "|",
+        "<",
+        ">",
+        "?",
+        "*",
+        "~",
+        "`",
+        '"',
+        "'",
+        "\\",
+        "{",
+        "}",
+        "(",
+        ")",
+    ]
+
+    # Escape the extra characters for shell
+    argv = [
+        arg.replace("'", "'\\''")
+        if all(char not in arg for char in extra_chars)
+        else "'" + arg.replace("'", "'\\''") + "'"
+        for arg in sys.argv
+    ]
+
+    if no_executable:
+        return " ".join(argv[1:])
+    else:
+        return sys.executable + " " + " ".join(argv)
+
+
+def main():
+    print(get_commandline_args())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/aishell2/transformer/utils/proc_conf_oss.py b/egs/aishell2/transformer/utils/proc_conf_oss.py
new file mode 100755
index 000000000..c4a90c5c1
--- /dev/null
+++ b/egs/aishell2/transformer/utils/proc_conf_oss.py
@@ -0,0 +1,35 @@
+from pathlib import Path
+
+import torch
+import yaml
+
+
+class NoAliasSafeDumper(yaml.SafeDumper):
+    # Disable anchor/alias in yaml because looks ugly
+    def ignore_aliases(self, data):
+        return True
+
+
+def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
+    """Safe-dump in yaml with no anchor/alias"""
+    return yaml.dump(
+        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
+    )
+
+
+def gen_conf(file, out_dir):
+    conf = torch.load(file)["config"]
+    conf["oss_bucket"] = "null"
+    print(conf)
+    output_dir = Path(out_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
+        yaml_no_alias_safe_dump(conf, f, indent=4, sort_keys=False)
+
+
+if __name__ == "__main__":
+    import sys
+
+    in_f = sys.argv[1]
+    out_f = sys.argv[2]
+    gen_conf(in_f, out_f)
diff --git a/egs/aishell2/transformer/utils/proce_text.py b/egs/aishell2/transformer/utils/proce_text.py
new file mode 100755
index 000000000..9e517a4e1
--- /dev/null
+++ b/egs/aishell2/transformer/utils/proce_text.py
@@ -0,0 +1,31 @@
+
+import sys
+import re
+
+in_f = sys.argv[1]
+out_f = sys.argv[2]
+
+
+with open(in_f, "r", encoding="utf-8") as f:
+    lines = f.readlines()
+
+with open(out_f, "w", encoding="utf-8") as f:
+    for line in lines:
+        outs = line.strip().split(" ", 1)
+        if len(outs) == 2:
+            idx, text = outs
+            # strip sentence-boundary / unk symbols and BPE markers
+            text = re.sub("<s>", "", text)
+            text = re.sub("</s>", "", text)
+            text = re.sub("@@", "", text)
+            text = re.sub("@", "", text)
+            text = re.sub("<unk>", "", text)
+            text = re.sub(" ", "", text)
+            text = text.lower()
+        else:
+            idx = outs[0]
+            text = " "
+
+        text = [x for x in text]
+        text = " ".join(text)
+        out = "{} {}\n".format(idx, text)
+        f.write(out)
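+# Illustrative example (assuming the <s>/</s>/<unk> markers above): the line
+# "ID1 <s>今天@@好</s>" becomes "ID1 今 天 好" -- markers and spaces are
+# stripped, then the remaining characters are re-joined with single spaces
+# (ASCII is also lowercased).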
diff --git a/egs/aishell2/transformer/utils/run.pl b/egs/aishell2/transformer/utils/run.pl
new file mode 100755
index 000000000..483f95bc6
--- /dev/null
+++ b/egs/aishell2/transformer/utils/run.pl
@@ -0,0 +1,356 @@
+#!/usr/bin/env perl
+use warnings; #sed replacement for -w perl parameter
+# In general, doing
+#  run.pl some.log a b c is like running the command a b c in
+# the bash shell, and putting the standard error and output into some.log.
+# To run parallel jobs (backgrounded on the host machine), you can do (e.g.)
+#  run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB
+# and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier].
+# If any of the jobs fails, this script will fail.
+
+# A typical example is:
+#  run.pl some.log my-prog "--opt=foo bar" foo \|  other-prog baz
+# and run.pl will run something like:
+#  ( my-prog '--opt=foo bar' foo |  other-prog baz ) >& some.log
+#
+# Basically it takes the command-line arguments, quotes them
+# as necessary to preserve spaces, and evaluates them with bash.
+# In addition it puts the command line at the top of the log, and
+# the start and end times of the command at the beginning and end.
+# The reason why this is useful is so that we can create a different
+# version of this program that uses a queueing system instead.
+
+#use Data::Dumper;
+
+@ARGV < 2 && die "usage: run.pl log-file command-line arguments...";
+
+#print STDERR "COMMAND-LINE: " . Dumper(\@ARGV) . "\n";
+$job_pick = 'all';
+$max_jobs_run = -1;
+$jobstart = 1;
+$jobend = 1;
+$ignored_opts = ""; # These will be ignored.
+
+# First parse an option like JOB=1:4, and any
+# options that would normally be given to
+# queue.pl, which we will just discard.
+
+for (my $x = 1; $x <= 2; $x++) { # This for-loop is to
+  # allow the JOB=1:n option to be interleaved with the
+  # options to qsub.
+  while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) {
+    # parse any options that would normally go to qsub, but which will be ignored here.
+    my $switch = shift @ARGV;
+    if ($switch eq "-V") {
+      $ignored_opts .= "-V ";
+    } elsif ($switch eq "--max-jobs-run" || $switch eq "-tc") {
+      # we do support the option --max-jobs-run n, and its GridEngine form -tc n.
+      # if the command appears multiple times uses the smallest option.
+      if ( $max_jobs_run <= 0 ) {
+        $max_jobs_run = shift @ARGV;
+      } else {
+        my $new_constraint = shift @ARGV;
+        if ( ($new_constraint < $max_jobs_run) ) {
+          $max_jobs_run = $new_constraint;
+        }
+      }
+
+      if (! ($max_jobs_run > 0)) {
+        die "run.pl: invalid option --max-jobs-run $max_jobs_run";
+      }
+    } else {
+      my $argument = shift @ARGV;
+      if ($argument =~ m/^--/) {
+        print STDERR "run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\n";
+      }
+      if ($switch eq "-sync" && $argument =~ m/^[yY]/) {
+        $ignored_opts .= "-sync "; # Note: in the
+        # corresponding code in queue.pl it says instead, just "$sync = 1;".
+      } elsif ($switch eq "-pe") { # e.g. -pe smp 5
+        my $argument2 = shift @ARGV;
+        $ignored_opts .= "$switch $argument $argument2 ";
+      } elsif ($switch eq "--gpu") {
+        $using_gpu = $argument;
+      } elsif ($switch eq "--pick") {
+        if($argument =~ m/^(all|failed|incomplete)$/) {
+          $job_pick = $argument;
+        } else {
+          print STDERR "run.pl: ERROR: --pick argument must be one of 'all', 'failed' or 'incomplete'"
+        }
+      } else {
+        # Ignore option.
+        $ignored_opts .= "$switch $argument ";
+      }
+    }
+  }
+  if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20
+    $jobname = $1;
+    $jobstart = $2;
+    $jobend = $3;
+    if ($jobstart > $jobend) {
+      die "run.pl: invalid job range $ARGV[0]";
+    }
+    if ($jobstart <= 0) {
+      die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility).";
+    }
+    shift;
+  } elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1.
+    $jobname = $1;
+    $jobstart = $2;
+    $jobend = $2;
+    shift;
+  } elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) {
+    print STDERR "run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\n";
+  }
+}
+
+# Users found this message confusing so we are removing it.
+# if ($ignored_opts ne "") { +# print STDERR "run.pl: Warning: ignoring options \"$ignored_opts\"\n"; +# } + +if ($max_jobs_run == -1) { # If --max-jobs-run option not set, + # then work out the number of processors if possible, + # and set it based on that. + $max_jobs_run = 0; + if ($using_gpu) { + if (open(P, "nvidia-smi -L |")) { + $max_jobs_run++ while (
+      $max_jobs_run++ while (<P>);
+      close(P);
+    }
+    if ($max_jobs_run == 0) {
+      $max_jobs_run = 1;
+      print STDERR "run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\n";
+    }
+  } elsif (open(P, "</proc/cpuinfo")) { # Linux
+    while (<P>) { if (m/^processor/) { $max_jobs_run++; } }
+    if ($max_jobs_run == 0) {
+      print STDERR "run.pl: Warning: failed to detect any processors from /proc/cpuinfo\n";
+      $max_jobs_run = 10; # reasonable default.
+    }
+    close(P);
+  } elsif (open(P, "sysctl -a |")) { # BSD/Darwin
+    while (<P>) {
+      if (m/hw\.ncpu\s*[:=]\s*(\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4
+        $max_jobs_run = $1;
+        last;
+      }
+    }
+    close(P);
+    if ($max_jobs_run == 0) {
+      print STDERR "run.pl: Warning: failed to detect any processors from sysctl -a\n";
+      $max_jobs_run = 10; # reasonable default.
+    }
+  } else {
+    # allow at most 32 jobs at once, on non-UNIX systems; change this code
+    # if you need to change this default.
+    $max_jobs_run = 32;
+  }
+  # The just-computed value of $max_jobs_run is just the number of processors
+  # (or our best guess); and if it happens that the number of jobs we need to
+  # run is just slightly above $max_jobs_run, it will make sense to increase
+  # $max_jobs_run to equal the number of jobs, so we don't have a small number
+  # of leftover jobs.
+  $num_jobs = $jobend - $jobstart + 1;
+  if (!$using_gpu &&
+      $num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) {
+    $max_jobs_run = $num_jobs;
+  }
+}
+
+sub pick_or_exit {
+  # pick_or_exit ( $logfile )
+  # Invoked before each job is started helps to run jobs selectively.
+  #
+  # Given the name of the output logfile decides whether the job must be
+  # executed (by returning from the subroutine) or not (by terminating the
+  # process calling exit)
+  #
+  # PRE: $job_pick is a global variable set by command line switch --pick
+  # and indicates which class of jobs must be executed.
+  #
+  # 1) If a failed job is not executed the process exit code will indicate
+  # failure, just as if the task was just executed and failed.
+  #
+  # 2) If a task is incomplete it will be executed. Incomplete may be either
+  # a job whose log file does not contain the accounting notes in the end,
+  # or a job whose log file does not exist.
+  #
+  # 3) If the $job_pick is set to 'all' (default behavior) a task will be
+  # executed regardless of the result of previous attempts.
+  #
+  # This logic could have been implemented in the main execution loop
+  # but a subroutine to preserve the current level of readability of
+  # that part of the code.
+  #
+  # Alexandre Felipe, (o.alexandre.felipe@gmail.com) 14th of August of 2020
+  #
+  if($job_pick eq 'all'){
+    return; # no need to bother with the previous log
+  }
+  open my $fh, "<", $_[0] or return; # job not executed yet
+  my $log_line;
+  my $cur_line;
+  while ($cur_line = <$fh>) {
+    if( $cur_line =~ m/# Ended \(code .*/ ) {
+      $log_line = $cur_line;
+    }
+  }
+  close $fh;
+  if (! defined($log_line)){
+    return; # incomplete
+  }
+  if ( $log_line =~ m/# Ended \(code 0\).*/ ) {
+    exit(0); # complete
+  } elsif ( $log_line =~ m/# Ended \(code \d+(; signal \d+)?\).*/ ){
+    if ($job_pick !~ m/^(failed|all)$/) {
+      exit(1); # failed but not going to run
+    } else {
+      return; # failed
+    }
+  } elsif ( $log_line =~ m/.*\S.*/ ) {
+    return; # incomplete jobs are always run
+  }
+}
+
+
+$logfile = shift @ARGV;
+
+if (defined $jobname && $logfile !~ m/$jobname/ &&
+    $jobend > $jobstart) {
+  print STDERR "run.pl: you are trying to run a parallel job but "
+    . "you are putting the output into just one log file ($logfile)\n";
+  exit(1);
+}
+
+$cmd = "";
+
" "; } + elsif ($x =~ m:\":) { $cmd .= "'$x' "; } + else { $cmd .= "\"$x\" "; } +} + +#$Data::Dumper::Indent=0; +$ret = 0; +$numfail = 0; +%active_pids=(); + +use POSIX ":sys_wait_h"; +for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) { + if (scalar(keys %active_pids) >= $max_jobs_run) { + + # Lets wait for a change in any child's status + # Then we have to work out which child finished + $r = waitpid(-1, 0); + $code = $?; + if ($r < 0 ) { die "run.pl: Error waiting for child process"; } # should never happen. + if ( defined $active_pids{$r} ) { + $jid=$active_pids{$r}; + $fail[$jid]=$code; + if ($code !=0) { $numfail++;} + delete $active_pids{$r}; + # print STDERR "Finished: $r/$jid " . Dumper(\%active_pids) . "\n"; + } else { + die "run.pl: Cannot find the PID of the child process that just finished."; + } + + # In theory we could do a non-blocking waitpid over all jobs running just + # to find out if only one or more jobs finished during the previous waitpid() + # However, we just omit this and will reap the next one in the next pass + # through the for(;;) cycle + } + $childpid = fork(); + if (!defined $childpid) { die "run.pl: Error forking in run.pl (writing to $logfile)"; } + if ($childpid == 0) { # We're in the child... this branch + # executes the job and returns (possibly with an error status). + if (defined $jobname) { + $cmd =~ s/$jobname/$jobid/g; + $logfile =~ s/$jobname/$jobid/g; + } + # exit if the job does not need to be executed + pick_or_exit( $logfile ); + + system("mkdir -p `dirname $logfile` 2>/dev/null"); + open(F, ">$logfile") || die "run.pl: Error opening log file $logfile"; + print F "# " . $cmd . "\n"; + print F "# Started at " . `date`; + $starttime = `date +'%s'`; + print F "#\n"; + close(F); + + # Pipe into bash.. make sure we're not using any other shell. + open(B, "|bash") || die "run.pl: Error opening shell command"; + print B "( " . $cmd . ") 2>>$logfile >> $logfile"; + close(B); # If there was an error, exit status is in $? + $ret = $?; + + $lowbits = $ret & 127; + $highbits = $ret >> 8; + if ($lowbits != 0) { $return_str = "code $highbits; signal $lowbits" } + else { $return_str = "code $highbits"; } + + $endtime = `date +'%s'`; + open(F, ">>$logfile") || die "run.pl: Error opening log file $logfile (again)"; + $enddate = `date`; + chop $enddate; + print F "# Accounting: time=" . ($endtime - $starttime) . " threads=1\n"; + print F "# Ended ($return_str) at " . $enddate . ", elapsed time " . ($endtime-$starttime) . " seconds\n"; + close(F); + exit($ret == 0 ? 0 : 1); + } else { + $pid[$jobid] = $childpid; + $active_pids{$childpid} = $jobid; + # print STDERR "Queued: " . Dumper(\%active_pids) . "\n"; + } +} + +# Now we have submitted all the jobs, lets wait until all the jobs finish +foreach $child (keys %active_pids) { + $jobid=$active_pids{$child}; + $r = waitpid($pid[$jobid], 0); + $code = $?; + if ($r == -1) { die "run.pl: Error waiting for child process"; } # should never happen. + if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # Completed successfully +} + +# Some sanity checks: +# The $fail array should not contain undefined codes +# The number of non-zeros in that array should be equal to $numfail +# We cannot do foreach() here, as the JOB ids do not start at zero +$failed_jids=0; +for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) { + $job_return = $fail[$jobid]; + if (not defined $job_return ) { + # print Dumper(\@fail); + + die "run.pl: Sanity check failed: we have indication that some jobs are running " . 
+ "even after we waited for all jobs to finish" ; + } + if ($job_return != 0 ){ $failed_jids++;} +} +if ($failed_jids != $numfail) { + die "run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail)." +} +if ($numfail > 0) { $ret = 1; } + +if ($ret != 0) { + $njobs = $jobend - $jobstart + 1; + if ($njobs == 1) { + if (defined $jobname) { + $logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with + # that job. + } + print STDERR "run.pl: job failed, log is in $logfile\n"; + if ($logfile =~ m/JOB/) { + print STDERR "run.pl: probably you forgot to put JOB=1:\$nj in your script."; + } + } + else { + $logfile =~ s/$jobname/*/g; + print STDERR "run.pl: $numfail / $njobs failed, log is in $logfile\n"; + } +} + + +exit ($ret); diff --git a/egs/aishell2/transformer/utils/shuffle_list.pl b/egs/aishell2/transformer/utils/shuffle_list.pl new file mode 100755 index 000000000..a116200f4 --- /dev/null +++ b/egs/aishell2/transformer/utils/shuffle_list.pl @@ -0,0 +1,44 @@ +#!/usr/bin/env perl + +# Copyright 2013 Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +if ($ARGV[0] eq "--srand") { + $n = $ARGV[1]; + $n =~ m/\d+/ || die "Bad argument to --srand option: \"$n\""; + srand($ARGV[1]); + shift; + shift; +} else { + srand(0); # Gives inconsistent behavior if we don't seed. +} + +if (@ARGV > 1 || $ARGV[0] =~ m/^-.+/) { # >1 args, or an option we + # don't understand. 
+  print "Usage: shuffle_list.pl [--srand N] [input file] > output\n";
+  print "randomizes the order of lines of input.\n";
+  exit(1);
+}
+
+@lines = ();
+while (<>) {
+   push @lines, [ (rand(), $_)] ;
+}
+
+@lines = sort { $a->[0] cmp $b->[0] } @lines;
+foreach $l (@lines) {
+  print $l->[1];
+}
\ No newline at end of file
diff --git a/egs/aishell2/transformer/utils/split_data.py b/egs/aishell2/transformer/utils/split_data.py
new file mode 100755
index 000000000..060eae6d3
--- /dev/null
+++ b/egs/aishell2/transformer/utils/split_data.py
@@ -0,0 +1,60 @@
+import os
+import sys
+import random
+
+
+in_dir = sys.argv[1]
+out_dir = sys.argv[2]
+num_split = sys.argv[3]
+
+
+def split_scp(scp, num):
+    assert len(scp) >= num
+    avg = len(scp) // num
+    out = []
+    begin = 0
+
+    for i in range(num):
+        if i == num - 1:
+            out.append(scp[begin:])
+        else:
+            out.append(scp[begin:begin+avg])
+            begin += avg
+
+    return out
+
+
+# sanity check: the input directory must provide wav.scp and text
+assert os.path.exists("{}/wav.scp".format(in_dir))
+assert os.path.exists("{}/text".format(in_dir))
+
+with open("{}/wav.scp".format(in_dir), 'r') as infile:
+    wav_list = infile.readlines()
+
+with open("{}/text".format(in_dir), 'r') as infile:
+    text_list = infile.readlines()
+
+assert len(wav_list) == len(text_list)
+
+x = list(zip(wav_list, text_list))
+random.shuffle(x)
+wav_shuffle_list, text_shuffle_list = zip(*x)
+
+num_split = int(num_split)
+wav_split_list = split_scp(wav_shuffle_list, num_split)
+text_split_list = split_scp(text_shuffle_list, num_split)
+
+for idx, wav_list in enumerate(wav_split_list, 1):
+    path = out_dir + "/split" + str(num_split) + "/" + str(idx)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    with open("{}/wav.scp".format(path), 'w') as wav_writer:
+        for line in wav_list:
+            wav_writer.write(line)
+
+for idx, text_list in enumerate(text_split_list, 1):
+    path = out_dir + "/split" + str(num_split) + "/" + str(idx)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    with open("{}/text".format(path), 'w') as text_writer:
+        for line in text_list:
+            text_writer.write(line)
diff --git a/egs/aishell2/transformer/utils/split_scp.pl b/egs/aishell2/transformer/utils/split_scp.pl
new file mode 100755
index 000000000..0876dcb6d
--- /dev/null
+++ b/egs/aishell2/transformer/utils/split_scp.pl
@@ -0,0 +1,246 @@
+#!/usr/bin/env perl
+
+# Copyright 2010-2011 Microsoft Corporation
+
+# See ../../COPYING for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This program splits up any kind of .scp or archive-type file.
+# If there is no utt2spk option it will work on any text file and
+# will split it up with an approximately equal number of lines in
+# each output.
+# With the --utt2spk option it will work on anything that has the
+# utterance-id as the first entry on each line; the utt2spk file is
+# of the form "utterance speaker" (on each line).
+# It splits it into equal size chunks as far as it can.  If you use the utt2spk
+# option it will make sure these chunks coincide with speaker boundaries.  In
+# this case, if there are more chunks than speakers (and in some other
+# circumstances), some of the resulting chunks will be empty and it will print
+# an error message and exit with nonzero status.
+# You will normally call this like:
+#  split_scp.pl scp scp.1 scp.2 scp.3 ...
+# or
+#  split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
+# Note that you can use this script to split the utt2spk file itself,
+#  e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
+
+# You can also call the scripts like:
+#  split_scp.pl -j 3 0 scp scp.0
+# [note: with this option, it assumes zero-based indexing of the split parts,
+# i.e. the second number must be 0 <= n < num-jobs.]
+
+use warnings;
+
+$num_jobs = 0;
+$job_id = 0;
+$utt2spk_file = "";
+$one_based = 0;
+
+for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
+  if ($ARGV[0] eq "-j") {
+    shift @ARGV;
+    $num_jobs = shift @ARGV;
+    $job_id = shift @ARGV;
+  }
+  if ($ARGV[0] =~ /--utt2spk=(.+)/) {
+    $utt2spk_file=$1;
+    shift;
+  }
+  if ($ARGV[0] eq '--one-based') {
+    $one_based = 1;
+    shift @ARGV;
+  }
+}
+
+if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
+                       $job_id - $one_based >= $num_jobs)) {
+  die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
+      ($one_based ? " --one-based" : "") . "'\n"
+}
+
+$one_based
+  and $job_id--;
+
+if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
+  die
+"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
+ or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
+ ... where 0 <= job-id < num-jobs, or 1 <= job-id <= num-jobs if --one-based.\n";
+}
+
+$error = 0;
+$inscp = shift @ARGV;
+if ($num_jobs == 0) { # without -j option
+  @OUTPUTS = @ARGV;
+} else {
+  for ($j = 0; $j < $num_jobs; $j++) {
+    if ($j == $job_id) {
+      if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
+      else { push @OUTPUTS, "-"; }
+    } else {
+      push @OUTPUTS, "/dev/null";
+    }
+  }
+}
+
+if ($utt2spk_file ne "") {  # We have the --utt2spk option...
+  open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
+  while(<$u_fh>) {
+    @A = split;
+    @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
+    ($u,$s) = @A;
+    $utt2spk{$u} = $s;
+  }
+  close $u_fh;
+  open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
+  @spkrs = ();
+  while(<$i_fh>) {
+    @A = split;
+    if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
+    $u = $A[0];
+    $s = $utt2spk{$u};
+    defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
+    if(!defined $spk_count{$s}) {
+      push @spkrs, $s;
+      $spk_count{$s} = 0;
+      $spk_data{$s} = [];  # ref to new empty array.
+    }
+    $spk_count{$s}++;
+    push @{$spk_data{$s}}, $_;
+  }
+  # Now split as equally as possible ..
+  # First allocate spks to files by allocating an approximately
+  # equal number of speakers.
+  $numspks = @spkrs;  # number of speakers.
+  $numscps = @OUTPUTS;  # number of output files.
+  if ($numspks < $numscps) {
+    die "$0: Refusing to split data because number of speakers $numspks " .
+        "is less than the number of output .scp files $numscps\n";
+  }
+  for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+    $scparray[$scpidx] = [];  # [] is array reference.
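+    # $scparray[$scpidx] will hold the speakers assigned to output file
+    # number $scpidx; $scpcount[$scpidx] (updated below) tracks how many
+    # utterances that file will receive.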
+  }
+  for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
+    $scpidx = int(($spkidx*$numscps) / $numspks);
+    $spk = $spkrs[$spkidx];
+    push @{$scparray[$scpidx]}, $spk;
+    $scpcount[$scpidx] += $spk_count{$spk};
+  }
+
+  # Now will try to reassign beginning + ending speakers
+  # to different scp's and see if it gets more balanced.
+  # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
+  # We can show that if considering changing just 2 scp's, we minimize
+  # this by minimizing the squared difference in sizes.  This is
+  # equivalent to minimizing the absolute difference in sizes.  This
+  # shows this method is bound to converge.
+
+  $changed = 1;
+  while($changed) {
+    $changed = 0;
+    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+      # First try to reassign ending spk of this scp.
+      if($scpidx < $numscps-1) {
+        $sz = @{$scparray[$scpidx]};
+        if($sz > 0) {
+          $spk = $scparray[$scpidx]->[$sz-1];
+          $count = $spk_count{$spk};
+          $nutt1 = $scpcount[$scpidx];
+          $nutt2 = $scpcount[$scpidx+1];
+          if( abs( ($nutt2+$count) - ($nutt1-$count))
+              < abs($nutt2 - $nutt1)) { # Would decrease
+            # size-diff by reassigning spk...
+            $scpcount[$scpidx+1] += $count;
+            $scpcount[$scpidx] -= $count;
+            pop @{$scparray[$scpidx]};
+            unshift @{$scparray[$scpidx+1]}, $spk;
+            $changed = 1;
+          }
+        }
+      }
+      if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
+        $spk = $scparray[$scpidx]->[0];
+        $count = $spk_count{$spk};
+        $nutt1 = $scpcount[$scpidx-1];
+        $nutt2 = $scpcount[$scpidx];
+        if( abs( ($nutt2-$count) - ($nutt1+$count))
+            < abs($nutt2 - $nutt1)) { # Would decrease
+          # size-diff by reassigning spk...
+          $scpcount[$scpidx-1] += $count;
+          $scpcount[$scpidx] -= $count;
+          shift @{$scparray[$scpidx]};
+          push @{$scparray[$scpidx-1]}, $spk;
+          $changed = 1;
+        }
+      }
+    }
+  }
+  # Now print out the files...
+  for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+    $scpfile = $OUTPUTS[$scpidx];
+    ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
+                     : open($f_fh, '>&', \*STDOUT)) ||
+        die "$0: Could not open scp file $scpfile for writing: $!\n";
+    $count = 0;
+    if(@{$scparray[$scpidx]} == 0) {
+      print STDERR "$0: error: split_scp.pl producing empty .scp file " .
+          "$scpfile (too many splits and too few speakers?)\n";
+      $error = 1;
+    } else {
+      foreach $spk ( @{$scparray[$scpidx]} ) {
+        print $f_fh @{$spk_data{$spk}};
+        $count += $spk_count{$spk};
+      }
+      $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
+    }
+    close($f_fh);
+  }
+} else {
+  # This block is the "normal" case where there is no --utt2spk
+  # option and we just break into equal size chunks.
+
+  open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
+
+  $numscps = @OUTPUTS;  # size of array.
+  @F = ();
+  while(<$i_fh>) {
+    push @F, $_;
+  }
+  $numlines = @F;
+  if($numlines == 0) {
+    print STDERR "$0: error: empty input scp file $inscp\n";
+    $error = 1;
+  }
+  $linesperscp = int( $numlines / $numscps);  # the "whole part"..
+  $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
+  $remainder = $numlines - ($linesperscp * $numscps);
+  ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
+  # [just doing int() rounds down].
+  $n = 0;
+  for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
+    $scpfile = $OUTPUTS[$scpidx];
+    ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
+                     : open($o_fh, '>&', \*STDOUT)) ||
+        die "$0: Could not open scp file $scpfile for writing: $!\n";
+    for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
+      print $o_fh $F[$n++];
+    }
+    close($o_fh) || die "$0: Error closing scp file $scpfile: $!\n";
+  }
+  $n == $numlines || die "$n != $numlines [code error]";
+}
+
+exit ($error);
diff --git a/egs/aishell2/transformer/utils/subset_data_dir_tr_cv.sh b/egs/aishell2/transformer/utils/subset_data_dir_tr_cv.sh
new file mode 100755
index 000000000..e16cebdf1
--- /dev/null
+++ b/egs/aishell2/transformer/utils/subset_data_dir_tr_cv.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+dev_num_utt=1000
+
+echo "$0 $@"
+. utils/parse_options.sh || exit 1;
+
+train_data=$1
+out_dir=$2
+
+[ ! -f ${train_data}/wav.scp ] && echo "$0: no such file ${train_data}/wav.scp" && exit 1;
+[ ! -f ${train_data}/text ] && echo "$0: no such file ${train_data}/text" && exit 1;
+
+mkdir -p ${out_dir}/train && mkdir -p ${out_dir}/dev
+
+cp ${train_data}/wav.scp ${out_dir}/train/wav.scp.bak
+cp ${train_data}/text ${out_dir}/train/text.bak
+
+num_utt=$(wc -l <${out_dir}/train/wav.scp.bak)
+
+# Both shuffles use the same fixed seed and the inputs have the same number
+# of lines, so they yield the same permutation and the utterance/transcript
+# pairing is preserved.
+utils/shuffle_list.pl --srand 1 ${out_dir}/train/wav.scp.bak > ${out_dir}/train/wav.scp.shuf
+head -n ${dev_num_utt} ${out_dir}/train/wav.scp.shuf > ${out_dir}/dev/wav.scp
+tail -n $((${num_utt}-${dev_num_utt})) ${out_dir}/train/wav.scp.shuf > ${out_dir}/train/wav.scp
+
+utils/shuffle_list.pl --srand 1 ${out_dir}/train/text.bak > ${out_dir}/train/text.shuf
+head -n ${dev_num_utt} ${out_dir}/train/text.shuf > ${out_dir}/dev/text
+tail -n $((${num_utt}-${dev_num_utt})) ${out_dir}/train/text.shuf > ${out_dir}/train/text
+
+rm ${out_dir}/train/wav.scp.bak ${out_dir}/train/text.bak
+rm ${out_dir}/train/wav.scp.shuf ${out_dir}/train/text.shuf
diff --git a/egs/aishell2/transformer/utils/text2token.py b/egs/aishell2/transformer/utils/text2token.py
new file mode 100755
index 000000000..56c39138f
--- /dev/null
+++ b/egs/aishell2/transformer/utils/text2token.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+
+# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+
+import argparse
+import codecs
+import re
+import sys
+
+is_python2 = sys.version_info[0] == 2
+
+
+def exist_or_not(i, match_pos):
+    start_pos = None
+    end_pos = None
+    for pos in match_pos:
+        if pos[0] <= i < pos[1]:
+            start_pos = pos[0]
+            end_pos = pos[1]
+            break
+
+    return start_pos, end_pos
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="convert raw text to tokenized text",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--nchar",
+        "-n",
+        default=1,
+        type=int,
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
+    )
+    parser.add_argument(
+        "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
+    )
+    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
+    parser.add_argument(
+        "--non-lang-syms",
+        "-l",
+        default=None,
+        type=str,
+        help="list of non-linguistic symbols, e.g., <NOISE> etc.",
+    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "--trans_type",
+        "-t",
+        type=str,
+        default="char",
+        choices=["char", "phn"],
+        help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
+                If trans_type is char,
+                read from SI1279.WRD file -> "bricks are an alternative"
+                Else if trans_type is phn,
+                read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
+                sil t er n ih sil t ih v sil" """,
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    rs = []
+    if args.non_lang_syms is not None:
+        with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
+            nls = [x.rstrip() for x in f.readlines()]
+            rs = [re.compile(re.escape(x)) for x in nls]
+
+    if args.text:
+        f = codecs.open(args.text, encoding="utf-8")
+    else:
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+
+    sys.stdout = codecs.getwriter("utf-8")(
+        sys.stdout if is_python2 else sys.stdout.buffer
+    )
+    line = f.readline()
+    n = args.nchar
+    while line:
+        x = line.split()
+        print(" ".join(x[: args.skip_ncols]), end=" ")
+        a = " ".join(x[args.skip_ncols :])
+
+        # get all matched positions
+        match_pos = []
+        for r in rs:
+            i = 0
+            while i >= 0:
+                m = r.search(a, i)
+                if m:
+                    match_pos.append([m.start(), m.end()])
+                    i = m.end()
+                else:
+                    break
+
+        if args.trans_type == "phn":
+            a = a.split(" ")
+        else:
+            if len(match_pos) > 0:
+                chars = []
+                i = 0
+                while i < len(a):
+                    start_pos, end_pos = exist_or_not(i, match_pos)
+                    if start_pos is not None:
+                        chars.append(a[start_pos:end_pos])
+                        i = end_pos
+                    else:
+                        chars.append(a[i])
+                        i += 1
+                a = chars
+
+        a = [a[j : j + n] for j in range(0, len(a), n)]
+
+        a_flat = []
+        for z in a:
+            a_flat.append("".join(z))
+
+        a_chars = [z.replace(" ", args.space) for z in a_flat]
+        if args.trans_type == "phn":
+            a_chars = [z.replace("sil", args.space) for z in a_chars]
+        print(" ".join(a_chars))
+        line = f.readline()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/aishell2/transformer/utils/text_tokenize.py b/egs/aishell2/transformer/utils/text_tokenize.py
new file mode 100755
index 000000000..962ea11bc
--- /dev/null
+++ b/egs/aishell2/transformer/utils/text_tokenize.py
@@ -0,0 +1,106 @@
+import re
+import argparse
+
+
+def load_dict(seg_file):
+    seg_dict = {}
+    with open(seg_file, 'r') as infile:
+        for line in infile:
+            s = line.strip().split()
+            key = s[0]
+            value = s[1:]
+            seg_dict[key] = " ".join(value)
+    return seg_dict
+
+
+def forward_segment(text, dic):
+    # greedy forward maximum matching against the segmentation dictionary
+    word_list = []
+    i = 0
+    while i < len(text):
+        longest_word = text[i]
+        for j in range(i + 1, len(text) + 1):
+            word = text[i:j]
+            if word in dic:
+                if len(word) > len(longest_word):
+                    longest_word = word
+        word_list.append(longest_word)
+        i += len(longest_word)
+    return word_list
+
+
+def tokenize(txt,
+             seg_dict):
+    out_txt = ""
+    pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
+    for word in txt:
+        if pattern.match(word):
+            if word in seg_dict:
+                out_txt += seg_dict[word] + " "
+            else:
+                out_txt += "<unk>" + " "
+        else:
+            continue
+    return out_txt.strip()
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="text tokenize",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--text-file",
+        "-t",
+        default=False,
+        required=True,
+        type=str,
+        help="input text",
+    )
+    parser.add_argument(
+        "--seg-file",
+        "-s",
+        default=False,
+        required=True,
+        type=str,
+        help="seg file",
+    )
+    parser.add_argument(
+        "--txt-index",
+        "-i",
+        default=1,
+        required=True,
+        type=int,
+        help="txt index",
+    )
+    parser.add_argument(
+        "--output-dir",
+        "-o",
+        default=False,
+        required=True,
+        type=str,
+        help="output dir",
+    )
+    return parser
+
+
+def main():
+    parser =
get_parser() + args = parser.parse_args() + + txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), 'w') + shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), 'w') + seg_dict = load_dict(args.seg_file) + with open(args.text_file, 'r') as infile: + for line in infile: + s = line.strip().split() + text_id = s[0] + text_list = forward_segment("".join(s[1:]).lower(), seg_dict) + text = tokenize(text_list, seg_dict) + lens = len(text.strip().split()) + txt_writer.write(text_id + " " + text + '\n') + shape_writer.write(text_id + " " + str(lens) + '\n') + + +if __name__ == '__main__': + main() + diff --git a/egs/aishell2/transformer/utils/text_tokenize.sh b/egs/aishell2/transformer/utils/text_tokenize.sh new file mode 100755 index 000000000..6b74fef80 --- /dev/null +++ b/egs/aishell2/transformer/utils/text_tokenize.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + + +# Begin configuration section. +nj=32 +cmd=utils/run.pl + +echo "$0 $@" + +. utils/parse_options.sh || exit 1; + +# tokenize configuration +text_dir=$1 +seg_file=$2 +logdir=$3 +output_dir=$4 + +txt_dir=${output_dir}/txt; mkdir -p ${output_dir}/txt +mkdir -p ${logdir} + +$cmd JOB=1:$nj $logdir/text_tokenize.JOB.log \ + python utils/text_tokenize.py -t ${text_dir}/txt/text.JOB.txt \ + -s ${seg_file} -i JOB -o ${txt_dir} \ + || exit 1; + +# concatenate the text files together. +for n in $(seq $nj); do + cat ${txt_dir}/text.$n.txt || exit 1 +done > ${output_dir}/text || exit 1 + +for n in $(seq $nj); do + cat ${txt_dir}/len.$n || exit 1 +done > ${output_dir}/text_shape || exit 1 + +echo "$0: Succeeded text tokenize" diff --git a/egs/aishell2/transformer/utils/textnorm_zh.py b/egs/aishell2/transformer/utils/textnorm_zh.py new file mode 100755 index 000000000..79feb83fd --- /dev/null +++ b/egs/aishell2/transformer/utils/textnorm_zh.py @@ -0,0 +1,834 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +# Authors: +# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) +# 2019.9 Jiayu DU +# +# requirements: +# - python 3.X +# notes: python 2.X WILL fail or produce misleading results + +import sys, os, argparse, codecs, string, re + +# ================================================================================ # +# basic constant +# ================================================================================ # +CHINESE_DIGIS = u'零一二三四五六七八九' +BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖' +BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖' +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬' +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载' +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載' +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬' + +ZERO_ALT = u'〇' +ONE_ALT = u'幺' +TWO_ALTS = [u'两', u'兩'] + +POSITIVE = [u'正', u'正'] +NEGATIVE = [u'负', u'負'] +POINT = [u'点', u'點'] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +FILLER_CHARS = ['呃', '啊'] +ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \ + '胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \ + '儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \ + '佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)' + +# 中文数字系统类型 +NUMBERING_TYPES = ['low', 'mid', 'high'] + +CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \ + '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)' +CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)' +COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \ + 
'砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
+                  '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
+                  '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
+                  '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
+                  '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
+
+# punctuation information is based on the Zhon project (https://github.com/tsroten/zhon.git)
+CHINESE_PUNC_STOP = '！？｡。'
+CHINESE_PUNC_NON_STOP = '＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏'
+CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
+
+# ================================================================================ #
+#                                   basic class
+# ================================================================================ #
+class ChineseChar(object):
+    """
+    Chinese character.
+    Each character has a simplified and a traditional form,
+    e.g. simplified = '负', traditional = '負'.
+    Conversion can target either form.
+    """
+
+    def __init__(self, simplified, traditional):
+        self.simplified = simplified
+        self.traditional = traditional
+        #self.__repr__ = self.__str__
+
+    def __str__(self):
+        return self.simplified or self.traditional or None
+
+    def __repr__(self):
+        return self.__str__()
+
+
+class ChineseNumberUnit(ChineseChar):
+    """
+    Chinese number/positional-unit character.
+    Besides the simplified/traditional forms, each unit has an extra
+    "capital" (banker's) form, e.g. '陆' and '陸'.
+    """
+
+    def __init__(self, power, simplified, traditional, big_s, big_t):
+        super(ChineseNumberUnit, self).__init__(simplified, traditional)
+        self.power = power
+        self.big_s = big_s
+        self.big_t = big_t
+
+    def __str__(self):
+        return '10^{}'.format(self.power)
+
+    @classmethod
+    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
+
+        if small_unit:
+            return ChineseNumberUnit(power=index + 1,
+                                     simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
+        elif numbering_type == NUMBERING_TYPES[0]:
+            return ChineseNumberUnit(power=index + 8,
+                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+        elif numbering_type == NUMBERING_TYPES[1]:
+            return ChineseNumberUnit(power=(index + 2) * 4,
+                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+        elif numbering_type == NUMBERING_TYPES[2]:
+            return ChineseNumberUnit(power=pow(2, index + 3),
+                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+        else:
+            raise ValueError(
+                'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
+
+
+class ChineseNumberDigit(ChineseChar):
+    """
+    Chinese digit character.
+    """
+
+    def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
+        super(ChineseNumberDigit, self).__init__(simplified, traditional)
+        self.value = value
+        self.big_s = big_s
+        self.big_t = big_t
+        self.alt_s = alt_s
+        self.alt_t = alt_t
+
+    def __str__(self):
+        return str(self.value)
+
+    @classmethod
+    def create(cls, i, v):
+        return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
+
+
+class ChineseMath(ChineseChar):
+    """
+    Chinese math-symbol character.
+    """
+
+    def __init__(self, simplified, traditional, symbol, expression=None):
+        super(ChineseMath, self).__init__(simplified, traditional)
+        self.symbol = symbol
+        self.expression = expression
+        self.big_s = simplified
+        self.big_t = traditional
+
+
+CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
+class NumberSystem(object):
+    """
+    Chinese number system.
+    """
+    pass
+
+
+class MathSymbol(object):
+    """
+    Math symbols (simplified/traditional) used in the Chinese number system, e.g.
+ positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v + + +# ================================================================================ # +# basic utils +# ================================================================================ # +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. + 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip( + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL) + larger_units = [CNU.create(i, v, numbering_type, False) + for i, v in enumerate(all_larger_units)] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip( + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL) + smaller_units = [CNU.create(i, v, small_unit=True) + for i, v in enumerate(all_smaller_units)] + # digis + chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x) + point_cn = CM(POINT[0], POINT[1], '.', lambda x, + y: float(str(x) + '.' 
+ str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, '' + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], \ + [get_symbol(c, system) for c in dec_string] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None)) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power: + result[-i - 1] = CNU(result[-i - 1].power + + current_unit.power, None, None, None, None) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * + pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = ''.join([str(d.value) for d in dec_part]) + if dec_part: + return '{0}.{1}'.format(int_str, dec_str) + else: + return int_str + + +def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False, + traditional=False, alt_zero=False, alt_one=False, alt_two=True, + use_zeros=True, use_units=True): + + def get_value(value_string, use_zeros=True): + + striped_string = value_string.lstrip('0') + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next(u for u in reversed( + system.units) if u.power < len(striped_string)) + result_string = value_string[:-result_unit.power] + return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:]) + + system = create_system(numbering_type) + + int_dec = number_string.split('.') + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError( + "invalid input num string with more than one dot: {}".format(number_string)) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, + system.digits[2].big_s, system.digits[2].big_t) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = result_symbols[i + + 1] if i < len(result_symbols) - 1 else None + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))): + if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = 'big_' + if traditional: + attr_name += 't' + else: + attr_name += 's' + else: + if traditional: + attr_name = 'traditional' + else: + attr_name = 'simplified' + + result = ''.join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + result + + # ^10, 11, .., 19 + if len(result) >= 2 and result[1] in 
[SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+                                          SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
+            result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
+        result = result[1:]
+
+    return result
+
+
+# ================================================================================ #
+#                          different types of rewriters
+# ================================================================================ #
+class Cardinal:
+    """
+    CARDINAL class
+    """
+
+    def __init__(self, cardinal=None, chntext=None):
+        self.cardinal = cardinal
+        self.chntext = chntext
+
+    def chntext2cardinal(self):
+        return chn2num(self.chntext)
+
+    def cardinal2chntext(self):
+        return num2chn(self.cardinal)
+
+
+class Digit:
+    """
+    DIGIT class
+    """
+
+    def __init__(self, digit=None, chntext=None):
+        self.digit = digit
+        self.chntext = chntext
+
+    # def chntext2digit(self):
+    #     return chn2num(self.chntext)
+
+    def digit2chntext(self):
+        return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+class TelePhone:
+    """
+    TELEPHONE class
+    """
+
+    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+        self.telephone = telephone
+        self.raw_chntext = raw_chntext
+        self.chntext = chntext
+
+    # def chntext2telephone(self):
+    #     sil_parts = self.raw_chntext.split('<SIL>')
+    #     self.telephone = '-'.join([
+    #         str(chn2num(p)) for p in sil_parts
+    #     ])
+    #     return self.telephone
+
+    def telephone2chntext(self, fixed=False):
+
+        if fixed:
+            sil_parts = self.telephone.split('-')
+            self.raw_chntext = '<SIL>'.join([
+                num2chn(part, alt_two=False, use_units=False) for part in sil_parts
+            ])
+            self.chntext = self.raw_chntext.replace('<SIL>', '')
+        else:
+            sp_parts = self.telephone.strip('+').split()
+            self.raw_chntext = '<SP>'.join([
+                num2chn(part, alt_two=False, use_units=False) for part in sp_parts
+            ])
+            self.chntext = self.raw_chntext.replace('<SP>', '')
+        return self.chntext
+
+
+class Fraction:
+    """
+    FRACTION class
+    """
+
+    def __init__(self, fraction=None, chntext=None):
+        self.fraction = fraction
+        self.chntext = chntext
+
+    def chntext2fraction(self):
+        denominator, numerator = self.chntext.split('分之')
+        return chn2num(numerator) + '/' + chn2num(denominator)
+
+    def fraction2chntext(self):
+        numerator, denominator = self.fraction.split('/')
+        return num2chn(denominator) + '分之' + num2chn(numerator)
+
+
+class Date:
+    """
+    DATE class
+    """
+
+    def __init__(self, date=None, chntext=None):
+        self.date = date
+        self.chntext = chntext
+
+    # def chntext2date(self):
+    #     chntext = self.chntext
+    #     try:
+    #         year, other = chntext.strip().split('年', maxsplit=1)
+    #         year = Digit(chntext=year).digit2chntext() + '年'
+    #     except ValueError:
+    #         other = chntext
+    #         year = ''
+    #     if other:
+    #         try:
+    #             month, day = other.strip().split('月', maxsplit=1)
+    #             month = Cardinal(chntext=month).chntext2cardinal() + '月'
+    #         except ValueError:
+    #             day = chntext
+    #             month = ''
+    #         if day:
+    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+    #     else:
+    #         month = ''
+    #         day = ''
+    #     date = year + month + day
+    #     self.date = date
+    #     return self.date
+
+    def date2chntext(self):
+        date = self.date
+        try:
+            year, other = date.strip().split('年', 1)
+            year = Digit(digit=year).digit2chntext() + '年'
+        except ValueError:
+            other = date
+            year = ''
+        if other:
+            try:
+                month, day = other.strip().split('月', 1)
+                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
+            except ValueError:
+                day = date
+                month = ''
+            if day:
+                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+        else:
+            month = ''
+            day = ''
+        chntext = year + month + day
+        self.chntext = chntext
+        return self.chntext
+
+
+class Money:
+    """
+    MONEY class
+    """
+
+    def __init__(self, money=None, chntext=None):
+        self.money = money
+        self.chntext = chntext
+
+    # def chntext2money(self):
+    #     return self.money
+
+    def money2chntext(self):
+        money = self.money
+        pattern = re.compile(r'(\d+(\.\d+)?)')
+        matchers = pattern.findall(money)
+        if matchers:
+            for matcher in matchers:
+                money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
+        self.chntext = money
+        return self.chntext
+
+
+class Percentage:
+    """
+    PERCENTAGE class
+    """
+
+    def __init__(self, percentage=None, chntext=None):
+        self.percentage = percentage
+        self.chntext = chntext
+
+    def chntext2percentage(self):
+        return chn2num(self.chntext.strip().strip('百分之')) + '%'
+
+    def percentage2chntext(self):
+        return '百分之' + num2chn(self.percentage.strip().strip('%'))
+
+
+def remove_erhua(text, er_whitelist):
+    """
+    Remove the erhua suffix '儿' except for words on the whitelist:
+    他女儿在那边儿 -> 他女儿在那边
+    """
+
+    er_pattern = re.compile(er_whitelist)
+    new_str=''
+    while re.search('儿',text):
+        a = re.search('儿',text).span()
+        remove_er_flag = 0
+
+        if er_pattern.search(text):
+            b = er_pattern.search(text).span()
+            if b[0] <= a[0]:
+                remove_er_flag = 1
+
+        if remove_er_flag == 0 :
+            new_str = new_str + text[0:a[0]]
+            text = text[a[1]:]
+        else:
+            new_str = new_str + text[0:b[1]]
+            text = text[b[1]:]
+
+    text = new_str + text
+    return text
+
+# ================================================================================ #
+#                                 NSW Normalizer
+# ================================================================================ #
+class NSWNormalizer:
+    def __init__(self, raw_text):
+        self.raw_text = '^' + raw_text + '$'
+        self.norm_text = ''
+
+    def _particular(self):
+        text = self.norm_text
+        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('particular')
+            for matcher in matchers:
+                text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
+        self.norm_text = text
+        return self.norm_text
+
+    def normalize(self):
+        text = self.raw_text
+
+        # normalize dates
+        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('date')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+        # normalize money expressions
+        pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('money')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
+
+        # normalize landline/mobile phone numbers
+        # mobile
+        # http://www.jihaoba.com/news/show/13680
+        # China Mobile: 139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
+        # China Unicom: 130、131、132、156、155、186、185、176
+        # China Telecom: 133、153、189、180、181、177
+        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('telephone')
+            for matcher in matchers:
+                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
+        # landline
+        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fixed telephone')
+            for matcher in matchers:
+                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
+
+        # normalize fractions
+        pattern = re.compile(r"(\d+/\d+)")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('fraction')
+            for matcher in matchers:
+                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
+
+        # normalize percentages
+        text = text.replace('％', '%')
+        pattern = re.compile(r"(\d+(\.\d+)?%)")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('percentage')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
+
+        # normalize plain number + quantifier
+        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('cardinal+quantifier')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        # normalize digit strings (serial numbers, IDs)
+        pattern = re.compile(r"(\d{4,32})")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('digit')
+            for matcher in matchers:
+                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+        # normalize plain numbers
+        pattern = re.compile(r"(\d+(\.\d+)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('cardinal')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        self.norm_text = text
+        self._particular()
+
+        return self.norm_text.lstrip('^').rstrip('$')
+
+
+def nsw_test_case(raw_text):
+    print('I:' + raw_text)
+    print('O:' + NSWNormalizer(raw_text).normalize())
+    print('')
+
+
+def nsw_test():
+    nsw_test_case('固话:0595-23865596或23880880。')
+    nsw_test_case('固话:0595-23865596或23880880。')
+    nsw_test_case('手机:+86 19859213959或15659451527。')
+    nsw_test_case('分数:32477/76391。')
+    nsw_test_case('百分数:80.03%。')
+    nsw_test_case('编号:31520181154418。')
+    nsw_test_case('纯数:2983.07克或12345.60米。')
+    nsw_test_case('日期:1999年2月20日或09年3月15号。')
+    nsw_test_case('金钱:12块5,34.5元,20.1万')
+    nsw_test_case('特殊:O2O或B2C。')
+    nsw_test_case('3456万吨')
+    nsw_test_case('2938个')
+    nsw_test_case('938')
+    nsw_test_case('今天吃了115个小笼包231个馒头')
+    nsw_test_case('有62%的概率')
+
+
+if __name__ == '__main__':
+    #nsw_test()
+
+    p = argparse.ArgumentParser()
+    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
+    p.add_argument('ofile', help='output filename')
+    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
+    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
+    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
+    p.add_argument('--remove_fillers', type=bool,
default=True, help='remove filler chars such as "呃, 啊"') + p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "这儿"') + p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines') + args = p.parse_args() + + ifile = codecs.open(args.ifile, 'r', 'utf8') + ofile = codecs.open(args.ofile, 'w+', 'utf8') + + n = 0 + for l in ifile: + key = '' + text = '' + if args.has_key: + cols = l.split(maxsplit=1) + key = cols[0] + if len(cols) == 2: + text = cols[1].strip() + else: + text = '' + else: + text = l.strip() + + # cases + if args.to_upper and args.to_lower: + sys.stderr.write('text norm: to_upper OR to_lower?') + exit(1) + if args.to_upper: + text = text.upper() + if args.to_lower: + text = text.lower() + + # Filler chars removal + if args.remove_fillers: + for ch in FILLER_CHARS: + text = text.replace(ch, '') + + if args.remove_erhua: + text = remove_erhua(text, ER_WHITELIST) + + # NSW(Non-Standard-Word) normalization + text = NSWNormalizer(text).normalize() + + # Punctuations removal + old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations + new_chars = ' ' * len(old_chars) + del_chars = '' + text = text.translate(str.maketrans(old_chars, new_chars, del_chars)) + + # + if args.has_key: + ofile.write(key + '\t' + text + '\n') + else: + ofile.write(text + '\n') + + n += 1 + if n % args.log_interval == 0: + sys.stderr.write("text norm: {} lines done.\n".format(n)) + + sys.stderr.write("text norm: {} lines done in total.\n".format(n)) + + ifile.close() + ofile.close()
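Note: the normalizer added above can also be exercised directly from Python. A minimal sketch (the module name follows the file path, and the sample sentence mirrors the nsw_test cases; it assumes egs/aishell2/transformer/utils is on PYTHONPATH):

    from textnorm_zh import NSWNormalizer

    # 62% is rewritten by the Percentage rule to 百分之六十二
    print(NSWNormalizer('有62%的概率').normalize())  # -> 有百分之六十二的概率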