diff --git a/data/list/text.txt b/data/list/text.txt new file mode 100644 index 000000000..f4d4fe457 --- /dev/null +++ b/data/list/text.txt @@ -0,0 +1,2 @@ +ID0012W0013 当客户风险承受能力评估依据发生变化时 +ID0012W0014 杨涛不得不将工厂关掉 \ No newline at end of file diff --git a/data/list/wav.scp b/data/list/wav.scp new file mode 100644 index 000000000..325a340c6 --- /dev/null +++ b/data/list/wav.scp @@ -0,0 +1,2 @@ +ID0012W0013 /Users/zhifu/funasr_github/test_local/aishell2_dev_ios/wav/D0012/ID0012W0013.wav +ID0012W0014 /Users/zhifu/funasr_github/test_local/aishell2_dev_ios/wav/D0012/ID0012W0014.wav \ No newline at end of file diff --git a/examples/aishell/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml b/examples/aishell/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml index 94b7f6d14..3a2231f4d 100644 --- a/examples/aishell/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml +++ b/examples/aishell/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml @@ -1,6 +1,6 @@ # network architecture -model: funasr.cli.models.paraformer:Paraformer +model: Paraformer model_conf: ctc_weight: 0.3 lsm_weight: 0.1 @@ -9,9 +9,8 @@ model_conf: sampling_ratio: 0.4 use_1st_decoder_loss: true - -# encoder related -encoder: conformer +# encoder +encoder: ConformerEncoder encoder_conf: output_size: 256 # dimension of attention attention_heads: 4 @@ -29,8 +28,8 @@ encoder_conf: use_cnn_module: true cnn_module_kernel: 15 -# decoder related -decoder: paraformer_decoder_san +# decoder +decoder: ParaformerSANDecoder decoder_conf: attention_heads: 4 linear_units: 2048 @@ -40,8 +39,17 @@ decoder_conf: self_attention_dropout_rate: 0.0 src_attention_dropout_rate: 0.0 +# predictor +predictor: CifPredictor +predictor_conf: + idim: 256 + threshold: 1.0 + l_order: 1 + r_order: 1 + tail_threshold: 0.45 + # frontend related -frontend: wav_frontend +frontend: WavFrontend frontend_conf: fs: 16000 window: hamming @@ -51,29 +59,7 @@ frontend_conf: lfr_m: 1 lfr_n: 1 - -train_conf: - accum_grad: 1 - grad_clip: 5 - max_epoch: 150 - val_scheduler_criterion: - - valid - - acc - best_model_criterion: - - - valid - - acc - - max - keep_nbest_models: 10 - log_interval: 50 - -optim: adam -optim_conf: - lr: 0.0005 -scheduler: warmuplr -scheduler_conf: - warmup_steps: 30000 - -specaug: specaug +specaug: SpecAug specaug_conf: apply_time_warp: true time_warp_window: 5 @@ -89,25 +75,43 @@ specaug_conf: - 40 num_time_mask: 2 -predictor: cif_predictor -predictor_conf: - idim: 256 - threshold: 1.0 - l_order: 1 - r_order: 1 - tail_threshold: 0.45 +train_conf: + accum_grad: 1 + grad_clip: 5 + max_epoch: 150 + keep_nbest_models: 10 + avg_nbest_model: 5 + log_interval: 50 +optim: adam +optim_conf: + lr: 0.0005 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 30000 + +dataset: AudioDataset dataset_conf: - data_names: speech,text - data_types: sound,text + index_ds: IndexDSJsonl + batch_sampler: RankFullLocalShuffleBatchSampler + batch_type: example # example or length + batch_size: 32 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len + max_token_length: 2048 # filter out samples if source_token_len+target_token_len > max_token_length + buffer_size: 1024 shuffle: True - shuffle_conf: - shuffle_size: 2048 - sort_size: 500 - batch_conf: - batch_type: example - batch_size: 2 - num_workers: 8 + num_workers: 0 + +tokenizer: CharTokenizer +tokenizer_conf: + unk_symbol: <unk> + split_with_space: true + + +ctc_conf: + dropout_rate: 0.0 + ctc_type: builtin + reduce: true + ignore_nan_grad: 
true +normalize: null -normalize: null \ No newline at end of file diff --git a/examples/aishell/finetune.sh b/examples/aishell/finetune.sh new file mode 100644 index 000000000..ea6c32c2c --- /dev/null +++ b/examples/aishell/finetune.sh @@ -0,0 +1,9 @@ + +cmd="funasr/bin/train.py" + +python $cmd \ +--config-path "/Users/zhifu/funasr_github/test_local/funasr_cli_egs" \ +--config-name "config.yaml" \ +++token_list="/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \ +++train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \ +++output_dir="/nfs/zhifu.gzf/ckpt/funasr2/exp1" \ No newline at end of file diff --git a/examples/aishell/paraformer/local/aishell_data_prep.sh b/examples/aishell/paraformer/local/aishell_data_prep.sh new file mode 100755 index 000000000..83f489b3c --- /dev/null +++ b/examples/aishell/paraformer/local/aishell_data_prep.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Copyright 2017 Xingyu Na +# Apache 2.0 + +#. ./path.sh || exit 1; + +if [ $# != 3 ]; then + echo "Usage: $0 <audio-path> <text-path> <output-path>" + echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data" + exit 1; +fi + +aishell_audio_dir=$1 +aishell_text=$2/aishell_transcript_v0.8.txt +output_dir=$3 + +train_dir=$output_dir/data/local/train +dev_dir=$output_dir/data/local/dev +test_dir=$output_dir/data/local/test +tmp_dir=$output_dir/data/local/tmp + +mkdir -p $train_dir +mkdir -p $dev_dir +mkdir -p $test_dir +mkdir -p $tmp_dir + +# data directory check +if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then + echo "Error: $0 requires a valid audio directory and transcript file" + exit 1; +fi + +# find wav audio file for train, dev and test resp. +find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist +n=`cat $tmp_dir/wav.flist | wc -l` +[ $n -ne 141925 ] && \ + echo Warning: expected 141925 data files, found $n + +grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1; +grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1; +grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1; + +rm -r $tmp_dir + +# Transcriptions preparation +for dir in $train_dir $dev_dir $test_dir; do + echo Preparing $dir transcriptions + sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list + paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all + utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt + awk '{print $1}' $dir/transcripts.txt > $dir/utt.list + utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp + sort -u $dir/transcripts.txt > $dir/text +done + +mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test + +for f in wav.scp text; do + cp $train_dir/$f $output_dir/data/train/$f || exit 1; + cp $dev_dir/$f $output_dir/data/dev/$f || exit 1; + cp $test_dir/$f $output_dir/data/test/$f || exit 1; +done + +echo "$0: AISHELL data preparation succeeded" +exit 0; diff --git a/examples/aishell/paraformer/local/download_and_untar.sh b/examples/aishell/paraformer/local/download_and_untar.sh new file mode 100755 index 000000000..d98255915 --- /dev/null +++ b/examples/aishell/paraformer/local/download_and_untar.sh @@ -0,0 +1,105 @@ +#!/usr/bin/env bash + +# Copyright 2014 Johns Hopkins University (author: Daniel Povey) +# 2017 Xingyu Na +# Apache 2.0 + +remove_archive=false + +if [ "$1" == --remove-archive ]; then + remove_archive=true + shift +fi + +if [ $# -ne 3 ]; then + echo 
"Usage: $0 [--remove-archive] " + echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell" + echo "With --remove-archive it will remove the archive after successfully un-tarring it." + echo " can be one of: data_aishell, resource_aishell." +fi + +data=$1 +url=$2 +part=$3 + +if [ ! -d "$data" ]; then + echo "$0: no such directory $data" + exit 1; +fi + +part_ok=false +list="data_aishell resource_aishell" +for x in $list; do + if [ "$part" == $x ]; then part_ok=true; fi +done +if ! $part_ok; then + echo "$0: expected to be one of $list, but got '$part'" + exit 1; +fi + +if [ -z "$url" ]; then + echo "$0: empty URL base." + exit 1; +fi + +if [ -f $data/$part/.complete ]; then + echo "$0: data part $part was already successfully extracted, nothing to do." + exit 0; +fi + +# sizes of the archive files in bytes. +sizes="15582913665 1246920" + +if [ -f $data/$part.tgz ]; then + size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}') + size_ok=false + for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done + if ! $size_ok; then + echo "$0: removing existing file $data/$part.tgz because its size in bytes $size" + echo "does not equal the size of one of the archives." + rm $data/$part.tgz + else + echo "$data/$part.tgz exists and appears to be complete." + fi +fi + +if [ ! -f $data/$part.tgz ]; then + if ! command -v wget >/dev/null; then + echo "$0: wget is not installed." + exit 1; + fi + full_url=$url/$part.tgz + echo "$0: downloading data from $full_url. This may take some time, please be patient." + + cd $data || exit 1 + if ! wget --no-check-certificate $full_url; then + echo "$0: error executing wget $full_url" + exit 1; + fi +fi + +cd $data || exit 1 + +if ! tar -xvzf $part.tgz; then + echo "$0: error un-tarring archive $data/$part.tgz" + exit 1; +fi + +touch $data/$part/.complete + +if [ $part == "data_aishell" ]; then + cd $data/$part/wav || exit 1 + for wav in ./*.tar.gz; do + echo "Extracting wav from $wav" + tar -zxf $wav && rm $wav + done +fi + +echo "$0: Successfully downloaded and un-tarred $data/$part.tgz" + +if $remove_archive; then + echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied." + rm $data/$part.tgz +fi + +exit 0; diff --git a/examples/aishell/paraformer/run.sh b/examples/aishell/paraformer/run.sh new file mode 100755 index 000000000..7972a13c0 --- /dev/null +++ b/examples/aishell/paraformer/run.sh @@ -0,0 +1,203 @@ +#!/usr/bin/env bash + +. ./path.sh || exit 1; + +# machines configuration +CUDA_VISIBLE_DEVICES="0,1" +gpu_num=2 +gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding +# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob +njob=1 + +# general configuration +feats_dir="../DATA" #feature output dictionary +exp_dir="." +lang=zh +token_type=char +stage=0 +stop_stage=5 + +# feature configuration +nj=64 + +# data +raw_data=../raw_data +data_url=www.openslr.org/resources/33 + +# exp tag +tag="exp1" + +. utils/parse_options.sh || exit 1; + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 
'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +train_set=train +valid_set=dev +test_sets="dev test" + +asr_config=conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml +model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}" + +#inference_config=conf/decode_asr_transformer_noctc_1best.yaml +#inference_asr_model=valid.acc.ave_10best.pb + +## you can set gpu num for decoding here +#gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default +#ngpu=$(echo $gpuid_list | awk -F "," '{print NF}') +# +#if ${gpu_inference}; then +# inference_nj=$[${ngpu}*${njob}] +# _ngpu=1 +#else +# inference_nj=$njob +# _ngpu=0 +#fi + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "stage -1: Data Download" + local/download_and_untar.sh ${raw_data} ${data_url} data_aishell + local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: Data preparation" + # Data preparation + local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir} + for x in train dev test; do + cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org + paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \ + > ${feats_dir}/data/${x}/text + utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org + mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text + + python funasr/datasets/audio_datasets/scp2jsonl.py \ + ++scp_file_list="[\"${feats_dir}/data/${x}/wav.scp\", \"${feats_dir}/data/${x}/text\"]" \ + ++data_type_list='["source", "target"]' \ + ++jsonl_file_out=${feats_dir}/data/${x}/audio_datasets.jsonl + done +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "stage 1: Feature and CMVN Generation" +# utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0 + python funasr/bin/compute_audio_cmvn.py \ + --config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \ + --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \ + ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \ + ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json" \ + ++dataset_conf.num_workers=$nj +fi + +token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt +echo "dictionary: ${token_list}" +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "stage 2: Dictionary Preparation" + mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/ + + echo "make a dictionary" + echo "<blank>" > ${token_list} + echo "<s>" >> ${token_list} + echo "</s>" >> ${token_list} + utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \ + | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list} + echo "<unk>" >> ${token_list} +fi + +# LM Training Stage +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "stage 3: LM Training" +fi + +# ASR Training Stage +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then +echo "stage 4: ASR Training" + +torchrun \ +--nnodes 1 \ +--nproc_per_node ${gpu_num} \ +funasr/bin/train.py \ +--config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \ +--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \ 
+++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \ +++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json" \ +++token_list="${token_list}" \ +++output_dir="${exp_dir}/exp/${model_dir}" +fi + +# +## Testing Stage +#if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then +# echo "stage 5: Inference" +# for dset in ${test_sets}; do +# asr_exp=${exp_dir}/exp/${model_dir} +# inference_tag="$(basename "${inference_config}" .yaml)" +# _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}" +# _logdir="${_dir}/logdir" +# if [ -d ${_dir} ]; then +# echo "${_dir} already exists. If you want to decode again, please delete this dir first." +# exit 0 +# fi +# mkdir -p "${_logdir}" +# _data="${feats_dir}/data/${dset}" +# key_file=${_data}/${scp} +# num_scp_file="$(<${key_file} wc -l)" +# _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file") +# split_scps= +# for n in $(seq "${_nj}"); do +# split_scps+=" ${_logdir}/keys.${n}.scp" +# done +# # shellcheck disable=SC2086 +# utils/split_scp.pl "${key_file}" ${split_scps} +# _opts= +# if [ -n "${inference_config}" ]; then +# _opts+="--config ${inference_config} " +# fi +# ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ +# python -m funasr.bin.asr_inference_launch \ +# --batch_size 1 \ +# --ngpu "${_ngpu}" \ +# --njob ${njob} \ +# --gpuid_list ${gpuid_list} \ +# --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \ +# --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \ +# --key_file "${_logdir}"/keys.JOB.scp \ +# --asr_train_config "${asr_exp}"/config.yaml \ +# --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ +# --output_dir "${_logdir}"/output.JOB \ +# --mode paraformer \ +# ${_opts} +# +# for f in token token_int score text; do +# if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then +# for i in $(seq "${_nj}"); do +# cat "${_logdir}/output.${i}/1best_recog/${f}" +# done | sort -k1 >"${_dir}/${f}" +# fi +# done +# python utils/proce_text.py ${_dir}/text ${_dir}/text.proc +# python utils/proce_text.py ${_data}/text ${_data}/text.proc +# python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer +# tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt +# cat ${_dir}/text.cer.txt +# done +#fi +# +## Prepare files for ModelScope fine-tuning and inference +#if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then +# echo "stage 6: ModelScope Preparation" +# cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn +# vocab_size=$(cat ${token_list} | wc -l) +# python utils/gen_modelscope_configuration.py \ +# --am_model_name $inference_asr_model \ +# --mode paraformer \ +# --model_name paraformer \ +# --dataset aishell \ +# --output_dir $exp_dir/exp/$model_dir \ +# --vocab_size $vocab_size \ +# --nat _nat \ +# --tag $tag +#fi \ No newline at end of file diff --git a/examples/aishell/paraformer/utils/extract_embeds.py b/examples/aishell/paraformer/utils/extract_embeds.py new file mode 100755 index 000000000..7b817d8ca --- /dev/null +++ b/examples/aishell/paraformer/utils/extract_embeds.py @@ -0,0 +1,47 @@ +from transformers import AutoTokenizer, AutoModel, pipeline +import numpy as np +import sys +import os +import torch +from kaldiio import WriteHelper +import re +text_file_json = sys.argv[1] +out_ark = sys.argv[2] +out_scp = sys.argv[3] +out_shape = sys.argv[4] +device = int(sys.argv[5]) +model_path = sys.argv[6] + +model = AutoModel.from_pretrained(model_path) 
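+# build the matching tokenizer and wrap model + tokenizer in a feature-extraction pipeline pinned to the requested device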
+tokenizer = AutoTokenizer.from_pretrained(model_path) +extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device) + +with open(text_file_json, 'r') as f: + js = f.readlines() + + +f_shape = open(out_shape, "w") +with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer: + with torch.no_grad(): + for idx, line in enumerate(js): + id, tokens = line.strip().split(" ", 1) + tokens = re.sub(" ", "", tokens.strip()) + tokens = ' '.join([j for j in tokens]) + token_num = len(tokens.split(" ")) + outputs = extractor(tokens) + outputs = np.array(outputs) + embeds = outputs[0, 1:-1, :] + + token_num_embeds, dim = embeds.shape + if token_num == token_num_embeds: + writer(id, embeds) + shape_line = "{} {},{}\n".format(id, token_num_embeds, dim) + f_shape.write(shape_line) + else: + print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens)) + + + +f_shape.close() + + diff --git a/examples/aishell/paraformer/utils/filter_scp.pl b/examples/aishell/paraformer/utils/filter_scp.pl new file mode 100755 index 000000000..003530d53 --- /dev/null +++ b/examples/aishell/paraformer/utils/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: scripts/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while(<F>) { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. 
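+ # print the line when its id was seen in id_list (or, with --exclude, when it was not)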
+ if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2) diff --git a/examples/aishell/paraformer/utils/fix_data.sh b/examples/aishell/paraformer/utils/fix_data.sh new file mode 100755 index 000000000..b1a2bb808 --- /dev/null +++ b/examples/aishell/paraformer/utils/fix_data.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +echo "$0 $@" +data_dir=$1 + +if [ ! -f ${data_dir}/wav.scp ]; then + echo "$0: wav.scp is not found" + exit 1; +fi + +if [ ! -f ${data_dir}/text ]; then + echo "$0: text is not found" + exit 1; +fi + + + +mkdir -p ${data_dir}/.backup + +awk '{print $1}' ${data_dir}/wav.scp > ${data_dir}/.backup/wav_id +awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id + +sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id + +cp ${data_dir}/wav.scp ${data_dir}/.backup/wav.scp +cp ${data_dir}/text ${data_dir}/.backup/text + +mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak +mv ${data_dir}/text ${data_dir}/text.bak + +utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak | sort -k1,1 -u > ${data_dir}/wav.scp +utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text + +rm ${data_dir}/wav.scp.bak +rm ${data_dir}/text.bak diff --git a/examples/aishell/paraformer/utils/fix_data_feat.sh b/examples/aishell/paraformer/utils/fix_data_feat.sh new file mode 100755 index 000000000..84eea36b6 --- /dev/null +++ b/examples/aishell/paraformer/utils/fix_data_feat.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash + +echo "$0 $@" +data_dir=$1 + +if [ ! -f ${data_dir}/feats.scp ]; then + echo "$0: feats.scp is not found" + exit 1; +fi + +if [ ! -f ${data_dir}/text ]; then + echo "$0: text is not found" + exit 1; +fi + +if [ ! -f ${data_dir}/speech_shape ]; then + echo "$0: feature lengths is not found" + exit 1; +fi + +if [ ! 
-f ${data_dir}/text_shape ]; then + echo "$0: text lengths is not found" + exit 1; +fi + +mkdir -p ${data_dir}/.backup + +awk '{print $1}' ${data_dir}/feats.scp > ${data_dir}/.backup/wav_id +awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id + +sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id + +cp ${data_dir}/feats.scp ${data_dir}/.backup/feats.scp +cp ${data_dir}/text ${data_dir}/.backup/text +cp ${data_dir}/speech_shape ${data_dir}/.backup/speech_shape +cp ${data_dir}/text_shape ${data_dir}/.backup/text_shape + +mv ${data_dir}/feats.scp ${data_dir}/feats.scp.bak +mv ${data_dir}/text ${data_dir}/text.bak +mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak +mv ${data_dir}/text_shape ${data_dir}/text_shape.bak + +utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak | sort -k1,1 -u > ${data_dir}/feats.scp +utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text +utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak | sort -k1,1 -u > ${data_dir}/speech_shape +utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak | sort -k1,1 -u > ${data_dir}/text_shape + +rm ${data_dir}/feats.scp.bak +rm ${data_dir}/text.bak +rm ${data_dir}/speech_shape.bak +rm ${data_dir}/text_shape.bak + diff --git a/examples/aishell/paraformer/utils/parse_options.sh b/examples/aishell/paraformer/utils/parse_options.sh new file mode 100755 index 000000000..71fb9e5ea --- /dev/null +++ b/examples/aishell/paraformer/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefined-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/examples/aishell/paraformer/utils/shuffle_list.pl b/examples/aishell/paraformer/utils/shuffle_list.pl new file mode 100755 index 000000000..a116200f4 --- /dev/null +++ b/examples/aishell/paraformer/utils/shuffle_list.pl @@ -0,0 +1,44 @@ +#!/usr/bin/env perl + +# Copyright 2013 Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +if ($ARGV[0] eq "--srand") { + $n = $ARGV[1]; + $n =~ m/\d+/ || die "Bad argument to --srand option: \"$n\""; + srand($ARGV[1]); + shift; + shift; +} else { + srand(0); # Gives inconsistent behavior if we don't seed. +} + +if (@ARGV > 1 || $ARGV[0] =~ m/^-.+/) { # >1 args, or an option we + # don't understand. 
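+ # anything else is a usage error: print help and exit nonzero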
+ print "Usage: shuffle_list.pl [--srand N] [input file] > output\n"; + print "randomizes the order of lines of input.\n"; + exit(1); +} + +@lines; +while (<>) { + push @lines, [ (rand(), $_)] ; +} + +@lines = sort { $a->[0] cmp $b->[0] } @lines; +foreach $l (@lines) { + print $l->[1]; +} \ No newline at end of file diff --git a/examples/aishell/paraformer/utils/split_scp.pl b/examples/aishell/paraformer/utils/split_scp.pl new file mode 100755 index 000000000..0876dcb6d --- /dev/null +++ b/examples/aishell/paraformer/utils/split_scp.pl @@ -0,0 +1,246 @@ +#!/usr/bin/env perl + +# Copyright 2010-2011 Microsoft Corporation + +# See ../../COPYING for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This program splits up any kind of .scp or archive-type file. +# If there is no utt2spk option it will work on any text file and +# will split it up with an approximately equal number of lines in +# each but. +# With the --utt2spk option it will work on anything that has the +# utterance-id as the first entry on each line; the utt2spk file is +# of the form "utterance speaker" (on each line). +# It splits it into equal size chunks as far as it can. If you use the utt2spk +# option it will make sure these chunks coincide with speaker boundaries. In +# this case, if there are more chunks than speakers (and in some other +# circumstances), some of the resulting chunks will be empty and it will print +# an error message and exit with nonzero status. +# You will normally call this like: +# split_scp.pl scp scp.1 scp.2 scp.3 ... +# or +# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ... +# Note that you can use this script to split the utt2spk file itself, +# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ... + +# You can also call the scripts like: +# split_scp.pl -j 3 0 scp scp.0 +# [note: with this option, it assumes zero-based indexing of the split parts, +# i.e. the second number must be 0 <= n < num-jobs.] + +use warnings; + +$num_jobs = 0; +$job_id = 0; +$utt2spk_file = ""; +$one_based = 0; + +for ($x = 1; $x <= 3 && @ARGV > 0; $x++) { + if ($ARGV[0] eq "-j") { + shift @ARGV; + $num_jobs = shift @ARGV; + $job_id = shift @ARGV; + } + if ($ARGV[0] =~ /--utt2spk=(.+)/) { + $utt2spk_file=$1; + shift; + } + if ($ARGV[0] eq '--one-based') { + $one_based = 1; + shift @ARGV; + } +} + +if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 || + $job_id - $one_based >= $num_jobs)) { + die "$0: Invalid job number/index values for '-j $num_jobs $job_id" . + ($one_based ? " --one-based" : "") . "'\n" +} + +$one_based + and $job_id--; + +if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) { + die +"Usage: split_scp.pl [--utt2spk=] in.scp out1.scp out2.scp ... + or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=] in.scp [out.scp] + ... 
where 0 <= job-id < num-jobs, or 1 <= job-id <= num-jobs if --one-based.\n"; +} + +$error = 0; +$inscp = shift @ARGV; +if ($num_jobs == 0) { # without -j option + @OUTPUTS = @ARGV; +} else { + for ($j = 0; $j < $num_jobs; $j++) { + if ($j == $job_id) { + if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; } + else { push @OUTPUTS, "-"; } + } else { + push @OUTPUTS, "/dev/null"; + } + } +} + +if ($utt2spk_file ne "") { # We have the --utt2spk option... + open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n"; + while(<$u_fh>) { + @A = split; + @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n"; + ($u,$s) = @A; + $utt2spk{$u} = $s; + } + close $u_fh; + open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n"; + @spkrs = (); + while(<$i_fh>) { + @A = split; + if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; } + $u = $A[0]; + $s = $utt2spk{$u}; + defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n"; + if(!defined $spk_count{$s}) { + push @spkrs, $s; + $spk_count{$s} = 0; + $spk_data{$s} = []; # ref to new empty array. + } + $spk_count{$s}++; + push @{$spk_data{$s}}, $_; + } + # Now split as equally as possible .. + # First allocate spks to files by allocating an approximately + # equal number of speakers. + $numspks = @spkrs; # number of speakers. + $numscps = @OUTPUTS; # number of output files. + if ($numspks < $numscps) { + die "$0: Refusing to split data because number of speakers $numspks " . + "is less than the number of output .scp files $numscps\n"; + } + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + $scparray[$scpidx] = []; # [] is array reference. + } + for ($spkidx = 0; $spkidx < $numspks; $spkidx++) { + $scpidx = int(($spkidx*$numscps) / $numspks); + $spk = $spkrs[$spkidx]; + push @{$scparray[$scpidx]}, $spk; + $scpcount[$scpidx] += $spk_count{$spk}; + } + + # Now we will try to reassign beginning + ending speakers + # to different scp's and see if it gets more balanced. + # Suppose the objf we're minimizing is sum_i (num utts in scp[i] - average)^2. + # We can show that if considering changing just 2 scp's, we minimize + # this by minimizing the squared difference in sizes. This is + # equivalent to minimizing the absolute difference in sizes. This + # shows this method is bound to converge. + + $changed = 1; + while($changed) { + $changed = 0; + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + # First try to reassign ending spk of this scp. + if($scpidx < $numscps-1) { + $sz = @{$scparray[$scpidx]}; + if($sz > 0) { + $spk = $scparray[$scpidx]->[$sz-1]; + $count = $spk_count{$spk}; + $nutt1 = $scpcount[$scpidx]; + $nutt2 = $scpcount[$scpidx+1]; + if( abs( ($nutt2+$count) - ($nutt1-$count)) + < abs($nutt2 - $nutt1)) { # Would decrease + # size-diff by reassigning spk... + $scpcount[$scpidx+1] += $count; + $scpcount[$scpidx] -= $count; + pop @{$scparray[$scpidx]}; + unshift @{$scparray[$scpidx+1]}, $spk; + $changed = 1; + } + } + } + if($scpidx > 0 && @{$scparray[$scpidx]} > 0) { + $spk = $scparray[$scpidx]->[0]; + $count = $spk_count{$spk}; + $nutt1 = $scpcount[$scpidx-1]; + $nutt2 = $scpcount[$scpidx]; + if( abs( ($nutt2-$count) - ($nutt1+$count)) + < abs($nutt2 - $nutt1)) { # Would decrease + # size-diff by reassigning spk... + $scpcount[$scpidx-1] += $count; + $scpcount[$scpidx] -= $count; + shift @{$scparray[$scpidx]}; + push @{$scparray[$scpidx-1]}, $spk; + $changed = 1; + } + } + } + } + # Now print out the files... 
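+ # each output scp gets the concatenated utterance lines of its assigned speakers; '-' writes to stdout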
+ for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + $scpfile = $OUTPUTS[$scpidx]; + ($scpfile ne '-' ? open($f_fh, '>', $scpfile) + : open($f_fh, '>&', \*STDOUT)) || + die "$0: Could not open scp file $scpfile for writing: $!\n"; + $count = 0; + if(@{$scparray[$scpidx]} == 0) { + print STDERR "$0: error: split_scp.pl producing empty .scp file " . + "$scpfile (too many splits and too few speakers?)\n"; + $error = 1; + } else { + foreach $spk ( @{$scparray[$scpidx]} ) { + print $f_fh @{$spk_data{$spk}}; + $count += $spk_count{$spk}; + } + $count == $scpcount[$scpidx] || die "Count mismatch [code error]"; + } + close($f_fh); + } +} else { + # This block is the "normal" case where there is no --utt2spk + # option and we just break into equal size chunks. + + open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n"; + + $numscps = @OUTPUTS; # size of array. + @F = (); + while(<$i_fh>) { + push @F, $_; + } + $numlines = @F; + if($numlines == 0) { + print STDERR "$0: error: empty input scp file $inscp\n"; + $error = 1; + } + $linesperscp = int( $numlines / $numscps); # the "whole part".. + $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n"; + $remainder = $numlines - ($linesperscp * $numscps); + ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder"; + # [just doing int() rounds down]. + $n = 0; + for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) { + $scpfile = $OUTPUTS[$scpidx]; + ($scpfile ne '-' ? open($o_fh, '>', $scpfile) + : open($o_fh, '>&', \*STDOUT)) || + die "$0: Could not open scp file $scpfile for writing: $!\n"; + for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) { + print $o_fh $F[$n++]; + } + close($o_fh) || die "$0: Error closing scp file $scpfile: $!\n"; + } + $n == $numlines || die "$n != $numlines [code error]"; +} + +exit ($error); diff --git a/examples/aishell/paraformer/utils/text2token.py b/examples/aishell/paraformer/utils/text2token.py new file mode 100755 index 000000000..56c39138f --- /dev/null +++ b/examples/aishell/paraformer/utils/text2token.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + + +import argparse +import codecs +import re +import sys + +is_python2 = sys.version_info[0] == 2 + + +def exist_or_not(i, match_pos): + start_pos = None + end_pos = None + for pos in match_pos: + if pos[0] <= i < pos[1]: + start_pos = pos[0] + end_pos = pos[1] + break + + return start_pos, end_pos + + +def get_parser(): + parser = argparse.ArgumentParser( + description="convert raw text to tokenized text", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--nchar", + "-n", + default=1, + type=int, + help="number of characters to split, i.e., \ + aabb -> a a b b with -n 1 and aa bb with -n 2", + ) + parser.add_argument( + "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" + ) + parser.add_argument("--space", default="<space>", type=str, help="space symbol") + parser.add_argument( + "--non-lang-syms", + "-l", + default=None, + type=str, + help="list of non-linguistic symbols, e.g., <NOISE> etc.", + ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") + parser.add_argument( + "--trans_type", + "-t", + type=str, + default="char", + choices=["char", "phn"], + help="""Transcript type. char/phn. 
e.g., for TIMIT FADG0_SI1279 - + If trans_type is char, + read from SI1279.WRD file -> "bricks are an alternative" + Else if trans_type is phn, + read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l + sil t er n ih sil t ih v sil" """, + ) + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + rs = [] + if args.non_lang_syms is not None: + with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f: + nls = [x.rstrip() for x in f.readlines()] + rs = [re.compile(re.escape(x)) for x in nls] + + if args.text: + f = codecs.open(args.text, encoding="utf-8") + else: + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) + + sys.stdout = codecs.getwriter("utf-8")( + sys.stdout if is_python2 else sys.stdout.buffer + ) + line = f.readline() + n = args.nchar + while line: + x = line.split() + print(" ".join(x[: args.skip_ncols]), end=" ") + a = " ".join(x[args.skip_ncols :]) + + # get all matched positions + match_pos = [] + for r in rs: + i = 0 + while i >= 0: + m = r.search(a, i) + if m: + match_pos.append([m.start(), m.end()]) + i = m.end() + else: + break + + if args.trans_type == "phn": + a = a.split(" ") + else: + if len(match_pos) > 0: + chars = [] + i = 0 + while i < len(a): + start_pos, end_pos = exist_or_not(i, match_pos) + if start_pos is not None: + chars.append(a[start_pos:end_pos]) + i = end_pos + else: + chars.append(a[i]) + i += 1 + a = chars + + a = [a[j : j + n] for j in range(0, len(a), n)] + + a_flat = [] + for z in a: + a_flat.append("".join(z)) + + a_chars = [z.replace(" ", args.space) for z in a_flat] + if args.trans_type == "phn": + a_chars = [z.replace("sil", args.space) for z in a_chars] + print(" ".join(a_chars)) + line = f.readline() + + +if __name__ == "__main__": + main() diff --git a/examples/aishell/paraformer/utils/text_tokenize.py b/examples/aishell/paraformer/utils/text_tokenize.py new file mode 100755 index 000000000..962ea11bc --- /dev/null +++ b/examples/aishell/paraformer/utils/text_tokenize.py @@ -0,0 +1,106 @@ +import re +import argparse + + +def load_dict(seg_file): + seg_dict = {} + with open(seg_file, 'r') as infile: + for line in infile: + s = line.strip().split() + key = s[0] + value = s[1:] + seg_dict[key] = " ".join(value) + return seg_dict + + +def forward_segment(text, dic): + word_list = [] + i = 0 + while i < len(text): + longest_word = text[i] + for j in range(i + 1, len(text) + 1): + word = text[i:j] + if word in dic: + if len(word) > len(longest_word): + longest_word = word + word_list.append(longest_word) + i += len(longest_word) + return word_list + + +def tokenize(txt, + seg_dict): + out_txt = "" + pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])") + for word in txt: + if pattern.match(word): + if word in seg_dict: + out_txt += seg_dict[word] + " " + else: + out_txt += "<unk>" + " " + else: + continue + return out_txt.strip() + + +def get_parser(): + parser = argparse.ArgumentParser( + description="text tokenize", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--text-file", + "-t", + default=False, + required=True, + type=str, + help="input text", + ) + parser.add_argument( + "--seg-file", + "-s", + default=False, + required=True, + type=str, + help="seg file", + ) + parser.add_argument( + "--txt-index", + "-i", + default=1, + required=True, + type=int, + help="txt index", + ) + parser.add_argument( + "--output-dir", + "-o", + default=False, + required=True, + type=str, + help="output dir", + ) + return parser + + +def main(): +
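 # parse CLI args, load the longest-match segmentation lexicon, then write tokenized text and per-utterance token counts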
parser = get_parser() + args = parser.parse_args() + + txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), 'w') + shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), 'w') + seg_dict = load_dict(args.seg_file) + with open(args.text_file, 'r') as infile: + for line in infile: + s = line.strip().split() + text_id = s[0] + text_list = forward_segment("".join(s[1:]).lower(), seg_dict) + text = tokenize(text_list, seg_dict) + lens = len(text.strip().split()) + txt_writer.write(text_id + " " + text + '\n') + shape_writer.write(text_id + " " + str(lens) + '\n') + + +if __name__ == '__main__': + main() + diff --git a/examples/aishell/paraformer/utils/text_tokenize.sh b/examples/aishell/paraformer/utils/text_tokenize.sh new file mode 100755 index 000000000..6b74fef80 --- /dev/null +++ b/examples/aishell/paraformer/utils/text_tokenize.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + + +# Begin configuration section. +nj=32 +cmd=utils/run.pl + +echo "$0 $@" + +. utils/parse_options.sh || exit 1; + +# tokenize configuration +text_dir=$1 +seg_file=$2 +logdir=$3 +output_dir=$4 + +txt_dir=${output_dir}/txt; mkdir -p ${output_dir}/txt +mkdir -p ${logdir} + +$cmd JOB=1:$nj $logdir/text_tokenize.JOB.log \ + python utils/text_tokenize.py -t ${text_dir}/txt/text.JOB.txt \ + -s ${seg_file} -i JOB -o ${txt_dir} \ + || exit 1; + +# concatenate the text files together. +for n in $(seq $nj); do + cat ${txt_dir}/text.$n.txt || exit 1 +done > ${output_dir}/text || exit 1 + +for n in $(seq $nj); do + cat ${txt_dir}/len.$n || exit 1 +done > ${output_dir}/text_shape || exit 1 + +echo "$0: Succeeded text tokenize" diff --git a/examples/aishell/paraformer/utils/textnorm_zh.py b/examples/aishell/paraformer/utils/textnorm_zh.py new file mode 100755 index 000000000..79feb83fd --- /dev/null +++ b/examples/aishell/paraformer/utils/textnorm_zh.py @@ -0,0 +1,834 @@ +#!/usr/bin/env python3 +# coding=utf-8 + +# Authors: +# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) +# 2019.9 Jiayu DU +# +# requirements: +# - python 3.X +# notes: python 2.X WILL fail or produce misleading results + +import sys, os, argparse, codecs, string, re + +# ================================================================================ # +# basic constant +# ================================================================================ # +CHINESE_DIGIS = u'零一二三四五六七八九' +BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖' +BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖' +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬' +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载' +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載' +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬' + +ZERO_ALT = u'〇' +ONE_ALT = u'幺' +TWO_ALTS = [u'两', u'兩'] + +POSITIVE = [u'正', u'正'] +NEGATIVE = [u'负', u'負'] +POINT = [u'点', u'點'] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +FILLER_CHARS = ['呃', '啊'] +ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \ + '胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \ + '儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \ + '佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)' + +# 中文数字系统类型 +NUMBERING_TYPES = ['low', 'mid', 'high'] + +CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \ + '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)' +CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)' +COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \ 
+ '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \ + '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \ + '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \ + '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \ + '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)' + +# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git) +CHINESE_PUNC_STOP = '!?。。' +CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏' +CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP + +# ================================================================================ # +# basic class +# ================================================================================ # +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + #self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. '陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return '10^{}'.format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + + if small_unit: + return ChineseNumberUnit(power=index + 1, + simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit(power=index + 8, + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit(power=(index + 2) * 4, + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit(power=pow(2, index + 3), + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + else: + raise ValueError( + 'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type)) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. 
+ positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v + + +# ================================================================================ # +# basic utils +# ================================================================================ # +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. + 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip( + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL) + larger_units = [CNU.create(i, v, numbering_type, False) + for i, v in enumerate(all_larger_units)] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip( + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL) + smaller_units = [CNU.create(i, v, small_unit=True) + for i, v in enumerate(all_smaller_units)] + # digis + chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x) + point_cn = CM(POINT[0], POINT[1], '.', lambda x, + y: float(str(x) + '.' 
+ str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, '' + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], \ + [get_symbol(c, system) for c in dec_string] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None)) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power: + result[-i - 1] = CNU(result[-i - 1].power + + current_unit.power, None, None, None, None) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * + pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = ''.join([str(d.value) for d in dec_part]) + if dec_part: + return '{0}.{1}'.format(int_str, dec_str) + else: + return int_str + + +def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False, + traditional=False, alt_zero=False, alt_one=False, alt_two=True, + use_zeros=True, use_units=True): + + def get_value(value_string, use_zeros=True): + + striped_string = value_string.lstrip('0') + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next(u for u in reversed( + system.units) if u.power < len(striped_string)) + result_string = value_string[:-result_unit.power] + return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:]) + + system = create_system(numbering_type) + + int_dec = number_string.split('.') + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError( + "invalid input num string with more than one dot: {}".format(number_string)) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, + system.digits[2].big_s, system.digits[2].big_t) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = result_symbols[i + + 1] if i < len(result_symbols) - 1 else None + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))): + if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = 'big_' + if traditional: + attr_name += 't' + else: + attr_name += 's' + else: + if traditional: + attr_name = 'traditional' + else: + attr_name = 'simplified' + + result = ''.join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + result + + # ^10, 11, .., 19 + if len(result) >= 2 and result[1] in 
+
+# ================================================================================ #
+#                        different types of rewriters                             #
+# ================================================================================ #
+class Cardinal:
+    """
+    CARDINAL class.
+    """
+
+    def __init__(self, cardinal=None, chntext=None):
+        self.cardinal = cardinal
+        self.chntext = chntext
+
+    def chntext2cardinal(self):
+        return chn2num(self.chntext)
+
+    def cardinal2chntext(self):
+        return num2chn(self.cardinal)
+
+
+class Digit:
+    """
+    DIGIT class: digit-by-digit reading, no positional units.
+    """
+
+    def __init__(self, digit=None, chntext=None):
+        self.digit = digit
+        self.chntext = chntext
+
+    # def chntext2digit(self):
+    #     return chn2num(self.chntext)
+
+    def digit2chntext(self):
+        return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+class TelePhone:
+    """
+    TELEPHONE class.
+    """
+
+    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+        self.telephone = telephone
+        self.raw_chntext = raw_chntext
+        self.chntext = chntext
+
+    # def chntext2telephone(self):
+    #     sil_parts = self.raw_chntext.split('<SIL>')
+    #     self.telephone = '-'.join([
+    #         str(chn2num(p)) for p in sil_parts
+    #     ])
+    #     return self.telephone
+
+    def telephone2chntext(self, fixed=False):
+
+        if fixed:
+            sil_parts = self.telephone.split('-')
+            self.raw_chntext = '<SIL>'.join([
+                num2chn(part, alt_two=False, use_units=False) for part in sil_parts
+            ])
+            self.chntext = self.raw_chntext.replace('<SIL>', '')
+        else:
+            sp_parts = self.telephone.strip('+').split()
+            self.raw_chntext = '<SP>'.join([
+                num2chn(part, alt_two=False, use_units=False) for part in sp_parts
+            ])
+            self.chntext = self.raw_chntext.replace('<SP>', '')
+        return self.chntext
+
+
+class Fraction:
+    """
+    FRACTION class.
+    """
+
+    def __init__(self, fraction=None, chntext=None):
+        self.fraction = fraction
+        self.chntext = chntext
+
+    def chntext2fraction(self):
+        denominator, numerator = self.chntext.split('分之')
+        return chn2num(numerator) + '/' + chn2num(denominator)
+
+    def fraction2chntext(self):
+        numerator, denominator = self.fraction.split('/')
+        return num2chn(denominator) + '分之' + num2chn(numerator)
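+
+# Usage of the rewriters above (illustrative, hand-derived outputs):
+#   Cardinal(cardinal='12.5').cardinal2chntext()  -> '十二点五'
+#   Digit(digit='2021').digit2chntext()           -> '二零二一'
+#   Fraction(fraction='2/3').fraction2chntext()   -> '三分之二'
+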
+
+class Date:
+    """
+    DATE class.
+    """
+
+    def __init__(self, date=None, chntext=None):
+        self.date = date
+        self.chntext = chntext
+
+    # def chntext2date(self):
+    #     chntext = self.chntext
+    #     try:
+    #         year, other = chntext.strip().split('年', maxsplit=1)
+    #         year = Digit(chntext=year).digit2chntext() + '年'
+    #     except ValueError:
+    #         other = chntext
+    #         year = ''
+    #     if other:
+    #         try:
+    #             month, day = other.strip().split('月', maxsplit=1)
+    #             month = Cardinal(chntext=month).chntext2cardinal() + '月'
+    #         except ValueError:
+    #             day = chntext
+    #             month = ''
+    #         if day:
+    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+    #     else:
+    #         month = ''
+    #         day = ''
+    #     date = year + month + day
+    #     self.date = date
+    #     return self.date
+
+    def date2chntext(self):
+        date = self.date
+        try:
+            year, other = date.strip().split('年', 1)
+            year = Digit(digit=year).digit2chntext() + '年'
+        except ValueError:
+            other = date
+            year = ''
+        if other:
+            try:
+                month, day = other.strip().split('月', 1)
+                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
+            except ValueError:
+                day = date
+                month = ''
+            if day:
+                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+        else:
+            month = ''
+            day = ''
+        chntext = year + month + day
+        self.chntext = chntext
+        return self.chntext
+
+
+class Money:
+    """
+    MONEY class.
+    """
+
+    def __init__(self, money=None, chntext=None):
+        self.money = money
+        self.chntext = chntext
+
+    # def chntext2money(self):
+    #     return self.money
+
+    def money2chntext(self):
+        money = self.money
+        pattern = re.compile(r'(\d+(\.\d+)?)')
+        matchers = pattern.findall(money)
+        if matchers:
+            for matcher in matchers:
+                money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
+        self.chntext = money
+        return self.chntext
+
+
+class Percentage:
+    """
+    PERCENTAGE class.
+    """
+
+    def __init__(self, percentage=None, chntext=None):
+        self.percentage = percentage
+        self.chntext = chntext
+
+    def chntext2percentage(self):
+        return chn2num(self.chntext.strip().strip('百分之')) + '%'
+
+    def percentage2chntext(self):
+        return '百分之' + num2chn(self.percentage.strip().strip('%'))
+
+
+def remove_erhua(text, er_whitelist):
+    """
+    Strip the erhua suffix 儿 unless the word is whitelisted:
+    他女儿在那边儿 -> 他女儿在那边
+    """
+
+    er_pattern = re.compile(er_whitelist)
+    new_str = ''
+    while re.search('儿', text):
+        a = re.search('儿', text).span()
+        remove_er_flag = 0
+
+        if er_pattern.search(text):
+            b = er_pattern.search(text).span()
+            if b[0] <= a[0]:
+                remove_er_flag = 1
+
+        if remove_er_flag == 0:
+            new_str = new_str + text[0:a[0]]
+            text = text[a[1]:]
+        else:
+            new_str = new_str + text[0:b[1]]
+            text = text[b[1]:]
+
+    text = new_str + text
+    return text
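+
+# Example (assuming ER_WHITELIST matches 女儿): remove_erhua keeps the
+# whitelisted 儿 and drops the erhua suffix, as in the docstring above:
+#   remove_erhua('他女儿在那边儿', ER_WHITELIST) -> '他女儿在那边'
+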
+
+# ================================================================================ #
+#                              NSW Normalizer                                     #
+# ================================================================================ #
+class NSWNormalizer:
+    def __init__(self, raw_text):
+        self.raw_text = '^' + raw_text + '$'
+        self.norm_text = ''
+
+    def _particular(self):
+        # handle patterns like 'O2O'/'B2C', turning 二 back into '2' between letters
+        text = self.norm_text
+        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('particular')
+            for matcher in matchers:
+                text = text.replace(matcher[0], matcher[1] + '2' + matcher[2], 1)
+        self.norm_text = text
+        return self.norm_text
+
+    def normalize(self):
+        text = self.raw_text
+
+        # normalize dates
+        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('date')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+        # normalize money amounts
+        pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('money')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
+
+        # normalize landline/mobile phone numbers
+        # mobile
+        # http://www.jihaoba.com/news/show/13680
+        # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
+        # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
+        # China Telecom: 133, 153, 189, 180, 181, 177
+        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('telephone')
+            for matcher in matchers:
+                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
+        # landline
+        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fixed telephone')
+            for matcher in matchers:
+                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
+
+        # normalize fractions
+        pattern = re.compile(r"(\d+/\d+)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fraction')
+            for matcher in matchers:
+                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
+
+        # normalize percentages (fold full-width ％ into ASCII % first)
+        text = text.replace('％', '%')
+        pattern = re.compile(r"(\d+(\.\d+)?%)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('percentage')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
+
+        # normalize cardinal + quantifier
+        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('cardinal+quantifier')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        # normalize digit sequences (IDs and other long numbers, read digit by digit)
+        pattern = re.compile(r"(\d{4,32})")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('digit')
+            for matcher in matchers:
+                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+        # normalize plain cardinals
+        pattern = re.compile(r"(\d+(\.\d+)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('cardinal')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        self.norm_text = text
+        self._particular()
+
+        return self.norm_text.lstrip('^').rstrip('$')
+
+
+def nsw_test_case(raw_text):
+    print('I:' + raw_text)
+    print('O:' + NSWNormalizer(raw_text).normalize())
+    print('')
+
+
+def nsw_test():
+    nsw_test_case('固话:0595-23865596或23880880。')
+    nsw_test_case('固话:0595-23865596或23880880。')
+    nsw_test_case('手机:+86 19859213959或15659451527。')
+    nsw_test_case('分数:32477/76391。')
+    nsw_test_case('百分数:80.03%。')
+    nsw_test_case('编号:31520181154418。')
+    nsw_test_case('纯数:2983.07克或12345.60米。')
+    nsw_test_case('日期:1999年2月20日或09年3月15号。')
+    nsw_test_case('金钱:12块5,34.5元,20.1万')
+    nsw_test_case('特殊:O2O或B2C。')
+    nsw_test_case('3456万吨')
+    nsw_test_case('2938个')
+    nsw_test_case('938')
+    nsw_test_case('今天吃了115个小笼包231个馒头')
+    nsw_test_case('有62%的概率')
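+
+# End-to-end example (hand-traced through the rules above, so treat the
+# exact output as illustrative):
+#   NSWNormalizer('今天吃了115个小笼包').normalize()
+#   -> '今天吃了一百一十五个小笼包'   (cardinal + quantifier rule)
+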
+
+
+if __name__ == '__main__':
+    # nsw_test()
+
+    p = argparse.ArgumentParser()
+    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
+    p.add_argument('ofile', help='output filename')
+    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
+    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
+    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
+    p.add_argument('--remove_fillers', type=bool, default=True, help='remove filler chars such as "呃, 啊"')
+    p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "这儿"')
+    p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
+    args = p.parse_args()
+
+    ifile = codecs.open(args.ifile, 'r', 'utf8')
+    ofile = codecs.open(args.ofile, 'w+', 'utf8')
+
+    n = 0
+    for l in ifile:
+        key = ''
+        text = ''
+        if args.has_key:
+            cols = l.split(maxsplit=1)
+            key = cols[0]
+            if len(cols) == 2:
+                text = cols[1].strip()
+            else:
+                text = ''
+        else:
+            text = l.strip()
+
+        # cases
+        if args.to_upper and args.to_lower:
+            sys.stderr.write('text norm: to_upper OR to_lower?')
+            exit(1)
+        if args.to_upper:
+            text = text.upper()
+        if args.to_lower:
+            text = text.lower()
+
+        # filler chars removal
+        if args.remove_fillers:
+            for ch in FILLER_CHARS:
+                text = text.replace(ch, '')
+
+        if args.remove_erhua:
+            text = remove_erhua(text, ER_WHITELIST)
+
+        # NSW (Non-Standard-Word) normalization
+        text = NSWNormalizer(text).normalize()
+
+        # punctuation removal
+        old_chars = CHINESE_PUNC_LIST + string.punctuation  # includes all CN and EN punctuations
+        new_chars = ' ' * len(old_chars)
+        del_chars = ''
+        text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
+
+        #
+        if args.has_key:
+            ofile.write(key + '\t' + text + '\n')
+        else:
+            ofile.write(text + '\n')
+
+        n += 1
+        if n % args.log_interval == 0:
+            sys.stderr.write("text norm: {} lines done.\n".format(n))
+
+    sys.stderr.write("text norm: {} lines done in total.\n".format(n))
+
+    ifile.close()
+    ofile.close()
diff --git a/examples/aishell/run.sh b/examples/aishell/run.sh
deleted file mode 100644
index 786691f67..000000000
--- a/examples/aishell/run.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-
-cmd="funasr_cli/cli/train_cli.py"
-
-python $cmd \
---config-path "/Users/zhifu/funasr_github/test_local/funasr_cli_egs" \
---config-name "config.yaml" \
-+token_list="/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \
-+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
-+output_dir="/nfs/zhifu.gzf/ckpt/funasr2/exp1"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/finetune.sh b/examples/industrial_data_pretraining/paraformer/finetune.sh
index 7d4ea9474..5fc7481d9 100644
--- a/examples/industrial_data_pretraining/paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -5,7 +5,14 @@
 #local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
 #git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
-
+## generate jsonl from wav.scp and text.txt
+#python funasr/datasets/audio_datasets/scp2jsonl.py \
+#++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
+#++data_type_list='["source", "target"]' \
+#++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
+# torchrun \
+# --nnodes 1 \
+# --nproc_per_node 1 \
 python funasr/bin/train.py \
 +model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
 +model_revision="v2.0.4" \
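The new funasr/bin/compute_audio_cmvn.py below streams fbank features through the training dataloader and accumulates per-dimension statistics for global CMVN. The math it relies on is the streaming identity Var[x] = E[x^2] - E[x]^2, so only a running sum, a running sum of squares, and a frame count need to be kept. A minimal self-contained sketch of that accumulation (illustrative only; accumulate_cmvn and the [T, D] batch shape are assumptions, not FunASR API):

    import numpy as np

    def accumulate_cmvn(batches):
        # batches: iterable of [T, D] fbank matrices
        mean_stats, var_stats, total_frames = 0.0, 0.0, 0
        for fbank in batches:
            mean_stats = mean_stats + fbank.sum(axis=0)            # per-dim sum of x
            var_stats = var_stats + np.square(fbank).sum(axis=0)   # per-dim sum of x^2
            total_frames += fbank.shape[0]
        mean = mean_stats / total_frames                        # E[x]
        std = np.sqrt(var_stats / total_frames - mean * mean)   # sqrt(E[x^2] - E[x]^2)
        return mean, std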
diff --git a/funasr/bin/compute_audio_cmvn.py b/funasr/bin/compute_audio_cmvn.py
new file mode 100644
index 000000000..b66bb14d6
--- /dev/null
+++ b/funasr/bin/compute_audio_cmvn.py
@@ -0,0 +1,123 @@
+import os
+import json
+import numpy as np
+import torch
+import hydra
+import logging
+from omegaconf import DictConfig, OmegaConf
+
+from funasr.register import tables
+from funasr.download.download_from_hub import download_model
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+    if kwargs.get("debug", False):
+        import pdb; pdb.set_trace()
+
+    assert "model" in kwargs
+    if "model_conf" not in kwargs:
+        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
+
+    main(**kwargs)
+
+
+def main(**kwargs):
+    print(kwargs)
+    # set random seed
+    tables.print()
+    set_all_random_seed(kwargs.get("seed", 0))
+    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
+    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
+    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+
+    tokenizer = kwargs.get("tokenizer", None)
+
+    # build frontend if frontend is not None
+    frontend = kwargs.get("frontend", None)
+    if frontend is not None:
+        frontend_class = tables.frontend_classes.get(frontend)
+        frontend = frontend_class(**kwargs["frontend_conf"])
+        kwargs["frontend"] = frontend
+        kwargs["input_size"] = frontend.output_size()
+
+    # dataset
+    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
+    dataset_train = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=None, is_training=False, **kwargs.get("dataset_conf"))
+
+    # dataloader
+    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
+    batch_sampler_train = None
+    if batch_sampler is not None:
+        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+        dataset_conf = kwargs.get("dataset_conf")
+        dataset_conf["batch_type"] = "example"
+        dataset_conf["batch_size"] = 1
+        batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
+
+    dataloader_train = torch.utils.data.DataLoader(dataset_train,
+                                                   collate_fn=dataset_train.collator,
+                                                   batch_sampler=batch_sampler_train,
+                                                   num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
+                                                   pin_memory=True)
+
+    iter_stop = int(kwargs.get("scale", 1.0) * len(dataloader_train))
+
+    total_frames = 0
+    for batch_idx, batch in enumerate(dataloader_train):
+        if batch_idx >= iter_stop:
+            break
+
+        fbank = batch["speech"].numpy()[0, :, :]
+        if total_frames == 0:
+            # first batch initializes the per-dimension accumulators; summing over
+            # the frame axis here (rather than keeping the full [T, D] matrix)
+            # keeps the accumulators shaped [D] like the later updates
+            mean_stats = np.sum(fbank, axis=0)
+            var_stats = np.sum(np.square(fbank), axis=0)
+        else:
+            mean_stats += np.sum(fbank, axis=0)
+            var_stats += np.sum(np.square(fbank), axis=0)
+        total_frames += fbank.shape[0]
+
+    cmvn_info = {
+        'mean_stats': list(mean_stats.tolist()),
+        'var_stats': list(var_stats.tolist()),
+        'total_frames': total_frames
+    }
+    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
+    with open(cmvn_file, 'w') as fout:
+        fout.write(json.dumps(cmvn_info))
+
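+    # The mean below is negated on purpose: in the Kaldi-style am.mvn emitted
+    # next, <AddShift> adds its vector to the input (x + (-E[x])) and <Rescale>
+    # multiplies by 1/stddev, which together implement (x - mean) / stddev.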
+    mean = -1.0 * mean_stats / total_frames
+    var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
+    dims = mean.shape[0]
+    am_mvn = os.path.dirname(cmvn_file) + "/am.mvn"
+    with open(am_mvn, 'w') as fout:
+        fout.write("<Nnet>" + "\n" + "<Splice> " + str(dims) + " " + str(dims) + '\n' + "[ 0 ]" + "\n" + "<AddShift> " + str(dims) + " " + str(dims) + "\n")
+        mean_str = str(list(mean)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+        fout.write("<LearnRateCoef> 0 " + mean_str + '\n')
+        fout.write("<Rescale> " + str(dims) + " " + str(dims) + '\n')
+        var_str = str(list(var)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+        fout.write("<LearnRateCoef> 0 " + var_str + '\n')
+        fout.write("</Nnet>" + '\n')
+
+
+if __name__ == "__main__":
+    main_hydra()
+    """
+    python funasr/bin/compute_audio_cmvn.py \
+    --config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \
+    --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
+    ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
+    ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
+    ++dataset_conf.num_workers=32
+    """
\ No newline at end of file
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8ea0c0db5..c9a4a6784 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,3 +1,6 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+
 import os
 import sys
 import torch
@@ -144,9 +147,8 @@ def main(**kwargs):
 
     # dataset
     dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
-    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
-    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
-                                **kwargs.get("dataset_conf"))
+    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
+    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
 
     # dataloader
     batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index ebb72a327..62acb44af 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -19,7 +19,7 @@ class AudioDataset(torch.utils.data.Dataset):
                  **kwargs):
         super().__init__()
         index_ds_class = tables.index_ds_classes.get(index_ds)
-        self.index_ds = index_ds_class(path)
+        self.index_ds = index_ds_class(path, **kwargs)
         preprocessor_speech = kwargs.get("preprocessor_speech", None)
         if preprocessor_speech:
             preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
@@ -63,9 +63,14 @@ class AudioDataset(torch.utils.data.Dataset):
         target = item["target"]
         if self.preprocessor_text:
             target = self.preprocessor_text(target)
-        ids = self.tokenizer.encode(target)
+        if self.tokenizer:
+            ids = self.tokenizer.encode(target)
+            text = torch.tensor(ids, dtype=torch.int64)
+        else:
+            ids = target
+            text = ids
         ids_lengths = len(ids)
-        text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
+        text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
 
         return {"speech": speech[0, :, :],
                 "speech_lengths": speech_lengths,
@@ -83,11 +88,13 @@ class AudioDataset(torch.utils.data.Dataset):
                 outputs[key].append(sample[key])
 
         for key, data_list in outputs.items():
-            if data_list[0].dtype == torch.int64:
-                pad_value = self.int_pad_value
-            else:
-                pad_value = self.float_pad_value
-            outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64:
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
         return outputs
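The collator change above makes padding conditional on the value being a tensor, so raw-string targets (the new tokenizer-less path) pass through unpadded. For reference, this is how pad_sequence behaves with the padding values involved (illustrative snippet; the -1 mirrors int_pad_value for int64 token ids, while float features would use float_pad_value):

    import torch

    ids = [torch.tensor([5, 7]), torch.tensor([9])]   # variable-length int64 ids
    padded = torch.nn.utils.rnn.pad_sequence(ids, batch_first=True, padding_value=-1)
    # padded: tensor([[ 5,  7],
    #                 [ 9, -1]])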
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 008b08ff1..12ffd23d5 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -1,6 +1,9 @@
+import os
 import json
 import torch
 import logging
+import concurrent.futures
+import librosa
 import torch.distributed as dist
 
 from funasr.register import tables
@@ -71,9 +74,19 @@ class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
 
 @tables.register("index_ds_classes", "IndexDSJsonlRankFull")
 class IndexDSJsonlRankFull(torch.utils.data.Dataset):
-    def __init__(self, path):
+    def __init__(self, path: str, **kwargs):
         super().__init__()
 
+        if isinstance(path, (list, tuple)):  # wav.scp, text.txt/text.trans
+            from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list
+            jsonl_outdir = os.path.dirname(path[0])
+            jsonl_name = "datalist_train.jsonl" if kwargs.get("is_training", True) else "datalist_val.jsonl"
+            jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name)
+            if not os.path.exists(jsonl_file_out):
+                print(f"datalist is: {path}, generating jsonl from it")
+                gen_jsonl_from_wav_text_list(path, jsonl_file_out=jsonl_file_out, **kwargs)
+            path = jsonl_file_out
+
         contents = []
         with open(path, encoding='utf-8') as fin:
             for line in fin:
diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
new file mode 100644
index 000000000..c60c6f577
--- /dev/null
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -0,0 +1,94 @@
+import os
+import json
+import torch
+import logging
+import hydra
+from omegaconf import DictConfig, OmegaConf
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+
+
+def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs):
+    try:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    except Exception:
+        rank = 0
+        world_size = 1
+
+    cpu_cores = os.cpu_count() or 1
+
+    if rank == 0:
+        json_dict = {}
+        for data_type, data_file in zip(data_type_list, path):
+            json_dict[data_type] = {}
+            with open(data_file, "r") as f:
+                data_file_lists = f.readlines()
+                lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
+                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
+                with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
+                    futures = [executor.submit(parse_context_length,
+                                               data_file_lists[i * lines_for_each_th:(i + 1) * lines_for_each_th],
+                                               data_type) for i in range(task_num)]
+                    for future in concurrent.futures.as_completed(futures):
+                        json_dict[data_type].update(future.result())
+        # print(json_dict)
+
+        with open(jsonl_file_out, "w") as f:
+            for key in json_dict[data_type_list[0]].keys():
+                jsonl_line = {"key": key}
+                for data_file in data_type_list:
+                    jsonl_line.update(json_dict[data_file][key])
+                jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
+                f.write(jsonl_line + "\n")
+                f.flush()
+    else:
+        pass
+
+    if world_size > 1:
+        dist.barrier()
+
+
+def parse_context_length(data_list: list, data_type: str):
+    res = {}
+    for i, line in enumerate(data_list):
+        key, line = line.strip().split(maxsplit=1)
+        line = line.strip()
+        if os.path.exists(line):
+            # audio entry: length counted in 10 ms frames at 16 kHz
+            waveform, _ = librosa.load(line, sr=16000)
+            sample_num = len(waveform)
+            context_len = int(sample_num // 16000 * 1000 / 10)
+        else:
+            # text entry: length counted in characters
+            context_len = len(line)
+        res[key] = {data_type: line, f"{data_type}_len": context_len}
+    return res
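+
+# Illustrative output line (hypothetical key and paths), combining the per-type
+# results keyed by utterance id; source_len counts 10 ms frames and target_len
+# counts characters:
+# {"key": "UTT001", "source": "/data/wav/UTT001.wav", "source_len": 512,
+#  "target": "今天天气不错", "target_len": 6}
+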
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+    """
+    python funasr/datasets/audio_datasets/scp2jsonl.py \
+    ++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
+    ++data_type_list='["source", "target"]' \
+    ++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
+    """
+
+    kwargs = OmegaConf.to_container(cfg, resolve=True)
+
+    scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
+    data_type_list = kwargs.get("data_type_list", ("source", "target"))
+    jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
+    gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out)
+
+
+if __name__ == "__main__":
+    main_hydra()
\ No newline at end of file
diff --git a/funasr/download/download_dataset_from_hub.py b/funasr/download/download_dataset_from_hub.py
new file mode 100644
index 000000000..d06d8213a
--- /dev/null
+++ b/funasr/download/download_dataset_from_hub.py
@@ -0,0 +1,11 @@
+
+def download_dataset():
+    pass
+
+
+def download_dataset_from_ms(**kwargs):
+    from modelscope.msdatasets import MsDataset
+    dataset_name = kwargs.get("dataset_name", 'speech_asr/speech_asr_aishell1_trainsets')
+    subset_name = kwargs.get("subset_name", 'default')
+    split = kwargs.get("split", 'train')
+    data_dump_dir = kwargs.get("data_dump_dir", None)
+    ds = MsDataset.load(dataset_name=dataset_name, subset_name=subset_name, split=split, cache_dir=data_dump_dir)
\ No newline at end of file
diff --git a/funasr/models/model_hf/__init__.py b/funasr/models/model_hf/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/funasr/register.py b/funasr/register.py
index 454105f1d..ebfdaacf2 100644
--- a/funasr/register.py
+++ b/funasr/register.py
@@ -52,8 +52,8 @@ class RegisterTables:
         registry = getattr(self, register_tables_key)
 
         registry_key = key if key is not None else target_class.__name__
-        assert not registry_key in registry, "(key: {} / class: {}) has been registered already, in {}".format(
-            registry_key, target_class, register_tables_key)
+        # assert not registry_key in registry, "(key: {} / class: {}) has been registered already, in {}".format(
+        #     registry_key, target_class, register_tables_key)
         registry[registry_key] = target_class
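Since this hunk disables the duplicate-registration assert, re-registering an existing key now silently overwrites the earlier class. For readers unfamiliar with the pattern, a minimal sketch of the decorator-based registry that funasr/register.py implements (simplified names; not the library's exact code):

    from typing import Any, Dict

    class RegisterTables:
        index_ds_classes: Dict[str, Any] = {}

        def register(self, table_name, key=None):
            def decorator(target_class):
                registry = getattr(self, table_name)
                registry[key or target_class.__name__] = target_class  # last write wins
                return target_class
            return decorator

    tables = RegisterTables()

    @tables.register("index_ds_classes", "MyIndexDS")  # hypothetical key
    class MyIndexDS:
        pass

    assert tables.index_ds_classes["MyIndexDS"] is MyIndexDS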
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index d144019aa..3cd61a155 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -204,25 +204,25 @@ class Trainer:
         my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
         with my_context():
             time2 = time.perf_counter()
-            print("before, GPU, memory: {:.1} MB, "
-                  "{:.1} MB, "
-                  "{:.1} MB, "
-                  "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024,
-                                    torch.cuda.max_memory_allocated()/1024/1024/1024,
-                                    torch.cuda.memory_reserved()/1024/1024/1024,
-                                    torch.cuda.max_memory_reserved()/1024/1024/1024,
-                                    ))
+            # print("before, GPU, memory: {:.3f} GB, "
+            #       "{:.3f} GB, "
+            #       "{:.3f} GB, "
+            #       "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
+            #                          torch.cuda.max_memory_allocated()/1024/1024/1024,
+            #                          torch.cuda.memory_reserved()/1024/1024/1024,
+            #                          torch.cuda.max_memory_reserved()/1024/1024/1024,
+            #                          ))
             retval = self.model(**batch)
             torch.cuda.empty_cache()
-            print("after, GPU, memory: {:.1} MB, "
-                  "{:.1} MB, "
-                  "{:.1} MB, "
-                  "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024,
-                                    torch.cuda.max_memory_allocated()/1024/1024/1024,
-                                    torch.cuda.memory_reserved()/1024/1024/1024,
-                                    torch.cuda.max_memory_reserved()/1024/1024/1024,
-                                    ))
+            # print("after, GPU, memory: {:.3f} GB, "
+            #       "{:.3f} GB, "
+            #       "{:.3f} GB, "
+            #       "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
+            #                          torch.cuda.max_memory_allocated()/1024/1024/1024,
+            #                          torch.cuda.memory_reserved()/1024/1024/1024,
+            #                          torch.cuda.max_memory_reserved()/1024/1024/1024,
+            #                          ))
             time3 = time.perf_counter()
             speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
             loss, stats, weight = retval
@@ -275,12 +275,21 @@ class Trainer:
             pbar.update(1)
 
             if self.local_rank == 0:
+                gpu_info = "GPU, memory: {:.3f} GB, " \
+                           "{:.3f} GB, " \
+                           "{:.3f} GB, " \
+                           "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
+                                              torch.cuda.max_memory_allocated()/1024/1024/1024,
+                                              torch.cuda.memory_reserved()/1024/1024/1024,
+                                              torch.cuda.max_memory_reserved()/1024/1024/1024,
+                                              )
                 description = (
                     f"Train epoch: {epoch}/{self.max_epoch}, "
                     f"step {batch_idx}/{len(self.dataloader_train)}, "
                     f"{speed_stats}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
                     f"{gpu_info}"
                 )
                 pbar.set_description(description)
                 if self.writer:
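A note on the memory strings this hunk cleans up: the old prints divided bytes by 1024^3 (GiB) but labeled the result "MB", and "{:.1}" formats a float to one significant digit, not one decimal place; the replacement gpu_info string uses "{:.3f} GB" throughout. A small helper in the same spirit (illustrative; cuda_mem_gb is not part of FunASR):

    import torch

    def cuda_mem_gb() -> str:
        # Report the four CUDA memory counters in GiB with three decimals.
        to_gb = lambda n_bytes: n_bytes / 1024 ** 3
        return ("GPU memory: {:.3f} GB allocated, {:.3f} GB max allocated, "
                "{:.3f} GB reserved, {:.3f} GB max reserved").format(
            to_gb(torch.cuda.memory_allocated()),
            to_gb(torch.cuda.max_memory_allocated()),
            to_gb(torch.cuda.memory_reserved()),
            to_gb(torch.cuda.max_memory_reserved()))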