From a73123bcfc14370b74b17084bc124f00c48613e4 Mon Sep 17 00:00:00 2001 From: smohan-speech Date: Sat, 6 May 2023 16:17:48 +0800 Subject: [PATCH] add speaker-attributed ASR task for alimeeting --- egs/alimeeting/sa-asr/asr_local.sh | 1562 +++++++++++++++++ egs/alimeeting/sa-asr/asr_local_infer.sh | 590 +++++++ .../sa-asr/conf/decode_asr_rnn.yaml | 6 + .../sa-asr/conf/train_asr_conformer.yaml | 88 + .../sa-asr/conf/train_lm_transformer.yaml | 29 + .../sa-asr/conf/train_sa_asr_conformer.yaml | 116 ++ .../sa-asr/local/alimeeting_data_prep.sh | 162 ++ .../local/alimeeting_data_prep_test_2023.sh | 129 ++ .../local/alimeeting_process_overlap_force.py | 235 +++ .../local/alimeeting_process_textgrid.py | 158 ++ egs/alimeeting/sa-asr/local/compute_cpcer.py | 91 + egs/alimeeting/sa-asr/local/compute_wer.py | 157 ++ .../sa-asr/local/download_xvector_model.py | 6 + .../sa-asr/local/filter_utt2spk_all_fifo.py | 22 + .../sa-asr/local/gen_cluster_profile_infer.py | 167 ++ .../sa-asr/local/gen_oracle_embedding.py | 70 + .../local/gen_oracle_profile_nopadding.py | 59 + .../local/gen_oracle_profile_padding.py | 68 + egs/alimeeting/sa-asr/local/proce_text.py | 32 + .../local/process_sot_fifo_textchar2spk.py | 86 + .../sa-asr/local/process_text_id.py | 24 + .../sa-asr/local/process_text_spk_merge.py | 55 + .../process_textgrid_to_single_speaker_wav.py | 127 ++ egs/alimeeting/sa-asr/local/text_format.pl | 14 + egs/alimeeting/sa-asr/local/text_normalize.pl | 38 + egs/alimeeting/sa-asr/path.sh | 6 + .../sa-asr/pyscripts/audio/format_wav_scp.py | 243 +++ .../sa-asr/pyscripts/utils/print_args.py | 45 + egs/alimeeting/sa-asr/run_m2met_2023.sh | 51 + egs/alimeeting/sa-asr/run_m2met_2023_infer.sh | 50 + .../sa-asr/scripts/audio/format_wav_scp.sh | 142 ++ .../scripts/utils/perturb_data_dir_speed.sh | 116 ++ egs/alimeeting/sa-asr/utils/apply_map.pl | 97 + egs/alimeeting/sa-asr/utils/combine_data.sh | 146 ++ egs/alimeeting/sa-asr/utils/copy_data_dir.sh | 145 ++ .../sa-asr/utils/data/get_reco2dur.sh | 143 ++ .../utils/data/get_segments_for_data.sh | 29 + .../sa-asr/utils/data/get_utt2dur.sh | 135 ++ .../sa-asr/utils/data/split_data.sh | 160 ++ egs/alimeeting/sa-asr/utils/filter_scp.pl | 87 + egs/alimeeting/sa-asr/utils/fix_data_dir.sh | 215 +++ egs/alimeeting/sa-asr/utils/parse_options.sh | 97 + .../sa-asr/utils/spk2utt_to_utt2spk.pl | 27 + egs/alimeeting/sa-asr/utils/split_scp.pl | 246 +++ .../sa-asr/utils/utt2spk_to_spk2utt.pl | 38 + .../sa-asr/utils/validate_data_dir.sh | 404 +++++ egs/alimeeting/sa-asr/utils/validate_text.pl | 136 ++ funasr/bin/asr_inference.py | 34 +- funasr/bin/asr_inference_launch.py | 5 +- funasr/bin/asr_train.py | 15 +- funasr/bin/sa_asr_inference.py | 674 +++++++ funasr/bin/sa_asr_train.py | 55 + funasr/fileio/sound_scp.py | 6 +- funasr/losses/nll_loss.py | 47 + funasr/models/decoder/decoder_layer_sa_asr.py | 169 ++ .../decoder/transformer_decoder_sa_asr.py | 291 +++ funasr/models/e2e_sa_asr.py | 521 ++++++ funasr/models/frontend/default.py | 11 +- funasr/models/pooling/statistic_pooling.py | 4 +- funasr/modules/attention.py | 37 + .../modules/beam_search/beam_search_sa_asr.py | 525 ++++++ funasr/tasks/abs_task.py | 16 +- funasr/tasks/sa_asr.py | 623 +++++++ funasr/utils/postprocess_utils.py | 7 +- setup.py | 7 +- 65 files changed, 9859 insertions(+), 37 deletions(-) create mode 100755 egs/alimeeting/sa-asr/asr_local.sh create mode 100755 egs/alimeeting/sa-asr/asr_local_infer.sh create mode 100644 egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml create mode 100644 egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml create mode 100644 egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml create mode 100644 egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml create mode 100755 egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh create mode 100755 egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh create mode 100755 egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py create mode 100755 egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py create mode 100644 egs/alimeeting/sa-asr/local/compute_cpcer.py create mode 100755 egs/alimeeting/sa-asr/local/compute_wer.py create mode 100644 egs/alimeeting/sa-asr/local/download_xvector_model.py create mode 100644 egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py create mode 100644 egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py create mode 100644 egs/alimeeting/sa-asr/local/gen_oracle_embedding.py create mode 100644 egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py create mode 100644 egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py create mode 100755 egs/alimeeting/sa-asr/local/proce_text.py create mode 100755 egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py create mode 100644 egs/alimeeting/sa-asr/local/process_text_id.py create mode 100644 egs/alimeeting/sa-asr/local/process_text_spk_merge.py create mode 100755 egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py create mode 100755 egs/alimeeting/sa-asr/local/text_format.pl create mode 100755 egs/alimeeting/sa-asr/local/text_normalize.pl create mode 100755 egs/alimeeting/sa-asr/path.sh create mode 100755 egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py create mode 100755 egs/alimeeting/sa-asr/pyscripts/utils/print_args.py create mode 100755 egs/alimeeting/sa-asr/run_m2met_2023.sh create mode 100755 egs/alimeeting/sa-asr/run_m2met_2023_infer.sh create mode 100755 egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh create mode 100755 egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh create mode 100755 egs/alimeeting/sa-asr/utils/apply_map.pl create mode 100755 egs/alimeeting/sa-asr/utils/combine_data.sh create mode 100755 egs/alimeeting/sa-asr/utils/copy_data_dir.sh create mode 100755 egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh create mode 100755 egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh create mode 100755 egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh create mode 100755 egs/alimeeting/sa-asr/utils/data/split_data.sh create mode 100755 egs/alimeeting/sa-asr/utils/filter_scp.pl create mode 100755 egs/alimeeting/sa-asr/utils/fix_data_dir.sh create mode 100755 egs/alimeeting/sa-asr/utils/parse_options.sh create mode 100755 egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl create mode 100755 egs/alimeeting/sa-asr/utils/split_scp.pl create mode 100755 egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl create mode 100755 egs/alimeeting/sa-asr/utils/validate_data_dir.sh create mode 100755 egs/alimeeting/sa-asr/utils/validate_text.pl create mode 100644 funasr/bin/sa_asr_inference.py create mode 100755 funasr/bin/sa_asr_train.py create mode 100644 funasr/losses/nll_loss.py create mode 100644 funasr/models/decoder/decoder_layer_sa_asr.py create mode 100644 funasr/models/decoder/transformer_decoder_sa_asr.py create mode 100644 funasr/models/e2e_sa_asr.py create mode 100755 funasr/modules/beam_search/beam_search_sa_asr.py create mode 100644 funasr/tasks/sa_asr.py diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh new file mode 100755 index 000000000..c0359eb35 --- /dev/null +++ b/egs/alimeeting/sa-asr/asr_local.sh @@ -0,0 +1,1562 @@ +#!/usr/bin/env bash + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} +min() { + local a b + a=$1 + for b in "$@"; do + if [ "${b}" -le "${a}" ]; then + a="${b}" + fi + done + echo "${a}" +} +SECONDS=0 + +# General configuration +stage=1 # Processes starts from the specified stage. +stop_stage=10000 # Processes is stopped at the specified stage. +skip_data_prep=false # Skip data preparation stages. +skip_train=false # Skip training stages. +skip_eval=false # Skip decoding and evaluation stages. +skip_upload=true # Skip packing and uploading stages. +ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu). +num_nodes=1 # The number of nodes. +nj=16 # The number of parallel jobs. +inference_nj=16 # The number of parallel jobs in decoding. +gpu_inference=false # Whether to perform gpu decoding. +njob_infer=4 +dumpdir=dump2 # Directory to dump features. +expdir=exp # Directory to save experiments. +python=python3 # Specify python to execute espnet commands. +device=0 + +# Data preparation related +local_data_opts= # The options given to local/data.sh. + +# Speed perturbation related +speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space). + +# Feature extraction related +feats_type=raw # Feature type (raw or fbank_pitch). +audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw). +fs=16000 # Sampling rate. +min_wav_duration=0.1 # Minimum duration in second. +max_wav_duration=20 # Maximum duration in second. + +# Tokenization related +token_type=bpe # Tokenization type (char or bpe). +nbpe=30 # The number of BPE vocabulary. +bpemode=unigram # Mode of BPE (unigram or bpe). +oov="" # Out of vocabulary symbol. +blank="" # CTC blank symbol +sos_eos="" # sos and eos symbole +bpe_input_sentence_size=100000000 # Size of input sentence for BPE. +bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE +bpe_char_cover=1.0 # character coverage when modeling BPE + +# Language model related +use_lm=true # Use language model for ASR decoding. +lm_tag= # Suffix to the result dir for language model training. +lm_exp= # Specify the direcotry path for LM experiment. + # If this option is specified, lm_tag is ignored. +lm_stats_dir= # Specify the direcotry path for LM statistics. +lm_config= # Config for language model training. +lm_args= # Arguments for language model training, e.g., "--max_epoch 10". + # Note that it will overwrite args in lm config. +use_word_lm=false # Whether to use word language model. +num_splits_lm=1 # Number of splitting for lm corpus. +# shellcheck disable=SC2034 +word_vocab_size=10000 # Size of word vocabulary. + +# ASR model related +asr_tag= # Suffix to the result dir for asr model training. +asr_exp= # Specify the direcotry path for ASR experiment. + # If this option is specified, asr_tag is ignored. +sa_asr_exp= +asr_stats_dir= # Specify the direcotry path for ASR statistics. +asr_config= # Config for asr model training. +sa_asr_config= +asr_args= # Arguments for asr model training, e.g., "--max_epoch 10". + # Note that it will overwrite args in asr config. +feats_normalize=global_mvn # Normalizaton layer type. +num_splits_asr=1 # Number of splitting for lm corpus. + +# Decoding related +inference_tag= # Suffix to the result dir for decoding. +inference_config= # Config for decoding. +inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1". + # Note that it will overwrite args in inference config. +sa_asr_inference_tag= +sa_asr_inference_args= + +inference_lm=valid.loss.ave.pb # Language modle path for decoding. +inference_asr_model=valid.acc.ave.pb # ASR model path for decoding. + # e.g. + # inference_asr_model=train.loss.best.pth + # inference_asr_model=3epoch.pth + # inference_asr_model=valid.acc.best.pth + # inference_asr_model=valid.loss.ave.pth +inference_sa_asr_model=valid.acc_spk.ave.pb +download_model= # Download a model from Model Zoo and use it for decoding. + +# [Task dependent] Set the datadir name created by local/data.sh +train_set= # Name of training set. +valid_set= # Name of validation set used for monitoring/tuning network training. +test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified. +bpe_train_text= # Text file path of bpe training set. +lm_train_text= # Text file path of language model training set. +lm_dev_text= # Text file path of language model development set. +lm_test_text= # Text file path of language model evaluation set. +nlsyms_txt=none # Non-linguistic symbol list if existing. +cleaner=none # Text cleaner. +g2p=none # g2p method (needed if token_type=phn). +lang=zh # The language type of corpus. +score_opts= # The options given to sclite scoring +local_score_opts= # The options given to local/score.sh. + + +help_message=$(cat << EOF +Usage: $0 --train-set "" --valid-set "" --test_sets "" + +Options: + # General configuration + --stage # Processes starts from the specified stage (default="${stage}"). + --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}"). + --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}"). + --skip_train # Skip training stages (default="${skip_train}"). + --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}"). + --skip_upload # Skip packing and uploading stages (default="${skip_upload}"). + --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}"). + --num_nodes # The number of nodes (default="${num_nodes}"). + --nj # The number of parallel jobs (default="${nj}"). + --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}"). + --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}"). + --dumpdir # Directory to dump features (default="${dumpdir}"). + --expdir # Directory to save experiments (default="${expdir}"). + --python # Specify python to execute espnet commands (default="${python}"). + --device # Which GPUs are use for local training (defalut="${device}"). + + # Data preparation related + --local_data_opts # The options given to local/data.sh (default="${local_data_opts}"). + + # Speed perturbation related + --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}"). + + # Feature extraction related + --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}"). + --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}"). + --fs # Sampling rate (default="${fs}"). + --min_wav_duration # Minimum duration in second (default="${min_wav_duration}"). + --max_wav_duration # Maximum duration in second (default="${max_wav_duration}"). + + # Tokenization related + --token_type # Tokenization type (char or bpe, default="${token_type}"). + --nbpe # The number of BPE vocabulary (default="${nbpe}"). + --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}"). + --oov # Out of vocabulary symbol (default="${oov}"). + --blank # CTC blank symbol (default="${blank}"). + --sos_eos # sos and eos symbole (default="${sos_eos}"). + --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}"). + --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}"). + --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}"). + + # Language model related + --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}"). + --lm_exp # Specify the direcotry path for LM experiment. + # If this option is specified, lm_tag is ignored (default="${lm_exp}"). + --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}"). + --lm_config # Config for language model training (default="${lm_config}"). + --lm_args # Arguments for language model training (default="${lm_args}"). + # e.g., --lm_args "--max_epoch 10" + # Note that it will overwrite args in lm config. + --use_word_lm # Whether to use word language model (default="${use_word_lm}"). + --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}"). + --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}"). + + # ASR model related + --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}"). + --asr_exp # Specify the direcotry path for ASR experiment. + # If this option is specified, asr_tag is ignored (default="${asr_exp}"). + --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}"). + --asr_config # Config for asr model training (default="${asr_config}"). + --asr_args # Arguments for asr model training (default="${asr_args}"). + # e.g., --asr_args "--max_epoch 10" + # Note that it will overwrite args in asr config. + --feats_normalize # Normalizaton layer type (default="${feats_normalize}"). + --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}"). + + # Decoding related + --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}"). + --inference_config # Config for decoding (default="${inference_config}"). + --inference_args # Arguments for decoding (default="${inference_args}"). + # e.g., --inference_args "--lm_weight 0.1" + # Note that it will overwrite args in inference config. + --inference_lm # Language modle path for decoding (default="${inference_lm}"). + --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}"). + --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}"). + + # [Task dependent] Set the datadir name created by local/data.sh + --train_set # Name of training set (required). + --valid_set # Name of validation set used for monitoring/tuning network training (required). + --test_sets # Names of test sets. + # Multiple items (e.g., both dev and eval sets) can be specified (required). + --bpe_train_text # Text file path of bpe training set. + --lm_train_text # Text file path of language model training set. + --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}"). + --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}"). + --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}"). + --cleaner # Text cleaner (default="${cleaner}"). + --g2p # g2p method (default="${g2p}"). + --lang # The language type of corpus (default=${lang}). + --score_opts # The options given to sclite scoring (default="{score_opts}"). + --local_score_opts # The options given to local/score.sh (default="{local_score_opts}"). +EOF +) + +log "$0 $*" +# Save command line args for logging (they will be lost after utils/parse_options.sh) +run_args=$(python -m funasr.utils.cli_utils $0 "$@") +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + log "${help_message}" + log "Error: No positional arguments are required." + exit 2 +fi + +. ./path.sh + + +# Check required arguments +[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; }; +[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; }; +[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; }; + +# Check feature type +if [ "${feats_type}" = raw ]; then + data_feats=${dumpdir}/raw +elif [ "${feats_type}" = fbank_pitch ]; then + data_feats=${dumpdir}/fbank_pitch +elif [ "${feats_type}" = fbank ]; then + data_feats=${dumpdir}/fbank +elif [ "${feats_type}" == extracted ]; then + data_feats=${dumpdir}/extracted +else + log "${help_message}" + log "Error: not supported: --feats_type ${feats_type}" + exit 2 +fi + +# Use the same text as ASR for bpe training if not specified. +[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text" +# Use the same text as ASR for lm training if not specified. +[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text" +# Use the same text as ASR for lm training if not specified. +[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text" +# Use the text of the 1st evaldir if lm_test is not specified +[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text" + +# Check tokenization type +if [ "${lang}" != noinfo ]; then + token_listdir=data/${lang}_token_list +else + token_listdir=data/token_list +fi +bpedir="${token_listdir}/bpe_${bpemode}${nbpe}" +bpeprefix="${bpedir}"/bpe +bpemodel="${bpeprefix}".model +bpetoken_list="${bpedir}"/tokens.txt +chartoken_list="${token_listdir}"/char/tokens.txt +# NOTE: keep for future development. +# shellcheck disable=SC2034 +wordtoken_list="${token_listdir}"/word/tokens.txt + +if [ "${token_type}" = bpe ]; then + token_list="${bpetoken_list}" +elif [ "${token_type}" = char ]; then + token_list="${chartoken_list}" + bpemodel=none +elif [ "${token_type}" = word ]; then + token_list="${wordtoken_list}" + bpemodel=none +else + log "Error: not supported --token_type '${token_type}'" + exit 2 +fi +if ${use_word_lm}; then + log "Error: Word LM is not supported yet" + exit 2 + + lm_token_list="${wordtoken_list}" + lm_token_type=word +else + lm_token_list="${token_list}" + lm_token_type="${token_type}" +fi + + +# Set tag for naming of model directory +if [ -z "${asr_tag}" ]; then + if [ -n "${asr_config}" ]; then + asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}" + else + asr_tag="train_${feats_type}" + fi + if [ "${lang}" != noinfo ]; then + asr_tag+="_${lang}_${token_type}" + else + asr_tag+="_${token_type}" + fi + if [ "${token_type}" = bpe ]; then + asr_tag+="${nbpe}" + fi + # Add overwritten arg's info + if [ -n "${asr_args}" ]; then + asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")" + fi + if [ -n "${speed_perturb_factors}" ]; then + asr_tag+="_sp" + fi +fi +if [ -z "${lm_tag}" ]; then + if [ -n "${lm_config}" ]; then + lm_tag="$(basename "${lm_config}" .yaml)" + else + lm_tag="train" + fi + if [ "${lang}" != noinfo ]; then + lm_tag+="_${lang}_${lm_token_type}" + else + lm_tag+="_${lm_token_type}" + fi + if [ "${lm_token_type}" = bpe ]; then + lm_tag+="${nbpe}" + fi + # Add overwritten arg's info + if [ -n "${lm_args}" ]; then + lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")" + fi +fi + +# The directory used for collect-stats mode +if [ -z "${asr_stats_dir}" ]; then + if [ "${lang}" != noinfo ]; then + asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}" + else + asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}" + fi + if [ "${token_type}" = bpe ]; then + asr_stats_dir+="${nbpe}" + fi + if [ -n "${speed_perturb_factors}" ]; then + asr_stats_dir+="_sp" + fi +fi +if [ -z "${lm_stats_dir}" ]; then + if [ "${lang}" != noinfo ]; then + lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}" + else + lm_stats_dir="${expdir}/lm_stats_${lm_token_type}" + fi + if [ "${lm_token_type}" = bpe ]; then + lm_stats_dir+="${nbpe}" + fi +fi +# The directory used for training commands +if [ -z "${asr_exp}" ]; then + asr_exp="${expdir}/asr_${asr_tag}" +fi +if [ -z "${lm_exp}" ]; then + lm_exp="${expdir}/lm_${lm_tag}" +fi + + +if [ -z "${inference_tag}" ]; then + if [ -n "${inference_config}" ]; then + inference_tag="$(basename "${inference_config}" .yaml)" + else + inference_tag=inference + fi + # Add overwritten arg's info + if [ -n "${inference_args}" ]; then + inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")" + fi + if "${use_lm}"; then + inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" + fi + inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" +fi + +if [ -z "${sa_asr_inference_tag}" ]; then + if [ -n "${inference_config}" ]; then + sa_asr_inference_tag="$(basename "${inference_config}" .yaml)" + else + sa_asr_inference_tag=sa_asr_inference + fi + # Add overwritten arg's info + if [ -n "${sa_asr_inference_args}" ]; then + sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")" + fi + if "${use_lm}"; then + sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" + fi + sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" +fi + +train_cmd="run.pl" +cuda_cmd="run.pl" +decode_cmd="run.pl" + +# ========================== Main stages start from here. ========================== + +if ! "${skip_data_prep}"; then + + if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + log "Stage 1: Data preparation for data/${train_set}, data/${valid_set}, etc." + + ./local/alimeeting_data_prep.sh --tgt Test + ./local/alimeeting_data_prep.sh --tgt Eval + ./local/alimeeting_data_prep.sh --tgt Train + fi + + if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + if [ -n "${speed_perturb_factors}" ]; then + log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp" + for factor in ${speed_perturb_factors}; do + if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then + scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}" + _dirs+="data/${train_set}_sp${factor} " + else + # If speed factor is 1, same as the original + _dirs+="data/${train_set} " + fi + done + utils/combine_data.sh "data/${train_set}_sp" ${_dirs} + else + log "Skip stage 2: Speed perturbation" + fi + fi + + if [ -n "${speed_perturb_factors}" ]; then + train_set="${train_set}_sp" + fi + + if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + if [ "${feats_type}" = raw ]; then + log "Stage 3: Format wav.scp: data/ -> ${data_feats}" + + # ====== Recreating "wav.scp" ====== + # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |", + # shouldn't be used in training process. + # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file + # and it can also change the audio-format and sampling rate. + # If nothing is need, then format_wav_scp.sh does nothing: + # i.e. the input file format and rate is same as the output. + + for dset in "${train_set}" "${valid_set}" "${test_sets}" ; do + if [ "${dset}" = "${train_set}" ] || [ "${dset}" = "${valid_set}" ]; then + _suf="/org" + else + if [ "${dset}" = "${test_sets}" ] && [ "${test_sets}" = "Test_Ali_far" ]; then + _suf="/org" + else + _suf="" + fi + fi + utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" + + cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/" + + rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur} + _opts= + if [ -e data/"${dset}"/segments ]; then + # "segments" is used for splitting wav files which are written in "wav".scp + # into utterances. The file format of segments: + # + # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5" + # Where the time is written in seconds. + _opts+="--segments data/${dset}/segments " + fi + # shellcheck disable=SC2086 + scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ + --audio-format "${audio_format}" --fs "${fs}" ${_opts} \ + "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}" + + echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type" + done + + else + log "Error: not supported: --feats_type ${feats_type}" + exit 2 + fi + fi + + + if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + log "Stage 4: Remove long/short data: ${data_feats}/org -> ${data_feats}" + + # NOTE(kamo): Not applying to test_sets to keep original data + if [ "${test_sets}" = "Test_Ali_far" ]; then + rm_dset="${train_set} ${valid_set} ${test_sets}" + else + rm_dset="${train_set} ${valid_set}" + fi + + for dset in $rm_dset; do + + # Copy data dir + utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}" + cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type" + + # Remove short utterances + _feats_type="$(<${data_feats}/${dset}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _fs=$(python3 -c "import humanfriendly as h;print(h.parse_size('${fs}'))") + _min_length=$(python3 -c "print(int(${min_wav_duration} * ${_fs}))") + _max_length=$(python3 -c "print(int(${max_wav_duration} * ${_fs}))") + + # utt2num_samples is created by format_wav_scp.sh + <"${data_feats}/org/${dset}/utt2num_samples" \ + awk -v min_length="${_min_length}" -v max_length="${_max_length}" \ + '{ if ($2 > min_length && $2 < max_length ) print $0; }' \ + >"${data_feats}/${dset}/utt2num_samples" + <"${data_feats}/org/${dset}/wav.scp" \ + utils/filter_scp.pl "${data_feats}/${dset}/utt2num_samples" \ + >"${data_feats}/${dset}/wav.scp" + else + # Get frame shift in ms from conf/fbank.conf + _frame_shift= + if [ -f conf/fbank.conf ] && [ "$( min_length && $2 < max_length) print $0; }' \ + >"${data_feats}/${dset}/feats_shape" + <"${data_feats}/org/${dset}/feats.scp" \ + utils/filter_scp.pl "${data_feats}/${dset}/feats_shape" \ + >"${data_feats}/${dset}/feats.scp" + fi + + # Remove empty text + <"${data_feats}/org/${dset}/text" \ + awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text" + + # fix_data_dir.sh leaves only utts which exist in all files + utils/fix_data_dir.sh "${data_feats}/${dset}" + + # generate uttid + cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid" + # filter utt2spk_all_fifo + python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset} + done + + # shellcheck disable=SC2002 + cat ${lm_train_text} | awk ' { if( NF != 1 ) print $0; } ' > "${data_feats}/lm_train.txt" + fi + + + if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + log "Stage 5: Dictionary Preparation" + mkdir -p data/${lang}_token_list/char/ + + echo "make a dictionary" + echo "" > ${token_list} + echo "" >> ${token_list} + echo "" >> ${token_list} + local/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \ + | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list} + num_token=$(cat ${token_list} | wc -l) + echo "" >> ${token_list} + vocab_size=$(cat ${token_list} | wc -l) + fi + + if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + log "Stage 6: Generate speaker settings" + mkdir -p "profile_log" + for dset in "${train_set}" "${valid_set}" "${test_sets}"; do + # generate text_id spk2id + python local/process_sot_fifo_textchar2spk.py --path ${data_feats}/${dset} + log "Successfully generate ${data_feats}/${dset}/text_id ${data_feats}/${dset}/spk2id" + # generate text_id_train for sot + python local/process_text_id.py ${data_feats}/${dset} + log "Successfully generate ${data_feats}/${dset}/text_id_train" + # generate oracle_embedding from single-speaker audio segment + python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> "profile_log/gen_oracle_embedding_${dset}.log" + log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)" + # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training) + if [ "${dset}" = "${train_set}" ]; then + python local/gen_oracle_profile_padding.py ${data_feats}/${dset} + log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_padding.scp)" + else + python local/gen_oracle_profile_nopadding.py ${data_feats}/${dset} + log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_nopadding.scp)" + fi + # generate cluster_profile with spectral-cluster directly (for infering and without oracle information) + if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then + python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log" + log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)" + fi + + done + fi + +else + log "Skip the stages for data preparation" +fi + + +# ========================== Data preparation is done here. ========================== + + +if ! "${skip_train}"; then + if "${use_lm}"; then + if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + log "Stage 7: LM collect stats: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}" + + _opts= + if [ -n "${lm_config}" ]; then + # To generate the config file: e.g. + # % python3 -m espnet2.bin.lm_train --print_config --optim adam + _opts+="--config ${lm_config} " + fi + + # 1. Split the key file + _logdir="${lm_stats_dir}/logdir" + mkdir -p "${_logdir}" + # Get the minimum number among ${nj} and the number lines of input files + _nj=$(min "${nj}" "$(<${data_feats}/lm_train.txt wc -l)" "$(<${lm_dev_text} wc -l)") + + key_file="${data_feats}/lm_train.txt" + split_scps="" + for n in $(seq ${_nj}); do + split_scps+=" ${_logdir}/train.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + key_file="${lm_dev_text}" + split_scps="" + for n in $(seq ${_nj}); do + split_scps+=" ${_logdir}/dev.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Generate run.sh + log "Generate '${lm_stats_dir}/run.sh'. You can resume the process from stage 6 using this script" + mkdir -p "${lm_stats_dir}"; echo "${run_args} --stage 6 \"\$@\"; exit \$?" > "${lm_stats_dir}/run.sh"; chmod +x "${lm_stats_dir}/run.sh" + + # 3. Submit jobs + log "LM collect-stats started... log: '${_logdir}/stats.*.log'" + # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted, + # but it's used only for deciding the sample ids. + # shellcheck disable=SC2086 + ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \ + ${python} -m funasr.bin.lm_train \ + --collect_stats true \ + --use_preprocessor true \ + --bpemodel "${bpemodel}" \ + --token_type "${lm_token_type}"\ + --token_list "${lm_token_list}" \ + --non_linguistic_symbols "${nlsyms_txt}" \ + --cleaner "${cleaner}" \ + --g2p "${g2p}" \ + --train_data_path_and_name_and_type "${data_feats}/lm_train.txt,text,text" \ + --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \ + --train_shape_file "${_logdir}/train.JOB.scp" \ + --valid_shape_file "${_logdir}/dev.JOB.scp" \ + --output_dir "${_logdir}/stats.JOB" \ + ${_opts} ${lm_args} || { cat "${_logdir}"/stats.1.log; exit 1; } + + # 4. Aggregate shape files + _opts= + for i in $(seq "${_nj}"); do + _opts+="--input_dir ${_logdir}/stats.${i} " + done + # shellcheck disable=SC2086 + ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${lm_stats_dir}" + + # Append the num-tokens at the last dimensions. This is used for batch-bins count + <"${lm_stats_dir}/train/text_shape" \ + awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \ + >"${lm_stats_dir}/train/text_shape.${lm_token_type}" + + <"${lm_stats_dir}/valid/text_shape" \ + awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \ + >"${lm_stats_dir}/valid/text_shape.${lm_token_type}" + fi + + + if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then + log "Stage 8: LM Training: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}" + + _opts= + if [ -n "${lm_config}" ]; then + # To generate the config file: e.g. + # % python3 -m espnet2.bin.lm_train --print_config --optim adam + _opts+="--config ${lm_config} " + fi + + if [ "${num_splits_lm}" -gt 1 ]; then + # If you met a memory error when parsing text files, this option may help you. + # The corpus is split into subsets and each subset is used for training one by one in order, + # so the memory footprint can be limited to the memory required for each dataset. + + _split_dir="${lm_stats_dir}/splits${num_splits_lm}" + if [ ! -f "${_split_dir}/.done" ]; then + rm -f "${_split_dir}/.done" + ${python} -m espnet2.bin.split_scps \ + --scps "${data_feats}/lm_train.txt" "${lm_stats_dir}/train/text_shape.${lm_token_type}" \ + --num_splits "${num_splits_lm}" \ + --output_dir "${_split_dir}" + touch "${_split_dir}/.done" + else + log "${_split_dir}/.done exists. Spliting is skipped" + fi + + _opts+="--train_data_path_and_name_and_type ${_split_dir}/lm_train.txt,text,text " + _opts+="--train_shape_file ${_split_dir}/text_shape.${lm_token_type} " + _opts+="--multiple_iterator true " + + else + _opts+="--train_data_path_and_name_and_type ${data_feats}/lm_train.txt,text,text " + _opts+="--train_shape_file ${lm_stats_dir}/train/text_shape.${lm_token_type} " + fi + + # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case + + log "Generate '${lm_exp}/run.sh'. You can resume the process from stage 8 using this script" + mkdir -p "${lm_exp}"; echo "${run_args} --stage 8 \"\$@\"; exit \$?" > "${lm_exp}/run.sh"; chmod +x "${lm_exp}/run.sh" + + log "LM training started... log: '${lm_exp}/train.log'" + if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then + # SGE can't include "/" in a job name + jobname="$(basename ${lm_exp})" + else + jobname="${lm_exp}/train.log" + fi + + mkdir -p ${lm_exp} + mkdir -p ${lm_exp}/log + INIT_FILE=${lm_exp}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $ngpu; ++i)); do + { + # i=0 + rank=$i + local_rank=$i + gpu_id=$(echo $device | cut -d',' -f$[$i+1]) + lm_train.py \ + --gpu_id $gpu_id \ + --use_preprocessor true \ + --bpemodel ${bpemodel} \ + --token_type ${token_type} \ + --token_list ${token_list} \ + --non_linguistic_symbols ${nlsyms_txt} \ + --cleaner ${cleaner} \ + --g2p ${g2p} \ + --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \ + --valid_shape_file "${lm_stats_dir}/valid/text_shape.${lm_token_type}" \ + --resume true \ + --output_dir ${lm_exp} \ + --config $lm_config \ + --ngpu $ngpu \ + --num_worker_count 1 \ + --multiprocessing_distributed true \ + --dist_init_method $init_method \ + --dist_world_size $ngpu \ + --dist_rank $rank \ + --local_rank $local_rank \ + ${_opts} 1> ${lm_exp}/log/train.log.$i 2>&1 + } & + done + wait + + fi + + + if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then + log "Stage 9: Calc perplexity: ${lm_test_text}" + _opts= + # TODO(kamo): Parallelize? + log "Perplexity calculation started... log: '${lm_exp}/perplexity_test/lm_calc_perplexity.log'" + # shellcheck disable=SC2086 + CUDA_VISIBLE_DEVICES=${device}\ + ${cuda_cmd} --gpu "${ngpu}" "${lm_exp}"/perplexity_test/lm_calc_perplexity.log \ + ${python} -m funasr.bin.lm_calc_perplexity \ + --ngpu "${ngpu}" \ + --data_path_and_name_and_type "${lm_test_text},text,text" \ + --train_config "${lm_exp}"/config.yaml \ + --model_file "${lm_exp}/${inference_lm}" \ + --output_dir "${lm_exp}/perplexity_test" \ + ${_opts} + log "PPL: ${lm_test_text}: $(cat ${lm_exp}/perplexity_test/ppl)" + + fi + + else + log "Stage 7-9: Skip lm-related stages: use_lm=${use_lm}" + fi + + + if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then + _asr_train_dir="${data_feats}/${train_set}" + _asr_valid_dir="${data_feats}/${valid_set}" + log "Stage 10: ASR collect stats: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}" + + _opts= + if [ -n "${asr_config}" ]; then + # To generate the config file: e.g. + # % python3 -m espnet2.bin.asr_train --print_config --optim adam + _opts+="--config ${asr_config} " + fi + + _feats_type="$(<${_asr_train_dir}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + # "sound" supports "wav", "flac", etc. + _type=sound + fi + _opts+="--frontend_conf fs=${fs} " + else + _scp=feats.scp + _type=kaldi_ark + _input_size="$(<${_asr_train_dir}/feats_dim)" + _opts+="--input_size=${_input_size} " + fi + + # 1. Split the key file + _logdir="${asr_stats_dir}/logdir" + mkdir -p "${_logdir}" + + # Get the minimum number among ${nj} and the number lines of input files + _nj=$(min "${nj}" "$(<${_asr_train_dir}/${_scp} wc -l)" "$(<${_asr_valid_dir}/${_scp} wc -l)") + + key_file="${_asr_train_dir}/${_scp}" + split_scps="" + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/train.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + key_file="${_asr_valid_dir}/${_scp}" + split_scps="" + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/valid.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Generate run.sh + log "Generate '${asr_stats_dir}/run.sh'. You can resume the process from stage 9 using this script" + mkdir -p "${asr_stats_dir}"; echo "${run_args} --stage 9 \"\$@\"; exit \$?" > "${asr_stats_dir}/run.sh"; chmod +x "${asr_stats_dir}/run.sh" + + # 3. Submit jobs + log "ASR collect-stats started... log: '${_logdir}/stats.*.log'" + + # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted, + # but it's used only for deciding the sample ids. + + # shellcheck disable=SC2086 + ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \ + ${python} -m funasr.bin.asr_train \ + --collect_stats true \ + --mc true \ + --use_preprocessor true \ + --bpemodel "${bpemodel}" \ + --token_type "${token_type}" \ + --token_list "${token_list}" \ + --split_with_space false \ + --non_linguistic_symbols "${nlsyms_txt}" \ + --cleaner "${cleaner}" \ + --g2p "${g2p}" \ + --train_data_path_and_name_and_type "${_asr_train_dir}/${_scp},speech,${_type}" \ + --train_data_path_and_name_and_type "${_asr_train_dir}/text,text,text" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \ + --train_shape_file "${_logdir}/train.JOB.scp" \ + --valid_shape_file "${_logdir}/valid.JOB.scp" \ + --output_dir "${_logdir}/stats.JOB" \ + ${_opts} ${asr_args} || { cat "${_logdir}"/stats.1.log; exit 1; } + + # 4. Aggregate shape files + _opts= + for i in $(seq "${_nj}"); do + _opts+="--input_dir ${_logdir}/stats.${i} " + done + # shellcheck disable=SC2086 + ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${asr_stats_dir}" + + # Append the num-tokens at the last dimensions. This is used for batch-bins count + <"${asr_stats_dir}/train/text_shape" \ + awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \ + >"${asr_stats_dir}/train/text_shape.${token_type}" + + <"${asr_stats_dir}/valid/text_shape" \ + awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \ + >"${asr_stats_dir}/valid/text_shape.${token_type}" + fi + + + if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then + _asr_train_dir="${data_feats}/${train_set}" + _asr_valid_dir="${data_feats}/${valid_set}" + log "Stage 11: ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}" + + _opts= + if [ -n "${asr_config}" ]; then + # To generate the config file: e.g. + # % python3 -m espnet2.bin.asr_train --print_config --optim adam + _opts+="--config ${asr_config} " + fi + + _feats_type="$(<${_asr_train_dir}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + # "sound" supports "wav", "flac", etc. + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + _opts+="--frontend_conf fs=${fs} " + else + _scp=feats.scp + _type=kaldi_ark + _input_size="$(<${_asr_train_dir}/feats_dim)" + _opts+="--input_size=${_input_size} " + + fi + if [ "${feats_normalize}" = global_mvn ]; then + # Default normalization is utterance_mvn and changes to global_mvn + _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz " + fi + + if [ "${num_splits_asr}" -gt 1 ]; then + # If you met a memory error when parsing text files, this option may help you. + # The corpus is split into subsets and each subset is used for training one by one in order, + # so the memory footprint can be limited to the memory required for each dataset. + + _split_dir="${asr_stats_dir}/splits${num_splits_asr}" + if [ ! -f "${_split_dir}/.done" ]; then + rm -f "${_split_dir}/.done" + ${python} -m espnet2.bin.split_scps \ + --scps \ + "${_asr_train_dir}/${_scp}" \ + "${_asr_train_dir}/text" \ + "${asr_stats_dir}/train/speech_shape" \ + "${asr_stats_dir}/train/text_shape.${token_type}" \ + --num_splits "${num_splits_asr}" \ + --output_dir "${_split_dir}" + touch "${_split_dir}/.done" + else + log "${_split_dir}/.done exists. Spliting is skipped" + fi + + _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} " + _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text " + _opts+="--train_shape_file ${_split_dir}/speech_shape " + _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} " + _opts+="--multiple_iterator true " + + else + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} " + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text " + _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape " + _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} " + fi + + # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script" + # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh" + + # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case + log "ASR training started... log: '${asr_exp}/log/train.log'" + # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then + # # SGE can't include "/" in a job name + # jobname="$(basename ${asr_exp})" + # else + # jobname="${asr_exp}/train.log" + # fi + + mkdir -p ${asr_exp} + mkdir -p ${asr_exp}/log + INIT_FILE=${asr_exp}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $ngpu; ++i)); do + { + # i=0 + rank=$i + local_rank=$i + gpu_id=$(echo $device | cut -d',' -f$[$i+1]) + asr_train.py \ + --mc true \ + --gpu_id $gpu_id \ + --use_preprocessor true \ + --bpemodel ${bpemodel} \ + --token_type ${token_type} \ + --token_list ${token_list} \ + --split_with_space false \ + --non_linguistic_symbols ${nlsyms_txt} \ + --cleaner ${cleaner} \ + --g2p ${g2p} \ + --valid_data_path_and_name_and_type ${_asr_valid_dir}/${_scp},speech,${_type} \ + --valid_data_path_and_name_and_type ${_asr_valid_dir}/text,text,text \ + --valid_shape_file ${asr_stats_dir}/valid/speech_shape \ + --valid_shape_file ${asr_stats_dir}/valid/text_shape.${token_type} \ + --resume true \ + --output_dir ${asr_exp} \ + --config $asr_config \ + --ngpu $ngpu \ + --num_worker_count 1 \ + --multiprocessing_distributed true \ + --dist_init_method $init_method \ + --dist_world_size $ngpu \ + --dist_rank $rank \ + --local_rank $local_rank \ + ${_opts} 1> ${asr_exp}/log/train.log.$i 2>&1 + } & + done + wait + + fi + + if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then + _asr_train_dir="${data_feats}/${train_set}" + _asr_valid_dir="${data_feats}/${valid_set}" + log "Stage 12: SA-ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}" + + _opts= + if [ -n "${sa_asr_config}" ]; then + # To generate the config file: e.g. + # % python3 -m espnet2.bin.asr_train --print_config --optim adam + _opts+="--config ${sa_asr_config} " + fi + + _feats_type="$(<${_asr_train_dir}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + # "sound" supports "wav", "flac", etc. + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + _opts+="--frontend_conf fs=${fs} " + else + _scp=feats.scp + _type=kaldi_ark + _input_size="$(<${_asr_train_dir}/feats_dim)" + _opts+="--input_size=${_input_size} " + + fi + if [ "${feats_normalize}" = global_mvn ]; then + # Default normalization is utterance_mvn and changes to global_mvn + _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz " + fi + + if [ "${num_splits_asr}" -gt 1 ]; then + # If you met a memory error when parsing text files, this option may help you. + # The corpus is split into subsets and each subset is used for training one by one in order, + # so the memory footprint can be limited to the memory required for each dataset. + + _split_dir="${asr_stats_dir}/splits${num_splits_asr}" + if [ ! -f "${_split_dir}/.done" ]; then + rm -f "${_split_dir}/.done" + ${python} -m espnet2.bin.split_scps \ + --scps \ + "${_asr_train_dir}/${_scp}" \ + "${_asr_train_dir}/text" \ + "${asr_stats_dir}/train/speech_shape" \ + "${asr_stats_dir}/train/text_shape.${token_type}" \ + --num_splits "${num_splits_asr}" \ + --output_dir "${_split_dir}" + touch "${_split_dir}/.done" + else + log "${_split_dir}/.done exists. Spliting is skipped" + fi + + _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} " + _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text " + _opts+="--train_data_path_and_name_and_type ${_split_dir}/text_id_train,text_id,text_int " + _opts+="--train_data_path_and_name_and_type ${_split_dir}/oracle_profile_padding.scp,profile,npy " + _opts+="--train_shape_file ${_split_dir}/speech_shape " + _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} " + _opts+="--multiple_iterator true " + + else + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} " + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text " + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/oracle_profile_padding.scp,profile,npy " + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text_id_train,text_id,text_int " + _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape " + _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} " + fi + + # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script" + # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh" + + # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case + log "SA-ASR training started... log: '${sa_asr_exp}/log/train.log'" + # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then + # # SGE can't include "/" in a job name + # jobname="$(basename ${asr_exp})" + # else + # jobname="${asr_exp}/train.log" + # fi + + mkdir -p ${sa_asr_exp} + mkdir -p ${sa_asr_exp}/log + INIT_FILE=${sa_asr_exp}/ddp_init + + if [ ! -f "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then + # download xvector extractor model file + python local/download_xvector_model.py exp + log "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" + fi + + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $ngpu; ++i)); do + { + # i=0 + rank=$i + local_rank=$i + gpu_id=$(echo $device | cut -d',' -f$[$i+1]) + sa_asr_train.py \ + --gpu_id $gpu_id \ + --use_preprocessor true \ + --unused_parameters true \ + --bpemodel ${bpemodel} \ + --token_type ${token_type} \ + --token_list ${token_list} \ + --max_spk_num 4 \ + --split_with_space false \ + --non_linguistic_symbols ${nlsyms_txt} \ + --cleaner ${cleaner} \ + --g2p ${g2p} \ + --allow_variable_data_keys true \ + --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \ + --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \ + --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \ + --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/oracle_profile_nopadding.scp,profile,npy" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/text_id_train,text_id,text_int" \ + --valid_shape_file "${asr_stats_dir}/valid/speech_shape" \ + --valid_shape_file "${asr_stats_dir}/valid/text_shape.${token_type}" \ + --resume true \ + --output_dir ${sa_asr_exp} \ + --config $sa_asr_config \ + --ngpu $ngpu \ + --num_worker_count 1 \ + --multiprocessing_distributed true \ + --dist_init_method $init_method \ + --dist_world_size $ngpu \ + --dist_rank $rank \ + --local_rank $local_rank \ + ${_opts} 1> ${sa_asr_exp}/log/train.log.$i 2>&1 + } & + done + wait + + fi + +else + log "Skip the training stages" +fi + + +if ! "${skip_eval}"; then + if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then + log "Stage 13: Decoding multi-talker ASR: training_dir=${asr_exp}" + + if ${gpu_inference}; then + _cmd="${cuda_cmd}" + inference_nj=$[${ngpu}*${njob_infer}] + _ngpu=1 + + else + _cmd="${decode_cmd}" + inference_nj=$inference_nj + _ngpu=0 + fi + + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + if "${use_lm}"; then + if "${use_word_lm}"; then + _opts+="--word_lm_train_config ${lm_exp}/config.yaml " + _opts+="--word_lm_file ${lm_exp}/${inference_lm} " + else + _opts+="--lm_train_config ${lm_exp}/config.yaml " + _opts+="--lm_file ${lm_exp}/${inference_lm} " + fi + fi + + # 2. Generate run.sh + log "Generate '${asr_exp}/${inference_tag}/run.sh'. You can resume the process from stage 13 using this script" + mkdir -p "${asr_exp}/${inference_tag}"; echo "${run_args} --stage 13 \"\$@\"; exit \$?" > "${asr_exp}/${inference_tag}/run.sh"; chmod +x "${asr_exp}/${inference_tag}/run.sh" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${asr_exp}/${inference_tag}/${dset}" + _logdir="${_dir}/logdir" + mkdir -p "${_logdir}" + + _feats_type="$(<${_data}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + else + _scp=feats.scp + _type=kaldi_ark + fi + + # 1. Split the key file + key_file=${_data}/${_scp} + split_scps="" + _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)") + echo $_nj + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Submit decoding jobs + log "Decoding started... log: '${_logdir}/asr_inference.*.log'" + + ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --nbest 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob_infer} \ + --gpuid_list ${device} \ + --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --asr_train_config "${asr_exp}"/config.yaml \ + --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode asr \ + ${_opts} + + # 3. Concatenates the output files from each jobs + for f in token token_int score text; do + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | LC_ALL=C sort -k1 >"${_dir}/${f}" + done + done + fi + + + if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then + log "Stage 14: Scoring multi-talker ASR" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${asr_exp}/${inference_tag}/${dset}" + + python local/proce_text.py ${_data}/text ${_data}/text.proc + python local/proce_text.py ${_dir}/text ${_dir}/text.proc + + python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt + cat ${_dir}/text.cer.txt + + done + + fi + + if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then + log "Stage 15: Decoding SA-ASR (oracle profile): training_dir=${sa_asr_exp}" + + if ${gpu_inference}; then + _cmd="${cuda_cmd}" + inference_nj=$[${ngpu}*${njob_infer}] + _ngpu=1 + + else + _cmd="${decode_cmd}" + inference_nj=$inference_nj + _ngpu=0 + fi + + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + if "${use_lm}"; then + if "${use_word_lm}"; then + _opts+="--word_lm_train_config ${lm_exp}/config.yaml " + _opts+="--word_lm_file ${lm_exp}/${inference_lm} " + else + _opts+="--lm_train_config ${lm_exp}/config.yaml " + _opts+="--lm_file ${lm_exp}/${inference_lm} " + fi + fi + + # 2. Generate run.sh + log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh'. You can resume the process from stage 15 using this script" + mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.oracle"; echo "${run_args} --stage 15 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}" + _logdir="${_dir}/logdir" + mkdir -p "${_logdir}" + + _feats_type="$(<${_data}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + else + _scp=feats.scp + _type=kaldi_ark + fi + + # 1. Split the key file + key_file=${_data}/${_scp} + split_scps="" + _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)") + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Submit decoding jobs + log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'" + # shellcheck disable=SC2086 + ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --nbest 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob_infer} \ + --gpuid_list ${device} \ + --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \ + --data_path_and_name_and_type "${_data}/oracle_profile_nopadding.scp,profile,npy" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --allow_variable_data_keys true \ + --asr_train_config "${sa_asr_exp}"/config.yaml \ + --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode sa_asr \ + ${_opts} + + + # 3. Concatenates the output files from each jobs + for f in token token_int score text text_id; do + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | LC_ALL=C sort -k1 >"${_dir}/${f}" + done + done + fi + + if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then + log "Stage 16: Scoring SA-ASR (oracle profile)" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}" + + python local/proce_text.py ${_data}/text ${_data}/text.proc + python local/proce_text.py ${_dir}/text ${_dir}/text.proc + + python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt + cat ${_dir}/text.cer.txt + + python local/process_text_spk_merge.py ${_dir} + python local/process_text_spk_merge.py ${_data} + + python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer + tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt + cat ${_dir}/text.cpcer.txt + + done + + fi + + if [ ${stage} -le 17 ] && [ ${stop_stage} -ge 17 ]; then + log "Stage 17: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}" + + if ${gpu_inference}; then + _cmd="${cuda_cmd}" + inference_nj=$[${ngpu}*${njob_infer}] + _ngpu=1 + + else + _cmd="${decode_cmd}" + inference_nj=$inference_nj + _ngpu=0 + fi + + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + if "${use_lm}"; then + if "${use_word_lm}"; then + _opts+="--word_lm_train_config ${lm_exp}/config.yaml " + _opts+="--word_lm_file ${lm_exp}/${inference_lm} " + else + _opts+="--lm_train_config ${lm_exp}/config.yaml " + _opts+="--lm_file ${lm_exp}/${inference_lm} " + fi + fi + + # 2. Generate run.sh + log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 17 using this script" + mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}" + _logdir="${_dir}/logdir" + mkdir -p "${_logdir}" + + _feats_type="$(<${_data}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + else + _scp=feats.scp + _type=kaldi_ark + fi + + # 1. Split the key file + key_file=${_data}/${_scp} + split_scps="" + _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)") + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Submit decoding jobs + log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'" + # shellcheck disable=SC2086 + ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --nbest 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob_infer} \ + --gpuid_list ${device} \ + --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \ + --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --allow_variable_data_keys true \ + --asr_train_config "${sa_asr_exp}"/config.yaml \ + --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode sa_asr \ + ${_opts} + + # 3. Concatenates the output files from each jobs + for f in token token_int score text text_id; do + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | LC_ALL=C sort -k1 >"${_dir}/${f}" + done + done + fi + + if [ ${stage} -le 18 ] && [ ${stop_stage} -ge 18 ]; then + log "Stage 18: Scoring SA-ASR (cluster profile)" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}" + + python local/proce_text.py ${_data}/text ${_data}/text.proc + python local/proce_text.py ${_dir}/text ${_dir}/text.proc + + python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt + cat ${_dir}/text.cer.txt + + python local/process_text_spk_merge.py ${_dir} + python local/process_text_spk_merge.py ${_data} + + python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer + tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt + cat ${_dir}/text.cpcer.txt + + done + + fi + +else + log "Skip the evaluation stages" +fi + + +log "Successfully finished. [elapsed=${SECONDS}s]" diff --git a/egs/alimeeting/sa-asr/asr_local_infer.sh b/egs/alimeeting/sa-asr/asr_local_infer.sh new file mode 100755 index 000000000..8e8148ff8 --- /dev/null +++ b/egs/alimeeting/sa-asr/asr_local_infer.sh @@ -0,0 +1,590 @@ +#!/usr/bin/env bash + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} +min() { + local a b + a=$1 + for b in "$@"; do + if [ "${b}" -le "${a}" ]; then + a="${b}" + fi + done + echo "${a}" +} +SECONDS=0 + +# General configuration +stage=1 # Processes starts from the specified stage. +stop_stage=10000 # Processes is stopped at the specified stage. +skip_data_prep=false # Skip data preparation stages. +skip_train=false # Skip training stages. +skip_eval=false # Skip decoding and evaluation stages. +skip_upload=true # Skip packing and uploading stages. +ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu). +num_nodes=1 # The number of nodes. +nj=16 # The number of parallel jobs. +inference_nj=16 # The number of parallel jobs in decoding. +gpu_inference=false # Whether to perform gpu decoding. +njob_infer=4 +dumpdir=dump2 # Directory to dump features. +expdir=exp # Directory to save experiments. +python=python3 # Specify python to execute espnet commands. +device=0 + +# Data preparation related +local_data_opts= # The options given to local/data.sh. + +# Speed perturbation related +speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space). + +# Feature extraction related +feats_type=raw # Feature type (raw or fbank_pitch). +audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw). +fs=16000 # Sampling rate. +min_wav_duration=0.1 # Minimum duration in second. +max_wav_duration=20 # Maximum duration in second. + +# Tokenization related +token_type=bpe # Tokenization type (char or bpe). +nbpe=30 # The number of BPE vocabulary. +bpemode=unigram # Mode of BPE (unigram or bpe). +oov="" # Out of vocabulary symbol. +blank="" # CTC blank symbol +sos_eos="" # sos and eos symbole +bpe_input_sentence_size=100000000 # Size of input sentence for BPE. +bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE +bpe_char_cover=1.0 # character coverage when modeling BPE + +# Language model related +use_lm=true # Use language model for ASR decoding. +lm_tag= # Suffix to the result dir for language model training. +lm_exp= # Specify the direcotry path for LM experiment. + # If this option is specified, lm_tag is ignored. +lm_stats_dir= # Specify the direcotry path for LM statistics. +lm_config= # Config for language model training. +lm_args= # Arguments for language model training, e.g., "--max_epoch 10". + # Note that it will overwrite args in lm config. +use_word_lm=false # Whether to use word language model. +num_splits_lm=1 # Number of splitting for lm corpus. +# shellcheck disable=SC2034 +word_vocab_size=10000 # Size of word vocabulary. + +# ASR model related +asr_tag= # Suffix to the result dir for asr model training. +asr_exp= # Specify the direcotry path for ASR experiment. + # If this option is specified, asr_tag is ignored. +sa_asr_exp= +asr_stats_dir= # Specify the direcotry path for ASR statistics. +asr_config= # Config for asr model training. +sa_asr_config= +asr_args= # Arguments for asr model training, e.g., "--max_epoch 10". + # Note that it will overwrite args in asr config. +feats_normalize=global_mvn # Normalizaton layer type. +num_splits_asr=1 # Number of splitting for lm corpus. + +# Decoding related +inference_tag= # Suffix to the result dir for decoding. +inference_config= # Config for decoding. +inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1". + # Note that it will overwrite args in inference config. +sa_asr_inference_tag= +sa_asr_inference_args= + +inference_lm=valid.loss.ave.pb # Language modle path for decoding. +inference_asr_model=valid.acc.ave.pb # ASR model path for decoding. + # e.g. + # inference_asr_model=train.loss.best.pth + # inference_asr_model=3epoch.pth + # inference_asr_model=valid.acc.best.pth + # inference_asr_model=valid.loss.ave.pth +inference_sa_asr_model=valid.acc_spk.ave.pb +download_model= # Download a model from Model Zoo and use it for decoding. + +# [Task dependent] Set the datadir name created by local/data.sh +train_set= # Name of training set. +valid_set= # Name of validation set used for monitoring/tuning network training. +test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified. +bpe_train_text= # Text file path of bpe training set. +lm_train_text= # Text file path of language model training set. +lm_dev_text= # Text file path of language model development set. +lm_test_text= # Text file path of language model evaluation set. +nlsyms_txt=none # Non-linguistic symbol list if existing. +cleaner=none # Text cleaner. +g2p=none # g2p method (needed if token_type=phn). +lang=zh # The language type of corpus. +score_opts= # The options given to sclite scoring +local_score_opts= # The options given to local/score.sh. + +help_message=$(cat << EOF +Usage: $0 --train-set "" --valid-set "" --test_sets "" + +Options: + # General configuration + --stage # Processes starts from the specified stage (default="${stage}"). + --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}"). + --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}"). + --skip_train # Skip training stages (default="${skip_train}"). + --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}"). + --skip_upload # Skip packing and uploading stages (default="${skip_upload}"). + --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}"). + --num_nodes # The number of nodes (default="${num_nodes}"). + --nj # The number of parallel jobs (default="${nj}"). + --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}"). + --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}"). + --dumpdir # Directory to dump features (default="${dumpdir}"). + --expdir # Directory to save experiments (default="${expdir}"). + --python # Specify python to execute espnet commands (default="${python}"). + --device # Which GPUs are use for local training (defalut="${device}"). + + # Data preparation related + --local_data_opts # The options given to local/data.sh (default="${local_data_opts}"). + + # Speed perturbation related + --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}"). + + # Feature extraction related + --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}"). + --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}"). + --fs # Sampling rate (default="${fs}"). + --min_wav_duration # Minimum duration in second (default="${min_wav_duration}"). + --max_wav_duration # Maximum duration in second (default="${max_wav_duration}"). + + # Tokenization related + --token_type # Tokenization type (char or bpe, default="${token_type}"). + --nbpe # The number of BPE vocabulary (default="${nbpe}"). + --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}"). + --oov # Out of vocabulary symbol (default="${oov}"). + --blank # CTC blank symbol (default="${blank}"). + --sos_eos # sos and eos symbole (default="${sos_eos}"). + --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}"). + --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}"). + --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}"). + + # Language model related + --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}"). + --lm_exp # Specify the direcotry path for LM experiment. + # If this option is specified, lm_tag is ignored (default="${lm_exp}"). + --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}"). + --lm_config # Config for language model training (default="${lm_config}"). + --lm_args # Arguments for language model training (default="${lm_args}"). + # e.g., --lm_args "--max_epoch 10" + # Note that it will overwrite args in lm config. + --use_word_lm # Whether to use word language model (default="${use_word_lm}"). + --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}"). + --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}"). + + # ASR model related + --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}"). + --asr_exp # Specify the direcotry path for ASR experiment. + # If this option is specified, asr_tag is ignored (default="${asr_exp}"). + --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}"). + --asr_config # Config for asr model training (default="${asr_config}"). + --asr_args # Arguments for asr model training (default="${asr_args}"). + # e.g., --asr_args "--max_epoch 10" + # Note that it will overwrite args in asr config. + --feats_normalize # Normalizaton layer type (default="${feats_normalize}"). + --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}"). + + # Decoding related + --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}"). + --inference_config # Config for decoding (default="${inference_config}"). + --inference_args # Arguments for decoding (default="${inference_args}"). + # e.g., --inference_args "--lm_weight 0.1" + # Note that it will overwrite args in inference config. + --inference_lm # Language modle path for decoding (default="${inference_lm}"). + --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}"). + --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}"). + + # [Task dependent] Set the datadir name created by local/data.sh + --train_set # Name of training set (required). + --valid_set # Name of validation set used for monitoring/tuning network training (required). + --test_sets # Names of test sets. + # Multiple items (e.g., both dev and eval sets) can be specified (required). + --bpe_train_text # Text file path of bpe training set. + --lm_train_text # Text file path of language model training set. + --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}"). + --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}"). + --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}"). + --cleaner # Text cleaner (default="${cleaner}"). + --g2p # g2p method (default="${g2p}"). + --lang # The language type of corpus (default=${lang}). + --score_opts # The options given to sclite scoring (default="{score_opts}"). + --local_score_opts # The options given to local/score.sh (default="{local_score_opts}"). +EOF +) + +log "$0 $*" +# Save command line args for logging (they will be lost after utils/parse_options.sh) +run_args=$(python -m funasr.utils.cli_utils $0 "$@") +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + log "${help_message}" + log "Error: No positional arguments are required." + exit 2 +fi + +. ./path.sh + + +# Check required arguments +[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; }; +[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; }; +[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; }; + +# Check feature type +if [ "${feats_type}" = raw ]; then + data_feats=${dumpdir}/raw +elif [ "${feats_type}" = fbank_pitch ]; then + data_feats=${dumpdir}/fbank_pitch +elif [ "${feats_type}" = fbank ]; then + data_feats=${dumpdir}/fbank +elif [ "${feats_type}" == extracted ]; then + data_feats=${dumpdir}/extracted +else + log "${help_message}" + log "Error: not supported: --feats_type ${feats_type}" + exit 2 +fi + +# Use the same text as ASR for bpe training if not specified. +[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text" +# Use the same text as ASR for lm training if not specified. +[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text" +# Use the same text as ASR for lm training if not specified. +[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text" +# Use the text of the 1st evaldir if lm_test is not specified +[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text" + +# Check tokenization type +if [ "${lang}" != noinfo ]; then + token_listdir=data/${lang}_token_list +else + token_listdir=data/token_list +fi +bpedir="${token_listdir}/bpe_${bpemode}${nbpe}" +bpeprefix="${bpedir}"/bpe +bpemodel="${bpeprefix}".model +bpetoken_list="${bpedir}"/tokens.txt +chartoken_list="${token_listdir}"/char/tokens.txt +# NOTE: keep for future development. +# shellcheck disable=SC2034 +wordtoken_list="${token_listdir}"/word/tokens.txt + +if [ "${token_type}" = bpe ]; then + token_list="${bpetoken_list}" +elif [ "${token_type}" = char ]; then + token_list="${chartoken_list}" + bpemodel=none +elif [ "${token_type}" = word ]; then + token_list="${wordtoken_list}" + bpemodel=none +else + log "Error: not supported --token_type '${token_type}'" + exit 2 +fi +if ${use_word_lm}; then + log "Error: Word LM is not supported yet" + exit 2 + + lm_token_list="${wordtoken_list}" + lm_token_type=word +else + lm_token_list="${token_list}" + lm_token_type="${token_type}" +fi + + +# Set tag for naming of model directory +if [ -z "${asr_tag}" ]; then + if [ -n "${asr_config}" ]; then + asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}" + else + asr_tag="train_${feats_type}" + fi + if [ "${lang}" != noinfo ]; then + asr_tag+="_${lang}_${token_type}" + else + asr_tag+="_${token_type}" + fi + if [ "${token_type}" = bpe ]; then + asr_tag+="${nbpe}" + fi + # Add overwritten arg's info + if [ -n "${asr_args}" ]; then + asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")" + fi + if [ -n "${speed_perturb_factors}" ]; then + asr_tag+="_sp" + fi +fi +if [ -z "${lm_tag}" ]; then + if [ -n "${lm_config}" ]; then + lm_tag="$(basename "${lm_config}" .yaml)" + else + lm_tag="train" + fi + if [ "${lang}" != noinfo ]; then + lm_tag+="_${lang}_${lm_token_type}" + else + lm_tag+="_${lm_token_type}" + fi + if [ "${lm_token_type}" = bpe ]; then + lm_tag+="${nbpe}" + fi + # Add overwritten arg's info + if [ -n "${lm_args}" ]; then + lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")" + fi +fi + +# The directory used for collect-stats mode +if [ -z "${asr_stats_dir}" ]; then + if [ "${lang}" != noinfo ]; then + asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}" + else + asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}" + fi + if [ "${token_type}" = bpe ]; then + asr_stats_dir+="${nbpe}" + fi + if [ -n "${speed_perturb_factors}" ]; then + asr_stats_dir+="_sp" + fi +fi +if [ -z "${lm_stats_dir}" ]; then + if [ "${lang}" != noinfo ]; then + lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}" + else + lm_stats_dir="${expdir}/lm_stats_${lm_token_type}" + fi + if [ "${lm_token_type}" = bpe ]; then + lm_stats_dir+="${nbpe}" + fi +fi +# The directory used for training commands +if [ -z "${asr_exp}" ]; then + asr_exp="${expdir}/asr_${asr_tag}" +fi +if [ -z "${lm_exp}" ]; then + lm_exp="${expdir}/lm_${lm_tag}" +fi + + +if [ -z "${inference_tag}" ]; then + if [ -n "${inference_config}" ]; then + inference_tag="$(basename "${inference_config}" .yaml)" + else + inference_tag=inference + fi + # Add overwritten arg's info + if [ -n "${inference_args}" ]; then + inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")" + fi + if "${use_lm}"; then + inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" + fi + inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" +fi + +if [ -z "${sa_asr_inference_tag}" ]; then + if [ -n "${inference_config}" ]; then + sa_asr_inference_tag="$(basename "${inference_config}" .yaml)" + else + sa_asr_inference_tag=sa_asr_inference + fi + # Add overwritten arg's info + if [ -n "${sa_asr_inference_args}" ]; then + sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")" + fi + if "${use_lm}"; then + sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" + fi + sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" +fi + +train_cmd="run.pl" +cuda_cmd="run.pl" +decode_cmd="run.pl" + +# ========================== Main stages start from here. ========================== + +if ! "${skip_data_prep}"; then + + if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + if [ "${feats_type}" = raw ]; then + log "Stage 1: Format wav.scp: data/ -> ${data_feats}" + + # ====== Recreating "wav.scp" ====== + # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |", + # shouldn't be used in training process. + # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file + # and it can also change the audio-format and sampling rate. + # If nothing is need, then format_wav_scp.sh does nothing: + # i.e. the input file format and rate is same as the output. + + for dset in "${test_sets}" ; do + + _suf="" + + utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" + + rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur} + _opts= + if [ -e data/"${dset}"/segments ]; then + # "segments" is used for splitting wav files which are written in "wav".scp + # into utterances. The file format of segments: + # + # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5" + # Where the time is written in seconds. + _opts+="--segments data/${dset}/segments " + fi + # shellcheck disable=SC2086 + scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ + --audio-format "${audio_format}" --fs "${fs}" ${_opts} \ + "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}" + + echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type" + done + + else + log "Error: not supported: --feats_type ${feats_type}" + exit 2 + fi + fi + + if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + log "Stage 2: Generate speaker profile by spectral-cluster" + mkdir -p "profile_log" + for dset in "${test_sets}"; do + # generate cluster_profile with spectral-cluster directly (for infering and without oracle information) + python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log" + log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)" + done + fi + +else + log "Skip the stages for data preparation" +fi + + +# ========================== Data preparation is done here. ========================== + +if ! "${skip_eval}"; then + + if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + log "Stage 3: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}" + + if ${gpu_inference}; then + _cmd="${cuda_cmd}" + inference_nj=$[${ngpu}*${njob_infer}] + _ngpu=1 + + else + _cmd="${decode_cmd}" + inference_nj=$njob_infer + _ngpu=0 + fi + + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + if "${use_lm}"; then + if "${use_word_lm}"; then + _opts+="--word_lm_train_config ${lm_exp}/config.yaml " + _opts+="--word_lm_file ${lm_exp}/${inference_lm} " + else + _opts+="--lm_train_config ${lm_exp}/config.yaml " + _opts+="--lm_file ${lm_exp}/${inference_lm} " + fi + fi + + # 2. Generate run.sh + log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 17 using this script" + mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}" + _logdir="${_dir}/logdir" + mkdir -p "${_logdir}" + + _feats_type="$(<${_data}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + else + _scp=feats.scp + _type=kaldi_ark + fi + + # 1. Split the key file + key_file=${_data}/${_scp} + split_scps="" + _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)") + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Submit decoding jobs + log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'" + # shellcheck disable=SC2086 + ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --nbest 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob_infer} \ + --gpuid_list ${device} \ + --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \ + --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --allow_variable_data_keys true \ + --asr_train_config "${sa_asr_exp}"/config.yaml \ + --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode sa_asr \ + ${_opts} + + # 3. Concatenates the output files from each jobs + for f in token token_int score text text_id; do + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | LC_ALL=C sort -k1 >"${_dir}/${f}" + done + done + fi + + if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + log "Stage 4: Generate SA-ASR results (cluster profile)" + + for dset in ${test_sets}; do + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}" + + python local/process_text_spk_merge.py ${_dir} + done + + fi + +else + log "Skip the evaluation stages" +fi + + +log "Successfully finished. [elapsed=${SECONDS}s]" diff --git a/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml new file mode 100644 index 000000000..88fdbc20b --- /dev/null +++ b/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml @@ -0,0 +1,6 @@ +beam_size: 20 +penalty: 0.0 +maxlenratio: 0.0 +minlenratio: 0.0 +ctc_weight: 0.6 +lm_weight: 0.3 diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml new file mode 100644 index 000000000..a8c996875 --- /dev/null +++ b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml @@ -0,0 +1,88 @@ +# network architecture +frontend: default +frontend_conf: + n_fft: 400 + win_length: 400 + hop_length: 160 + use_channel: 0 + +# encoder related +encoder: conformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder architecture type + normalize_before: true + rel_pos_type: latest + pos_enc_layer_type: rel_pos + selfattention_layer_type: rel_selfattn + activation_type: swish + macaron_style: true + use_cnn_module: true + cnn_module_kernel: 15 + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# ctc related +ctc_conf: + ignore_nan_grad: true + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +# minibatch related +batch_type: numel +batch_bins: 10000000 # reduce/increase this number according to your GPU memory + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 100 +val_scheduler_criterion: + - valid + - acc +best_model_criterion: +- - valid + - acc + - max +keep_nbest_models: 10 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_range: + - 0 + - 40 + num_time_mask: 2 diff --git a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml new file mode 100644 index 000000000..68520ae23 --- /dev/null +++ b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml @@ -0,0 +1,29 @@ +lm: transformer +lm_conf: + pos_enc: null + embed_unit: 128 + att_unit: 512 + head: 8 + unit: 2048 + layer: 16 + dropout_rate: 0.1 + +# optimization related +grad_clip: 5.0 +batch_type: numel +batch_bins: 500000 # 4gpus * 500000 +accum_grad: 1 +max_epoch: 15 # 15epoch is enougth + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + +best_model_criterion: +- - valid + - loss + - min +keep_nbest_models: 10 # 10 is good. diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml new file mode 100644 index 000000000..e91db1804 --- /dev/null +++ b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml @@ -0,0 +1,116 @@ +# network architecture +frontend: default +frontend_conf: + n_fft: 400 + win_length: 400 + hop_length: 160 + use_channel: 0 + +# encoder related +asr_encoder: conformer +asr_encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder architecture type + normalize_before: true + pos_enc_layer_type: rel_pos + selfattention_layer_type: rel_selfattn + activation_type: swish + macaron_style: true + use_cnn_module: true + cnn_module_kernel: 15 + +spk_encoder: resnet34_diar +spk_encoder_conf: + use_head_conv: true + batchnorm_momentum: 0.5 + use_head_maxpool: false + num_nodes_pooling_layer: 256 + layers_in_block: + - 3 + - 4 + - 6 + - 3 + filters_in_block: + - 32 + - 64 + - 128 + - 256 + pooling_type: statistic + num_nodes_resnet1: 256 + num_nodes_last_layer: 256 + batchnorm_momentum: 0.5 + +# decoder related +decoder: sa_decoder +decoder_conf: + attention_heads: 4 + linear_units: 2048 + asr_num_blocks: 6 + spk_num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + spk_weight: 0.5 + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +ctc_conf: + ignore_nan_grad: true + +# minibatch related +batch_type: numel +batch_bins: 10000000 + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 60 +val_scheduler_criterion: + - valid + - loss +best_model_criterion: +- - valid + - acc + - max +- - valid + - acc_spk + - max +- - valid + - loss + - min +keep_nbest_models: 10 + +optim: adam +optim_conf: + lr: 0.0005 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 8000 + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_range: + - 0 + - 40 + num_time_mask: 2 + diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh new file mode 100755 index 000000000..8151bae30 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh @@ -0,0 +1,162 @@ +#!/usr/bin/env bash +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +help_messge=$(cat << EOF +Usage: $0 + +Options: + --no_overlap (bool): Whether to ignore the overlapping utterance in the training set. + --tgt (string): Which set to process, test or train. +EOF +) + +SECONDS=0 +tgt=Train #Train or Eval + + +log "$0 $*" +echo $tgt +. ./utils/parse_options.sh + +. ./path.sh + +AliMeeting="${PWD}/dataset" + +if [ $# -gt 2 ]; then + log "${help_message}" + exit 2 +fi + + +if [ ! -d "${AliMeeting}" ]; then + log "Error: ${AliMeeting} is empty." + exit 2 +fi + +# To absolute path +AliMeeting=$(cd ${AliMeeting}; pwd) +echo $AliMeeting +far_raw_dir=${AliMeeting}/${tgt}_Ali_far/ +near_raw_dir=${AliMeeting}/${tgt}_Ali_near/ + +far_dir=data/local/${tgt}_Ali_far +near_dir=data/local/${tgt}_Ali_near +far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker +mkdir -p $far_single_speaker_dir + +stage=1 +stop_stage=4 +mkdir -p $far_dir +mkdir -p $near_dir + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + log "stage 1:process alimeeting near dir" + + find -L $near_raw_dir/audio_dir -iname "*.wav" > $near_dir/wavlist + awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid + find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist + n1_wav=$(wc -l < $near_dir/wavlist) + n2_text=$(wc -l < $near_dir/textgrid.flist) + log near file found $n1_wav wav and $n2_text text. + + paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp + + # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp + cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp + + python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False + cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text + utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk + #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1 + #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk + utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt + utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments + sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1 + sed -e 's/!//g' $near_dir/tmp1> $near_dir/tmp2 + sed -e 's/?//g' $near_dir/tmp2> $near_dir/text + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + log "stage 2:process alimeeting far dir" + + find -L $far_raw_dir/audio_dir -iname "*.wav" > $far_dir/wavlist + awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid + find -L $far_raw_dir/textgrid_dir -iname "*.TextGrid" > $far_dir/textgrid.flist + n1_wav=$(wc -l < $far_dir/wavlist) + n2_text=$(wc -l < $far_dir/textgrid.flist) + log far file found $n1_wav wav and $n2_text text. + + paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp + + cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp + + python local/alimeeting_process_overlap_force.py --path $far_dir \ + --no-overlap false --mars True \ + --overlap_length 0.8 --max_length 7 + + cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text + utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk + #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk + + utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt + utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments + sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1 + sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2 + sed -e 's/!//g' $far_dir/tmp2> $far_dir/tmp3 + sed -e 's/?//g' $far_dir/tmp3> $far_dir/text +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + log "stage 3: finali data process" + + utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near + utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far + + sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo + sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo + + # remove space in text + for x in ${tgt}_Ali_near ${tgt}_Ali_far; do + cp data/${x}/text data/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ + > data/${x}/text + rm data/${x}/text.org + done + + log "Successfully finished. [elapsed=${SECONDS}s]" +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + log "stage 4: process alimeeting far dir (single speaker by oracle time strap)" + cp -r $far_dir/* $far_single_speaker_dir + mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath + paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist + python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir + + cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text + utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt + + ./utils/fix_data_dir.sh $far_single_speaker_dir + utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker + + # remove space in text + for x in ${tgt}_Ali_far_single_speaker; do + cp data/${x}/text data/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ + > data/${x}/text + rm data/${x}/text.org + done + log "Successfully finished. [elapsed=${SECONDS}s]" +fi \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh new file mode 100755 index 000000000..382a05669 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh @@ -0,0 +1,129 @@ +#!/usr/bin/env bash +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +help_messge=$(cat << EOF +Usage: $0 + +Options: + --no_overlap (bool): Whether to ignore the overlapping utterance in the training set. + --tgt (string): Which set to process, test or train. +EOF +) + +SECONDS=0 +tgt=Train #Train or Eval + + +log "$0 $*" +echo $tgt +. ./utils/parse_options.sh + +. ./path.sh + +AliMeeting="${PWD}/dataset" + +if [ $# -gt 2 ]; then + log "${help_message}" + exit 2 +fi + + +if [ ! -d "${AliMeeting}" ]; then + log "Error: ${AliMeeting} is empty." + exit 2 +fi + +# To absolute path +AliMeeting=$(cd ${AliMeeting}; pwd) +echo $AliMeeting +far_raw_dir=${AliMeeting}/${tgt}_Ali_far/ + +far_dir=data/local/${tgt}_Ali_far +far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker +mkdir -p $far_single_speaker_dir + +stage=1 +stop_stage=3 +mkdir -p $far_dir + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + log "stage 1:process alimeeting far dir" + + find -L $far_raw_dir/audio_dir -iname "*.wav" > $far_dir/wavlist + awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid + find -L $far_raw_dir/textgrid_dir -iname "*.TextGrid" > $far_dir/textgrid.flist + n1_wav=$(wc -l < $far_dir/wavlist) + n2_text=$(wc -l < $far_dir/textgrid.flist) + log far file found $n1_wav wav and $n2_text text. + + paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp + + cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp + + python local/alimeeting_process_overlap_force.py --path $far_dir \ + --no-overlap false --mars True \ + --overlap_length 0.8 --max_length 7 + + cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text + utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk + #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk + + utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt + utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments + sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1 + sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2 + sed -e 's/!//g' $far_dir/tmp2> $far_dir/tmp3 + sed -e 's/?//g' $far_dir/tmp3> $far_dir/text +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + log "stage 2: finali data process" + + utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far + + sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo + sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo + + # remove space in text + for x in ${tgt}_Ali_far; do + cp data/${x}/text data/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ + > data/${x}/text + rm data/${x}/text.org + done + + log "Successfully finished. [elapsed=${SECONDS}s]" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + log "stage 3:process alimeeting far dir (single speaker by oracal time strap)" + cp -r $far_dir/* $far_single_speaker_dir + mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath + paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist + python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir + + cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text + utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt + + ./utils/fix_data_dir.sh $far_single_speaker_dir + utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker + + # remove space in text + for x in ${tgt}_Ali_far_single_speaker; do + cp data/${x}/text data/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ + > data/${x}/text + rm data/${x}/text.org + done + log "Successfully finished. [elapsed=${SECONDS}s]" +fi \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py new file mode 100755 index 000000000..8ece75706 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +""" +Process the textgrid files +""" +import argparse +import codecs +from distutils.util import strtobool +from pathlib import Path +import textgrid +import pdb + +class Segment(object): + def __init__(self, uttid, spkr, stime, etime, text): + self.uttid = uttid + self.spkr = spkr + self.spkr_all = uttid+"-"+spkr + self.stime = round(stime, 2) + self.etime = round(etime, 2) + self.text = text + self.spk_text = {uttid+"-"+spkr: text} + + def change_stime(self, time): + self.stime = time + + def change_etime(self, time): + self.etime = time + + +def get_args(): + parser = argparse.ArgumentParser(description="process the textgrid files") + parser.add_argument("--path", type=str, required=True, help="Data path") + parser.add_argument( + "--no-overlap", + type=strtobool, + default=False, + help="Whether to ignore the overlapping utterances.", + ) + parser.add_argument( + "--max_length", + default=100000, + type=float, + help="overlap speech max time,if longger than max length should cut", + ) + parser.add_argument( + "--overlap_length", + default=1, + type=float, + help="if length longer than max length, speech overlength shorter, is cut", + ) + parser.add_argument( + "--mars", + type=strtobool, + default=False, + help="Whether to process mars data set.", + ) + args = parser.parse_args() + return args + + +def preposs_overlap(segments,max_length,overlap_length): + new_segments = [] + # init a helper list to store all overlap segments + tmp_segments = segments[0] + min_stime = segments[0].stime + max_etime = segments[0].etime + overlap_length_big = 1.5 + max_length_big = 15 + for i in range(1, len(segments)): + if segments[i].stime >= max_etime: + # doesn't overlap with preivous segments + new_segments.append(tmp_segments) + tmp_segments = segments[i] + min_stime = segments[i].stime + max_etime = segments[i].etime + else: + # overlap with previous segments + dur_time = max_etime - min_stime + if dur_time < max_length: + if min_stime > segments[i].stime: + min_stime = segments[i].stime + if max_etime < segments[i].etime: + max_etime = segments[i].etime + tmp_segments.stime = min_stime + tmp_segments.etime = max_etime + tmp_segments.text = tmp_segments.text + "src" + segments[i].text + spk_name =segments[i].uttid +"-" + segments[i].spkr + if spk_name in tmp_segments.spk_text: + tmp_segments.spk_text[spk_name] += segments[i].text + else: + tmp_segments.spk_text[spk_name] = segments[i].text + tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name + else: + overlap_time = max_etime - segments[i].stime + if dur_time < max_length_big: + overlap_length_option = overlap_length + else: + overlap_length_option = overlap_length_big + if overlap_time > overlap_length_option: + if min_stime > segments[i].stime: + min_stime = segments[i].stime + if max_etime < segments[i].etime: + max_etime = segments[i].etime + tmp_segments.stime = min_stime + tmp_segments.etime = max_etime + tmp_segments.text = tmp_segments.text + "src" + segments[i].text + spk_name =segments[i].uttid +"-" + segments[i].spkr + if spk_name in tmp_segments.spk_text: + tmp_segments.spk_text[spk_name] += segments[i].text + else: + tmp_segments.spk_text[spk_name] = segments[i].text + tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name + else: + new_segments.append(tmp_segments) + tmp_segments = segments[i] + min_stime = segments[i].stime + max_etime = segments[i].etime + + return new_segments + +def filter_overlap(segments): + new_segments = [] + # init a helper list to store all overlap segments + tmp_segments = [segments[0]] + min_stime = segments[0].stime + max_etime = segments[0].etime + + for i in range(1, len(segments)): + if segments[i].stime >= max_etime: + # doesn't overlap with preivous segments + if len(tmp_segments) == 1: + new_segments.append(tmp_segments[0]) + # TODO: for multi-spkr asr, we can reset the stime/etime to + # min_stime/max_etime for generating a max length mixutre speech + tmp_segments = [segments[i]] + min_stime = segments[i].stime + max_etime = segments[i].etime + else: + # overlap with previous segments + tmp_segments.append(segments[i]) + if min_stime > segments[i].stime: + min_stime = segments[i].stime + if max_etime < segments[i].etime: + max_etime = segments[i].etime + + return new_segments + + +def main(args): + wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8") + textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8") + + # get the path of textgrid file for each utterance + utt2textgrid = {} + for line in textgrid_flist: + path = Path(line.strip()) + uttid = path.stem + utt2textgrid[uttid] = path + + # parse the textgrid file for each utterance + all_segments = [] + for line in wav_scp: + uttid = line.strip().split(" ")[0] + uttid_part=uttid + if args.mars == True: + uttid_list = uttid.split("_") + uttid_part= uttid_list[0]+"_"+uttid_list[1] + if uttid_part not in utt2textgrid: + print("%s doesn't have transcription" % uttid) + continue + + segments = [] + tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part]) + for i in range(tg.__len__()): + for j in range(tg[i].__len__()): + if tg[i][j].mark: + segments.append( + Segment( + uttid, + tg[i].name, + tg[i][j].minTime, + tg[i][j].maxTime, + tg[i][j].mark.strip(), + ) + ) + + segments = sorted(segments, key=lambda x: x.stime) + + if args.no_overlap: + segments = filter_overlap(segments) + else: + segments = preposs_overlap(segments,args.max_length,args.overlap_length) + all_segments += segments + + wav_scp.close() + textgrid_flist.close() + + segments_file = codecs.open(Path(args.path) / "segments_all", "w", "utf-8") + utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8") + text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8") + utt2spk_file_fifo = codecs.open(Path(args.path) / "utt2spk_all_fifo", "w", "utf-8") + + for i in range(len(all_segments)): + utt_name = "%s-%s-%07d-%07d" % ( + all_segments[i].uttid, + all_segments[i].spkr, + all_segments[i].stime * 100, + all_segments[i].etime * 100, + ) + + segments_file.write( + "%s %s %.2f %.2f\n" + % ( + utt_name, + all_segments[i].uttid, + all_segments[i].stime, + all_segments[i].etime, + ) + ) + utt2spk_file.write( + "%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr) + ) + utt2spk_file_fifo.write( + "%s %s\n" % (utt_name, all_segments[i].spkr_all) + ) + text_file.write("%s %s\n" % (utt_name, all_segments[i].text)) + + segments_file.close() + utt2spk_file.close() + text_file.close() + utt2spk_file_fifo.close() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py new file mode 100755 index 000000000..81c19659a --- /dev/null +++ b/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +""" +Process the textgrid files +""" +import argparse +import codecs +from distutils.util import strtobool +from pathlib import Path +import textgrid +import pdb + +class Segment(object): + def __init__(self, uttid, spkr, stime, etime, text): + self.uttid = uttid + self.spkr = spkr + self.stime = round(stime, 2) + self.etime = round(etime, 2) + self.text = text + + def change_stime(self, time): + self.stime = time + + def change_etime(self, time): + self.etime = time + + +def get_args(): + parser = argparse.ArgumentParser(description="process the textgrid files") + parser.add_argument("--path", type=str, required=True, help="Data path") + parser.add_argument( + "--no-overlap", + type=strtobool, + default=False, + help="Whether to ignore the overlapping utterances.", + ) + parser.add_argument( + "--mars", + type=strtobool, + default=False, + help="Whether to process mars data set.", + ) + args = parser.parse_args() + return args + + +def filter_overlap(segments): + new_segments = [] + # init a helper list to store all overlap segments + tmp_segments = [segments[0]] + min_stime = segments[0].stime + max_etime = segments[0].etime + + for i in range(1, len(segments)): + if segments[i].stime >= max_etime: + # doesn't overlap with preivous segments + if len(tmp_segments) == 1: + new_segments.append(tmp_segments[0]) + # TODO: for multi-spkr asr, we can reset the stime/etime to + # min_stime/max_etime for generating a max length mixutre speech + tmp_segments = [segments[i]] + min_stime = segments[i].stime + max_etime = segments[i].etime + else: + # overlap with previous segments + tmp_segments.append(segments[i]) + if min_stime > segments[i].stime: + min_stime = segments[i].stime + if max_etime < segments[i].etime: + max_etime = segments[i].etime + + return new_segments + + +def main(args): + wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8") + textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8") + + # get the path of textgrid file for each utterance + utt2textgrid = {} + for line in textgrid_flist: + path = Path(line.strip()) + uttid = path.stem + utt2textgrid[uttid] = path + + # parse the textgrid file for each utterance + all_segments = [] + for line in wav_scp: + uttid = line.strip().split(" ")[0] + uttid_part=uttid + if args.mars == True: + uttid_list = uttid.split("_") + uttid_part= uttid_list[0]+"_"+uttid_list[1] + if uttid_part not in utt2textgrid: + print("%s doesn't have transcription" % uttid) + continue + #pdb.set_trace() + segments = [] + try: + tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part]) + except: + pdb.set_trace() + for i in range(tg.__len__()): + for j in range(tg[i].__len__()): + if tg[i][j].mark: + segments.append( + Segment( + uttid, + tg[i].name, + tg[i][j].minTime, + tg[i][j].maxTime, + tg[i][j].mark.strip(), + ) + ) + + segments = sorted(segments, key=lambda x: x.stime) + + if args.no_overlap: + segments = filter_overlap(segments) + + all_segments += segments + + wav_scp.close() + textgrid_flist.close() + + segments_file = codecs.open(Path(args.path) / "segments_all", "w", "utf-8") + utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8") + text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8") + + for i in range(len(all_segments)): + utt_name = "%s-%s-%07d-%07d" % ( + all_segments[i].uttid, + all_segments[i].spkr, + all_segments[i].stime * 100, + all_segments[i].etime * 100, + ) + + segments_file.write( + "%s %s %.2f %.2f\n" + % ( + utt_name, + all_segments[i].uttid, + all_segments[i].stime, + all_segments[i].etime, + ) + ) + utt2spk_file.write( + "%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr) + ) + text_file.write("%s %s\n" % (utt_name, all_segments[i].text)) + + segments_file.close() + utt2spk_file.close() + text_file.close() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/alimeeting/sa-asr/local/compute_cpcer.py b/egs/alimeeting/sa-asr/local/compute_cpcer.py new file mode 100644 index 000000000..f4d4a7978 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/compute_cpcer.py @@ -0,0 +1,91 @@ +import editdistance +import sys +import os +from itertools import permutations + + +def load_transcripts(file_path): + trans_list = [] + for one_line in open(file_path, "rt"): + meeting_id, trans = one_line.strip().split(" ") + trans_list.append((meeting_id.strip(), trans.strip())) + + return trans_list + +def calc_spk_trans(trans): + spk_trans_ = [x.strip() for x in trans.split("$")] + spk_trans = [] + for i in range(len(spk_trans_)): + spk_trans.append((str(i), spk_trans_[i])) + return spk_trans + +def calc_cer(ref_trans, hyp_trans): + ref_spk_trans = calc_spk_trans(ref_trans) + hyp_spk_trans = calc_spk_trans(hyp_trans) + ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans) + num_spk = max(len(ref_spk_trans), len(hyp_spk_trans)) + ref_spk_trans.extend([("", "")] * (num_spk - len(ref_spk_trans))) + hyp_spk_trans.extend([("", "")] * (num_spk - len(hyp_spk_trans))) + + errors, counts, permutes = [], [], [] + min_error = 0 + cost_dict = {} + for perm in permutations(range(num_spk)): + flag = True + p_err, p_count = 0, 0 + for idx, p in enumerate(perm): + if abs(len(ref_spk_trans[idx][1]) - len(hyp_spk_trans[p][1])) > min_error > 0: + flag = False + break + cost_key = "{}-{}".format(idx, p) + if cost_key in cost_dict: + _e = cost_dict[cost_key] + else: + _e = editdistance.eval(ref_spk_trans[idx][1], hyp_spk_trans[p][1]) + cost_dict[cost_key] = _e + if _e > min_error > 0: + flag = False + break + p_err += _e + p_count += len(ref_spk_trans[idx][1]) + + if flag: + if p_err < min_error or min_error == 0: + min_error = p_err + + errors.append(p_err) + counts.append(p_count) + permutes.append(perm) + + sd_cer = [(err, cnt, err/cnt, permute) + for err, cnt, permute in zip(errors, counts, permutes)] + # import ipdb;ipdb.set_trace() + best_rst = min(sd_cer, key=lambda x: x[2]) + + return best_rst[0], best_rst[1], ref_spk_num, hyp_spk_num + + +def main(): + ref=sys.argv[1] + hyp=sys.argv[2] + result_path=sys.argv[3] + ref_list = load_transcripts(ref) + hyp_list = load_transcripts(hyp) + result_file = open(result_path,'w') + error, count = 0, 0 + for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list): + assert ref_id == hyp_id + mid = ref_id + dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans) + error, count = error + dist, count + length + result_file.write("{} {:.2f} {} {}\n".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num)) + + # print("{} {:.2f} {} {}".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num)) + + result_file.write("CP-CER: {:.2f}\n".format(error / count * 100.0)) + result_file.close() + # print("Sum/Avg: {:.2f}".format(error / count * 100.0)) + + +if __name__ == '__main__': + main() diff --git a/egs/alimeeting/sa-asr/local/compute_wer.py b/egs/alimeeting/sa-asr/local/compute_wer.py new file mode 100755 index 000000000..349a3f609 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/compute_wer.py @@ -0,0 +1,157 @@ +import os +import numpy as np +import sys + +def compute_wer(ref_file, + hyp_file, + cer_detail_file): + rst = { + 'Wrd': 0, + 'Corr': 0, + 'Ins': 0, + 'Del': 0, + 'Sub': 0, + 'Snt': 0, + 'Err': 0.0, + 'S.Err': 0.0, + 'wrong_words': 0, + 'wrong_sentences': 0 + } + + hyp_dict = {} + ref_dict = {} + with open(hyp_file, 'r') as hyp_reader: + for line in hyp_reader: + key = line.strip().split()[0] + value = line.strip().split()[1:] + hyp_dict[key] = value + with open(ref_file, 'r') as ref_reader: + for line in ref_reader: + key = line.strip().split()[0] + value = line.strip().split()[1:] + ref_dict[key] = value + + cer_detail_writer = open(cer_detail_file, 'w') + for hyp_key in hyp_dict: + if hyp_key in ref_dict: + out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key]) + rst['Wrd'] += out_item['nwords'] + rst['Corr'] += out_item['cor'] + rst['wrong_words'] += out_item['wrong'] + rst['Ins'] += out_item['ins'] + rst['Del'] += out_item['del'] + rst['Sub'] += out_item['sub'] + rst['Snt'] += 1 + if out_item['wrong'] > 0: + rst['wrong_sentences'] += 1 + cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n') + cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n') + cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n') + + if rst['Wrd'] > 0: + rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) + if rst['Snt'] > 0: + rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) + + cer_detail_writer.write('\n') + cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) + + ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n') + cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n') + cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n') + + +def compute_wer_by_line(hyp, + ref): + hyp = list(map(lambda x: x.lower(), hyp)) + ref = list(map(lambda x: x.lower(), ref)) + + len_hyp = len(hyp) + len_ref = len(ref) + + cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16) + + ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8) + + for i in range(len_hyp + 1): + cost_matrix[i][0] = i + for j in range(len_ref + 1): + cost_matrix[0][j] = j + + for i in range(1, len_hyp + 1): + for j in range(1, len_ref + 1): + if hyp[i - 1] == ref[j - 1]: + cost_matrix[i][j] = cost_matrix[i - 1][j - 1] + else: + substitution = cost_matrix[i - 1][j - 1] + 1 + insertion = cost_matrix[i - 1][j] + 1 + deletion = cost_matrix[i][j - 1] + 1 + + compare_val = [substitution, insertion, deletion] + + min_val = min(compare_val) + operation_idx = compare_val.index(min_val) + 1 + cost_matrix[i][j] = min_val + ops_matrix[i][j] = operation_idx + + match_idx = [] + i = len_hyp + j = len_ref + rst = { + 'nwords': len_ref, + 'cor': 0, + 'wrong': 0, + 'ins': 0, + 'del': 0, + 'sub': 0 + } + while i >= 0 or j >= 0: + i_idx = max(0, i) + j_idx = max(0, j) + + if ops_matrix[i_idx][j_idx] == 0: # correct + if i - 1 >= 0 and j - 1 >= 0: + match_idx.append((j - 1, i - 1)) + rst['cor'] += 1 + + i -= 1 + j -= 1 + + elif ops_matrix[i_idx][j_idx] == 2: # insert + i -= 1 + rst['ins'] += 1 + + elif ops_matrix[i_idx][j_idx] == 3: # delete + j -= 1 + rst['del'] += 1 + + elif ops_matrix[i_idx][j_idx] == 1: # substitute + i -= 1 + j -= 1 + rst['sub'] += 1 + + if i < 0 and j >= 0: + rst['del'] += 1 + elif j < 0 and i >= 0: + rst['ins'] += 1 + + match_idx.reverse() + wrong_cnt = cost_matrix[len_hyp][len_ref] + rst['wrong'] = wrong_cnt + + return rst + +def print_cer_detail(rst): + return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor']) + + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub=" + + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords']) + + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords'])) + +if __name__ == '__main__': + if len(sys.argv) != 4: + print("usage : python compute-wer.py test.ref test.hyp test.wer") + sys.exit(0) + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + cer_detail_file = sys.argv[3] + compute_wer(ref_file, hyp_file, cer_detail_file) diff --git a/egs/alimeeting/sa-asr/local/download_xvector_model.py b/egs/alimeeting/sa-asr/local/download_xvector_model.py new file mode 100644 index 000000000..7da655944 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/download_xvector_model.py @@ -0,0 +1,6 @@ +from modelscope.hub.snapshot_download import snapshot_download +import sys + + +cache_dir = sys.argv[1] +model_dir = snapshot_download('damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch', cache_dir=cache_dir) diff --git a/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py b/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py new file mode 100644 index 000000000..e6061625b --- /dev/null +++ b/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py @@ -0,0 +1,22 @@ +import sys +if __name__=="__main__": + uttid_path=sys.argv[1] + src_path=sys.argv[2] + tgt_path=sys.argv[3] + uttid_file=open(uttid_path,'r') + uttid_line=uttid_file.readlines() + uttid_file.close() + ori_utt2spk_all_fifo_file=open(src_path+'/utt2spk_all_fifo','r') + ori_utt2spk_all_fifo_line=ori_utt2spk_all_fifo_file.readlines() + ori_utt2spk_all_fifo_file.close() + new_utt2spk_all_fifo_file=open(tgt_path+'/utt2spk_all_fifo','w') + + uttid_list=[] + for line in uttid_line: + uttid_list.append(line.strip()) + + for line in ori_utt2spk_all_fifo_line: + if line.strip().split(' ')[0] in uttid_list: + new_utt2spk_all_fifo_file.write(line) + + new_utt2spk_all_fifo_file.close() \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py b/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py new file mode 100644 index 000000000..c37abf9a0 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py @@ -0,0 +1,167 @@ +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +import numpy as np +import sys +import os +import soundfile +from itertools import permutations +from sklearn.metrics.pairwise import cosine_similarity +from sklearn import cluster + + +def custom_spectral_clustering(affinity, min_n_clusters=2, max_n_clusters=4, refine=True, + threshold=0.995, laplacian_type="graph_cut"): + if refine: + # Symmetrization + affinity = np.maximum(affinity, np.transpose(affinity)) + # Diffusion + affinity = np.matmul(affinity, np.transpose(affinity)) + # Row-wise max normalization + row_max = affinity.max(axis=1, keepdims=True) + affinity = affinity / row_max + + # a) Construct S and set diagonal elements to 0 + affinity = affinity - np.diag(np.diag(affinity)) + # b) Compute Laplacian matrix L and perform normalization: + degree = np.diag(np.sum(affinity, axis=1)) + laplacian = degree - affinity + if laplacian_type == "random_walk": + degree_norm = np.diag(1 / (np.diag(degree) + 1e-10)) + laplacian_norm = degree_norm.dot(laplacian) + else: + degree_half = np.diag(degree) ** 0.5 + 1e-15 + laplacian_norm = laplacian / degree_half[:, np.newaxis] / degree_half + + # c) Compute eigenvalues and eigenvectors of L_norm + eigenvalues, eigenvectors = np.linalg.eig(laplacian_norm) + eigenvalues = eigenvalues.real + eigenvectors = eigenvectors.real + index_array = np.argsort(eigenvalues) + eigenvalues = eigenvalues[index_array] + eigenvectors = eigenvectors[:, index_array] + + # d) Compute the number of clusters k + k = min_n_clusters + for k in range(min_n_clusters, max_n_clusters + 1): + if eigenvalues[k] > threshold: + break + k = max(k, min_n_clusters) + spectral_embeddings = eigenvectors[:, :k] + # print(mid, k, eigenvalues[:10]) + + spectral_embeddings = spectral_embeddings / np.linalg.norm(spectral_embeddings, axis=1, ord=2, keepdims=True) + solver = cluster.KMeans(n_clusters=k, max_iter=1000, random_state=42) + solver.fit(spectral_embeddings) + return solver.labels_ + + +if __name__ == "__main__": + path = sys.argv[1] # dump2/raw/Eval_Ali_far + raw_path = sys.argv[2] # data/local/Eval_Ali_far + threshold = float(sys.argv[3]) # 0.996 + sv_threshold = float(sys.argv[4]) # 0.815 + wav_scp_file = open(path+'/wav.scp', 'r') + wav_scp = wav_scp_file.readlines() + wav_scp_file.close() + raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r') + raw_meeting_scp = raw_meeting_scp_file.readlines() + raw_meeting_scp_file.close() + segments_scp_file = open(raw_path + '/segments', 'r') + segments_scp = segments_scp_file.readlines() + segments_scp_file.close() + + segments_map = {} + for line in segments_scp: + line_list = line.strip().split(' ') + meeting = line_list[1] + seg = (float(line_list[-2]), float(line_list[-1])) + if meeting not in segments_map.keys(): + segments_map[meeting] = [seg] + else: + segments_map[meeting].append(seg) + + inference_sv_pipline = pipeline( + task=Tasks.speaker_verification, + model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch' + ) + + chunk_len = int(1.5*16000) # 1.5 seconds + hop_len = int(0.75*16000) # 0.75 seconds + + os.system("mkdir -p " + path + "/cluster_profile_infer") + cluster_spk_num_file = open(path + '/cluster_spk_num', 'w') + meeting_map = {} + for line in raw_meeting_scp: + meeting = line.strip().split('\t')[0] + wav_path = line.strip().split('\t')[1] + wav = soundfile.read(wav_path)[0] + # take the first channel + if wav.ndim == 2: + wav=wav[:, 0] + # gen_seg_embedding + segments_list = segments_map[meeting] + + # import ipdb;ipdb.set_trace() + all_seg_embedding_list = [] + for seg in segments_list: + wav_seg = wav[int(seg[0] * 16000): int(seg[1] * 16000)] + wav_seg_len = wav_seg.shape[0] + i = 0 + while i < wav_seg_len: + if i + chunk_len < wav_seg_len: + cur_wav_chunk = wav_seg[i: i+chunk_len] + else: + cur_wav_chunk=wav_seg[i: ] + # chunks under 0.2s are ignored + if cur_wav_chunk.shape[0] >= 0.2 * 16000: + cur_chunk_embedding = inference_sv_pipline(audio_in=cur_wav_chunk)["spk_embedding"] + all_seg_embedding_list.append(cur_chunk_embedding) + i += hop_len + all_seg_embedding = np.vstack(all_seg_embedding_list) + # all_seg_embedding (n, dim) + + # compute affinity + affinity=cosine_similarity(all_seg_embedding) + + affinity = np.maximum(affinity - sv_threshold, 0.0001) / (affinity.max() - sv_threshold) + + # clustering + labels = custom_spectral_clustering( + affinity=affinity, + min_n_clusters=2, + max_n_clusters=4, + refine=True, + threshold=threshold, + laplacian_type="graph_cut") + + + cluster_dict={} + for j in range(labels.shape[0]): + if labels[j] not in cluster_dict.keys(): + cluster_dict[labels[j]] = np.atleast_2d(all_seg_embedding[j]) + else: + cluster_dict[labels[j]] = np.concatenate((cluster_dict[labels[j]], np.atleast_2d(all_seg_embedding[j]))) + + emb_list = [] + # get cluster center + for k in cluster_dict.keys(): + cluster_dict[k] = np.mean(cluster_dict[k], axis=0) + emb_list.append(cluster_dict[k]) + + spk_num = len(emb_list) + profile_for_infer = np.vstack(emb_list) + # save profile for each meeting + np.save(path + '/cluster_profile_infer/' + meeting + '.npy', profile_for_infer) + meeting_map[meeting] = (path + '/cluster_profile_infer/' + meeting + '.npy', spk_num) + cluster_spk_num_file.write(meeting + ' ' + str(spk_num) + '\n') + cluster_spk_num_file.flush() + + cluster_spk_num_file.close() + + profile_scp = open(path + "/cluster_profile_infer.scp", 'w') + for line in wav_scp: + uttid = line.strip().split(' ')[0] + meeting = uttid.split('-')[0] + profile_scp.write(uttid + ' ' + meeting_map[meeting][0] + '\n') + profile_scp.flush() + profile_scp.close() diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py b/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py new file mode 100644 index 000000000..18286b42d --- /dev/null +++ b/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py @@ -0,0 +1,70 @@ +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +import numpy as np +import sys +import os +import soundfile + + +if __name__=="__main__": + path = sys.argv[1] # dump2/raw/Eval_Ali_far + raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker + raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r') + raw_meeting_scp = raw_meeting_scp_file.readlines() + raw_meeting_scp_file.close() + segments_scp_file = open(raw_path + '/segments', 'r') + segments_scp = segments_scp_file.readlines() + segments_scp_file.close() + + oracle_emb_dir = path + '/oracle_embedding/' + os.system("mkdir -p " + oracle_emb_dir) + oracle_emb_scp_file = open(path+'/oracle_embedding.scp', 'w') + + raw_wav_map = {} + for line in raw_meeting_scp: + meeting = line.strip().split('\t')[0] + wav_path = line.strip().split('\t')[1] + raw_wav_map[meeting] = wav_path + + spk_map = {} + for line in segments_scp: + line_list = line.strip().split(' ') + meeting = line_list[1] + spk_id = line_list[0].split('_')[3] + spk = meeting + '_' + spk_id + time_start = float(line_list[-2]) + time_end = float(line_list[-1]) + if time_end - time_start > 0.5: + if spk not in spk_map.keys(): + spk_map[spk] = [(int(time_start * 16000), int(time_end * 16000))] + else: + spk_map[spk].append((int(time_start * 16000), int(time_end * 16000))) + + inference_sv_pipline = pipeline( + task=Tasks.speaker_verification, + model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch' + ) + + for spk in spk_map.keys(): + meeting = spk.split('_SPK')[0] + wav_path = raw_wav_map[meeting] + wav = soundfile.read(wav_path)[0] + # take the first channel + if wav.ndim == 2: + wav = wav[:, 0] + all_seg_embedding_list=[] + # import ipdb;ipdb.set_trace() + for seg_time in spk_map[spk]: + if seg_time[0] < wav.shape[0] - 0.5 * 16000: + if seg_time[1] > wav.shape[0]: + cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_time[0]: ])["spk_embedding"] + else: + cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_time[0]: seg_time[1]])["spk_embedding"] + all_seg_embedding_list.append(cur_seg_embedding) + all_seg_embedding = np.vstack(all_seg_embedding_list) + spk_embedding = np.mean(all_seg_embedding, axis=0) + np.save(oracle_emb_dir + spk + '.npy', spk_embedding) + oracle_emb_scp_file.write(spk + ' ' + oracle_emb_dir + spk + '.npy' + '\n') + oracle_emb_scp_file.flush() + + oracle_emb_scp_file.close() \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py b/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py new file mode 100644 index 000000000..f44fcd449 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py @@ -0,0 +1,59 @@ +import random +import numpy as np +import os +import sys + + +if __name__=="__main__": + path = sys.argv[1] # dump2/raw/Eval_Ali_far + wav_scp_file = open(path+"/wav.scp", 'r') + wav_scp = wav_scp_file.readlines() + wav_scp_file.close() + spk2id_file = open(path + "/spk2id", 'r') + spk2id = spk2id_file.readlines() + spk2id_file.close() + embedding_scp_file = open(path + "/oracle_embedding.scp", 'r') + embedding_scp = embedding_scp_file.readlines() + embedding_scp_file.close() + + embedding_map = {} + for line in embedding_scp: + spk = line.strip().split(' ')[0] + if spk not in embedding_map.keys(): + emb=np.load(line.strip().split(' ')[1]) + embedding_map[spk] = emb + + meeting_map_tmp = {} + global_spk_list = [] + for line in spk2id: + line_list = line.strip().split(' ') + meeting = line_list[0].split('-')[0] + spk_id = line_list[0].split('-')[-1].split('_')[-1] + spk = meeting + '_' + spk_id + global_spk_list.append(spk) + if meeting in meeting_map_tmp.keys(): + meeting_map_tmp[meeting].append(spk) + else: + meeting_map_tmp[meeting] = [spk] + + meeting_map = {} + os.system('mkdir -p ' + path + '/oracle_profile_nopadding') + for meeting in meeting_map_tmp.keys(): + emb_list = [] + for i in range(len(meeting_map_tmp[meeting])): + spk = meeting_map_tmp[meeting][i] + emb_list.append(embedding_map[spk]) + profile = np.vstack(emb_list) + np.save(path + '/oracle_profile_nopadding/' + meeting + '.npy', profile) + meeting_map[meeting] = path + '/oracle_profile_nopadding/' + meeting + '.npy' + + profile_scp = open(path + '/oracle_profile_nopadding.scp', 'w') + profile_map_scp = open(path + '/oracle_profile_nopadding_spk_list', 'w') + + for line in wav_scp: + uttid = line.strip().split(' ')[0] + meeting = uttid.split('-')[0] + profile_scp.write(uttid + ' ' + meeting_map[meeting] + '\n') + profile_map_scp.write(uttid + ' ' + '$'.join(meeting_map_tmp[meeting]) + '\n') + profile_scp.close() + profile_map_scp.close() \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py new file mode 100644 index 000000000..b70a32a19 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py @@ -0,0 +1,68 @@ +import random +import numpy as np +import os +import sys + + +if __name__=="__main__": + path = sys.argv[1] # dump2/raw/Train_Ali_far + wav_scp_file = open(path+"/wav.scp", 'r') + wav_scp = wav_scp_file.readlines() + wav_scp_file.close() + spk2id_file = open(path+"/spk2id", 'r') + spk2id = spk2id_file.readlines() + spk2id_file.close() + embedding_scp_file = open(path + "/oracle_embedding.scp", 'r') + embedding_scp = embedding_scp_file.readlines() + embedding_scp_file.close() + + embedding_map = {} + for line in embedding_scp: + spk = line.strip().split(' ')[0] + if spk not in embedding_map.keys(): + emb = np.load(line.strip().split(' ')[1]) + embedding_map[spk] = emb + + meeting_map_tmp = {} + global_spk_list = [] + for line in spk2id: + line_list = line.strip().split(' ') + meeting = line_list[0].split('-')[0] + spk_id = line_list[0].split('-')[-1].split('_')[-1] + spk = meeting+'_' + spk_id + global_spk_list.append(spk) + if meeting in meeting_map_tmp.keys(): + meeting_map_tmp[meeting].append(spk) + else: + meeting_map_tmp[meeting] = [spk] + + for meeting in meeting_map_tmp.keys(): + num = len(meeting_map_tmp[meeting]) + if num < 4: + global_spk_list_tmp = global_spk_list[: ] + for spk in meeting_map_tmp[meeting]: + global_spk_list_tmp.remove(spk) + padding_spk = random.sample(global_spk_list_tmp, 4 - num) + meeting_map_tmp[meeting] = meeting_map_tmp[meeting] + padding_spk + + meeting_map = {} + os.system('mkdir -p ' + path + '/oracle_profile_padding') + for meeting in meeting_map_tmp.keys(): + emb_list = [] + for i in range(len(meeting_map_tmp[meeting])): + spk = meeting_map_tmp[meeting][i] + emb_list.append(embedding_map[spk]) + profile = np.vstack(emb_list) + np.save(path + '/oracle_profile_padding/' + meeting + '.npy',profile) + meeting_map[meeting] = path + '/oracle_profile_padding/' + meeting + '.npy' + + profile_scp = open(path + '/oracle_profile_padding.scp', 'w') + profile_map_scp = open(path + '/oracle_profile_padding_spk_list', 'w') + + for line in wav_scp: + uttid = line.strip().split(' ')[0] + meeting = uttid.split('-')[0] + profile_scp.write(uttid+' ' + meeting_map[meeting] + '\n') + profile_map_scp.write(uttid+' ' + '$'.join(meeting_map_tmp[meeting]) + '\n') + profile_scp.close() + profile_map_scp.close() \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/proce_text.py b/egs/alimeeting/sa-asr/local/proce_text.py new file mode 100755 index 000000000..e56cc0f37 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/proce_text.py @@ -0,0 +1,32 @@ + +import sys +import re + +in_f = sys.argv[1] +out_f = sys.argv[2] + + +with open(in_f, "r", encoding="utf-8") as f: + lines = f.readlines() + +with open(out_f, "w", encoding="utf-8") as f: + for line in lines: + outs = line.strip().split(" ", 1) + if len(outs) == 2: + idx, text = outs + text = re.sub("", "", text) + text = re.sub("", "", text) + text = re.sub("@@", "", text) + text = re.sub("@", "", text) + text = re.sub("", "", text) + text = re.sub(" ", "", text) + text = re.sub("\$", "", text) + text = text.lower() + else: + idx = outs[0] + text = " " + + text = [x for x in text] + text = " ".join(text) + out = "{} {}\n".format(idx, text) + f.write(out) diff --git a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py new file mode 100755 index 000000000..d900bb17a --- /dev/null +++ b/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +""" +Process the textgrid files +""" +import argparse +import codecs +from distutils.util import strtobool +from pathlib import Path +import textgrid +import pdb + +def get_args(): + parser = argparse.ArgumentParser(description="process the textgrid files") + parser.add_argument("--path", type=str, required=True, help="Data path") + args = parser.parse_args() + return args + +class Segment(object): + def __init__(self, uttid, text): + self.uttid = uttid + self.text = text + +def main(args): + text = codecs.open(Path(args.path) / "text", "r", "utf-8") + spk2utt = codecs.open(Path(args.path) / "spk2utt", "r", "utf-8") + utt2spk = codecs.open(Path(args.path) / "utt2spk_all_fifo", "r", "utf-8") + spk2id = codecs.open(Path(args.path) / "spk2id", "w", "utf-8") + + spkid_map = {} + meetingid_map = {} + for line in spk2utt: + spkid = line.strip().split(" ")[0] + meeting_id_list = spkid.split("_")[:3] + meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2] + if meeting_id not in meetingid_map: + meetingid_map[meeting_id] = 1 + else: + meetingid_map[meeting_id] += 1 + spkid_map[spkid] = meetingid_map[meeting_id] + spk2id.write("%s %s\n" % (spkid, meetingid_map[meeting_id])) + + utt2spklist = {} + for line in utt2spk: + uttid = line.strip().split(" ")[0] + spkid = line.strip().split(" ")[1] + spklist = spkid.split("$") + tmp = [] + for index in range(len(spklist)): + tmp.append(spkid_map[spklist[index]]) + utt2spklist[uttid] = tmp + # parse the textgrid file for each utterance + all_segments = [] + for line in text: + uttid = line.strip().split(" ")[0] + context = line.strip().split(" ")[1] + spklist = utt2spklist[uttid] + length_text = len(context) + cnt = 0 + tmp_text = "" + for index in range(length_text): + if context[index] != "$": + tmp_text += str(spklist[cnt]) + else: + tmp_text += "$" + cnt += 1 + tmp_seg = Segment(uttid,tmp_text) + all_segments.append(tmp_seg) + + text.close() + utt2spk.close() + spk2utt.close() + spk2id.close() + + text_id = codecs.open(Path(args.path) / "text_id", "w", "utf-8") + + for i in range(len(all_segments)): + uttid_tmp = all_segments[i].uttid + text_tmp = all_segments[i].text + + text_id.write("%s %s\n" % (uttid_tmp, text_tmp)) + + text_id.close() + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/alimeeting/sa-asr/local/process_text_id.py b/egs/alimeeting/sa-asr/local/process_text_id.py new file mode 100644 index 000000000..0a9506e29 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/process_text_id.py @@ -0,0 +1,24 @@ +import sys +if __name__=="__main__": + path=sys.argv[1] + + text_id_old_file=open(path+"/text_id",'r') + text_id_old=text_id_old_file.readlines() + text_id_old_file.close() + + text_id=open(path+"/text_id_train",'w') + for line in text_id_old: + uttid=line.strip().split(' ')[0] + old_id=line.strip().split(' ')[1] + pre_id='0' + new_id_list=[] + for i in old_id: + if i == '$': + new_id_list.append(pre_id) + else: + new_id_list.append(str(int(i)-1)) + pre_id=str(int(i)-1) + new_id_list.append(pre_id) + new_id=' '.join(new_id_list) + text_id.write(uttid+' '+new_id+'\n') + text_id.close() \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/process_text_spk_merge.py b/egs/alimeeting/sa-asr/local/process_text_spk_merge.py new file mode 100644 index 000000000..f15d509b5 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/process_text_spk_merge.py @@ -0,0 +1,55 @@ +import sys + + +if __name__ == "__main__": + path=sys.argv[1] + text_scp_file = open(path + '/text', 'r') + text_scp = text_scp_file.readlines() + text_scp_file.close() + text_id_scp_file = open(path + '/text_id', 'r') + text_id_scp = text_id_scp_file.readlines() + text_id_scp_file.close() + text_spk_merge_file = open(path + '/text_spk_merge', 'w') + assert len(text_scp) == len(text_id_scp) + + meeting_map = {} # {meeting_id: [(start_time, text, text_id), (start_time, text, text_id), ...]} + for i in range(len(text_scp)): + text_line = text_scp[i].strip().split(' ') + text_id_line = text_id_scp[i].strip().split(' ') + assert text_line[0] == text_id_line[0] + if len(text_line) > 1: + uttid = text_line[0] + text = text_line[1] + text_id = text_id_line[1] + meeting_id = uttid.split('-')[0] + start_time = int(uttid.split('-')[-2]) + if meeting_id not in meeting_map: + meeting_map[meeting_id] = [(start_time,text,text_id)] + else: + meeting_map[meeting_id].append((start_time,text,text_id)) + + for meeting_id in sorted(meeting_map.keys()): + cur_meeting_list = sorted(meeting_map[meeting_id], key=lambda x: x[0]) + text_spk_merge_map = {} #{1: text1, 2: text2, ...} + for cur_utt in cur_meeting_list: + cur_text = cur_utt[1] + cur_text_id = cur_utt[2] + assert len(cur_text)==len(cur_text_id) + if len(cur_text) != 0: + cur_text_split = cur_text.split('$') + cur_text_id_split = cur_text_id.split('$') + assert len(cur_text_split) == len(cur_text_id_split) + for i in range(len(cur_text_split)): + if len(cur_text_split[i]) != 0: + spk_id = int(cur_text_id_split[i][0]) + if spk_id not in text_spk_merge_map.keys(): + text_spk_merge_map[spk_id] = cur_text_split[i] + else: + text_spk_merge_map[spk_id] += cur_text_split[i] + text_spk_merge_list = [] + for spk_id in sorted(text_spk_merge_map.keys()): + text_spk_merge_list.append(text_spk_merge_map[spk_id]) + text_spk_merge_file.write(meeting_id + ' ' + '$'.join(text_spk_merge_list) + '\n') + text_spk_merge_file.flush() + + text_spk_merge_file.close() \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py new file mode 100755 index 000000000..fdf246090 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +""" +Process the textgrid files +""" +import argparse +import codecs +from distutils.util import strtobool +from pathlib import Path +import textgrid +import pdb +import numpy as np +import sys +import math + + +class Segment(object): + def __init__(self, uttid, spkr, stime, etime, text): + self.uttid = uttid + self.spkr = spkr + self.stime = round(stime, 2) + self.etime = round(etime, 2) + self.text = text + + def change_stime(self, time): + self.stime = time + + def change_etime(self, time): + self.etime = time + + +def get_args(): + parser = argparse.ArgumentParser(description="process the textgrid files") + parser.add_argument("--path", type=str, required=True, help="Data path") + args = parser.parse_args() + return args + + + +def main(args): + textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8") + segment_file = codecs.open(Path(args.path)/"segments", "w", "utf-8") + utt2spk = codecs.open(Path(args.path)/"utt2spk", "w", "utf-8") + + # get the path of textgrid file for each utterance + for line in textgrid_flist: + line_array = line.strip().split(" ") + path = Path(line_array[1]) + uttid = line_array[0] + + try: + tg = textgrid.TextGrid.fromFile(path) + except: + pdb.set_trace() + num_spk = tg.__len__() + spk2textgrid = {} + spk2weight = {} + weight2spk = {} + cnt = 2 + xmax = 0 + for i in range(tg.__len__()): + spk_name = tg[i].name + if spk_name not in spk2weight: + spk2weight[spk_name] = cnt + weight2spk[cnt] = spk_name + cnt = cnt * 2 + segments = [] + for j in range(tg[i].__len__()): + if tg[i][j].mark: + if xmax < tg[i][j].maxTime: + xmax = tg[i][j].maxTime + segments.append( + Segment( + uttid, + tg[i].name, + tg[i][j].minTime, + tg[i][j].maxTime, + tg[i][j].mark.strip(), + ) + ) + segments = sorted(segments, key=lambda x: x.stime) + spk2textgrid[spk_name] = segments + olp_label = np.zeros((num_spk, int(xmax/0.01)), dtype=np.int32) + for spkid in spk2weight.keys(): + weight = spk2weight[spkid] + segments = spk2textgrid[spkid] + idx = int(math.log2(weight) )- 1 + for i in range(len(segments)): + stime = segments[i].stime + etime = segments[i].etime + olp_label[idx, int(stime/0.01): int(etime/0.01)] = weight + sum_label = olp_label.sum(axis=0) + stime = 0 + pre_value = 0 + for pos in range(sum_label.shape[0]): + if sum_label[pos] in weight2spk: + if pre_value in weight2spk: + if sum_label[pos] != pre_value: + spkids = weight2spk[pre_value] + spkid_array = spkids.split("_") + spkid = spkid_array[-1] + #spkid = uttid+spkid + if round(stime*0.01, 2) != round((pos-1)*0.01, 2): + segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2))) + utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid)) + stime = pos + pre_value = sum_label[pos] + else: + stime = pos + pre_value = sum_label[pos] + else: + if pre_value in weight2spk: + spkids = weight2spk[pre_value] + spkid_array = spkids.split("_") + spkid = spkid_array[-1] + #spkid = uttid+spkid + if round(stime*0.01, 2) != round((pos-1)*0.01, 2): + segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2))) + utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid)) + stime = pos + pre_value = sum_label[pos] + textgrid_flist.close() + segment_file.close() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/alimeeting/sa-asr/local/text_format.pl b/egs/alimeeting/sa-asr/local/text_format.pl new file mode 100755 index 000000000..45f1f6428 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/text_format.pl @@ -0,0 +1,14 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter +# Copyright Chao Weng + +# normalizations for hkust trascript +# see the docs/trans-guidelines.pdf for details + +while () { + @A = split(" ", $_); + if (@A == 1) { + next; + } + print $_ +} diff --git a/egs/alimeeting/sa-asr/local/text_normalize.pl b/egs/alimeeting/sa-asr/local/text_normalize.pl new file mode 100755 index 000000000..ac301d466 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/text_normalize.pl @@ -0,0 +1,38 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter +# Copyright Chao Weng + +# normalizations for hkust trascript +# see the docs/trans-guidelines.pdf for details + +while () { + @A = split(" ", $_); + print "$A[0] "; + for ($n = 1; $n < @A; $n++) { + $tmp = $A[$n]; + if ($tmp =~ //) {$tmp =~ s:::g;} + if ($tmp =~ /<%>/) {$tmp =~ s:<%>::g;} + if ($tmp =~ /<->/) {$tmp =~ s:<->::g;} + if ($tmp =~ /<\$>/) {$tmp =~ s:<\$>::g;} + if ($tmp =~ /<#>/) {$tmp =~ s:<#>::g;} + if ($tmp =~ /<_>/) {$tmp =~ s:<_>::g;} + if ($tmp =~ //) {$tmp =~ s:::g;} + if ($tmp =~ /`/) {$tmp =~ s:`::g;} + if ($tmp =~ /&/) {$tmp =~ s:&::g;} + if ($tmp =~ /,/) {$tmp =~ s:,::g;} + if ($tmp =~ /[a-zA-Z]/) {$tmp=uc($tmp);} + if ($tmp =~ /A/) {$tmp =~ s:A:A:g;} + if ($tmp =~ /a/) {$tmp =~ s:a:A:g;} + if ($tmp =~ /b/) {$tmp =~ s:b:B:g;} + if ($tmp =~ /c/) {$tmp =~ s:c:C:g;} + if ($tmp =~ /k/) {$tmp =~ s:k:K:g;} + if ($tmp =~ /t/) {$tmp =~ s:t:T:g;} + if ($tmp =~ /,/) {$tmp =~ s:,::g;} + if ($tmp =~ /丶/) {$tmp =~ s:丶::g;} + if ($tmp =~ /。/) {$tmp =~ s:。::g;} + if ($tmp =~ /、/) {$tmp =~ s:、::g;} + if ($tmp =~ /?/) {$tmp =~ s:?::g;} + print "$tmp "; + } + print "\n"; +} diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa-asr/path.sh new file mode 100755 index 000000000..3aa13d0c2 --- /dev/null +++ b/egs/alimeeting/sa-asr/path.sh @@ -0,0 +1,6 @@ +export FUNASR_DIR=$PWD/../../.. + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PATH=$FUNASR_DIR/funasr/bin:$PATH +export PATH=$PWD/utils/:$PATH \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py b/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py new file mode 100755 index 000000000..1fd63d690 --- /dev/null +++ b/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +import argparse +import logging +from io import BytesIO +from pathlib import Path +from typing import Tuple, Optional + +import kaldiio +import humanfriendly +import numpy as np +import resampy +import soundfile +from tqdm import tqdm +from typeguard import check_argument_types + +from funasr.utils.cli_utils import get_commandline_args +from funasr.fileio.read_text import read_2column_text +from funasr.fileio.sound_scp import SoundScpWriter + + +def humanfriendly_or_none(value: str): + if value in ("none", "None", "NONE"): + return None + return humanfriendly.parse_size(value) + + +def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]: + """ + + >>> str2int_tuple('3,4,5') + (3, 4, 5) + + """ + assert check_argument_types() + if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"): + return None + return tuple(map(int, integers.strip().split(","))) + + +def main(): + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + logging.basicConfig(level=logging.INFO, format=logfmt) + logging.info(get_commandline_args()) + + parser = argparse.ArgumentParser( + description='Create waves list from "wav.scp"', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("scp") + parser.add_argument("outdir") + parser.add_argument( + "--name", + default="wav", + help="Specify the prefix word of output file name " 'such as "wav.scp"', + ) + parser.add_argument("--segments", default=None) + parser.add_argument( + "--fs", + type=humanfriendly_or_none, + default=None, + help="If the sampling rate specified, " "Change the sampling rate.", + ) + parser.add_argument("--audio-format", default="wav") + group = parser.add_mutually_exclusive_group() + group.add_argument("--ref-channels", default=None, type=str2int_tuple) + group.add_argument("--utt2ref-channels", default=None, type=str) + args = parser.parse_args() + + out_num_samples = Path(args.outdir) / f"utt2num_samples" + + if args.ref_channels is not None: + + def utt2ref_channels(x) -> Tuple[int, ...]: + return args.ref_channels + + elif args.utt2ref_channels is not None: + utt2ref_channels_dict = read_2column_text(args.utt2ref_channels) + + def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]: + chs_str = d[x] + return tuple(map(int, chs_str.split())) + + else: + utt2ref_channels = None + + Path(args.outdir).mkdir(parents=True, exist_ok=True) + out_wavscp = Path(args.outdir) / f"{args.name}.scp" + if args.segments is not None: + # Note: kaldiio supports only wav-pcm-int16le file. + loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments) + if args.audio_format.endswith("ark"): + fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb") + fscp = out_wavscp.open("w") + else: + writer = SoundScpWriter( + args.outdir, + out_wavscp, + format=args.audio_format, + ) + + with out_num_samples.open("w") as fnum_samples: + for uttid, (rate, wave) in tqdm(loader): + # wave: (Time,) or (Time, Nmic) + if wave.ndim == 2 and utt2ref_channels is not None: + wave = wave[:, utt2ref_channels(uttid)] + + if args.fs is not None and args.fs != rate: + # FIXME(kamo): To use sox? + wave = resampy.resample( + wave.astype(np.float64), rate, args.fs, axis=0 + ) + wave = wave.astype(np.int16) + rate = args.fs + if args.audio_format.endswith("ark"): + if "flac" in args.audio_format: + suf = "flac" + elif "wav" in args.audio_format: + suf = "wav" + else: + raise RuntimeError("wav.ark or flac") + + # NOTE(kamo): Using extended ark format style here. + # This format is incompatible with Kaldi + kaldiio.save_ark( + fark, + {uttid: (wave, rate)}, + scp=fscp, + append=True, + write_function=f"soundfile_{suf}", + ) + + else: + writer[uttid] = rate, wave + fnum_samples.write(f"{uttid} {len(wave)}\n") + else: + if args.audio_format.endswith("ark"): + fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb") + else: + wavdir = Path(args.outdir) / f"data_{args.name}" + wavdir.mkdir(parents=True, exist_ok=True) + + with Path(args.scp).open("r") as fscp, out_wavscp.open( + "w" + ) as fout, out_num_samples.open("w") as fnum_samples: + for line in tqdm(fscp): + uttid, wavpath = line.strip().split(None, 1) + + if wavpath.endswith("|"): + # Streaming input e.g. cat a.wav | + with kaldiio.open_like_kaldi(wavpath, "rb") as f: + with BytesIO(f.read()) as g: + wave, rate = soundfile.read(g, dtype=np.int16) + if wave.ndim == 2 and utt2ref_channels is not None: + wave = wave[:, utt2ref_channels(uttid)] + + if args.fs is not None and args.fs != rate: + # FIXME(kamo): To use sox? + wave = resampy.resample( + wave.astype(np.float64), rate, args.fs, axis=0 + ) + wave = wave.astype(np.int16) + rate = args.fs + + if args.audio_format.endswith("ark"): + if "flac" in args.audio_format: + suf = "flac" + elif "wav" in args.audio_format: + suf = "wav" + else: + raise RuntimeError("wav.ark or flac") + + # NOTE(kamo): Using extended ark format style here. + # This format is incompatible with Kaldi + kaldiio.save_ark( + fark, + {uttid: (wave, rate)}, + scp=fout, + append=True, + write_function=f"soundfile_{suf}", + ) + else: + owavpath = str(wavdir / f"{uttid}.{args.audio_format}") + soundfile.write(owavpath, wave, rate) + fout.write(f"{uttid} {owavpath}\n") + else: + wave, rate = soundfile.read(wavpath, dtype=np.int16) + if wave.ndim == 2 and utt2ref_channels is not None: + wave = wave[:, utt2ref_channels(uttid)] + save_asis = False + + elif args.audio_format.endswith("ark"): + save_asis = False + + elif Path(wavpath).suffix == "." + args.audio_format and ( + args.fs is None or args.fs == rate + ): + save_asis = True + + else: + save_asis = False + + if save_asis: + # Neither --segments nor --fs are specified and + # the line doesn't end with "|", + # i.e. not using unix-pipe, + # only in this case, + # just using the original file as is. + fout.write(f"{uttid} {wavpath}\n") + else: + if args.fs is not None and args.fs != rate: + # FIXME(kamo): To use sox? + wave = resampy.resample( + wave.astype(np.float64), rate, args.fs, axis=0 + ) + wave = wave.astype(np.int16) + rate = args.fs + + if args.audio_format.endswith("ark"): + if "flac" in args.audio_format: + suf = "flac" + elif "wav" in args.audio_format: + suf = "wav" + else: + raise RuntimeError("wav.ark or flac") + + # NOTE(kamo): Using extended ark format style here. + # This format is not supported in Kaldi. + kaldiio.save_ark( + fark, + {uttid: (wave, rate)}, + scp=fout, + append=True, + write_function=f"soundfile_{suf}", + ) + else: + owavpath = str(wavdir / f"{uttid}.{args.audio_format}") + soundfile.write(owavpath, wave, rate) + fout.write(f"{uttid} {owavpath}\n") + fnum_samples.write(f"{uttid} {len(wave)}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py b/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py new file mode 100755 index 000000000..b0c61e5b4 --- /dev/null +++ b/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +import sys + + +def get_commandline_args(no_executable=True): + extra_chars = [ + " ", + ";", + "&", + "|", + "<", + ">", + "?", + "*", + "~", + "`", + '"', + "'", + "\\", + "{", + "}", + "(", + ")", + ] + + # Escape the extra characters for shell + argv = [ + arg.replace("'", "'\\''") + if all(char not in arg for char in extra_chars) + else "'" + arg.replace("'", "'\\''") + "'" + for arg in sys.argv + ] + + if no_executable: + return " ".join(argv[1:]) + else: + return sys.executable + " " + " ".join(argv) + + +def main(): + print(get_commandline_args()) + + +if __name__ == "__main__": + main() diff --git a/egs/alimeeting/sa-asr/run_m2met_2023.sh b/egs/alimeeting/sa-asr/run_m2met_2023.sh new file mode 100755 index 000000000..807e49948 --- /dev/null +++ b/egs/alimeeting/sa-asr/run_m2met_2023.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +ngpu=4 +device="0,1,2,3" + +#stage 1 creat both near and far +stage=1 +stop_stage=18 + + +train_set=Train_Ali_far +valid_set=Eval_Ali_far +test_sets="Test_Ali_far" +asr_config=conf/train_asr_conformer.yaml +sa_asr_config=conf/train_sa_asr_conformer.yaml +inference_config=conf/decode_asr_rnn.yaml + +lm_config=conf/train_lm_transformer.yaml +use_lm=false +use_wordlm=false +./asr_local.sh \ + --device ${device} \ + --ngpu ${ngpu} \ + --stage ${stage} \ + --stop_stage ${stop_stage} \ + --gpu_inference true \ + --njob_infer 4 \ + --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \ + --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \ + --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \ + --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \ + --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \ + --lang zh \ + --audio_format wav \ + --feats_type raw \ + --token_type char \ + --use_lm ${use_lm} \ + --use_word_lm ${use_wordlm} \ + --lm_config "${lm_config}" \ + --asr_config "${asr_config}" \ + --sa_asr_config "${sa_asr_config}" \ + --inference_config "${inference_config}" \ + --train_set "${train_set}" \ + --valid_set "${valid_set}" \ + --test_sets "${test_sets}" \ + --lm_train_text "data/${train_set}/text" "$@" diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh new file mode 100755 index 000000000..d35e6a693 --- /dev/null +++ b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +ngpu=4 +device="0,1,2,3" + +stage=1 +stop_stage=4 + + +train_set=Train_Ali_far +valid_set=Eval_Ali_far +test_sets="Test_2023_Ali_far" +asr_config=conf/train_asr_conformer.yaml +sa_asr_config=conf/train_sa_asr_conformer.yaml +inference_config=conf/decode_asr_rnn.yaml + +lm_config=conf/train_lm_transformer.yaml +use_lm=false +use_wordlm=false +./asr_local_infer.sh \ + --device ${device} \ + --ngpu ${ngpu} \ + --stage ${stage} \ + --stop_stage ${stop_stage} \ + --gpu_inference true \ + --njob_infer 4 \ + --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \ + --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \ + --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \ + --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \ + --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \ + --lang zh \ + --audio_format wav \ + --feats_type raw \ + --token_type char \ + --use_lm ${use_lm} \ + --use_word_lm ${use_wordlm} \ + --lm_config "${lm_config}" \ + --asr_config "${asr_config}" \ + --sa_asr_config "${sa_asr_config}" \ + --inference_config "${inference_config}" \ + --train_set "${train_set}" \ + --valid_set "${valid_set}" \ + --test_sets "${test_sets}" \ + --lm_train_text "data/${train_set}/text" "$@" diff --git a/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh new file mode 100755 index 000000000..15e4563f1 --- /dev/null +++ b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash +set -euo pipefail +SECONDS=0 +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} +help_message=$(cat << EOF +Usage: $0 [ []] +e.g. +$0 data/test/wav.scp data/test_format/ + +Format 'wav.scp': In short words, +changing "kaldi-datadir" to "modified-kaldi-datadir" + +The 'wav.scp' format in kaldi is very flexible, +e.g. It can use unix-pipe as describing that wav file, +but it sometime looks confusing and make scripts more complex. +This tools creates actual wav files from 'wav.scp' +and also segments wav files using 'segments'. + +Options + --fs + --segments + --nj + --cmd +EOF +) + +out_filename=wav.scp +cmd=utils/run.pl +nj=30 +fs=none +segments= + +ref_channels= +utt2ref_channels= + +audio_format=wav +write_utt2num_samples=true + +log "$0 $*" +. utils/parse_options.sh + +if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then + log "${help_message}" + log "Error: invalid command line arguments" + exit 1 +fi + +. ./path.sh # Setup the environment + +scp=$1 +if [ ! -f "${scp}" ]; then + log "${help_message}" + echo "$0: Error: No such file: ${scp}" + exit 1 +fi +dir=$2 + + +if [ $# -eq 2 ]; then + logdir=${dir}/logs + outdir=${dir}/data + +elif [ $# -eq 3 ]; then + logdir=$3 + outdir=${dir}/data + +elif [ $# -eq 4 ]; then + logdir=$3 + outdir=$4 +fi + + +mkdir -p ${logdir} + +rm -f "${dir}/${out_filename}" + + +opts= +if [ -n "${utt2ref_channels}" ]; then + opts="--utt2ref-channels ${utt2ref_channels} " +elif [ -n "${ref_channels}" ]; then + opts="--ref-channels ${ref_channels} " +fi + + +if [ -n "${segments}" ]; then + log "[info]: using ${segments}" + nutt=$(<${segments} wc -l) + nj=$((nj /dev/null + +# concatenate the .scp files together. +for n in $(seq ${nj}); do + cat "${outdir}/format.${n}/wav.scp" || exit 1; +done > "${dir}/${out_filename}" || exit 1 + +if "${write_utt2num_samples}"; then + for n in $(seq ${nj}); do + cat "${outdir}/format.${n}/utt2num_samples" || exit 1; + done > "${dir}/utt2num_samples" || exit 1 +fi + +log "Successfully finished. [elapsed=${SECONDS}s]" diff --git a/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh new file mode 100755 index 000000000..9e08dba72 --- /dev/null +++ b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh @@ -0,0 +1,116 @@ +#!/usr/bin/env bash + +# 2020 @kamo-naoyuki +# This file was copied from Kaldi and +# I deleted parts related to wav duration +# because we shouldn't use kaldi's command here +# and we don't need the files actually. + +# Copyright 2013 Johns Hopkins University (author: Daniel Povey) +# 2014 Tom Ko +# 2018 Emotech LTD (author: Pawel Swietojanski) +# Apache 2.0 + +# This script operates on a directory, such as in data/train/, +# that contains some subset of the following files: +# wav.scp +# spk2utt +# utt2spk +# text +# +# It generates the files which are used for perturbing the speed of the original data. + +export LC_ALL=C +set -euo pipefail + +if [[ $# != 3 ]]; then + echo "Usage: perturb_data_dir_speed.sh " + echo "e.g.:" + echo " $0 0.9 data/train_si284 data/train_si284p" + exit 1 +fi + +factor=$1 +srcdir=$2 +destdir=$3 +label="sp" +spk_prefix="${label}${factor}-" +utt_prefix="${label}${factor}-" + +#check is sox on the path + +! command -v sox &>/dev/null && echo "sox: command not found" && exit 1; + +if [[ ! -f ${srcdir}/utt2spk ]]; then + echo "$0: no such file ${srcdir}/utt2spk" + exit 1; +fi + +if [[ ${destdir} == "${srcdir}" ]]; then + echo "$0: this script requires and to be different." + exit 1 +fi + +mkdir -p "${destdir}" + +<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map" +<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map" +<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map" +if [[ ! -f ${srcdir}/utt2uniq ]]; then + <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq" +else + <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq" +fi + + +<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \ + utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk + +utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt + +if [[ -f ${srcdir}/segments ]]; then + + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \ + utils/apply_map.pl -f 2 "${destdir}"/reco_map | \ + awk -v factor="${factor}" \ + '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \ + >"${destdir}"/segments + + utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ + # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" + awk -v factor="${factor}" \ + '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} + else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" } + else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \ + > "${destdir}"/wav.scp + if [[ -f ${srcdir}/reco2file_and_channel ]]; then + utils/apply_map.pl -f 1 "${destdir}"/reco_map \ + <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel + fi + +else # no segments->wav indexed by utterance. + if [[ -f ${srcdir}/wav.scp ]]; then + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ + # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" + awk -v factor="${factor}" \ + '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} + else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" } + else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \ + > "${destdir}"/wav.scp + fi +fi + +if [[ -f ${srcdir}/text ]]; then + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text +fi +if [[ -f ${srcdir}/spk2gender ]]; then + utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender +fi +if [[ -f ${srcdir}/utt2lang ]]; then + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang +fi + +rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null +echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}" + +utils/validate_data_dir.sh --no-feats --no-text "${destdir}" diff --git a/egs/alimeeting/sa-asr/utils/apply_map.pl b/egs/alimeeting/sa-asr/utils/apply_map.pl new file mode 100755 index 000000000..725d3463a --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/apply_map.pl @@ -0,0 +1,97 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0. + +# This program is a bit like ./sym2int.pl in that it applies a map +# to things in a file, but it's a bit more general in that it doesn't +# assume the things being mapped to are single tokens, they could +# be sequences of tokens. See the usage message. + + +$permissive = 0; + +for ($x = 0; $x <= 2; $x++) { + + if (@ARGV > 0 && $ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. + } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } + } + + if (@ARGV > 0 && $ARGV[0] eq '--permissive') { + shift @ARGV; + # Mapping is optional (missing key is printed to output) + $permissive = 1; + } +} + +if(@ARGV != 1) { + print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n"; + print STDERR <<'EOF'; +Usage: apply_map.pl [options] map output + options: [-f ] [--permissive] + This applies a map to some specified fields of some input text: + For each line in the map file: the first field is the thing we + map from, and the remaining fields are the sequence we map it to. + The -f (field-range) option says which fields of the input file the map + map should apply to. + If the --permissive option is supplied, fields which are not present + in the map will be left as they were. + Applies the map 'map' to all input text, where each line of the map + is interpreted as a map from the first field to the list of the other fields + Note: can look like 4-5, or 4-, or 5-, or 1, it means the field + range in the input to apply the map to. + e.g.: echo A B | apply_map.pl a.txt + where a.txt is: + A a1 a2 + B b + will produce: + a1 a2 b +EOF + exit(1); +} + +($map_file) = @ARGV; +open(M, "<$map_file") || die "Error opening map file $map_file: $!"; + +while () { + @A = split(" ", $_); + @A >= 1 || die "apply_map.pl: empty line."; + $i = shift @A; + $o = join(" ", @A); + $map{$i} = $o; +} + +while() { + @A = split(" ", $_); + for ($x = 0; $x < @A; $x++) { + if ( (!defined $field_begin || $x >= $field_begin) + && (!defined $field_end || $x <= $field_end)) { + $a = $A[$x]; + if (!defined $map{$a}) { + if (!$permissive) { + die "apply_map.pl: undefined key $a in $map_file\n"; + } else { + print STDERR "apply_map.pl: warning! missing key $a in $map_file\n"; + } + } else { + $A[$x] = $map{$a}; + } + } + } + print join(" ", @A) . "\n"; +} diff --git a/egs/alimeeting/sa-asr/utils/combine_data.sh b/egs/alimeeting/sa-asr/utils/combine_data.sh new file mode 100755 index 000000000..e1eba8539 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/combine_data.sh @@ -0,0 +1,146 @@ +#!/usr/bin/env bash +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. +# 2014 David Snyder + +# This script combines the data from multiple source directories into +# a single destination directory. + +# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information +# about what these directories contain. + +# Begin configuration section. +extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..." +skip_fix=false # skip the fix_data_dir.sh in the end +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# -lt 2 ]; then + echo "Usage: combine_data.sh [--extra-files 'file1 file2'] ..." + echo "Note, files that don't appear in all source dirs will not be combined," + echo "with the exception of utt2uniq and segments, which are created where necessary." + exit 1 +fi + +dest=$1; +shift; + +first_src=$1; + +rm -r $dest 2>/dev/null || true +mkdir -p $dest; + +export LC_ALL=C + +for dir in $*; do + if [ ! -f $dir/utt2spk ]; then + echo "$0: no such file $dir/utt2spk" + exit 1; + fi +done + +# Check that frame_shift are compatible, where present together with features. +dir_with_frame_shift= +for dir in $*; do + if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then + if [[ $dir_with_frame_shift ]] && + ! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then + echo "$0:error: different frame_shift in directories $dir and " \ + "$dir_with_frame_shift. Cannot combine features." + exit 1; + fi + dir_with_frame_shift=$dir + fi +done + +# W.r.t. utt2uniq file the script has different behavior compared to other files +# it is not compulsary for it to exist in src directories, but if it exists in +# even one it should exist in all. We will create the files where necessary +has_utt2uniq=false +for in_dir in $*; do + if [ -f $in_dir/utt2uniq ]; then + has_utt2uniq=true + break + fi +done + +if $has_utt2uniq; then + # we are going to create an utt2uniq file in the destdir + for in_dir in $*; do + if [ ! -f $in_dir/utt2uniq ]; then + # we assume that utt2uniq is a one to one mapping + cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}' + else + cat $in_dir/utt2uniq + fi + done | sort -k1 > $dest/utt2uniq + echo "$0: combined utt2uniq" +else + echo "$0 [info]: not combining utt2uniq as it does not exist" +fi +# some of the old scripts might provide utt2uniq as an extrafile, so just remove it +extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g") + +# segments are treated similarly to utt2uniq. If it exists in some, but not all +# src directories, then we generate segments where necessary. +has_segments=false +for in_dir in $*; do + if [ -f $in_dir/segments ]; then + has_segments=true + break + fi +done + +if $has_segments; then + for in_dir in $*; do + if [ ! -f $in_dir/segments ]; then + echo "$0 [info]: will generate missing segments for $in_dir" 1>&2 + utils/data/get_segments_for_data.sh $in_dir + else + cat $in_dir/segments + fi + done | sort -k1 > $dest/segments + echo "$0: combined segments" +else + echo "$0 [info]: not combining segments as it does not exist" +fi + +for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do + exists_somewhere=false + absent_somewhere=false + for d in $*; do + if [ -f $d/$file ]; then + exists_somewhere=true + else + absent_somewhere=true + fi + done + + if ! $absent_somewhere; then + set -o pipefail + ( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1; + set +o pipefail + echo "$0: combined $file" + else + if ! $exists_somewhere; then + echo "$0 [info]: not combining $file as it does not exist" + else + echo "$0 [info]: **not combining $file as it does not exist everywhere**" + fi + fi +done + +utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt + +if [[ $dir_with_frame_shift ]]; then + cp $dir_with_frame_shift/frame_shift $dest +fi + +if ! $skip_fix ; then + utils/fix_data_dir.sh $dest || exit 1; +fi + +exit 0 diff --git a/egs/alimeeting/sa-asr/utils/copy_data_dir.sh b/egs/alimeeting/sa-asr/utils/copy_data_dir.sh new file mode 100755 index 000000000..9fd420c42 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/copy_data_dir.sh @@ -0,0 +1,145 @@ +#!/usr/bin/env bash + +# Copyright 2013 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +# This script operates on a directory, such as in data/train/, +# that contains some subset of the following files: +# feats.scp +# wav.scp +# vad.scp +# spk2utt +# utt2spk +# text +# +# It copies to another directory, possibly adding a specified prefix or a suffix +# to the utterance and/or speaker names. Note, the recording-ids stay the same. +# + + +# begin configuration section +spk_prefix= +utt_prefix= +spk_suffix= +utt_suffix= +validate_opts= # should rarely be needed. +# end configuration section + +. utils/parse_options.sh + +if [ $# != 2 ]; then + echo "Usage: " + echo " $0 [options] " + echo "e.g.:" + echo " $0 --spk-prefix=1- --utt-prefix=1- data/train data/train_1" + echo "Options" + echo " --spk-prefix= # Prefix for speaker ids, default empty" + echo " --utt-prefix= # Prefix for utterance ids, default empty" + echo " --spk-suffix= # Suffix for speaker ids, default empty" + echo " --utt-suffix= # Suffix for utterance ids, default empty" + exit 1; +fi + + +export LC_ALL=C + +srcdir=$1 +destdir=$2 + +if [ ! -f $srcdir/utt2spk ]; then + echo "copy_data_dir.sh: no such file $srcdir/utt2spk" + exit 1; +fi + +if [ "$destdir" == "$srcdir" ]; then + echo "$0: this script requires and to be different." + exit 1 +fi + +set -e; + +mkdir -p $destdir + +cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/utt_map +cat $srcdir/spk2utt | awk -v p=$spk_prefix -v s=$spk_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/spk_map + +if [ ! -f $srcdir/utt2uniq ]; then + if [[ ! -z $utt_prefix || ! -z $utt_suffix ]]; then + cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $1);}' > $destdir/utt2uniq + fi +else + cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq +fi + +cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \ + utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk + +utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt + +if [ -f $srcdir/feats.scp ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp +fi + +if [ -f $srcdir/vad.scp ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp +fi + +if [ -f $srcdir/segments ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments + cp $srcdir/wav.scp $destdir +else # no segments->wav indexed by utt. + if [ -f $srcdir/wav.scp ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp + fi +fi + +if [ -f $srcdir/reco2file_and_channel ]; then + cp $srcdir/reco2file_and_channel $destdir/ +fi + +if [ -f $srcdir/text ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text +fi +if [ -f $srcdir/utt2dur ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur +fi +if [ -f $srcdir/utt2num_frames ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames +fi +if [ -f $srcdir/reco2dur ]; then + if [ -f $srcdir/segments ]; then + cp $srcdir/reco2dur $destdir/reco2dur + else + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur + fi +fi +if [ -f $srcdir/spk2gender ]; then + utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender +fi +if [ -f $srcdir/cmvn.scp ]; then + utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp +fi +for f in frame_shift stm glm ctm; do + if [ -f $srcdir/$f ]; then + cp $srcdir/$f $destdir + fi +done + +rm $destdir/spk_map $destdir/utt_map + +echo "$0: copied data from $srcdir to $destdir" + +for f in feats.scp cmvn.scp vad.scp utt2lang utt2uniq utt2dur utt2num_frames text wav.scp reco2file_and_channel frame_shift stm glm ctm; do + if [ -f $destdir/$f ] && [ ! -f $srcdir/$f ]; then + echo "$0: file $f exists in dest $destdir but not in src $srcdir. Moving it to" + echo " ... $destdir/.backup/$f" + mkdir -p $destdir/.backup + mv $destdir/$f $destdir/.backup/ + fi +done + + +[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats" +[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text" + +utils/validate_data_dir.sh $validate_opts $destdir diff --git a/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh b/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh new file mode 100755 index 000000000..24f51e723 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash + +# Copyright 2016 Johns Hopkins University (author: Daniel Povey) +# 2018 Andrea Carmantini +# Apache 2.0 + +# This script operates on a data directory, such as in data/train/, and adds the +# reco2dur file if it does not already exist. The file 'reco2dur' maps from +# recording to the duration of the recording in seconds. This script works it +# out from the 'wav.scp' file, or, if utterance-ids are the same as recording-ids, from the +# utt2dur file (it first tries interrogating the headers, and if this fails, it reads the wave +# files in entirely.) +# We could use durations from segments file, but that's not the duration of the recordings +# but the sum of utterance lenghts (silence in between could be excluded from segments) +# For sum of utterance lenghts: +# awk 'FNR==NR{uttdur[$1]=$2;next} +# { for(i=2;i<=NF;i++){dur+=uttdur[$i];} +# print $1 FS dur; dur=0 }' $data/utt2dur $data/reco2utt + + +frame_shift=0.01 +cmd=run.pl +nj=4 + +. utils/parse_options.sh +. ./path.sh + +if [ $# != 1 ]; then + echo "Usage: $0 [options] " + echo "e.g.:" + echo " $0 data/train" + echo " Options:" + echo " --frame-shift # frame shift in seconds. Only relevant when we are" + echo " # getting duration from feats.scp (default: 0.01). " + exit 1 +fi + +export LC_ALL=C + +data=$1 + + +if [ -s $data/reco2dur ] && \ + [ $(wc -l < $data/wav.scp) -eq $(wc -l < $data/reco2dur) ]; then + echo "$0: $data/reco2dur already exists with the expected length. We won't recompute it." + exit 0; +fi + +if [ -s $data/utt2dur ] && \ + [ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ] && \ + [ ! -s $data/segments ]; then + + echo "$0: $data/wav.scp indexed by utt-id; copying utt2dur to reco2dur" + cp $data/utt2dur $data/reco2dur && exit 0; + +elif [ -f $data/wav.scp ]; then + echo "$0: obtaining durations from recordings" + + # if the wav.scp contains only lines of the form + # utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph | + if cat $data/wav.scp | perl -e ' + while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space. + @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ && + $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); } + $reco = $A[0]; $sphere_file = $A[4]; + + if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; } + $sample_rate = -1; $sample_count = -1; + for ($n = 0; $n <= 30; $n++) { + $line = ; + if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; } + if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; } + if ($line =~ m/end_head/) { break; } + } + close(F); + if ($sample_rate == -1 || $sample_count == -1) { + die "could not parse sphere header from $sphere_file"; + } + $duration = $sample_count * 1.0 / $sample_rate; + print "$reco $duration\n"; + } ' > $data/reco2dur; then + echo "$0: successfully obtained recording lengths from sphere-file headers" + else + echo "$0: could not get recording lengths from sphere-file headers, using wav-to-duration" + if ! command -v wav-to-duration >/dev/null; then + echo "$0: wav-to-duration is not on your path" + exit 1; + fi + + read_entire_file=false + if grep -q 'sox.*speed' $data/wav.scp; then + read_entire_file=true + echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow." + echo "... It is much faster if you call get_reco2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or " + echo "... perturb_data_dir_speed_3way.sh." + fi + + num_recos=$(wc -l <$data/wav.scp) + if [ $nj -gt $num_recos ]; then + nj=$num_recos + fi + + temp_data_dir=$data/wav${nj}split + wavscps=$(for n in `seq $nj`; do echo $temp_data_dir/$n/wav.scp; done) + subdirs=$(for n in `seq $nj`; do echo $temp_data_dir/$n; done) + + if ! mkdir -p $subdirs >&/dev/null; then + for n in `seq $nj`; do + mkdir -p $temp_data_dir/$n + done + fi + + utils/split_scp.pl $data/wav.scp $wavscps + + + $cmd JOB=1:$nj $data/log/get_reco_durations.JOB.log \ + wav-to-duration --read-entire-file=$read_entire_file \ + scp:$temp_data_dir/JOB/wav.scp ark,t:$temp_data_dir/JOB/reco2dur || \ + { echo "$0: there was a problem getting the durations"; exit 1; } # This could + + for n in `seq $nj`; do + cat $temp_data_dir/$n/reco2dur + done > $data/reco2dur + fi + rm -r $temp_data_dir +else + echo "$0: Expected $data/wav.scp to exist" + exit 1 +fi + +len1=$(wc -l < $data/wav.scp) +len2=$(wc -l < $data/reco2dur) +if [ "$len1" != "$len2" ]; then + echo "$0: warning: length of reco2dur does not equal that of wav.scp, $len2 != $len1" + if [ $len1 -gt $[$len2*2] ]; then + echo "$0: less than half of recordings got a duration: failing." + exit 1 + fi +fi + +echo "$0: computed $data/reco2dur" + +exit 0 diff --git a/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh b/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh new file mode 100755 index 000000000..6b161b31e --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +# This script operates on a data directory, such as in data/train/, +# and writes new segments to stdout. The file 'segments' maps from +# utterance to time offsets into a recording, with the format: +# +# This script assumes utterance and recording ids are the same (i.e., that +# wav.scp is indexed by utterance), and uses durations from 'utt2dur', +# created if necessary by get_utt2dur.sh. + +. ./path.sh + +if [ $# != 1 ]; then + echo "Usage: $0 [options] " + echo "e.g.:" + echo " $0 data/train > data/train/segments" + exit 1 +fi + +data=$1 + +if [ ! -s $data/utt2dur ]; then + utils/data/get_utt2dur.sh $data 1>&2 || exit 1; +fi + +# 0 +awk '{ print $1, $1, 0, $2 }' $data/utt2dur + +exit 0 diff --git a/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh b/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh new file mode 100755 index 000000000..5ee7ea30d --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh @@ -0,0 +1,135 @@ +#!/usr/bin/env bash + +# Copyright 2016 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +# This script operates on a data directory, such as in data/train/, and adds the +# utt2dur file if it does not already exist. The file 'utt2dur' maps from +# utterance to the duration of the utterance in seconds. This script works it +# out from the 'segments' file, or, if not present, from the wav.scp file (it +# first tries interrogating the headers, and if this fails, it reads the wave +# files in entirely.) + +frame_shift=0.01 +cmd=run.pl +nj=4 +read_entire_file=false + +. utils/parse_options.sh +. ./path.sh + +if [ $# != 1 ]; then + echo "Usage: $0 [options] " + echo "e.g.:" + echo " $0 data/train" + echo " Options:" + echo " --frame-shift # frame shift in seconds. Only relevant when we are" + echo " # getting duration from feats.scp, and only if the " + echo " # file frame_shift does not exist (default: 0.01). " + exit 1 +fi + +export LC_ALL=C + +data=$1 + +if [ -s $data/utt2dur ] && \ + [ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ]; then + echo "$0: $data/utt2dur already exists with the expected length. We won't recompute it." + exit 0; +fi + +if [ -s $data/segments ]; then + echo "$0: working out $data/utt2dur from $data/segments" + awk '{len=$4-$3; print $1, len;}' < $data/segments > $data/utt2dur +elif [[ -s $data/frame_shift && -f $data/utt2num_frames ]]; then + echo "$0: computing $data/utt2dur from $data/{frame_shift,utt2num_frames}." + frame_shift=$(cat $data/frame_shift) || exit 1 + # The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift. + awk -v fs=$frame_shift '{ $2=($2+1.5)*fs; print }' <$data/utt2num_frames >$data/utt2dur +elif [ -f $data/wav.scp ]; then + echo "$0: segments file does not exist so getting durations from wave files" + + # if the wav.scp contains only lines of the form + # utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph | + if perl <$data/wav.scp -e ' + while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space. + @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ && + $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); } + $utt = $A[0]; $sphere_file = $A[4]; + + if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; } + $sample_rate = -1; $sample_count = -1; + for ($n = 0; $n <= 30; $n++) { + $line = ; + if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; } + if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; } + if ($line =~ m/end_head/) { break; } + } + close(F); + if ($sample_rate == -1 || $sample_count == -1) { + die "could not parse sphere header from $sphere_file"; + } + $duration = $sample_count * 1.0 / $sample_rate; + print "$utt $duration\n"; + } ' > $data/utt2dur; then + echo "$0: successfully obtained utterance lengths from sphere-file headers" + else + echo "$0: could not get utterance lengths from sphere-file headers, using wav-to-duration" + if ! command -v wav-to-duration >/dev/null; then + echo "$0: wav-to-duration is not on your path" + exit 1; + fi + + if grep -q 'sox.*speed' $data/wav.scp; then + read_entire_file=true + echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow." + echo "... It is much faster if you call get_utt2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or " + echo "... perturb_data_dir_speed_3way.sh." + fi + + + num_utts=$(wc -l <$data/utt2spk) + if [ $nj -gt $num_utts ]; then + nj=$num_utts + fi + + utils/data/split_data.sh --per-utt $data $nj + sdata=$data/split${nj}utt + + $cmd JOB=1:$nj $data/log/get_durations.JOB.log \ + wav-to-duration --read-entire-file=$read_entire_file \ + scp:$sdata/JOB/wav.scp ark,t:$sdata/JOB/utt2dur || \ + { echo "$0: there was a problem getting the durations"; exit 1; } + + for n in `seq $nj`; do + cat $sdata/$n/utt2dur + done > $data/utt2dur + fi +elif [ -f $data/feats.scp ]; then + echo "$0: wave file does not exist so getting durations from feats files" + if [[ -s $data/frame_shift ]]; then + frame_shift=$(cat $data/frame_shift) || exit 1 + echo "$0: using frame_shift=$frame_shift from file $data/frame_shift" + fi + # The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift. + feat-to-len scp:$data/feats.scp ark,t:- | + awk -v frame_shift=$frame_shift '{print $1, ($2+1.5)*frame_shift}' >$data/utt2dur +else + echo "$0: Expected $data/wav.scp, $data/segments or $data/feats.scp to exist" + exit 1 +fi + +len1=$(wc -l < $data/utt2spk) +len2=$(wc -l < $data/utt2dur) +if [ "$len1" != "$len2" ]; then + echo "$0: warning: length of utt2dur does not equal that of utt2spk, $len2 != $len1" + if [ $len1 -gt $[$len2*2] ]; then + echo "$0: less than half of utterances got a duration: failing." + exit 1 + fi +fi + +echo "$0: computed $data/utt2dur" + +exit 0 diff --git a/egs/alimeeting/sa-asr/utils/data/split_data.sh b/egs/alimeeting/sa-asr/utils/data/split_data.sh new file mode 100755 index 000000000..8aa71a1f2 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/data/split_data.sh @@ -0,0 +1,160 @@ +#!/usr/bin/env bash +# Copyright 2010-2013 Microsoft Corporation +# Johns Hopkins University (Author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +split_per_spk=true +if [ "$1" == "--per-utt" ]; then + split_per_spk=false + shift +fi + +if [ $# != 2 ]; then + echo "Usage: $0 [--per-utt] " + echo "E.g.: $0 data/train 50" + echo "It creates its output in e.g. data/train/split50/{1,2,3,...50}, or if the " + echo "--per-utt option was given, in e.g. data/train/split50utt/{1,2,3,...50}." + echo "" + echo "This script will not split the data-dir if it detects that the output is newer than the input." + echo "By default it splits per speaker (so each speaker is in only one split dir)," + echo "but with the --per-utt option it will ignore the speaker information while splitting." + exit 1 +fi + +data=$1 +numsplit=$2 + +if ! [ "$numsplit" -gt 0 ]; then + echo "Invalid num-split argument $numsplit"; + exit 1; +fi + +if $split_per_spk; then + warning_opt= +else + # suppress warnings from filter_scps.pl about 'some input lines were output + # to multiple files'. + warning_opt="--no-warn" +fi + +n=0; +feats="" +wavs="" +utt2spks="" +texts="" + +nu=`cat $data/utt2spk | wc -l` +nf=`cat $data/feats.scp 2>/dev/null | wc -l` +nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file +if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then + echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can " + echo "** use utils/fix_data_dir.sh $data to fix this." +fi +if [ -f $data/text ] && [ $nu -ne $nt ]; then + echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can " + echo "** use utils/fix_data_dir.sh to fix this." +fi + + +if $split_per_spk; then + utt2spk_opt="--utt2spk=$data/utt2spk" + utt="" +else + utt2spk_opt= + utt="utt" +fi + +s1=$data/split${numsplit}${utt}/1 +if [ ! -d $s1 ]; then + need_to_split=true +else + need_to_split=false + for f in utt2spk spk2utt spk2warp feats.scp text wav.scp cmvn.scp spk2gender \ + vad.scp segments reco2file_and_channel utt2lang; do + if [[ -f $data/$f && ( ! -f $s1/$f || $s1/$f -ot $data/$f ) ]]; then + need_to_split=true + fi + done +fi + +if ! $need_to_split; then + exit 0; +fi + +utt2spks=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2spk; done) + +directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n; done) + +# if this mkdir fails due to argument-list being too long, iterate. +if ! mkdir -p $directories >&/dev/null; then + for n in `seq $numsplit`; do + mkdir -p $data/split${numsplit}${utt}/$n + done +fi + +# If lockfile is not installed, just don't lock it. It's not a big deal. +which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock +trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM + +utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1 + +for n in `seq $numsplit`; do + dsn=$data/split${numsplit}${utt}/$n + utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1; +done + +maybe_wav_scp= +if [ ! -f $data/segments ]; then + maybe_wav_scp=wav.scp # If there is no segments file, then wav file is + # indexed per utt. +fi + +# split some things that are indexed by utterance. +for f in feats.scp text vad.scp utt2lang $maybe_wav_scp utt2dur utt2num_frames; do + if [ -f $data/$f ]; then + utils/filter_scps.pl JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/utt2spk $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1; + fi +done + +# split some things that are indexed by speaker +for f in spk2gender spk2warp cmvn.scp; do + if [ -f $data/$f ]; then + utils/filter_scps.pl $warning_opt JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/spk2utt $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1; + fi +done + +if [ -f $data/segments ]; then + utils/filter_scps.pl JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/utt2spk $data/segments $data/split${numsplit}${utt}/JOB/segments || exit 1 + for n in `seq $numsplit`; do + dsn=$data/split${numsplit}${utt}/$n + awk '{print $2;}' $dsn/segments | sort | uniq > $dsn/tmp.reco # recording-ids. + done + if [ -f $data/reco2file_and_channel ]; then + utils/filter_scps.pl $warning_opt JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/tmp.reco $data/reco2file_and_channel \ + $data/split${numsplit}${utt}/JOB/reco2file_and_channel || exit 1 + fi + if [ -f $data/wav.scp ]; then + utils/filter_scps.pl $warning_opt JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/tmp.reco $data/wav.scp \ + $data/split${numsplit}${utt}/JOB/wav.scp || exit 1 + fi + for f in $data/split${numsplit}${utt}/*/tmp.reco; do rm $f; done +fi + +exit 0 diff --git a/egs/alimeeting/sa-asr/utils/filter_scp.pl b/egs/alimeeting/sa-asr/utils/filter_scp.pl new file mode 100755 index 000000000..b76d37f41 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/alimeeting/sa-asr/utils/fix_data_dir.sh b/egs/alimeeting/sa-asr/utils/fix_data_dir.sh new file mode 100755 index 000000000..ed4710d0b --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/fix_data_dir.sh @@ -0,0 +1,215 @@ +#!/usr/bin/env bash + +# This script makes sure that only the segments present in +# all of "feats.scp", "wav.scp" [if present], segments [if present] +# text, and utt2spk are present in any of them. +# It puts the original contents of data-dir into +# data-dir/.backup + +cmd="$@" + +utt_extra_files= +spk_extra_files= + +. utils/parse_options.sh + +if [ $# != 1 ]; then + echo "Usage: utils/data/fix_data_dir.sh " + echo "e.g.: utils/data/fix_data_dir.sh data/train" + echo "This script helps ensure that the various files in a data directory" + echo "are correctly sorted and filtered, for example removing utterances" + echo "that have no features (if feats.scp is present)" + exit 1 +fi + +data=$1 + +if [ -f $data/images.scp ]; then + image/fix_data_dir.sh $cmd + exit $? +fi + +mkdir -p $data/.backup + +[ ! -d $data ] && echo "$0: no such directory $data" && exit 1; + +[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1; + +set -e -o pipefail -u + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM + +export LC_ALL=C + +function check_sorted { + file=$1 + sort -k1,1 -u <$file >$file.tmp + if ! cmp -s $file $file.tmp; then + echo "$0: file $1 is not in sorted order or not unique, sorting it" + mv $file.tmp $file + else + rm $file.tmp + fi +} + +for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \ + reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + check_sorted $data/$x + fi +done + + +function filter_file { + filter=$1 + file_to_filter=$2 + cp $file_to_filter ${file_to_filter}.tmp + utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter + if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then + length1=$(cat ${file_to_filter}.tmp | wc -l) + length2=$(cat ${file_to_filter} | wc -l) + if [ $length1 -ne $length2 ]; then + echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter." + fi + fi + rm $file_to_filter.tmp +} + +function filter_recordings { + # We call this once before the stage when we filter on utterance-id, and once + # after. + + if [ -f $data/segments ]; then + # We have a segments file -> we need to filter this and the file wav.scp, and + # reco2file_and_utt, if it exists, to make sure they have the same list of + # recording-ids. + + if [ ! -f $data/wav.scp ]; then + echo "$0: $data/segments exists but not $data/wav.scp" + exit 1; + fi + awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings + n1=$(cat $tmpdir/recordings | wc -l) + [ ! -s $tmpdir/recordings ] && \ + echo "Empty list of recordings (bad file $data/segments)?" && exit 1; + utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp + mv $tmpdir/recordings.tmp $tmpdir/recordings + + + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + filter_file $tmpdir/recordings $data/segments + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + rm $data/segments.tmp + + filter_file $tmpdir/recordings $data/wav.scp + [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel + [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur + true + fi +} + +function filter_speakers { + # throughout this program, we regard utt2spk as primary and spk2utt as derived, so... + utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + for s in cmvn.scp spk2gender; do + f=$data/$s + if [ -f $f ]; then + filter_file $f $tmpdir/speakers + fi + done + + filter_file $tmpdir/speakers $data/spk2utt + utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk + + for s in cmvn.scp spk2gender $spk_extra_files; do + f=$data/$s + if [ -f $f ]; then + filter_file $tmpdir/speakers $f + fi + done +} + +function filter_utts { + cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts + + ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order (fix this yourself)" && exit 1; + + ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order when sorted first on speaker-id " && \ + echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1; + + ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \ + echo "spk2utt is not in sorted order (fix this yourself)" && exit 1; + + if [ -f $data/utt2uniq ]; then + ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \ + echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1; + fi + + maybe_wav= + maybe_reco2dur= + [ ! -f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist. + [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts + + maybe_utt2dur= + if [ -f $data/utt2dur ]; then + cat $data/utt2dur | \ + awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2dur.ok || exit 1 + maybe_utt2dur=utt2dur.ok + fi + + maybe_utt2num_frames= + if [ -f $data/utt2num_frames ]; then + cat $data/utt2num_frames | \ + awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2num_frames.ok || exit 1 + maybe_utt2num_frames=utt2num_frames.ok + fi + + for x in feats.scp text segments utt2lang $maybe_wav $maybe_utt2dur $maybe_utt2num_frames; do + if [ -f $data/$x ]; then + utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp + mv $tmpdir/utts.tmp $tmpdir/utts + fi + done + rm $data/utt2dur.ok 2>/dev/null || true + rm $data/utt2num_frames.ok 2>/dev/null || true + + [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \ + rm $tmpdir/utts && exit 1; + + + if [ -f $data/utt2spk ]; then + new_nutts=$(cat $tmpdir/utts | wc -l) + old_nutts=$(cat $data/utt2spk | wc -l) + if [ $new_nutts -ne $old_nutts ]; then + echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts" + else + echo "fix_data_dir.sh: kept all $old_nutts utterances." + fi + fi + + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then + utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x + fi + fi + done + +} + +filter_recordings +filter_speakers +filter_utts +filter_speakers +filter_recordings + +utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + +echo "fix_data_dir.sh: old files are kept in $data/.backup" diff --git a/egs/alimeeting/sa-asr/utils/parse_options.sh b/egs/alimeeting/sa-asr/utils/parse_options.sh new file mode 100755 index 000000000..71fb9e5ea --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl new file mode 100755 index 000000000..23992f25d --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl @@ -0,0 +1,27 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +while(<>){ + @A = split(" ", $_); + @A > 1 || die "Invalid line in spk2utt file: $_"; + $s = shift @A; + foreach $u ( @A ) { + print "$u $s\n"; + } +} + + diff --git a/egs/alimeeting/sa-asr/utils/split_scp.pl b/egs/alimeeting/sa-asr/utils/split_scp.pl new file mode 100755 index 000000000..0876dcb6d --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/split_scp.pl @@ -0,0 +1,246 @@ +#!/usr/bin/env perl + +# Copyright 2010-2011 Microsoft Corporation + +# See ../../COPYING for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This program splits up any kind of .scp or archive-type file. +# If there is no utt2spk option it will work on any text file and +# will split it up with an approximately equal number of lines in +# each but. +# With the --utt2spk option it will work on anything that has the +# utterance-id as the first entry on each line; the utt2spk file is +# of the form "utterance speaker" (on each line). +# It splits it into equal size chunks as far as it can. If you use the utt2spk +# option it will make sure these chunks coincide with speaker boundaries. In +# this case, if there are more chunks than speakers (and in some other +# circumstances), some of the resulting chunks will be empty and it will print +# an error message and exit with nonzero status. +# You will normally call this like: +# split_scp.pl scp scp.1 scp.2 scp.3 ... +# or +# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ... +# Note that you can use this script to split the utt2spk file itself, +# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ... + +# You can also call the scripts like: +# split_scp.pl -j 3 0 scp scp.0 +# [note: with this option, it assumes zero-based indexing of the split parts, +# i.e. the second number must be 0 <= n < num-jobs.] + +use warnings; + +$num_jobs = 0; +$job_id = 0; +$utt2spk_file = ""; +$one_based = 0; + +for ($x = 1; $x <= 3 && @ARGV > 0; $x++) { + if ($ARGV[0] eq "-j") { + shift @ARGV; + $num_jobs = shift @ARGV; + $job_id = shift @ARGV; + } + if ($ARGV[0] =~ /--utt2spk=(.+)/) { + $utt2spk_file=$1; + shift; + } + if ($ARGV[0] eq '--one-based') { + $one_based = 1; + shift @ARGV; + } +} + +if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 || + $job_id - $one_based >= $num_jobs)) { + die "$0: Invalid job number/index values for '-j $num_jobs $job_id" . + ($one_based ? " --one-based" : "") . "'\n" +} + +$one_based + and $job_id--; + +if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) { + die +"Usage: split_scp.pl [--utt2spk=] in.scp out1.scp out2.scp ... + or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=] in.scp [out.scp] + ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n"; +} + +$error = 0; +$inscp = shift @ARGV; +if ($num_jobs == 0) { # without -j option + @OUTPUTS = @ARGV; +} else { + for ($j = 0; $j < $num_jobs; $j++) { + if ($j == $job_id) { + if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; } + else { push @OUTPUTS, "-"; } + } else { + push @OUTPUTS, "/dev/null"; + } + } +} + +if ($utt2spk_file ne "") { # We have the --utt2spk option... + open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n"; + while(<$u_fh>) { + @A = split; + @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n"; + ($u,$s) = @A; + $utt2spk{$u} = $s; + } + close $u_fh; + open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n"; + @spkrs = (); + while(<$i_fh>) { + @A = split; + if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; } + $u = $A[0]; + $s = $utt2spk{$u}; + defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n"; + if(!defined $spk_count{$s}) { + push @spkrs, $s; + $spk_count{$s} = 0; + $spk_data{$s} = []; # ref to new empty array. + } + $spk_count{$s}++; + push @{$spk_data{$s}}, $_; + } + # Now split as equally as possible .. + # First allocate spks to files by allocating an approximately + # equal number of speakers. + $numspks = @spkrs; # number of speakers. + $numscps = @OUTPUTS; # number of output files. + if ($numspks < $numscps) { + die "$0: Refusing to split data because number of speakers $numspks " . + "is less than the number of output .scp files $numscps\n"; + } + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + $scparray[$scpidx] = []; # [] is array reference. + } + for ($spkidx = 0; $spkidx < $numspks; $spkidx++) { + $scpidx = int(($spkidx*$numscps) / $numspks); + $spk = $spkrs[$spkidx]; + push @{$scparray[$scpidx]}, $spk; + $scpcount[$scpidx] += $spk_count{$spk}; + } + + # Now will try to reassign beginning + ending speakers + # to different scp's and see if it gets more balanced. + # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2. + # We can show that if considering changing just 2 scp's, we minimize + # this by minimizing the squared difference in sizes. This is + # equivalent to minimizing the absolute difference in sizes. This + # shows this method is bound to converge. + + $changed = 1; + while($changed) { + $changed = 0; + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + # First try to reassign ending spk of this scp. + if($scpidx < $numscps-1) { + $sz = @{$scparray[$scpidx]}; + if($sz > 0) { + $spk = $scparray[$scpidx]->[$sz-1]; + $count = $spk_count{$spk}; + $nutt1 = $scpcount[$scpidx]; + $nutt2 = $scpcount[$scpidx+1]; + if( abs( ($nutt2+$count) - ($nutt1-$count)) + < abs($nutt2 - $nutt1)) { # Would decrease + # size-diff by reassigning spk... + $scpcount[$scpidx+1] += $count; + $scpcount[$scpidx] -= $count; + pop @{$scparray[$scpidx]}; + unshift @{$scparray[$scpidx+1]}, $spk; + $changed = 1; + } + } + } + if($scpidx > 0 && @{$scparray[$scpidx]} > 0) { + $spk = $scparray[$scpidx]->[0]; + $count = $spk_count{$spk}; + $nutt1 = $scpcount[$scpidx-1]; + $nutt2 = $scpcount[$scpidx]; + if( abs( ($nutt2-$count) - ($nutt1+$count)) + < abs($nutt2 - $nutt1)) { # Would decrease + # size-diff by reassigning spk... + $scpcount[$scpidx-1] += $count; + $scpcount[$scpidx] -= $count; + shift @{$scparray[$scpidx]}; + push @{$scparray[$scpidx-1]}, $spk; + $changed = 1; + } + } + } + } + # Now print out the files... + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + $scpfile = $OUTPUTS[$scpidx]; + ($scpfile ne '-' ? open($f_fh, '>', $scpfile) + : open($f_fh, '>&', \*STDOUT)) || + die "$0: Could not open scp file $scpfile for writing: $!\n"; + $count = 0; + if(@{$scparray[$scpidx]} == 0) { + print STDERR "$0: eError: split_scp.pl producing empty .scp file " . + "$scpfile (too many splits and too few speakers?)\n"; + $error = 1; + } else { + foreach $spk ( @{$scparray[$scpidx]} ) { + print $f_fh @{$spk_data{$spk}}; + $count += $spk_count{$spk}; + } + $count == $scpcount[$scpidx] || die "Count mismatch [code error]"; + } + close($f_fh); + } +} else { + # This block is the "normal" case where there is no --utt2spk + # option and we just break into equal size chunks. + + open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n"; + + $numscps = @OUTPUTS; # size of array. + @F = (); + while(<$i_fh>) { + push @F, $_; + } + $numlines = @F; + if($numlines == 0) { + print STDERR "$0: error: empty input scp file $inscp\n"; + $error = 1; + } + $linesperscp = int( $numlines / $numscps); # the "whole part".. + $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n"; + $remainder = $numlines - ($linesperscp * $numscps); + ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder"; + # [just doing int() rounds down]. + $n = 0; + for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) { + $scpfile = $OUTPUTS[$scpidx]; + ($scpfile ne '-' ? open($o_fh, '>', $scpfile) + : open($o_fh, '>&', \*STDOUT)) || + die "$0: Could not open scp file $scpfile for writing: $!\n"; + for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) { + print $o_fh $F[$n++]; + } + close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n"; + } + $n == $numlines || die "$n != $numlines [code error]"; +} + +exit ($error); diff --git a/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl new file mode 100755 index 000000000..6e0e438ca --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl @@ -0,0 +1,38 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# converts an utt2spk file to a spk2utt file. +# Takes input from the stdin or from a file argument; +# output goes to the standard out. + +if ( @ARGV > 1 ) { + die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt"; +} + +while(<>){ + @A = split(" ", $_); + @A == 2 || die "Invalid line in utt2spk file: $_"; + ($u,$s) = @A; + if(!$seen_spk{$s}) { + $seen_spk{$s} = 1; + push @spklist, $s; + } + push (@{$spk_hash{$s}}, "$u"); +} +foreach $s (@spklist) { + $l = join(' ',@{$spk_hash{$s}}); + print "$s $l\n"; +} diff --git a/egs/alimeeting/sa-asr/utils/validate_data_dir.sh b/egs/alimeeting/sa-asr/utils/validate_data_dir.sh new file mode 100755 index 000000000..3eec443a0 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/validate_data_dir.sh @@ -0,0 +1,404 @@ +#!/usr/bin/env bash + +cmd="$@" + +no_feats=false +no_wav=false +no_text=false +no_spk_sort=false +non_print=false + + +function show_help +{ + echo "Usage: $0 [--no-feats] [--no-text] [--non-print] [--no-wav] [--no-spk-sort] " + echo "The --no-xxx options mean that the script does not require " + echo "xxx.scp to be present, but it will check it if it is present." + echo "--no-spk-sort means that the script does not require the utt2spk to be " + echo "sorted by the speaker-id in addition to being sorted by utterance-id." + echo "--non-print ignore the presence of non-printable characters." + echo "By default, utt2spk is expected to be sorted by both, which can be " + echo "achieved by making the speaker-id prefixes of the utterance-ids" + echo "e.g.: $0 data/train" +} + +while [ $# -ne 0 ] ; do + case "$1" in + "--no-feats") + no_feats=true; + ;; + "--no-text") + no_text=true; + ;; + "--non-print") + non_print=true; + ;; + "--no-wav") + no_wav=true; + ;; + "--no-spk-sort") + no_spk_sort=true; + ;; + *) + if ! [ -z "$data" ] ; then + show_help; + exit 1 + fi + data=$1 + ;; + esac + shift +done + + + +if [ ! -d $data ]; then + echo "$0: no such directory $data" + exit 1; +fi + +if [ -f $data/images.scp ]; then + cmd=${cmd/--no-wav/} # remove --no-wav if supplied + image/validate_data_dir.sh $cmd + exit $? +fi + +for f in spk2utt utt2spk; do + if [ ! -f $data/$f ]; then + echo "$0: no such file $f" + exit 1; + fi + if [ ! -s $data/$f ]; then + echo "$0: empty file $f" + exit 1; + fi +done + +! cat $data/utt2spk | awk '{if (NF != 2) exit(1); }' && \ + echo "$0: $data/utt2spk has wrong format." && exit; + +ns=$(wc -l < $data/spk2utt) +if [ "$ns" == 1 ]; then + echo "$0: WARNING: you have only one speaker. This probably a bad idea." + echo " Search for the word 'bold' in http://kaldi-asr.org/doc/data_prep.html" + echo " for more information." +fi + + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM + +export LC_ALL=C + +function check_sorted_and_uniq { + ! perl -ne '((substr $_,-1) eq "\n") or die "file $ARGV has invalid newline";' $1 && exit 1; + ! awk '{print $1}' < $1 | sort -uC && echo "$0: file $1 is not sorted or has duplicates" && exit 1; +} + +function partial_diff { + diff -U1 $1 $2 | (head -n 6; echo "..."; tail -n 6) + n1=`cat $1 | wc -l` + n2=`cat $2 | wc -l` + echo "[Lengths are $1=$n1 versus $2=$n2]" +} + +check_sorted_and_uniq $data/utt2spk + +if ! $no_spk_sort; then + ! sort -k2 -C $data/utt2spk && \ + echo "$0: utt2spk is not in sorted order when sorted first on speaker-id " && \ + echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1; +fi + +check_sorted_and_uniq $data/spk2utt + +! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \ + <(utils/spk2utt_to_utt2spk.pl $data/spk2utt) && \ + echo "$0: spk2utt and utt2spk do not seem to match" && exit 1; + +cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts + +if [ ! -f $data/text ] && ! $no_text; then + echo "$0: no such file $data/text (if this is by design, specify --no-text)" + exit 1; +fi + +num_utts=`cat $tmpdir/utts | wc -l` +if ! $no_text; then + if ! $non_print; then + if locale -a | grep "C.UTF-8" >/dev/null; then + L=C.UTF-8 + else + L=en_US.UTF-8 + fi + n_non_print=$(LC_ALL="$L" grep -c '[^[:print:][:space:]]' $data/text) && \ + echo "$0: text contains $n_non_print lines with non-printable characters" &&\ + exit 1; + fi + utils/validate_text.pl $data/text || exit 1; + check_sorted_and_uniq $data/text + text_len=`cat $data/text | wc -l` + illegal_sym_list=" #0" + for x in $illegal_sym_list; do + if grep -w "$x" $data/text > /dev/null; then + echo "$0: Error: in $data, text contains illegal symbol $x" + exit 1; + fi + done + awk '{print $1}' < $data/text > $tmpdir/utts.txt + if ! cmp -s $tmpdir/utts{,.txt}; then + echo "$0: Error: in $data, utterance lists extracted from utt2spk and text" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.txt} + exit 1; + fi +fi + +if [ -f $data/segments ] && [ ! -f $data/wav.scp ]; then + echo "$0: in directory $data, segments file exists but no wav.scp" + exit 1; +fi + + +if [ ! -f $data/wav.scp ] && ! $no_wav; then + echo "$0: no such file $data/wav.scp (if this is by design, specify --no-wav)" + exit 1; +fi + +if [ -f $data/wav.scp ]; then + check_sorted_and_uniq $data/wav.scp + + if grep -E -q '^\S+\s+~' $data/wav.scp; then + # note: it's not a good idea to have any kind of tilde in wav.scp, even if + # part of a command, as it would cause compatibility problems if run by + # other users, but this used to be not checked for so we let it slide unless + # it's something of the form "foo ~/foo.wav" (i.e. a plain file name) which + # would definitely cause problems as the fopen system call does not do + # tilde expansion. + echo "$0: Please do not use tilde (~) in your wav.scp." + exit 1; + fi + + if [ -f $data/segments ]; then + + check_sorted_and_uniq $data/segments + # We have a segments file -> interpret wav file as "recording-ids" not utterance-ids. + ! cat $data/segments | \ + awk '{if (NF != 4 || $4 <= $3) { print "Bad line in segments file", $0; exit(1); }}' && \ + echo "$0: badly formatted segments file" && exit 1; + + segments_len=`cat $data/segments | wc -l` + if [ -f $data/text ]; then + ! cmp -s $tmpdir/utts <(awk '{print $1}' <$data/segments) && \ + echo "$0: Utterance list differs between $data/utt2spk and $data/segments " && \ + echo "$0: Lengths are $segments_len vs $num_utts" && \ + exit 1 + fi + + cat $data/segments | awk '{print $2}' | sort | uniq > $tmpdir/recordings + awk '{print $1}' $data/wav.scp > $tmpdir/recordings.wav + if ! cmp -s $tmpdir/recordings{,.wav}; then + echo "$0: Error: in $data, recording-ids extracted from segments and wav.scp" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/recordings{,.wav} + exit 1; + fi + if [ -f $data/reco2file_and_channel ]; then + # this file is needed only for ctm scoring; it's indexed by recording-id. + check_sorted_and_uniq $data/reco2file_and_channel + ! cat $data/reco2file_and_channel | \ + awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) { + if ( NF == 3 && $3 == "1" ) { + warning_issued = 1; + } else { + print "Bad line ", $0; exit 1; + } + } + } + END { + if (warning_issued == 1) { + print "The channel should be marked as A or B, not 1! You should change it ASAP! " + } + }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1; + cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/recordings.r2fc + if ! cmp -s $tmpdir/recordings{,.r2fc}; then + echo "$0: Error: in $data, recording-ids extracted from segments and reco2file_and_channel" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/recordings{,.r2fc} + exit 1; + fi + fi + else + # No segments file -> assume wav.scp indexed by utterance. + cat $data/wav.scp | awk '{print $1}' > $tmpdir/utts.wav + if ! cmp -s $tmpdir/utts{,.wav}; then + echo "$0: Error: in $data, utterance lists extracted from utt2spk and wav.scp" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.wav} + exit 1; + fi + + if [ -f $data/reco2file_and_channel ]; then + # this file is needed only for ctm scoring; it's indexed by recording-id. + check_sorted_and_uniq $data/reco2file_and_channel + ! cat $data/reco2file_and_channel | \ + awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) { + if ( NF == 3 && $3 == "1" ) { + warning_issued = 1; + } else { + print "Bad line ", $0; exit 1; + } + } + } + END { + if (warning_issued == 1) { + print "The channel should be marked as A or B, not 1! You should change it ASAP! " + } + }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1; + cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/utts.r2fc + if ! cmp -s $tmpdir/utts{,.r2fc}; then + echo "$0: Error: in $data, utterance-ids extracted from segments and reco2file_and_channel" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.r2fc} + exit 1; + fi + fi + fi +fi + +if [ ! -f $data/feats.scp ] && ! $no_feats; then + echo "$0: no such file $data/feats.scp (if this is by design, specify --no-feats)" + exit 1; +fi + +if [ -f $data/feats.scp ]; then + check_sorted_and_uniq $data/feats.scp + cat $data/feats.scp | awk '{print $1}' > $tmpdir/utts.feats + if ! cmp -s $tmpdir/utts{,.feats}; then + echo "$0: Error: in $data, utterance-ids extracted from utt2spk and features" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.feats} + exit 1; + fi +fi + + +if [ -f $data/cmvn.scp ]; then + check_sorted_and_uniq $data/cmvn.scp + cat $data/cmvn.scp | awk '{print $1}' > $tmpdir/speakers.cmvn + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + if ! cmp -s $tmpdir/speakers{,.cmvn}; then + echo "$0: Error: in $data, speaker lists extracted from spk2utt and cmvn" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/speakers{,.cmvn} + exit 1; + fi +fi + +if [ -f $data/spk2gender ]; then + check_sorted_and_uniq $data/spk2gender + ! cat $data/spk2gender | awk '{if (!((NF == 2 && ($2 == "m" || $2 == "f")))) exit 1; }' && \ + echo "$0: Mal-formed spk2gender file" && exit 1; + cat $data/spk2gender | awk '{print $1}' > $tmpdir/speakers.spk2gender + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + if ! cmp -s $tmpdir/speakers{,.spk2gender}; then + echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2gender" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/speakers{,.spk2gender} + exit 1; + fi +fi + +if [ -f $data/spk2warp ]; then + check_sorted_and_uniq $data/spk2warp + ! cat $data/spk2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \ + echo "$0: Mal-formed spk2warp file" && exit 1; + cat $data/spk2warp | awk '{print $1}' > $tmpdir/speakers.spk2warp + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + if ! cmp -s $tmpdir/speakers{,.spk2warp}; then + echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2warp" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/speakers{,.spk2warp} + exit 1; + fi +fi + +if [ -f $data/utt2warp ]; then + check_sorted_and_uniq $data/utt2warp + ! cat $data/utt2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \ + echo "$0: Mal-formed utt2warp file" && exit 1; + cat $data/utt2warp | awk '{print $1}' > $tmpdir/utts.utt2warp + cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts + if ! cmp -s $tmpdir/utts{,.utt2warp}; then + echo "$0: Error: in $data, utterance lists extracted from utt2spk and utt2warp" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.utt2warp} + exit 1; + fi +fi + +# check some optionally-required things +for f in vad.scp utt2lang utt2uniq; do + if [ -f $data/$f ]; then + check_sorted_and_uniq $data/$f + if ! cmp -s <( awk '{print $1}' $data/utt2spk ) \ + <( awk '{print $1}' $data/$f ); then + echo "$0: error: in $data, $f and utt2spk do not have identical utterance-id list" + exit 1; + fi + fi +done + + +if [ -f $data/utt2dur ]; then + check_sorted_and_uniq $data/utt2dur + cat $data/utt2dur | awk '{print $1}' > $tmpdir/utts.utt2dur + if ! cmp -s $tmpdir/utts{,.utt2dur}; then + echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2dur file" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.utt2dur} + exit 1; + fi + cat $data/utt2dur | \ + awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line utt2dur:" NR ":" $0; exit(1) }}' || exit 1 +fi + +if [ -f $data/utt2num_frames ]; then + check_sorted_and_uniq $data/utt2num_frames + cat $data/utt2num_frames | awk '{print $1}' > $tmpdir/utts.utt2num_frames + if ! cmp -s $tmpdir/utts{,.utt2num_frames}; then + echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2num_frames file" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.utt2num_frames} + exit 1 + fi + awk <$data/utt2num_frames '{ + if (NF != 2 || !($2 > 0) || $2 != int($2)) { + print "Bad line utt2num_frames:" NR ":" $0 + exit 1 } }' || exit 1 +fi + +if [ -f $data/reco2dur ]; then + check_sorted_and_uniq $data/reco2dur + cat $data/reco2dur | awk '{print $1}' > $tmpdir/recordings.reco2dur + if [ -f $tmpdir/recordings ]; then + if ! cmp -s $tmpdir/recordings{,.reco2dur}; then + echo "$0: Error: in $data, recording-ids extracted from segments and reco2dur file" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/recordings{,.reco2dur} + exit 1; + fi + else + if ! cmp -s $tmpdir/{utts,recordings.reco2dur}; then + echo "$0: Error: in $data, recording-ids extracted from wav.scp and reco2dur file" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/{utts,recordings.reco2dur} + exit 1; + fi + fi + cat $data/reco2dur | \ + awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line : " $0; exit(1) }}' || exit 1 +fi + + +echo "$0: Successfully validated data-directory $data" diff --git a/egs/alimeeting/sa-asr/utils/validate_text.pl b/egs/alimeeting/sa-asr/utils/validate_text.pl new file mode 100755 index 000000000..7f75cf12f --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/validate_text.pl @@ -0,0 +1,136 @@ +#!/usr/bin/env perl +# +#=============================================================================== +# Copyright 2017 Johns Hopkins University (author: Yenda Trmal ) +# Johns Hopkins University (author: Daniel Povey) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +# validation script for data//text +# to be called (preferably) from utils/validate_data_dir.sh +use strict; +use warnings; +use utf8; +use Fcntl qw< SEEK_SET >; + +# this function reads the opened file (supplied as a first +# parameter) into an array of lines. For each +# line, it tests whether it's a valid utf-8 compatible +# line. If all lines are valid utf-8, it returns the lines +# decoded as utf-8, otherwise it assumes the file's encoding +# is one of those 1-byte encodings, such as ISO-8859-x +# or Windows CP-X. +# Please recall we do not really care about +# the actually encoding, we just need to +# make sure the length of the (decoded) string +# is correct (to make the output formatting looking right). +sub get_utf8_or_bytestream { + use Encode qw(decode encode); + my $is_utf_compatible = 1; + my @unicode_lines; + my @raw_lines; + my $raw_text; + my $lineno = 0; + my $file = shift; + + while (<$file>) { + $raw_text = $_; + last unless $raw_text; + if ($is_utf_compatible) { + my $decoded_text = eval { decode("UTF-8", $raw_text, Encode::FB_CROAK) } ; + $is_utf_compatible = $is_utf_compatible && defined($decoded_text); + push @unicode_lines, $decoded_text; + } else { + #print STDERR "WARNING: the line $raw_text cannot be interpreted as UTF-8: $decoded_text\n"; + ; + } + push @raw_lines, $raw_text; + $lineno += 1; + } + + if (!$is_utf_compatible) { + return (0, @raw_lines); + } else { + return (1, @unicode_lines); + } +} + +# check if the given unicode string contain unicode whitespaces +# other than the usual four: TAB, LF, CR and SPACE +sub validate_utf8_whitespaces { + my $unicode_lines = shift; + use feature 'unicode_strings'; + for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) { + my $current_line = $unicode_lines->[$i]; + if ((substr $current_line, -1) ne "\n"){ + print STDERR "$0: The current line (nr. $i) has invalid newline\n"; + return 1; + } + my @A = split(" ", $current_line); + my $utt_id = $A[0]; + # we replace TAB, LF, CR, and SPACE + # this is to simplify the test + if ($current_line =~ /\x{000d}/) { + print STDERR "$0: The line for utterance $utt_id contains CR (0x0D) character\n"; + return 1; + } + $current_line =~ s/[\x{0009}\x{000a}\x{0020}]/./g; + if ($current_line =~/\s/) { + print STDERR "$0: The line for utterance $utt_id contains disallowed Unicode whitespaces\n"; + return 1; + } + } + return 0; +} + +# checks if the text in the file (supplied as the argument) is utf-8 compatible +# if yes, checks if it contains only allowed whitespaces. If no, then does not +# do anything. The function seeks to the original position in the file after +# reading the text. +sub check_allowed_whitespace { + my $file = shift; + my $filename = shift; + my $pos = tell($file); + (my $is_utf, my @lines) = get_utf8_or_bytestream($file); + seek($file, $pos, SEEK_SET); + if ($is_utf) { + my $has_invalid_whitespaces = validate_utf8_whitespaces(\@lines); + if ($has_invalid_whitespaces) { + print STDERR "$0: ERROR: text file '$filename' contains disallowed UTF-8 whitespace character(s)\n"; + return 0; + } + } + return 1; +} + +if(@ARGV != 1) { + die "Usage: validate_text.pl \n" . + "e.g.: validate_text.pl data/train/text\n"; +} + +my $text = shift @ARGV; + +if (-z "$text") { + print STDERR "$0: ERROR: file '$text' is empty or does not exist\n"; + exit 1; +} + +if(!open(FILE, "<$text")) { + print STDERR "$0: ERROR: failed to open $text\n"; + exit 1; +} + +check_allowed_whitespace(\*FILE, $text) or exit 1; +close(FILE); diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py index 47226021f..c18472f51 100644 --- a/funasr/bin/asr_inference.py +++ b/funasr/bin/asr_inference.py @@ -40,7 +40,6 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend header_colors = '\033[95m' @@ -91,8 +90,6 @@ class Speech2Text: asr_train_config, asr_model_file, cmvn_file, device ) frontend = None - if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) logging.info("asr_model: {}".format(asr_model)) logging.info("asr_train_args: {}".format(asr_train_args)) @@ -111,7 +108,7 @@ class Speech2Text: # 2. Build Language model if lm_train_config is not None: lm, lm_train_args = LMTask.build_model_from_file( - lm_train_config, lm_file, device + lm_train_config, lm_file, None, device ) scorers["lm"] = lm.lm @@ -142,6 +139,13 @@ class Speech2Text: pre_beam_score_key=None if ctc_weight == 1.0 else "full", ) + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + logging.info(f"Beam_search: {beam_search}") + logging.info(f"Decoding device={device}, dtype={dtype}") + # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text if token_type is None: token_type = asr_train_args.token_type @@ -198,16 +202,7 @@ class Speech2Text: if isinstance(speech, np.ndarray): speech = torch.tensor(speech) - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None - else: - feats = speech - feats_len = speech_lengths - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - batch = {"speech": feats, "speech_lengths": feats_len} + batch = {"speech": speech, "speech_lengths": speech_lengths} # a. To device batch = to_device(batch, device=self.device) @@ -355,6 +350,9 @@ def inference_modelscope( if ngpu > 1: raise NotImplementedError("only single GPU decoding is supported") + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.basicConfig( level=log_level, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", @@ -408,6 +406,7 @@ def inference_modelscope( data_path_and_name_and_type, dtype=dtype, fs=fs, + mc=True, batch_size=batch_size, key_file=key_file, num_workers=num_workers, @@ -452,7 +451,7 @@ def inference_modelscope( # Write the result to each file ibest_writer["token"][key] = " ".join(token) - # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) + ibest_writer["token_int"][key] = " ".join(map(str, token_int)) ibest_writer["score"][key] = str(hyp.score) if text is not None: @@ -463,6 +462,9 @@ def inference_modelscope( asr_utils.print_progress(finish_count / file_count) if writer is not None: ibest_writer["text"][key] = text + + logging.info("uttid: {}".format(key)) + logging.info("text predictions: {}\n".format(text)) return asr_result_list return _forward @@ -637,4 +639,4 @@ def main(cmd=None): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index e10ebf404..e165531f8 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -288,6 +288,9 @@ def inference_launch_funasr(**kwargs): if mode == "asr": from funasr.bin.asr_inference import inference return inference(**kwargs) + elif mode == "sa_asr": + from funasr.bin.sa_asr_inference import inference + return inference(**kwargs) elif mode == "uniasr": from funasr.bin.asr_inference_uniasr import inference return inference(**kwargs) @@ -342,4 +345,4 @@ def main(cmd=None): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py index bba50daf0..c1e2cb2a1 100755 --- a/funasr/bin/asr_train.py +++ b/funasr/bin/asr_train.py @@ -2,6 +2,14 @@ import os +import logging + +logging.basicConfig( + level='INFO', + format=f"[{os.uname()[1].split('.')[0]}]" + f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", +) + from funasr.tasks.asr import ASRTask @@ -27,7 +35,8 @@ if __name__ == '__main__': args = parse_args() # setup local gpu_id - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) + if args.ngpu > 0: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) # DDP settings if args.ngpu > 1: @@ -38,9 +47,9 @@ if __name__ == '__main__': # re-compute batch size: when dataset type is small if args.dataset_type == "small": - if args.batch_size is not None: + if args.batch_size is not None and args.ngpu > 0: args.batch_size = args.batch_size * args.ngpu - if args.batch_bins is not None: + if args.batch_bins is not None and args.ngpu > 0: args.batch_bins = args.batch_bins * args.ngpu main(args=args) diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py new file mode 100644 index 000000000..be63af111 --- /dev/null +++ b/funasr/bin/sa_asr_inference.py @@ -0,0 +1,674 @@ +import argparse +import logging +import sys +from pathlib import Path +from typing import Any +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from typing import Dict + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type + +from funasr.fileio.datadir_writer import DatadirWriter +from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim +from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch +from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis +from funasr.modules.scorers.ctc import CTCPrefixScorer +from funasr.modules.scorers.length_bonus import LengthBonus +from funasr.modules.scorers.scorer_interface import BatchScorerInterface +from funasr.modules.subsampling import TooShortUttError +from funasr.tasks.sa_asr import ASRTask +from funasr.tasks.lm import LMTask +from funasr.text.build_tokenizer import build_tokenizer +from funasr.text.token_id_converter import TokenIDConverter +from funasr.torch_utils.device_funcs import to_device +from funasr.torch_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none +from funasr.utils import asr_utils, wav_utils, postprocess_utils + + +header_colors = '\033[95m' +end_colors = '\033[0m' + + +class Speech2Text: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, hypothesis object), ...] + + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + batch_size: int = 1, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + streaming: bool = False, + frontend_conf: dict = None, + **kwargs, + ): + assert check_argument_types() + + # 1. Build ASR model + scorers = {} + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, cmvn_file, device + ) + frontend = None + + logging.info("asr_model: {}".format(asr_model)) + logging.info("asr_train_args: {}".format(asr_train_args)) + asr_model.to(dtype=getattr(torch, dtype)).eval() + + decoder = asr_model.decoder + + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + token_list = asr_model.token_list + scorers.update( + decoder=decoder, + ctc=ctc, + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, None, device + ) + scorers["lm"] = lm.lm + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + # 4. Build BeamSearch object + # transducer is not supported now + beam_search_transducer = None + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + logging.info(f"Beam_search: {beam_search}") + logging.info(f"Decoding device={device}, dtype={dtype}") + + # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + self.beam_search = beam_search + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + self.frontend = frontend + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray] + ) -> List[ + Tuple[ + Optional[str], + Optional[str], + List[str], + List[int], + Union[Hypothesis], + ] + ]: + """Inference + + Args: + speech: Input speech data + Returns: + text, text_id, token, token_int, hyp + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + if isinstance(profile, np.ndarray): + profile = torch.tensor(profile) + + batch = {"speech": speech, "speech_lengths": speech_lengths} + + # a. To device + batch = to_device(batch, device=self.device) + + # b. Forward Encoder + asr_enc, _, spk_enc = self.asr_model.encode(**batch) + if isinstance(asr_enc, tuple): + asr_enc = asr_enc[0] + if isinstance(spk_enc, tuple): + spk_enc = spk_enc[0] + assert len(asr_enc) == 1, len(asr_enc) + assert len(spk_enc) == 1, len(spk_enc) + + # c. Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio + ) + + nbest_hyps = nbest_hyps[: self.nbest] + + results = [] + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1: last_pos] + else: + token_int = hyp.yseq[1: last_pos].tolist() + + spk_weigths=torch.stack(hyp.spk_weigths, dim=0) + + token_ori = self.converter.ids2tokens(token_int) + text_ori = self.tokenizer.tokens2text(token_ori) + + text_ori_spklist = text_ori.split('$') + cur_index = 0 + spk_choose = [] + for i in range(len(text_ori_spklist)): + text_ori_split = text_ori_spklist[i] + n = len(text_ori_split) + spk_weights_local = spk_weigths[cur_index: cur_index + n] + cur_index = cur_index + n + 1 + spk_weights_local = spk_weights_local.mean(dim=0) + spk_choose_local = spk_weights_local.argmax(-1) + spk_choose.append(spk_choose_local.item() + 1) + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + + text_spklist = text.split('$') + assert len(spk_choose) == len(text_spklist) + + spk_list=[] + for i in range(len(text_spklist)): + text_split = text_spklist[i] + n = len(text_split) + spk_list.append(str(spk_choose[i]) * n) + + text_id = '$'.join(spk_list) + + assert len(text) == len(text_id) + + results.append((text, text_id, token, token_int, hyp)) + + assert check_return_type(results) + return results + +def inference( + maxlenratio: float, + minlenratio: float, + batch_size: int, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + log_level: Union[int, str], + data_path_and_name_and_type, + asr_train_config: Optional[str], + asr_model_file: Optional[str], + cmvn_file: Optional[str] = None, + raw_inputs: Union[np.ndarray, torch.Tensor] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + token_type: Optional[str] = None, + key_file: Optional[str] = None, + word_lm_train_config: Optional[str] = None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + streaming: bool = False, + output_dir: Optional[str] = None, + dtype: str = "float32", + seed: int = 0, + ngram_weight: float = 0.9, + nbest: int = 1, + num_workers: int = 1, + **kwargs, +): + inference_pipeline = inference_modelscope( + maxlenratio=maxlenratio, + minlenratio=minlenratio, + batch_size=batch_size, + beam_size=beam_size, + ngpu=ngpu, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + penalty=penalty, + log_level=log_level, + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + cmvn_file=cmvn_file, + raw_inputs=raw_inputs, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + key_file=key_file, + word_lm_train_config=word_lm_train_config, + bpemodel=bpemodel, + allow_variable_data_keys=allow_variable_data_keys, + streaming=streaming, + output_dir=output_dir, + dtype=dtype, + seed=seed, + ngram_weight=ngram_weight, + nbest=nbest, + num_workers=num_workers, + **kwargs, + ) + return inference_pipeline(data_path_and_name_and_type, raw_inputs) + +def inference_modelscope( + maxlenratio: float, + minlenratio: float, + batch_size: int, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + log_level: Union[int, str], + # data_path_and_name_and_type, + asr_train_config: Optional[str], + asr_model_file: Optional[str], + cmvn_file: Optional[str] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + token_type: Optional[str] = None, + key_file: Optional[str] = None, + word_lm_train_config: Optional[str] = None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + streaming: bool = False, + output_dir: Optional[str] = None, + dtype: str = "float32", + seed: int = 0, + ngram_weight: float = 0.9, + nbest: int = 1, + num_workers: int = 1, + param_dict: dict = None, + **kwargs, +): + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") + if word_lm_train_config is not None: + raise NotImplementedError("Word LM is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if ngpu >= 1 and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build speech2text + speech2text_kwargs = dict( + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + cmvn_file=cmvn_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + bpemodel=bpemodel, + device=device, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + dtype=dtype, + beam_size=beam_size, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + streaming=streaming, + ) + logging.info("speech2text_kwargs: {}".format(speech2text_kwargs)) + speech2text = Speech2Text(**speech2text_kwargs) + + def _forward(data_path_and_name_and_type, + raw_inputs: Union[np.ndarray, torch.Tensor] = None, + output_dir_v2: Optional[str] = None, + fs: dict = None, + param_dict: dict = None, + **kwargs, + ): + # 3. Build data-iterator + if data_path_and_name_and_type is None and raw_inputs is not None: + if isinstance(raw_inputs, torch.Tensor): + raw_inputs = raw_inputs.numpy() + data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] + loader = ASRTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + fs=fs, + mc=True, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), + collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + finish_count = 0 + file_count = 1 + # 7 .Start for-loop + # FIXME(kamo): The output format should be discussed about + asr_result_list = [] + output_path = output_dir_v2 if output_dir_v2 is not None else output_dir + if output_path is not None: + writer = DatadirWriter(output_path) + else: + writer = None + + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + # N-best list of (text, token, token_int, hyp_object) + try: + results = speech2text(**batch) + except TooShortUttError as e: + logging.warning(f"Utterance {keys} {e}") + hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) + results = [[" ", ["sil"], [2], hyp]] * nbest + + # Only supporting batch_size==1 + key = keys[0] + for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results): + # Create a directory: outdir/{n}best_recog + if writer is not None: + ibest_writer = writer[f"{n}best_recog"] + + # Write the result to each file + ibest_writer["token"][key] = " ".join(token) + ibest_writer["token_int"][key] = " ".join(map(str, token_int)) + ibest_writer["score"][key] = str(hyp.score) + ibest_writer["text_id"][key] = text_id + + if text is not None: + text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + item = {'key': key, 'value': text_postprocessed} + asr_result_list.append(item) + finish_count += 1 + asr_utils.print_progress(finish_count / file_count) + if writer is not None: + ibest_writer["text"][key] = text + + logging.info("uttid: {}".format(key)) + logging.info("text predictions: {}".format(text)) + logging.info("text_id predictions: {}\n".format(text_id)) + return asr_result_list + + return _forward + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="ASR Decoding", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument( + "--gpuid_list", + type=str, + default="", + help="The visible gpus", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=False, + action="append", + ) + group.add_argument("--raw_inputs", type=list, default=None) + # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("The model configuration related") + group.add_argument( + "--asr_train_config", + type=str, + help="ASR training configuration", + ) + group.add_argument( + "--asr_model_file", + type=str, + help="ASR model parameter file", + ) + group.add_argument( + "--cmvn_file", + type=str, + help="Global cmvn file", + ) + group.add_argument( + "--lm_train_config", + type=str, + help="LM training configuration", + ) + group.add_argument( + "--lm_file", + type=str, + help="LM parameter file", + ) + group.add_argument( + "--word_lm_train_config", + type=str, + help="Word LM training configuration", + ) + group.add_argument( + "--word_lm_file", + type=str, + help="Word LM parameter file", + ) + group.add_argument( + "--ngram_file", + type=str, + help="N-gram parameter file", + ) + group.add_argument( + "--model_tag", + type=str, + help="Pretrained model tag. If specify this option, *_train_config and " + "*_file will be overwritten", + ) + + group = parser.add_argument_group("Beam-search related") + group.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + group.add_argument("--beam_size", type=int, default=20, help="Beam size") + group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") + group.add_argument( + "--maxlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain max output length. " + "If maxlenratio=0.0 (default), it uses a end-detect " + "function " + "to automatically find maximum hypothesis lengths." + "If maxlenratio<0.0, its absolute value is interpreted" + "as a constant max output length", + ) + group.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain min output length", + ) + group.add_argument( + "--ctc_weight", + type=float, + default=0.5, + help="CTC weight in joint decoding", + ) + group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") + group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") + group.add_argument("--streaming", type=str2bool, default=False) + + group = parser.add_argument_group("Text converter related") + group.add_argument( + "--token_type", + type=str_or_none, + default=None, + choices=["char", "bpe", None], + help="The token type for ASR model. " + "If not given, refers from the training args", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model path of sentencepiece. " + "If not given, refers from the training args", + ) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + inference(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py new file mode 100755 index 000000000..c7c7c42a4 --- /dev/null +++ b/funasr/bin/sa_asr_train.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 + +import os + +import logging + +logging.basicConfig( + level='INFO', + format=f"[{os.uname()[1].split('.')[0]}]" + f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", +) + +from funasr.tasks.sa_asr import ASRTask + + +# for ASR Training +def parse_args(): + parser = ASRTask.get_parser() + parser.add_argument( + "--gpu_id", + type=int, + default=0, + help="local gpu id.", + ) + args = parser.parse_args() + return args + + +def main(args=None, cmd=None): + # for ASR Training + ASRTask.main(args=args, cmd=cmd) + + +if __name__ == '__main__': + args = parse_args() + + # setup local gpu_id + if args.ngpu > 0: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) + + # DDP settings + if args.ngpu > 1: + args.distributed = True + else: + args.distributed = False + assert args.num_worker_count == 1 + + # re-compute batch size: when dataset type is small + if args.dataset_type == "small": + if args.batch_size is not None and args.ngpu > 0: + args.batch_size = args.batch_size * args.ngpu + if args.batch_bins is not None and args.ngpu > 0: + args.batch_bins = args.batch_bins * args.ngpu + + main(args=args) diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py index dc872b047..d757f7f8c 100644 --- a/funasr/fileio/sound_scp.py +++ b/funasr/fileio/sound_scp.py @@ -46,13 +46,15 @@ class SoundScpReader(collections.abc.Mapping): if self.normalize: # soundfile.read normalizes data to [-1,1] if dtype is not given array, rate = librosa.load( - wav, sr=self.dest_sample_rate, mono=not self.always_2d + wav, sr=self.dest_sample_rate, mono=self.always_2d ) else: array, rate = librosa.load( - wav, sr=self.dest_sample_rate, mono=not self.always_2d, dtype=self.dtype + wav, sr=self.dest_sample_rate, mono=self.always_2d, dtype=self.dtype ) + if array.ndim==2: + array=array.transpose((1, 0)) return rate, array def get_path(self, key): diff --git a/funasr/losses/nll_loss.py b/funasr/losses/nll_loss.py new file mode 100644 index 000000000..7e4e29496 --- /dev/null +++ b/funasr/losses/nll_loss.py @@ -0,0 +1,47 @@ +import torch +from torch import nn + +class NllLoss(nn.Module): + """Nll loss. + + :param int size: the number of class + :param int padding_idx: ignored class id + :param bool normalize_length: normalize loss by sequence length if True + :param torch.nn.Module criterion: loss function + """ + + def __init__( + self, + size, + padding_idx, + normalize_length=False, + criterion=nn.NLLLoss(reduction='none'), + ): + """Construct an LabelSmoothingLoss object.""" + super(NllLoss, self).__init__() + self.criterion = criterion + self.padding_idx = padding_idx + self.size = size + self.true_dist = None + self.normalize_length = normalize_length + + def forward(self, x, target): + """Compute loss between x and target. + + :param torch.Tensor x: prediction (batch, seqlen, class) + :param torch.Tensor target: + target signal masked with self.padding_id (batch, seqlen) + :return: scalar float value + :rtype torch.Tensor + """ + assert x.size(2) == self.size + batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + with torch.no_grad(): + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + kl = self.criterion(x , target) + denom = total if self.normalize_length else batch_size + return kl.masked_fill(ignore, 0).sum() / denom diff --git a/funasr/models/decoder/decoder_layer_sa_asr.py b/funasr/models/decoder/decoder_layer_sa_asr.py new file mode 100644 index 000000000..80afc5168 --- /dev/null +++ b/funasr/models/decoder/decoder_layer_sa_asr.py @@ -0,0 +1,169 @@ +import torch +from torch import nn + +from funasr.modules.layer_norm import LayerNorm + + +class SpeakerAttributeSpkDecoderFirstLayer(nn.Module): + + def __init__( + self, + size, + self_attn, + src_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an DecoderLayer object.""" + super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear1 = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + size, size) + + def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None): + + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + # compute only the last frame query keeping dim: max_time_out -> 1 + assert cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = None + if tgt_mask is not None: + tgt_q_mask = tgt_mask[:, -1:, :] + + if self.concat_after: + tgt_concat = torch.cat( + (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 + ) + x = residual + self.concat_linear1(tgt_concat) + else: + x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) + if not self.normalize_before: + x = self.norm1(x) + z = x + + residual = x + if self.normalize_before: + x = self.norm1(x) + + skip = self.src_attn(x, asr_memory, spk_memory, memory_mask) + + if self.concat_after: + x_concat = torch.cat( + (x, skip), dim=-1 + ) + x = residual + self.concat_linear2(x_concat) + else: + x = residual + self.dropout(skip) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, asr_memory, spk_memory, memory_mask, z + +class SpeakerAttributeAsrDecoderFirstLayer(nn.Module): + + def __init__( + self, + size, + d_size, + src_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an DecoderLayer object.""" + super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__() + self.size = size + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.norm3 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + self.spk_linear = nn.Linear(d_size, size, bias=False) + if self.concat_after: + self.concat_linear1 = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + size, size) + + def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None): + + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = None + if tgt_mask is not None: + tgt_q_mask = tgt_mask[:, -1:, :] + + x = tgt_q + if self.normalize_before: + x = self.norm2(x) + if self.concat_after: + x_concat = torch.cat( + (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 + ) + x = residual + self.concat_linear2(x_concat) + else: + x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) + if not self.normalize_before: + x = self.norm2(x) + residual = x + + if dn!=None: + x = x + self.spk_linear(dn) + if self.normalize_before: + x = self.norm3(x) + + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm3(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, memory, memory_mask + + + diff --git a/funasr/models/decoder/transformer_decoder_sa_asr.py b/funasr/models/decoder/transformer_decoder_sa_asr.py new file mode 100644 index 000000000..949f9c898 --- /dev/null +++ b/funasr/models/decoder/transformer_decoder_sa_asr.py @@ -0,0 +1,291 @@ +from typing import Any +from typing import List +from typing import Sequence +from typing import Tuple + +import torch +from typeguard import check_argument_types + +from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.attention import MultiHeadedAttention +from funasr.modules.attention import CosineDistanceAttention +from funasr.models.decoder.transformer_decoder import DecoderLayer +from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer +from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer +from funasr.modules.dynamic_conv import DynamicConvolution +from funasr.modules.dynamic_conv2d import DynamicConvolution2D +from funasr.modules.embedding import PositionalEncoding +from funasr.modules.layer_norm import LayerNorm +from funasr.modules.lightconv import LightweightConvolution +from funasr.modules.lightconv2d import LightweightConvolution2D +from funasr.modules.mask import subsequent_mask +from funasr.modules.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from funasr.modules.repeat import repeat +from funasr.modules.scorers.scorer_interface import BatchScorerInterface +from funasr.models.decoder.abs_decoder import AbsDecoder + +class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface): + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + spker_embedding_dim: int = 256, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + input_layer: str = "embed", + use_asr_output_layer: bool = True, + use_spk_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + ): + assert check_argument_types() + super().__init__() + attention_dim = encoder_output_size + + if input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(vocab_size, attention_dim), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(vocab_size, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + else: + raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}") + + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + if use_asr_output_layer: + self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size) + else: + self.asr_output_layer = None + + if use_spk_output_layer: + self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim) + else: + self.spk_output_layer = None + + self.cos_distance_att = CosineDistanceAttention() + + self.decoder1 = None + self.decoder2 = None + self.decoder3 = None + self.decoder4 = None + + def forward( + self, + asr_hs_pad: torch.Tensor, + spk_hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + profile: torch.Tensor, + profile_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + tgt = ys_in_pad + # tgt_mask: (B, 1, L) + tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) + # m: (1, L, L) + m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) + tgt_mask = tgt_mask & m + + asr_memory = asr_hs_pad + spk_memory = spk_hs_pad + memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device) + # Spk decoder + x = self.embed(tgt) + + x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1( + x, tgt_mask, asr_memory, spk_memory, memory_mask + ) + x, tgt_mask, spk_memory, memory_mask = self.decoder2( + x, tgt_mask, spk_memory, memory_mask + ) + if self.normalize_before: + x = self.after_norm(x) + if self.spk_output_layer is not None: + x = self.spk_output_layer(x) + dn, weights = self.cos_distance_att(x, profile, profile_lens) + # Asr decoder + x, tgt_mask, asr_memory, memory_mask = self.decoder3( + z, tgt_mask, asr_memory, memory_mask, dn + ) + x, tgt_mask, asr_memory, memory_mask = self.decoder4( + x, tgt_mask, asr_memory, memory_mask + ) + + if self.normalize_before: + x = self.after_norm(x) + if self.asr_output_layer is not None: + x = self.asr_output_layer(x) + + olens = tgt_mask.sum(1) + return x, weights, olens + + + def forward_one_step( + self, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + asr_memory: torch.Tensor, + spk_memory: torch.Tensor, + profile: torch.Tensor, + cache: List[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + + x = self.embed(tgt) + + if cache is None: + cache = [None] * (2 + len(self.decoder2) + len(self.decoder4)) + new_cache = [] + x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1( + x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0] + ) + new_cache.append(x) + for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2): + x, tgt_mask, spk_memory, _ = decoder( + x, tgt_mask, spk_memory, None, cache=c + ) + new_cache.append(x) + if self.normalize_before: + x = self.after_norm(x) + else: + x = x + if self.spk_output_layer is not None: + x = self.spk_output_layer(x) + dn, weights = self.cos_distance_att(x, profile, None) + + x, tgt_mask, asr_memory, _ = self.decoder3( + z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1] + ) + new_cache.append(x) + + for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4): + x, tgt_mask, asr_memory, _ = decoder( + x, tgt_mask, asr_memory, None, cache=c + ) + new_cache.append(x) + + if self.normalize_before: + y = self.after_norm(x[:, -1]) + else: + y = x[:, -1] + if self.asr_output_layer is not None: + y = torch.log_softmax(self.asr_output_layer(y), dim=-1) + + return y, weights, new_cache + + def score(self, ys, state, asr_enc, spk_enc, profile): + """Score.""" + ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0) + logp, weights, state = self.forward_one_step( + ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state + ) + return logp.squeeze(0), weights.squeeze(), state + +class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder): + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + spker_embedding_dim: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + asr_num_blocks: int = 6, + spk_num_blocks: int = 3, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_asr_output_layer: bool = True, + use_spk_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + ): + assert check_argument_types() + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + spker_embedding_dim=spker_embedding_dim, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_asr_output_layer=use_asr_output_layer, + use_spk_output_layer=use_spk_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + + self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer( + attention_dim, + MultiHeadedAttention( + attention_heads, attention_dim, self_attention_dropout_rate + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ) + self.decoder2 = repeat( + spk_num_blocks - 1, + lambda lnum: DecoderLayer( + attention_dim, + MultiHeadedAttention( + attention_heads, attention_dim, self_attention_dropout_rate + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + + self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer( + attention_dim, + spker_embedding_dim, + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ) + self.decoder4 = repeat( + asr_num_blocks - 1, + lambda lnum: DecoderLayer( + attention_dim, + MultiHeadedAttention( + attention_heads, attention_dim, self_attention_dropout_rate + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py new file mode 100644 index 000000000..0d4097ec2 --- /dev/null +++ b/funasr/models/e2e_sa_asr.py @@ -0,0 +1,521 @@ +# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import logging +from contextlib import contextmanager +from distutils.version import LooseVersion +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +import torch.nn.functional as F +from typeguard import check_argument_types + +from funasr.layers.abs_normalize import AbsNormalize +from funasr.losses.label_smoothing_loss import ( + LabelSmoothingLoss, # noqa: H301 +) +from funasr.losses.nll_loss import NllLoss +from funasr.models.ctc import CTC +from funasr.models.decoder.abs_decoder import AbsDecoder +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.postencoder.abs_postencoder import AbsPostEncoder +from funasr.models.preencoder.abs_preencoder import AbsPreEncoder +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.modules.add_sos_eos import add_sos_eos +from funasr.modules.e2e_asr_common import ErrorCalculator +from funasr.modules.nets_utils import th_accuracy +from funasr.torch_utils.device_funcs import force_gatherable +from funasr.train.abs_espnet_model import AbsESPnetModel + +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +class ESPnetASRModel(AbsESPnetModel): + """CTC-attention hybrid Encoder-Decoder model""" + + def __init__( + self, + vocab_size: int, + max_spk_num: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], + asr_encoder: AbsEncoder, + spk_encoder: torch.nn.Module, + postencoder: Optional[AbsPostEncoder], + decoder: AbsDecoder, + ctc: CTC, + spk_weight: float = 0.5, + ctc_weight: float = 0.5, + interctc_weight: float = 0.0, + ignore_id: int = -1, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = "", + sym_blank: str = "", + extract_feats_in_collect_stats: bool = True, + ): + assert check_argument_types() + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + assert 0.0 <= interctc_weight < 1.0, interctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.blank_id = 0 + self.sos = 1 + self.eos = 2 + self.vocab_size = vocab_size + self.max_spk_num=max_spk_num + self.ignore_id = ignore_id + self.spk_weight = spk_weight + self.ctc_weight = ctc_weight + self.interctc_weight = interctc_weight + self.token_list = token_list.copy() + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + self.preencoder = preencoder + self.postencoder = postencoder + self.asr_encoder = asr_encoder + self.spk_encoder = spk_encoder + + if not hasattr(self.asr_encoder, "interctc_use_conditioning"): + self.asr_encoder.interctc_use_conditioning = False + if self.asr_encoder.interctc_use_conditioning: + self.asr_encoder.conditioning_layer = torch.nn.Linear( + vocab_size, self.asr_encoder.output_size() + ) + + self.error_calculator = None + + + # we set self.decoder = None in the CTC mode since + # self.decoder parameters were never used and PyTorch complained + # and threw an Exception in the multi-GPU experiment. + # thanks Jeff Farris for pointing out the issue. + if ctc_weight == 1.0: + self.decoder = None + else: + self.decoder = decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + self.criterion_spk = NllLoss( + size=max_spk_num, + padding_idx=ignore_id, + normalize_length=length_normalized_loss, + ) + + if report_cer or report_wer: + self.error_calculator = ErrorCalculator( + token_list, sym_space, sym_blank, report_cer, report_wer + ) + + if ctc_weight == 0.0: + self.ctc = None + else: + self.ctc = ctc + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + profile: torch.Tensor, + profile_lengths: torch.Tensor, + text_id: torch.Tensor, + text_id_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + profile: (Batch, Length, Dim) + profile_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + batch_size = speech.shape[0] + + # for data-parallel + text = text[:, : text_lengths.max()] + + # 1. Encoder + asr_encoder_out, encoder_out_lens, spk_encoder_out = self.encode(speech, speech_lengths) + intermediate_outs = None + if isinstance(asr_encoder_out, tuple): + intermediate_outs = asr_encoder_out[1] + asr_encoder_out = asr_encoder_out[0] + + loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = None, None, None, None, None, None + loss_ctc, cer_ctc = None, None + stats = dict() + + # 1. CTC branch + if self.ctc_weight != 0.0: + loss_ctc, cer_ctc = self._calc_ctc_loss( + asr_encoder_out, encoder_out_lens, text, text_lengths + ) + + + # Intermediate CTC (optional) + loss_interctc = 0.0 + if self.interctc_weight != 0.0 and intermediate_outs is not None: + for layer_idx, intermediate_out in intermediate_outs: + # we assume intermediate_out has the same length & padding + # as those of encoder_out + loss_ic, cer_ic = self._calc_ctc_loss( + intermediate_out, encoder_out_lens, text, text_lengths + ) + loss_interctc = loss_interctc + loss_ic + + # Collect Intermedaite CTC stats + stats["loss_interctc_layer{}".format(layer_idx)] = ( + loss_ic.detach() if loss_ic is not None else None + ) + stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic + + loss_interctc = loss_interctc / len(intermediate_outs) + + # calculate whole encoder loss + loss_ctc = ( + 1 - self.interctc_weight + ) * loss_ctc + self.interctc_weight * loss_interctc + + + # 2b. Attention decoder branch + if self.ctc_weight != 1.0: + loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss( + asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths, profile, profile_lengths, text_id, text_id_lengths + ) + + # 3. CTC-Att loss definition + if self.ctc_weight == 0.0: + loss_asr = loss_att + elif self.ctc_weight == 1.0: + loss_asr = loss_ctc + else: + loss_asr = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + + if self.spk_weight == 0.0: + loss = loss_asr + else: + loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr + + + stats = dict( + loss=loss.detach(), + loss_asr=loss_asr.detach(), + loss_att=loss_att.detach() if loss_att is not None else None, + loss_ctc=loss_ctc.detach() if loss_ctc is not None else None, + loss_spk=loss_spk.detach() if loss_spk is not None else None, + acc=acc_att, + acc_spk=acc_spk, + cer=cer_att, + wer=wer_att, + cer_ctc=cer_ctc, + ) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + "Generating dummy stats for feats and feats_lengths, " + "because encoder_conf.extract_feats_in_collect_stats is " + f"{self.extract_feats_in_collect_stats}" + ) + feats, feats_lengths = speech, speech_lengths + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + feats_raw = feats.clone() + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + + # 4. Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + if self.asr_encoder.interctc_use_conditioning: + encoder_out, encoder_out_lens, _ = self.asr_encoder( + feats, feats_lengths, ctc=self.ctc + ) + else: + encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths) + intermediate_outs = None + if isinstance(encoder_out, tuple): + intermediate_outs = encoder_out[1] + encoder_out = encoder_out[0] + + encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0] + # import ipdb;ipdb.set_trace() + if encoder_out_spk_ori.size(1)!=encoder_out.size(1): + encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1) + else: + encoder_out_spk=encoder_out_spk_ori + # Post-encoder, e.g. NLU + if self.postencoder is not None: + encoder_out, encoder_out_lens = self.postencoder( + encoder_out, encoder_out_lens + ) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + assert encoder_out_spk.size(0) == speech.size(0), ( + encoder_out_spk.size(), + speech.size(0), + ) + + if intermediate_outs is not None: + return (encoder_out, intermediate_outs), encoder_out_lens + + return encoder_out, encoder_out_lens, encoder_out_spk + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + # Frontend + # e.g. STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + def nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ) -> torch.Tensor: + """Compute negative log likelihood(nll) from transformer-decoder + + Normally, this function is called in batchify_nll. + + Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder( + encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens + ) # [batch, seqlen, dim] + batch_size = decoder_out.size(0) + decoder_num_class = decoder_out.size(2) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + decoder_out.view(-1, decoder_num_class), + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction="none", + ) + nll = nll.view(batch_size, -1) + nll = nll.sum(dim=1) + assert nll.size(0) == batch_size + return nll + + def batchify_nll( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + batch_size: int = 100, + ): + """Compute negative log likelihood(nll) from transformer-decoder + + To avoid OOM, this fuction seperate the input into batches. + Then call nll for each batch and combine and return results. + Args: + encoder_out: (Batch, Length, Dim) + encoder_out_lens: (Batch,) + ys_pad: (Batch, Length) + ys_pad_lens: (Batch,) + batch_size: int, samples each batch contain when computing nll, + you may change this to avoid OOM or increase + GPU memory usage + """ + total_num = encoder_out.size(0) + if total_num <= batch_size: + nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + else: + nll = [] + start_idx = 0 + while True: + end_idx = min(start_idx + batch_size, total_num) + batch_encoder_out = encoder_out[start_idx:end_idx, :, :] + batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx] + batch_ys_pad = ys_pad[start_idx:end_idx, :] + batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx] + batch_nll = self.nll( + batch_encoder_out, + batch_encoder_out_lens, + batch_ys_pad, + batch_ys_pad_lens, + ) + nll.append(batch_nll) + start_idx = end_idx + if start_idx == total_num: + break + nll = torch.cat(nll) + assert nll.size(0) == total_num + return nll + + def _calc_att_loss( + self, + asr_encoder_out: torch.Tensor, + spk_encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + profile: torch.Tensor, + profile_lens: torch.Tensor, + text_id: torch.Tensor, + text_id_lengths: torch.Tensor + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, weights_no_pad, _ = self.decoder( + asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens + ) + + spk_num_no_pad=weights_no_pad.size(-1) + pad=(0,self.max_spk_num-spk_num_no_pad) + weights=F.pad(weights_no_pad, pad, mode='constant', value=0) + + # pre_id=weights.argmax(-1) + # pre_text=decoder_out.argmax(-1) + # id_mask=(pre_id==text_id).to(dtype=text_id.dtype) + # pre_text_mask=pre_text*id_mask+1-id_mask #相同的地方不变,不同的地方设为1() + # padding_mask= ys_out_pad != self.ignore_id + # numerator = torch.sum(pre_text_mask.masked_select(padding_mask) == ys_out_pad.masked_select(padding_mask)) + # denominator = torch.sum(padding_mask) + # sd_acc = float(numerator) / float(denominator) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + loss_spk = self.criterion_spk(torch.log(weights), text_id) + + acc_spk= th_accuracy( + weights.view(-1, self.max_spk_num), + text_id, + ignore_label=self.ignore_id, + ) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py index 9671fe9d9..2e1b0c454 100644 --- a/funasr/models/frontend/default.py +++ b/funasr/models/frontend/default.py @@ -38,6 +38,7 @@ class DefaultFrontend(AbsFrontend): htk: bool = False, frontend_conf: Optional[dict] = get_default_kwargs(Frontend), apply_stft: bool = True, + use_channel: int = None, ): assert check_argument_types() super().__init__() @@ -77,6 +78,7 @@ class DefaultFrontend(AbsFrontend): ) self.n_mels = n_mels self.frontend_type = "default" + self.use_channel = use_channel def output_size(self) -> int: return self.n_mels @@ -100,9 +102,12 @@ class DefaultFrontend(AbsFrontend): if input_stft.dim() == 4: # h: (B, T, C, F) -> h: (B, T, F) if self.training: - # Select 1ch randomly - ch = np.random.randint(input_stft.size(2)) - input_stft = input_stft[:, :, ch, :] + if self.use_channel == None: + input_stft = input_stft[:, :, 0, :] + else: + # Select 1ch randomly + ch = np.random.randint(input_stft.size(2)) + input_stft = input_stft[:, :, ch, :] else: # Use the first channel input_stft = input_stft[:, :, 0, :] diff --git a/funasr/models/pooling/statistic_pooling.py b/funasr/models/pooling/statistic_pooling.py index 8f85de99d..39d94be5b 100644 --- a/funasr/models/pooling/statistic_pooling.py +++ b/funasr/models/pooling/statistic_pooling.py @@ -83,9 +83,9 @@ def windowed_statistic_pooling( num_chunk = int(math.ceil(tt / pooling_stride)) pad = pooling_size // 2 if len(xs_pad.shape) == 4: - features = F.pad(xs_pad, (0, 0, pad, pad), "reflect") + features = F.pad(xs_pad, (0, 0, pad, pad), "replicate") else: - features = F.pad(xs_pad, (pad, pad), "reflect") + features = F.pad(xs_pad, (pad, pad), "replicate") stat_list = [] for i in range(num_chunk): diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py index 62020796e..fcb3ed412 100644 --- a/funasr/modules/attention.py +++ b/funasr/modules/attention.py @@ -13,6 +13,9 @@ import torch from torch import nn from typing import Optional, Tuple +import torch.nn.functional as F +from funasr.modules.nets_utils import make_pad_mask + class MultiHeadedAttention(nn.Module): """Multi-Head Attention layer. @@ -959,3 +962,37 @@ class RelPositionMultiHeadedAttentionChunk(torch.nn.Module): q, k, v = self.forward_qkv(query, key, value) scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask) + + +class CosineDistanceAttention(nn.Module): + """ Compute Cosine Distance between spk decoder output and speaker profile + Args: + profile_path: speaker profile file path (.npy file) + """ + + def __init__(self): + super().__init__() + self.softmax = nn.Softmax(dim=-1) + + def forward(self, spk_decoder_out, profile, profile_lens=None): + """ + Args: + spk_decoder_out(torch.Tensor):(B, L, D) + spk_profiles(torch.Tensor):(B, N, D) + """ + x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D) + if profile_lens is not None: + + mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device) + min_value = float( + numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min + ) + weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value) + weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0) # (B, L, N) + else: + x = x[:, -1:, :, :] + weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1) + weights = self.softmax(weights_not_softmax) # (B, 1, N) + spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D) + + return spk_embedding, weights diff --git a/funasr/modules/beam_search/beam_search_sa_asr.py b/funasr/modules/beam_search/beam_search_sa_asr.py new file mode 100755 index 000000000..b2b6833c8 --- /dev/null +++ b/funasr/modules/beam_search/beam_search_sa_asr.py @@ -0,0 +1,525 @@ +"""Beam search module.""" + +from itertools import chain +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Tuple +from typing import Union + +import torch + +from funasr.modules.e2e_asr_common import end_detect +from funasr.modules.scorers.scorer_interface import PartialScorerInterface +from funasr.modules.scorers.scorer_interface import ScorerInterface +from funasr.models.decoder.abs_decoder import AbsDecoder + + +class Hypothesis(NamedTuple): + """Hypothesis data type.""" + + yseq: torch.Tensor + spk_weigths : List + score: Union[float, torch.Tensor] = 0 + scores: Dict[str, Union[float, torch.Tensor]] = dict() + states: Dict[str, Any] = dict() + + def asdict(self) -> dict: + """Convert data to JSON-friendly dict.""" + return self._replace( + yseq=self.yseq.tolist(), + score=float(self.score), + scores={k: float(v) for k, v in self.scores.items()}, + )._asdict() + + +class BeamSearch(torch.nn.Module): + """Beam search implementation.""" + + def __init__( + self, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + beam_size: int, + vocab_size: int, + sos: int, + eos: int, + token_list: List[str] = None, + pre_beam_ratio: float = 1.5, + pre_beam_score_key: str = None, + ): + """Initialize beam search. + + Args: + scorers (dict[str, ScorerInterface]): Dict of decoder modules + e.g., Decoder, CTCPrefixScorer, LM + The scorer will be ignored if it is `None` + weights (dict[str, float]): Dict of weights for each scorers + The scorer will be ignored if its weight is 0 + beam_size (int): The number of hypotheses kept during search + vocab_size (int): The number of vocabulary + sos (int): Start of sequence id + eos (int): End of sequence id + token_list (list[str]): List of tokens for debug log + pre_beam_score_key (str): key of scores to perform pre-beam search + pre_beam_ratio (float): beam size in the pre-beam search + will be `int(pre_beam_ratio * beam_size)` + + """ + super().__init__() + # set scorers + self.weights = weights + self.scorers = dict() + self.full_scorers = dict() + self.part_scorers = dict() + # this module dict is required for recursive cast + # `self.to(device, dtype)` in `recog.py` + self.nn_dict = torch.nn.ModuleDict() + for k, v in scorers.items(): + w = weights.get(k, 0) + if w == 0 or v is None: + continue + assert isinstance( + v, ScorerInterface + ), f"{k} ({type(v)}) does not implement ScorerInterface" + self.scorers[k] = v + if isinstance(v, PartialScorerInterface): + self.part_scorers[k] = v + else: + self.full_scorers[k] = v + if isinstance(v, torch.nn.Module): + self.nn_dict[k] = v + + # set configurations + self.sos = sos + self.eos = eos + self.token_list = token_list + self.pre_beam_size = int(pre_beam_ratio * beam_size) + self.beam_size = beam_size + self.n_vocab = vocab_size + if ( + pre_beam_score_key is not None + and pre_beam_score_key != "full" + and pre_beam_score_key not in self.full_scorers + ): + raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}") + self.pre_beam_score_key = pre_beam_score_key + self.do_pre_beam = ( + self.pre_beam_score_key is not None + and self.pre_beam_size < self.n_vocab + and len(self.part_scorers) > 0 + ) + + def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]: + """Get an initial hypothesis data. + + Args: + x (torch.Tensor): The encoder output feature + + Returns: + Hypothesis: The initial hypothesis. + + """ + init_states = dict() + init_scores = dict() + for k, d in self.scorers.items(): + init_states[k] = d.init_state(x) + init_scores[k] = 0.0 + return [ + Hypothesis( + score=0.0, + scores=init_scores, + states=init_states, + yseq=torch.tensor([self.sos], device=x.device), + spk_weigths=[], + ) + ] + + @staticmethod + def append_token(xs: torch.Tensor, x: int) -> torch.Tensor: + """Append new token to prefix tokens. + + Args: + xs (torch.Tensor): The prefix token + x (int): The new token to append + + Returns: + torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device + + """ + x = torch.tensor([x], dtype=xs.dtype, device=xs.device) + return torch.cat((xs, x)) + + def score_full( + self, hyp: Hypothesis, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.full_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.full_scorers` + and tensor score values of shape: `(self.n_vocab,)`, + and state dict that has string keys + and state values of `self.full_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.full_scorers.items(): + if isinstance(d, AbsDecoder): + scores[k], spk_weigths, states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc, spk_enc, profile) + else: + scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc) + return scores, spk_weigths, states + + def score_partial( + self, hyp: Hypothesis, ids: torch.Tensor, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.part_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + ids (torch.Tensor): 1D tensor of new partial tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.part_scorers` + and tensor score values of shape: `(len(ids),)`, + and state dict that has string keys + and state values of `self.part_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.part_scorers.items(): + if isinstance(d, AbsDecoder): + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc, spk_enc, profile) + else: + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc) + return scores, states + + def beam( + self, weighted_scores: torch.Tensor, ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute topk full token ids and partial token ids. + + Args: + weighted_scores (torch.Tensor): The weighted sum scores for each tokens. + Its shape is `(self.n_vocab,)`. + ids (torch.Tensor): The partial token ids to compute topk + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + The topk full token ids and partial token ids. + Their shapes are `(self.beam_size,)` + + """ + # no pre beam performed + if weighted_scores.size(0) == ids.size(0): + top_ids = weighted_scores.topk(self.beam_size)[1] + return top_ids, top_ids + + # mask pruned in pre-beam not to select in topk + tmp = weighted_scores[ids] + weighted_scores[:] = -float("inf") + weighted_scores[ids] = tmp + top_ids = weighted_scores.topk(self.beam_size)[1] + local_ids = weighted_scores[ids].topk(self.beam_size)[1] + return top_ids, local_ids + + @staticmethod + def merge_scores( + prev_scores: Dict[str, float], + next_full_scores: Dict[str, torch.Tensor], + full_idx: int, + next_part_scores: Dict[str, torch.Tensor], + part_idx: int, + ) -> Dict[str, torch.Tensor]: + """Merge scores for new hypothesis. + + Args: + prev_scores (Dict[str, float]): + The previous hypothesis scores by `self.scorers` + next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers` + full_idx (int): The next token id for `next_full_scores` + next_part_scores (Dict[str, torch.Tensor]): + scores of partial tokens by `self.part_scorers` + part_idx (int): The new token id for `next_part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are scalar tensors by the scorers. + + """ + new_scores = dict() + for k, v in next_full_scores.items(): + new_scores[k] = prev_scores[k] + v[full_idx] + for k, v in next_part_scores.items(): + new_scores[k] = prev_scores[k] + v[part_idx] + return new_scores + + def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any: + """Merge states for new hypothesis. + + Args: + states: states of `self.full_scorers` + part_states: states of `self.part_scorers` + part_idx (int): The new token id for `part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are states of the scorers. + + """ + new_states = dict() + for k, v in states.items(): + new_states[k] = v + for k, d in self.part_scorers.items(): + new_states[k] = d.select_state(part_states[k], part_idx) + return new_states + + def search( + self, running_hyps: List[Hypothesis], asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor + ) -> List[Hypothesis]: + """Search new tokens for running hypotheses and encoded speech x. + + Args: + running_hyps (List[Hypothesis]): Running hypotheses on beam + x (torch.Tensor): Encoded speech feature (T, D) + + Returns: + List[Hypotheses]: Best sorted hypotheses + + """ + # import ipdb;ipdb.set_trace() + best_hyps = [] + part_ids = torch.arange(self.n_vocab, device=asr_enc.device) # no pre-beam + for hyp in running_hyps: + # scoring + weighted_scores = torch.zeros(self.n_vocab, dtype=asr_enc.dtype, device=asr_enc.device) + scores, spk_weigths, states = self.score_full(hyp, asr_enc, spk_enc, profile) + for k in self.full_scorers: + weighted_scores += self.weights[k] * scores[k] + # partial scoring + if self.do_pre_beam: + pre_beam_scores = ( + weighted_scores + if self.pre_beam_score_key == "full" + else scores[self.pre_beam_score_key] + ) + part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1] + part_scores, part_states = self.score_partial(hyp, part_ids, asr_enc, spk_enc, profile) + for k in self.part_scorers: + weighted_scores[part_ids] += self.weights[k] * part_scores[k] + # add previous hyp score + weighted_scores += hyp.score + + # update hyps + for j, part_j in zip(*self.beam(weighted_scores, part_ids)): + # will be (2 x beam at most) + best_hyps.append( + Hypothesis( + score=weighted_scores[j], + yseq=self.append_token(hyp.yseq, j), + scores=self.merge_scores( + hyp.scores, scores, j, part_scores, part_j + ), + states=self.merge_states(states, part_states, part_j), + spk_weigths=hyp.spk_weigths+[spk_weigths], + ) + ) + + # sort and prune 2 x beam -> beam + best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ + : min(len(best_hyps), self.beam_size) + ] + return best_hyps + + def forward( + self, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 + ) -> List[Hypothesis]: + """Perform beam search. + + Args: + x (torch.Tensor): Encoded speech feature (T, D) + maxlenratio (float): Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths + minlenratio (float): Input length ratio to obtain min output length. + + Returns: + list[Hypothesis]: N-best decoding results + + """ + # import ipdb;ipdb.set_trace() + # set length bounds + if maxlenratio == 0: + maxlen = asr_enc.shape[0] + else: + maxlen = max(1, int(maxlenratio * asr_enc.size(0))) + minlen = int(minlenratio * asr_enc.size(0)) + logging.info("decoder input length: " + str(asr_enc.shape[0])) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # main loop of prefix search + running_hyps = self.init_hyp(asr_enc) + ended_hyps = [] + for i in range(maxlen): + logging.debug("position " + str(i)) + best = self.search(running_hyps, asr_enc, spk_enc, profile) + #import pdb;pdb.set_trace() + # post process of one iteration + running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) + # end detection + if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): + logging.info(f"end detected at {i}") + break + if len(running_hyps) == 0: + logging.info("no hypothesis. Finish decoding.") + break + else: + logging.debug(f"remained hypotheses: {len(running_hyps)}") + + nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) + # check the number of hypotheses reaching to eos + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, perform recognition " + "again with smaller minlenratio." + ) + return ( + [] + if minlenratio < 0.1 + else self.forward(asr_enc, spk_enc, profile, maxlenratio, max(0.0, minlenratio - 0.1)) + ) + + # report the best result + best = nbest_hyps[0] + for k, v in best.scores.items(): + logging.info( + f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}" + ) + logging.info(f"total log probability: {best.score:.2f}") + logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}") + logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}") + if self.token_list is not None: + logging.info( + "best hypo: " + + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + + "\n" + ) + return nbest_hyps + + def post_process( + self, + i: int, + maxlen: int, + maxlenratio: float, + running_hyps: List[Hypothesis], + ended_hyps: List[Hypothesis], + ) -> List[Hypothesis]: + """Perform post-processing of beam search iterations. + + Args: + i (int): The length of hypothesis tokens. + maxlen (int): The maximum length of tokens in beam search. + maxlenratio (int): The maximum length ratio in beam search. + running_hyps (List[Hypothesis]): The running hypotheses in beam search. + ended_hyps (List[Hypothesis]): The ended hypotheses in beam search. + + Returns: + List[Hypothesis]: The new running hypotheses. + + """ + logging.debug(f"the number of running hypotheses: {len(running_hyps)}") + if self.token_list is not None: + logging.debug( + "best hypo: " + + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) + ) + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last position in the loop") + running_hyps = [ + h._replace(yseq=self.append_token(h.yseq, self.eos)) + for h in running_hyps + ] + + # add ended hypotheses to a final list, and removed them from current hypotheses + # (this will be a problem, number of hyps < beam) + remained_hyps = [] + for hyp in running_hyps: + if hyp.yseq[-1] == self.eos: + # e.g., Word LM needs to add final score + for k, d in chain(self.full_scorers.items(), self.part_scorers.items()): + s = d.final_score(hyp.states[k]) + hyp.scores[k] += s + hyp = hyp._replace(score=hyp.score + self.weights[k] * s) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + return remained_hyps + + +def beam_search( + x: torch.Tensor, + sos: int, + eos: int, + beam_size: int, + vocab_size: int, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + token_list: List[str] = None, + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + pre_beam_ratio: float = 1.5, + pre_beam_score_key: str = "full", +) -> list: + """Perform beam search with scorers. + + Args: + x (torch.Tensor): Encoded speech feature (T, D) + sos (int): Start of sequence id + eos (int): End of sequence id + beam_size (int): The number of hypotheses kept during search + vocab_size (int): The number of vocabulary + scorers (dict[str, ScorerInterface]): Dict of decoder modules + e.g., Decoder, CTCPrefixScorer, LM + The scorer will be ignored if it is `None` + weights (dict[str, float]): Dict of weights for each scorers + The scorer will be ignored if its weight is 0 + token_list (list[str]): List of tokens for debug log + maxlenratio (float): Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths + minlenratio (float): Input length ratio to obtain min output length. + pre_beam_score_key (str): key of scores to perform pre-beam search + pre_beam_ratio (float): beam size in the pre-beam search + will be `int(pre_beam_ratio * beam_size)` + + Returns: + list: N-best decoding results + + """ + ret = BeamSearch( + scorers, + weights, + beam_size=beam_size, + vocab_size=vocab_size, + pre_beam_ratio=pre_beam_ratio, + pre_beam_score_key=pre_beam_score_key, + sos=sos, + eos=eos, + token_list=token_list, + ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) + return [h.asdict() for h in ret] diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index 3d2004c2d..f8c100961 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -444,6 +444,12 @@ class AbsTask(ABC): default=False, help='Perform on "collect stats" mode', ) + group.add_argument( + "--mc", + type=bool, + default=False, + help="MultiChannel input", + ) group.add_argument( "--write_collected_feats", type=str2bool, @@ -635,8 +641,8 @@ class AbsTask(ABC): group.add_argument( "--init_param", type=str, + action="append", default=[], - nargs="*", help="Specify the file path used for initialization of parameters. " "The format is ':::', " "where file_path is the model file path, " @@ -662,7 +668,7 @@ class AbsTask(ABC): "--freeze_param", type=str, default=[], - nargs="*", + action="append", help="Freeze parameters", ) @@ -1153,10 +1159,10 @@ class AbsTask(ABC): elif args.distributed and args.simple_ddp: distributed_option.init_torch_distributed_pai(args) args.ngpu = dist.get_world_size() - if args.dataset_type == "small": + if args.dataset_type == "small" and args.ngpu > 0: if args.batch_size is not None: args.batch_size = args.batch_size * args.ngpu - if args.batch_bins is not None: + if args.batch_bins is not None and args.ngpu > 0: args.batch_bins = args.batch_bins * args.ngpu # filter samples if wav.scp and text are mismatch @@ -1316,6 +1322,7 @@ class AbsTask(ABC): data_path_and_name_and_type=args.train_data_path_and_name_and_type, key_file=train_key_file, batch_size=args.batch_size, + mc=args.mc, dtype=args.train_dtype, num_workers=args.num_workers, allow_variable_data_keys=args.allow_variable_data_keys, @@ -1327,6 +1334,7 @@ class AbsTask(ABC): data_path_and_name_and_type=args.valid_data_path_and_name_and_type, key_file=valid_key_file, batch_size=args.valid_batch_size, + mc=args.mc, dtype=args.train_dtype, num_workers=args.num_workers, allow_variable_data_keys=args.allow_variable_data_keys, diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py new file mode 100644 index 000000000..738ec522d --- /dev/null +++ b/funasr/tasks/sa_asr.py @@ -0,0 +1,623 @@ +import argparse +import logging +import os +from pathlib import Path +from typing import Callable +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +import torch +import yaml +from typeguard import check_argument_types +from typeguard import check_return_type + +from funasr.datasets.collate_fn import CommonCollateFn +from funasr.datasets.preprocessor import CommonPreprocessor +from funasr.layers.abs_normalize import AbsNormalize +from funasr.layers.global_mvn import GlobalMVN +from funasr.layers.utterance_mvn import UtteranceMVN +from funasr.models.ctc import CTC +from funasr.models.decoder.abs_decoder import AbsDecoder +from funasr.models.decoder.rnn_decoder import RNNDecoder +from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt +from funasr.models.decoder.transformer_decoder import ( + DynamicConvolution2DTransformerDecoder, # noqa: H301 +) +from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder +from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder +from funasr.models.decoder.transformer_decoder import ( + LightweightConvolution2DTransformerDecoder, # noqa: H301 +) +from funasr.models.decoder.transformer_decoder import ( + LightweightConvolutionTransformerDecoder, # noqa: H301 +) +from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN +from funasr.models.decoder.transformer_decoder import TransformerDecoder +from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder +from funasr.models.e2e_sa_asr import ESPnetASRModel +from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer +from funasr.models.e2e_tp import TimestampPredictor +from funasr.models.e2e_asr_mfcca import MFCCA +from funasr.models.e2e_uni_asr import UniASR +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.encoder.conformer_encoder import ConformerEncoder +from funasr.models.encoder.data2vec_encoder import Data2VecEncoder +from funasr.models.encoder.rnn_encoder import RNNEncoder +from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt +from funasr.models.encoder.transformer_encoder import TransformerEncoder +from funasr.models.encoder.mfcca_encoder import MFCCAEncoder +from funasr.models.encoder.resnet34_encoder import ResNet34,ResNet34Diar +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.frontend.default import DefaultFrontend +from funasr.models.frontend.default import MultiChannelFrontend +from funasr.models.frontend.fused import FusedFrontends +from funasr.models.frontend.s3prl import S3prlFrontend +from funasr.models.frontend.wav_frontend import WavFrontend +from funasr.models.frontend.windowing import SlidingWindow +from funasr.models.postencoder.abs_postencoder import AbsPostEncoder +from funasr.models.postencoder.hugging_face_transformers_postencoder import ( + HuggingFaceTransformersPostEncoder, # noqa: H301 +) +from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3 +from funasr.models.preencoder.abs_preencoder import AbsPreEncoder +from funasr.models.preencoder.linear import LinearProjection +from funasr.models.preencoder.sinc import LightweightSincConvs +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.specaug.specaug import SpecAug +from funasr.models.specaug.specaug import SpecAugLFR +from funasr.modules.subsampling import Conv1dSubsampling +from funasr.tasks.abs_task import AbsTask +from funasr.text.phoneme_tokenizer import g2p_choices +from funasr.torch_utils.initialize import initialize +from funasr.train.abs_espnet_model import AbsESPnetModel +from funasr.train.class_choices import ClassChoices +from funasr.train.trainer import Trainer +from funasr.utils.get_default_kwargs import get_default_kwargs +from funasr.utils.nested_dict_action import NestedDictAction +from funasr.utils.types import float_or_none +from funasr.utils.types import int_or_none +from funasr.utils.types import str2bool +from funasr.utils.types import str_or_none + +frontend_choices = ClassChoices( + name="frontend", + classes=dict( + default=DefaultFrontend, + sliding_window=SlidingWindow, + s3prl=S3prlFrontend, + fused=FusedFrontends, + wav_frontend=WavFrontend, + multichannelfrontend=MultiChannelFrontend, + ), + type_check=AbsFrontend, + default="default", +) +specaug_choices = ClassChoices( + name="specaug", + classes=dict( + specaug=SpecAug, + specaug_lfr=SpecAugLFR, + ), + type_check=AbsSpecAug, + default=None, + optional=True, +) +normalize_choices = ClassChoices( + "normalize", + classes=dict( + global_mvn=GlobalMVN, + utterance_mvn=UtteranceMVN, + ), + type_check=AbsNormalize, + default=None, + optional=True, +) +model_choices = ClassChoices( + "model", + classes=dict( + asr=ESPnetASRModel, + uniasr=UniASR, + paraformer=Paraformer, + paraformer_bert=ParaformerBert, + bicif_paraformer=BiCifParaformer, + contextual_paraformer=ContextualParaformer, + mfcca=MFCCA, + timestamp_prediction=TimestampPredictor, + ), + type_check=AbsESPnetModel, + default="asr", +) +preencoder_choices = ClassChoices( + name="preencoder", + classes=dict( + sinc=LightweightSincConvs, + linear=LinearProjection, + ), + type_check=AbsPreEncoder, + default=None, + optional=True, +) +asr_encoder_choices = ClassChoices( + "asr_encoder", + classes=dict( + conformer=ConformerEncoder, + transformer=TransformerEncoder, + rnn=RNNEncoder, + sanm=SANMEncoder, + sanm_chunk_opt=SANMEncoderChunkOpt, + data2vec_encoder=Data2VecEncoder, + mfcca_enc=MFCCAEncoder, + ), + type_check=AbsEncoder, + default="rnn", +) + +spk_encoder_choices = ClassChoices( + "spk_encoder", + classes=dict( + resnet34_diar=ResNet34Diar, + ), + default="resnet34_diar", +) + +encoder_choices2 = ClassChoices( + "encoder2", + classes=dict( + conformer=ConformerEncoder, + transformer=TransformerEncoder, + rnn=RNNEncoder, + sanm=SANMEncoder, + sanm_chunk_opt=SANMEncoderChunkOpt, + ), + type_check=AbsEncoder, + default="rnn", +) +postencoder_choices = ClassChoices( + name="postencoder", + classes=dict( + hugging_face_transformers=HuggingFaceTransformersPostEncoder, + ), + type_check=AbsPostEncoder, + default=None, + optional=True, +) +decoder_choices = ClassChoices( + "decoder", + classes=dict( + transformer=TransformerDecoder, + lightweight_conv=LightweightConvolutionTransformerDecoder, + lightweight_conv2d=LightweightConvolution2DTransformerDecoder, + dynamic_conv=DynamicConvolutionTransformerDecoder, + dynamic_conv2d=DynamicConvolution2DTransformerDecoder, + rnn=RNNDecoder, + fsmn_scama_opt=FsmnDecoderSCAMAOpt, + paraformer_decoder_sanm=ParaformerSANMDecoder, + paraformer_decoder_san=ParaformerDecoderSAN, + contextual_paraformer_decoder=ContextualParaformerDecoder, + sa_decoder=SAAsrTransformerDecoder, + ), + type_check=AbsDecoder, + default="sa_decoder", +) +decoder_choices2 = ClassChoices( + "decoder2", + classes=dict( + transformer=TransformerDecoder, + lightweight_conv=LightweightConvolutionTransformerDecoder, + lightweight_conv2d=LightweightConvolution2DTransformerDecoder, + dynamic_conv=DynamicConvolutionTransformerDecoder, + dynamic_conv2d=DynamicConvolution2DTransformerDecoder, + rnn=RNNDecoder, + fsmn_scama_opt=FsmnDecoderSCAMAOpt, + paraformer_decoder_sanm=ParaformerSANMDecoder, + ), + type_check=AbsDecoder, + default="rnn", +) +predictor_choices = ClassChoices( + name="predictor", + classes=dict( + cif_predictor=CifPredictor, + ctc_predictor=None, + cif_predictor_v2=CifPredictorV2, + cif_predictor_v3=CifPredictorV3, + ), + type_check=None, + default="cif_predictor", + optional=True, +) +predictor_choices2 = ClassChoices( + name="predictor2", + classes=dict( + cif_predictor=CifPredictor, + ctc_predictor=None, + cif_predictor_v2=CifPredictorV2, + ), + type_check=None, + default="cif_predictor", + optional=True, +) +stride_conv_choices = ClassChoices( + name="stride_conv", + classes=dict( + stride_conv1d=Conv1dSubsampling + ), + type_check=None, + default="stride_conv1d", + optional=True, +) + + +class ASRTask(AbsTask): + # If you need more than one optimizers, change this value + num_optimizers: int = 1 + + # Add variable objects configurations + class_choices_list = [ + # --frontend and --frontend_conf + frontend_choices, + # --specaug and --specaug_conf + specaug_choices, + # --normalize and --normalize_conf + normalize_choices, + # --model and --model_conf + model_choices, + # --preencoder and --preencoder_conf + preencoder_choices, + # --asr_encoder and --asr_encoder_conf + asr_encoder_choices, + # --spk_encoder and --spk_encoder_conf + spk_encoder_choices, + # --postencoder and --postencoder_conf + postencoder_choices, + # --decoder and --decoder_conf + decoder_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description="Task related") + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. Instead of it, do as + # required = parser.get_default("required") + # required += ["token_list"] + + group.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="A text mapping int-id to token", + ) + group.add_argument( + "--split_with_space", + type=str2bool, + default=True, + help="whether to split text using ", + ) + group.add_argument( + "--max_spk_num", + type=int_or_none, + default=None, + help="A text mapping int-id to token", + ) + group.add_argument( + "--seg_dict_file", + type=str, + default=None, + help="seg_dict_file for text processing", + ) + group.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + + group.add_argument( + "--input_size", + type=int_or_none, + default=None, + help="The number of input dimension of the feature", + ) + + group.add_argument( + "--ctc_conf", + action=NestedDictAction, + default=get_default_kwargs(CTC), + help="The keyword arguments for CTC class.", + ) + group.add_argument( + "--joint_net_conf", + action=NestedDictAction, + default=None, + help="The keyword arguments for joint network class.", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=True, + help="Apply preprocessing to data or not", + ) + group.add_argument( + "--token_type", + type=str, + default="bpe", + choices=["bpe", "char", "word", "phn"], + help="The text will be tokenized " "in the specified level token", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model file of sentencepiece", + ) + parser.add_argument( + "--non_linguistic_symbols", + type=str_or_none, + default=None, + help="non_linguistic_symbols file path", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese"], + default=None, + help="Apply text cleaning", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=g2p_choices, + default=None, + help="Specify g2p method if --token_type=phn", + ) + parser.add_argument( + "--speech_volume_normalize", + type=float_or_none, + default=None, + help="Scale the maximum amplitude to the given value.", + ) + parser.add_argument( + "--rir_scp", + type=str_or_none, + default=None, + help="The file path of rir scp file.", + ) + parser.add_argument( + "--rir_apply_prob", + type=float, + default=1.0, + help="THe probability for applying RIR convolution.", + ) + parser.add_argument( + "--cmvn_file", + type=str_or_none, + default=None, + help="The file path of noise scp file.", + ) + parser.add_argument( + "--noise_scp", + type=str_or_none, + default=None, + help="The file path of noise scp file.", + ) + parser.add_argument( + "--noise_apply_prob", + type=float, + default=1.0, + help="The probability applying Noise adding.", + ) + parser.add_argument( + "--noise_db_range", + type=str, + default="13_15", + help="The range of noise decibel level.", + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. --encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=args.bpemodel, + non_linguistic_symbols=args.non_linguistic_symbols, + text_cleaner=args.cleaner, + g2p_type=args.g2p, + split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, + seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, + # NOTE(kamo): Check attribute existence for backward compatibility + rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, + rir_apply_prob=args.rir_apply_prob + if hasattr(args, "rir_apply_prob") + else 1.0, + noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, + noise_apply_prob=args.noise_apply_prob + if hasattr(args, "noise_apply_prob") + else 1.0, + noise_db_range=args.noise_db_range + if hasattr(args, "noise_db_range") + else "13_15", + speech_volume_normalize=args.speech_volume_normalize + if hasattr(args, "rir_scp") + else None, + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("speech", "text") + else: + # Recognition mode + retval = ("speech",) + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + retval = () + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace): + assert check_argument_types() + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] + + # Overwriting token_list to keep it as "portable". + args.token_list = list(token_list) + elif isinstance(args.token_list, (tuple, list)): + token_list = list(args.token_list) + else: + raise RuntimeError("token_list must be str or list") + vocab_size = len(token_list) + logging.info(f"Vocabulary size: {vocab_size}") + + # 1. frontend + if args.input_size is None: + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + if args.frontend == 'wav_frontend': + frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) + else: + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + args.frontend = None + args.frontend_conf = {} + frontend = None + input_size = args.input_size + + # 2. Data augmentation for spectrogram + if args.specaug is not None: + specaug_class = specaug_choices.get_class(args.specaug) + specaug = specaug_class(**args.specaug_conf) + else: + specaug = None + + # 3. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 4. Pre-encoder input block + # NOTE(kan-bayashi): Use getattr to keep the compatibility + if getattr(args, "preencoder", None) is not None: + preencoder_class = preencoder_choices.get_class(args.preencoder) + preencoder = preencoder_class(**args.preencoder_conf) + input_size = preencoder.output_size() + else: + preencoder = None + + # 5. Encoder + asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder) + asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf) + spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder) + spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf) + + # 6. Post-encoder block + # NOTE(kan-bayashi): Use getattr to keep the compatibility + asr_encoder_output_size = asr_encoder.output_size() + if getattr(args, "postencoder", None) is not None: + postencoder_class = postencoder_choices.get_class(args.postencoder) + postencoder = postencoder_class( + input_size=asr_encoder_output_size, **args.postencoder_conf + ) + asr_encoder_output_size = postencoder.output_size() + else: + postencoder = None + + # 7. Decoder + decoder_class = decoder_choices.get_class(args.decoder) + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=asr_encoder_output_size, + **args.decoder_conf, + ) + + # 8. CTC + ctc = CTC( + odim=vocab_size, encoder_output_size=asr_encoder_output_size, **args.ctc_conf + ) + + max_spk_num=int(args.max_spk_num) + + # import ipdb;ipdb.set_trace() + # 9. Build model + try: + model_class = model_choices.get_class(args.model) + except AttributeError: + model_class = model_choices.get_class("asr") + model = model_class( + vocab_size=vocab_size, + max_spk_num=max_spk_num, + frontend=frontend, + specaug=specaug, + normalize=normalize, + preencoder=preencoder, + asr_encoder=asr_encoder, + spk_encoder=spk_encoder, + postencoder=postencoder, + decoder=decoder, + ctc=ctc, + token_list=token_list, + **args.model_conf, + ) + + # 10. Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py index b607e1da0..014a79ffa 100644 --- a/funasr/utils/postprocess_utils.py +++ b/funasr/utils/postprocess_utils.py @@ -106,18 +106,17 @@ def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]: if num in abbr_begin: if time_stamp is not None: begin = time_stamp[ts_nums[num]][0] - abbr_word = words[num].upper() + word_lists.append(words[num].upper()) num += 1 while num < words_size: if num in abbr_end: - abbr_word += words[num].upper() + word_lists.append(words[num].upper()) last_num = num break else: if words[num].encode('utf-8').isalpha(): - abbr_word += words[num].upper() + word_lists.append(words[num].upper()) num += 1 - word_lists.append(abbr_word) if time_stamp is not None: end = time_stamp[ts_nums[num]][1] ts_lists.append([begin, end]) diff --git a/setup.py b/setup.py index e83763726..ea556066b 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ requirements = { "install": [ "setuptools>=38.5.1", # "configargparse>=1.2.1", - "typeguard<=2.13.3", + "typeguard==2.13.3", "humanfriendly", "scipy>=1.4.1", # "filelock", @@ -42,7 +42,10 @@ requirements = { "oss2", # "kaldi-native-fbank", # timestamp - "edit-distance" + "edit-distance", + # textgrid + "textgrid", + "protobuf==3.20.0", ], # train: The modules invoked when training only. "train": [