From a73123bcfc14370b74b17084bc124f00c48613e4 Mon Sep 17 00:00:00 2001 From: smohan-speech Date: Sat, 6 May 2023 16:17:48 +0800 Subject: [PATCH 1/5] add speaker-attributed ASR task for alimeeting --- egs/alimeeting/sa-asr/asr_local.sh | 1562 +++++++++++++++++ egs/alimeeting/sa-asr/asr_local_infer.sh | 590 +++++++ .../sa-asr/conf/decode_asr_rnn.yaml | 6 + .../sa-asr/conf/train_asr_conformer.yaml | 88 + .../sa-asr/conf/train_lm_transformer.yaml | 29 + .../sa-asr/conf/train_sa_asr_conformer.yaml | 116 ++ .../sa-asr/local/alimeeting_data_prep.sh | 162 ++ .../local/alimeeting_data_prep_test_2023.sh | 129 ++ .../local/alimeeting_process_overlap_force.py | 235 +++ .../local/alimeeting_process_textgrid.py | 158 ++ egs/alimeeting/sa-asr/local/compute_cpcer.py | 91 + egs/alimeeting/sa-asr/local/compute_wer.py | 157 ++ .../sa-asr/local/download_xvector_model.py | 6 + .../sa-asr/local/filter_utt2spk_all_fifo.py | 22 + .../sa-asr/local/gen_cluster_profile_infer.py | 167 ++ .../sa-asr/local/gen_oracle_embedding.py | 70 + .../local/gen_oracle_profile_nopadding.py | 59 + .../local/gen_oracle_profile_padding.py | 68 + egs/alimeeting/sa-asr/local/proce_text.py | 32 + .../local/process_sot_fifo_textchar2spk.py | 86 + .../sa-asr/local/process_text_id.py | 24 + .../sa-asr/local/process_text_spk_merge.py | 55 + .../process_textgrid_to_single_speaker_wav.py | 127 ++ egs/alimeeting/sa-asr/local/text_format.pl | 14 + egs/alimeeting/sa-asr/local/text_normalize.pl | 38 + egs/alimeeting/sa-asr/path.sh | 6 + .../sa-asr/pyscripts/audio/format_wav_scp.py | 243 +++ .../sa-asr/pyscripts/utils/print_args.py | 45 + egs/alimeeting/sa-asr/run_m2met_2023.sh | 51 + egs/alimeeting/sa-asr/run_m2met_2023_infer.sh | 50 + .../sa-asr/scripts/audio/format_wav_scp.sh | 142 ++ .../scripts/utils/perturb_data_dir_speed.sh | 116 ++ egs/alimeeting/sa-asr/utils/apply_map.pl | 97 + egs/alimeeting/sa-asr/utils/combine_data.sh | 146 ++ egs/alimeeting/sa-asr/utils/copy_data_dir.sh | 145 ++ 
.../sa-asr/utils/data/get_reco2dur.sh | 143 ++ .../utils/data/get_segments_for_data.sh | 29 + .../sa-asr/utils/data/get_utt2dur.sh | 135 ++ .../sa-asr/utils/data/split_data.sh | 160 ++ egs/alimeeting/sa-asr/utils/filter_scp.pl | 87 + egs/alimeeting/sa-asr/utils/fix_data_dir.sh | 215 +++ egs/alimeeting/sa-asr/utils/parse_options.sh | 97 + .../sa-asr/utils/spk2utt_to_utt2spk.pl | 27 + egs/alimeeting/sa-asr/utils/split_scp.pl | 246 +++ .../sa-asr/utils/utt2spk_to_spk2utt.pl | 38 + .../sa-asr/utils/validate_data_dir.sh | 404 +++++ egs/alimeeting/sa-asr/utils/validate_text.pl | 136 ++ funasr/bin/asr_inference.py | 34 +- funasr/bin/asr_inference_launch.py | 5 +- funasr/bin/asr_train.py | 15 +- funasr/bin/sa_asr_inference.py | 674 +++++++ funasr/bin/sa_asr_train.py | 55 + funasr/fileio/sound_scp.py | 6 +- funasr/losses/nll_loss.py | 47 + funasr/models/decoder/decoder_layer_sa_asr.py | 169 ++ .../decoder/transformer_decoder_sa_asr.py | 291 +++ funasr/models/e2e_sa_asr.py | 521 ++++++ funasr/models/frontend/default.py | 11 +- funasr/models/pooling/statistic_pooling.py | 4 +- funasr/modules/attention.py | 37 + .../modules/beam_search/beam_search_sa_asr.py | 525 ++++++ funasr/tasks/abs_task.py | 16 +- funasr/tasks/sa_asr.py | 623 +++++++ funasr/utils/postprocess_utils.py | 7 +- setup.py | 7 +- 65 files changed, 9859 insertions(+), 37 deletions(-) create mode 100755 egs/alimeeting/sa-asr/asr_local.sh create mode 100755 egs/alimeeting/sa-asr/asr_local_infer.sh create mode 100644 egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml create mode 100644 egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml create mode 100644 egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml create mode 100644 egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml create mode 100755 egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh create mode 100755 egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh create mode 100755 egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py create 
mode 100755 egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py create mode 100644 egs/alimeeting/sa-asr/local/compute_cpcer.py create mode 100755 egs/alimeeting/sa-asr/local/compute_wer.py create mode 100644 egs/alimeeting/sa-asr/local/download_xvector_model.py create mode 100644 egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py create mode 100644 egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py create mode 100644 egs/alimeeting/sa-asr/local/gen_oracle_embedding.py create mode 100644 egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py create mode 100644 egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py create mode 100755 egs/alimeeting/sa-asr/local/proce_text.py create mode 100755 egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py create mode 100644 egs/alimeeting/sa-asr/local/process_text_id.py create mode 100644 egs/alimeeting/sa-asr/local/process_text_spk_merge.py create mode 100755 egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py create mode 100755 egs/alimeeting/sa-asr/local/text_format.pl create mode 100755 egs/alimeeting/sa-asr/local/text_normalize.pl create mode 100755 egs/alimeeting/sa-asr/path.sh create mode 100755 egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py create mode 100755 egs/alimeeting/sa-asr/pyscripts/utils/print_args.py create mode 100755 egs/alimeeting/sa-asr/run_m2met_2023.sh create mode 100755 egs/alimeeting/sa-asr/run_m2met_2023_infer.sh create mode 100755 egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh create mode 100755 egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh create mode 100755 egs/alimeeting/sa-asr/utils/apply_map.pl create mode 100755 egs/alimeeting/sa-asr/utils/combine_data.sh create mode 100755 egs/alimeeting/sa-asr/utils/copy_data_dir.sh create mode 100755 egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh create mode 100755 egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh create mode 100755 
egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh create mode 100755 egs/alimeeting/sa-asr/utils/data/split_data.sh create mode 100755 egs/alimeeting/sa-asr/utils/filter_scp.pl create mode 100755 egs/alimeeting/sa-asr/utils/fix_data_dir.sh create mode 100755 egs/alimeeting/sa-asr/utils/parse_options.sh create mode 100755 egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl create mode 100755 egs/alimeeting/sa-asr/utils/split_scp.pl create mode 100755 egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl create mode 100755 egs/alimeeting/sa-asr/utils/validate_data_dir.sh create mode 100755 egs/alimeeting/sa-asr/utils/validate_text.pl create mode 100644 funasr/bin/sa_asr_inference.py create mode 100755 funasr/bin/sa_asr_train.py create mode 100644 funasr/losses/nll_loss.py create mode 100644 funasr/models/decoder/decoder_layer_sa_asr.py create mode 100644 funasr/models/decoder/transformer_decoder_sa_asr.py create mode 100644 funasr/models/e2e_sa_asr.py create mode 100755 funasr/modules/beam_search/beam_search_sa_asr.py create mode 100644 funasr/tasks/sa_asr.py diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh new file mode 100755 index 000000000..c0359eb35 --- /dev/null +++ b/egs/alimeeting/sa-asr/asr_local.sh @@ -0,0 +1,1562 @@ +#!/usr/bin/env bash + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} +min() { + local a b + a=$1 + for b in "$@"; do + if [ "${b}" -le "${a}" ]; then + a="${b}" + fi + done + echo "${a}" +} +SECONDS=0 + +# General configuration +stage=1 # Processes starts from the specified stage. +stop_stage=10000 # Processes is stopped at the specified stage. +skip_data_prep=false # Skip data preparation stages. +skip_train=false # Skip training stages. 
+skip_eval=false # Skip decoding and evaluation stages. +skip_upload=true # Skip packing and uploading stages. +ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu). +num_nodes=1 # The number of nodes. +nj=16 # The number of parallel jobs. +inference_nj=16 # The number of parallel jobs in decoding. +gpu_inference=false # Whether to perform gpu decoding. +njob_infer=4 +dumpdir=dump2 # Directory to dump features. +expdir=exp # Directory to save experiments. +python=python3 # Specify python to execute espnet commands. +device=0 + +# Data preparation related +local_data_opts= # The options given to local/data.sh. + +# Speed perturbation related +speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space). + +# Feature extraction related +feats_type=raw # Feature type (raw or fbank_pitch). +audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw). +fs=16000 # Sampling rate. +min_wav_duration=0.1 # Minimum duration in second. +max_wav_duration=20 # Maximum duration in second. + +# Tokenization related +token_type=bpe # Tokenization type (char or bpe). +nbpe=30 # The number of BPE vocabulary. +bpemode=unigram # Mode of BPE (unigram or bpe). +oov="" # Out of vocabulary symbol. +blank="" # CTC blank symbol +sos_eos="" # sos and eos symbole +bpe_input_sentence_size=100000000 # Size of input sentence for BPE. +bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE +bpe_char_cover=1.0 # character coverage when modeling BPE + +# Language model related +use_lm=true # Use language model for ASR decoding. +lm_tag= # Suffix to the result dir for language model training. +lm_exp= # Specify the direcotry path for LM experiment. + # If this option is specified, lm_tag is ignored. +lm_stats_dir= # Specify the direcotry path for LM statistics. +lm_config= # Config for language model training. +lm_args= # Arguments for language model training, e.g., "--max_epoch 10". 
+ # Note that it will overwrite args in lm config. +use_word_lm=false # Whether to use word language model. +num_splits_lm=1 # Number of splitting for lm corpus. +# shellcheck disable=SC2034 +word_vocab_size=10000 # Size of word vocabulary. + +# ASR model related +asr_tag= # Suffix to the result dir for asr model training. +asr_exp= # Specify the direcotry path for ASR experiment. + # If this option is specified, asr_tag is ignored. +sa_asr_exp= +asr_stats_dir= # Specify the direcotry path for ASR statistics. +asr_config= # Config for asr model training. +sa_asr_config= +asr_args= # Arguments for asr model training, e.g., "--max_epoch 10". + # Note that it will overwrite args in asr config. +feats_normalize=global_mvn # Normalizaton layer type. +num_splits_asr=1 # Number of splitting for lm corpus. + +# Decoding related +inference_tag= # Suffix to the result dir for decoding. +inference_config= # Config for decoding. +inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1". + # Note that it will overwrite args in inference config. +sa_asr_inference_tag= +sa_asr_inference_args= + +inference_lm=valid.loss.ave.pb # Language modle path for decoding. +inference_asr_model=valid.acc.ave.pb # ASR model path for decoding. + # e.g. + # inference_asr_model=train.loss.best.pth + # inference_asr_model=3epoch.pth + # inference_asr_model=valid.acc.best.pth + # inference_asr_model=valid.loss.ave.pth +inference_sa_asr_model=valid.acc_spk.ave.pb +download_model= # Download a model from Model Zoo and use it for decoding. + +# [Task dependent] Set the datadir name created by local/data.sh +train_set= # Name of training set. +valid_set= # Name of validation set used for monitoring/tuning network training. +test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified. +bpe_train_text= # Text file path of bpe training set. +lm_train_text= # Text file path of language model training set. 
+lm_dev_text= # Text file path of language model development set. +lm_test_text= # Text file path of language model evaluation set. +nlsyms_txt=none # Non-linguistic symbol list if existing. +cleaner=none # Text cleaner. +g2p=none # g2p method (needed if token_type=phn). +lang=zh # The language type of corpus. +score_opts= # The options given to sclite scoring +local_score_opts= # The options given to local/score.sh. + + +help_message=$(cat << EOF +Usage: $0 --train-set "" --valid-set "" --test_sets "" + +Options: + # General configuration + --stage # Processes starts from the specified stage (default="${stage}"). + --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}"). + --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}"). + --skip_train # Skip training stages (default="${skip_train}"). + --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}"). + --skip_upload # Skip packing and uploading stages (default="${skip_upload}"). + --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}"). + --num_nodes # The number of nodes (default="${num_nodes}"). + --nj # The number of parallel jobs (default="${nj}"). + --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}"). + --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}"). + --dumpdir # Directory to dump features (default="${dumpdir}"). + --expdir # Directory to save experiments (default="${expdir}"). + --python # Specify python to execute espnet commands (default="${python}"). + --device # Which GPUs are use for local training (defalut="${device}"). + + # Data preparation related + --local_data_opts # The options given to local/data.sh (default="${local_data_opts}"). + + # Speed perturbation related + --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}"). 
+ + # Feature extraction related + --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}"). + --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}"). + --fs # Sampling rate (default="${fs}"). + --min_wav_duration # Minimum duration in second (default="${min_wav_duration}"). + --max_wav_duration # Maximum duration in second (default="${max_wav_duration}"). + + # Tokenization related + --token_type # Tokenization type (char or bpe, default="${token_type}"). + --nbpe # The number of BPE vocabulary (default="${nbpe}"). + --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}"). + --oov # Out of vocabulary symbol (default="${oov}"). + --blank # CTC blank symbol (default="${blank}"). + --sos_eos # sos and eos symbole (default="${sos_eos}"). + --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}"). + --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}"). + --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}"). + + # Language model related + --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}"). + --lm_exp # Specify the direcotry path for LM experiment. + # If this option is specified, lm_tag is ignored (default="${lm_exp}"). + --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}"). + --lm_config # Config for language model training (default="${lm_config}"). + --lm_args # Arguments for language model training (default="${lm_args}"). + # e.g., --lm_args "--max_epoch 10" + # Note that it will overwrite args in lm config. + --use_word_lm # Whether to use word language model (default="${use_word_lm}"). + --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}"). + --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}"). 
+ + # ASR model related + --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}"). + --asr_exp # Specify the direcotry path for ASR experiment. + # If this option is specified, asr_tag is ignored (default="${asr_exp}"). + --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}"). + --asr_config # Config for asr model training (default="${asr_config}"). + --asr_args # Arguments for asr model training (default="${asr_args}"). + # e.g., --asr_args "--max_epoch 10" + # Note that it will overwrite args in asr config. + --feats_normalize # Normalizaton layer type (default="${feats_normalize}"). + --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}"). + + # Decoding related + --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}"). + --inference_config # Config for decoding (default="${inference_config}"). + --inference_args # Arguments for decoding (default="${inference_args}"). + # e.g., --inference_args "--lm_weight 0.1" + # Note that it will overwrite args in inference config. + --inference_lm # Language modle path for decoding (default="${inference_lm}"). + --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}"). + --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}"). + + # [Task dependent] Set the datadir name created by local/data.sh + --train_set # Name of training set (required). + --valid_set # Name of validation set used for monitoring/tuning network training (required). + --test_sets # Names of test sets. + # Multiple items (e.g., both dev and eval sets) can be specified (required). + --bpe_train_text # Text file path of bpe training set. + --lm_train_text # Text file path of language model training set. + --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}"). 
+ --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}"). + --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}"). + --cleaner # Text cleaner (default="${cleaner}"). + --g2p # g2p method (default="${g2p}"). + --lang # The language type of corpus (default=${lang}). + --score_opts # The options given to sclite scoring (default="{score_opts}"). + --local_score_opts # The options given to local/score.sh (default="{local_score_opts}"). +EOF +) + +log "$0 $*" +# Save command line args for logging (they will be lost after utils/parse_options.sh) +run_args=$(python -m funasr.utils.cli_utils $0 "$@") +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + log "${help_message}" + log "Error: No positional arguments are required." + exit 2 +fi + +. ./path.sh + + +# Check required arguments +[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; }; +[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; }; +[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; }; + +# Check feature type +if [ "${feats_type}" = raw ]; then + data_feats=${dumpdir}/raw +elif [ "${feats_type}" = fbank_pitch ]; then + data_feats=${dumpdir}/fbank_pitch +elif [ "${feats_type}" = fbank ]; then + data_feats=${dumpdir}/fbank +elif [ "${feats_type}" == extracted ]; then + data_feats=${dumpdir}/extracted +else + log "${help_message}" + log "Error: not supported: --feats_type ${feats_type}" + exit 2 +fi + +# Use the same text as ASR for bpe training if not specified. +[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text" +# Use the same text as ASR for lm training if not specified. +[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text" +# Use the same text as ASR for lm training if not specified. 
+[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text" +# Use the text of the 1st evaldir if lm_test is not specified +[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text" + +# Check tokenization type +if [ "${lang}" != noinfo ]; then + token_listdir=data/${lang}_token_list +else + token_listdir=data/token_list +fi +bpedir="${token_listdir}/bpe_${bpemode}${nbpe}" +bpeprefix="${bpedir}"/bpe +bpemodel="${bpeprefix}".model +bpetoken_list="${bpedir}"/tokens.txt +chartoken_list="${token_listdir}"/char/tokens.txt +# NOTE: keep for future development. +# shellcheck disable=SC2034 +wordtoken_list="${token_listdir}"/word/tokens.txt + +if [ "${token_type}" = bpe ]; then + token_list="${bpetoken_list}" +elif [ "${token_type}" = char ]; then + token_list="${chartoken_list}" + bpemodel=none +elif [ "${token_type}" = word ]; then + token_list="${wordtoken_list}" + bpemodel=none +else + log "Error: not supported --token_type '${token_type}'" + exit 2 +fi +if ${use_word_lm}; then + log "Error: Word LM is not supported yet" + exit 2 + + lm_token_list="${wordtoken_list}" + lm_token_type=word +else + lm_token_list="${token_list}" + lm_token_type="${token_type}" +fi + + +# Set tag for naming of model directory +if [ -z "${asr_tag}" ]; then + if [ -n "${asr_config}" ]; then + asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}" + else + asr_tag="train_${feats_type}" + fi + if [ "${lang}" != noinfo ]; then + asr_tag+="_${lang}_${token_type}" + else + asr_tag+="_${token_type}" + fi + if [ "${token_type}" = bpe ]; then + asr_tag+="${nbpe}" + fi + # Add overwritten arg's info + if [ -n "${asr_args}" ]; then + asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")" + fi + if [ -n "${speed_perturb_factors}" ]; then + asr_tag+="_sp" + fi +fi +if [ -z "${lm_tag}" ]; then + if [ -n "${lm_config}" ]; then + lm_tag="$(basename "${lm_config}" .yaml)" + else + lm_tag="train" + fi + if [ "${lang}" != noinfo ]; then + 
lm_tag+="_${lang}_${lm_token_type}" + else + lm_tag+="_${lm_token_type}" + fi + if [ "${lm_token_type}" = bpe ]; then + lm_tag+="${nbpe}" + fi + # Add overwritten arg's info + if [ -n "${lm_args}" ]; then + lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")" + fi +fi + +# The directory used for collect-stats mode +if [ -z "${asr_stats_dir}" ]; then + if [ "${lang}" != noinfo ]; then + asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}" + else + asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}" + fi + if [ "${token_type}" = bpe ]; then + asr_stats_dir+="${nbpe}" + fi + if [ -n "${speed_perturb_factors}" ]; then + asr_stats_dir+="_sp" + fi +fi +if [ -z "${lm_stats_dir}" ]; then + if [ "${lang}" != noinfo ]; then + lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}" + else + lm_stats_dir="${expdir}/lm_stats_${lm_token_type}" + fi + if [ "${lm_token_type}" = bpe ]; then + lm_stats_dir+="${nbpe}" + fi +fi +# The directory used for training commands +if [ -z "${asr_exp}" ]; then + asr_exp="${expdir}/asr_${asr_tag}" +fi +if [ -z "${lm_exp}" ]; then + lm_exp="${expdir}/lm_${lm_tag}" +fi + + +if [ -z "${inference_tag}" ]; then + if [ -n "${inference_config}" ]; then + inference_tag="$(basename "${inference_config}" .yaml)" + else + inference_tag=inference + fi + # Add overwritten arg's info + if [ -n "${inference_args}" ]; then + inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")" + fi + if "${use_lm}"; then + inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" + fi + inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" +fi + +if [ -z "${sa_asr_inference_tag}" ]; then + if [ -n "${inference_config}" ]; then + sa_asr_inference_tag="$(basename "${inference_config}" .yaml)" + else + sa_asr_inference_tag=sa_asr_inference + fi + # Add overwritten arg's info + if [ -n 
"${sa_asr_inference_args}" ]; then + sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")" + fi + if "${use_lm}"; then + sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" + fi + sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" +fi + +train_cmd="run.pl" +cuda_cmd="run.pl" +decode_cmd="run.pl" + +# ========================== Main stages start from here. ========================== + +if ! "${skip_data_prep}"; then + + if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + log "Stage 1: Data preparation for data/${train_set}, data/${valid_set}, etc." + + ./local/alimeeting_data_prep.sh --tgt Test + ./local/alimeeting_data_prep.sh --tgt Eval + ./local/alimeeting_data_prep.sh --tgt Train + fi + + if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + if [ -n "${speed_perturb_factors}" ]; then + log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp" + for factor in ${speed_perturb_factors}; do + if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then + scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}" + _dirs+="data/${train_set}_sp${factor} " + else + # If speed factor is 1, same as the original + _dirs+="data/${train_set} " + fi + done + utils/combine_data.sh "data/${train_set}_sp" ${_dirs} + else + log "Skip stage 2: Speed perturbation" + fi + fi + + if [ -n "${speed_perturb_factors}" ]; then + train_set="${train_set}_sp" + fi + + if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + if [ "${feats_type}" = raw ]; then + log "Stage 3: Format wav.scp: data/ -> ${data_feats}" + + # ====== Recreating "wav.scp" ====== + # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |", + # shouldn't be used in training process. 
+ # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file + # and it can also change the audio-format and sampling rate. + # If nothing is need, then format_wav_scp.sh does nothing: + # i.e. the input file format and rate is same as the output. + + for dset in "${train_set}" "${valid_set}" "${test_sets}" ; do + if [ "${dset}" = "${train_set}" ] || [ "${dset}" = "${valid_set}" ]; then + _suf="/org" + else + if [ "${dset}" = "${test_sets}" ] && [ "${test_sets}" = "Test_Ali_far" ]; then + _suf="/org" + else + _suf="" + fi + fi + utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" + + cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/" + + rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur} + _opts= + if [ -e data/"${dset}"/segments ]; then + # "segments" is used for splitting wav files which are written in "wav".scp + # into utterances. The file format of segments: + # + # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5" + # Where the time is written in seconds. 
+ _opts+="--segments data/${dset}/segments " + fi + # shellcheck disable=SC2086 + scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ + --audio-format "${audio_format}" --fs "${fs}" ${_opts} \ + "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}" + + echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type" + done + + else + log "Error: not supported: --feats_type ${feats_type}" + exit 2 + fi + fi + + + if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + log "Stage 4: Remove long/short data: ${data_feats}/org -> ${data_feats}" + + # NOTE(kamo): Not applying to test_sets to keep original data + if [ "${test_sets}" = "Test_Ali_far" ]; then + rm_dset="${train_set} ${valid_set} ${test_sets}" + else + rm_dset="${train_set} ${valid_set}" + fi + + for dset in $rm_dset; do + + # Copy data dir + utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}" + cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type" + + # Remove short utterances + _feats_type="$(<${data_feats}/${dset}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _fs=$(python3 -c "import humanfriendly as h;print(h.parse_size('${fs}'))") + _min_length=$(python3 -c "print(int(${min_wav_duration} * ${_fs}))") + _max_length=$(python3 -c "print(int(${max_wav_duration} * ${_fs}))") + + # utt2num_samples is created by format_wav_scp.sh + <"${data_feats}/org/${dset}/utt2num_samples" \ + awk -v min_length="${_min_length}" -v max_length="${_max_length}" \ + '{ if ($2 > min_length && $2 < max_length ) print $0; }' \ + >"${data_feats}/${dset}/utt2num_samples" + <"${data_feats}/org/${dset}/wav.scp" \ + utils/filter_scp.pl "${data_feats}/${dset}/utt2num_samples" \ + >"${data_feats}/${dset}/wav.scp" + else + # Get frame shift in ms from conf/fbank.conf + _frame_shift= + if [ -f conf/fbank.conf ] && [ "$( min_length && $2 < max_length) print $0; }' \ + >"${data_feats}/${dset}/feats_shape" + <"${data_feats}/org/${dset}/feats.scp" 
\ + utils/filter_scp.pl "${data_feats}/${dset}/feats_shape" \ + >"${data_feats}/${dset}/feats.scp" + fi + + # Remove empty text + <"${data_feats}/org/${dset}/text" \ + awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text" + + # fix_data_dir.sh leaves only utts which exist in all files + utils/fix_data_dir.sh "${data_feats}/${dset}" + + # generate uttid + cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid" + # filter utt2spk_all_fifo + python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset} + done + + # shellcheck disable=SC2002 + cat ${lm_train_text} | awk ' { if( NF != 1 ) print $0; } ' > "${data_feats}/lm_train.txt" + fi + + + if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + log "Stage 5: Dictionary Preparation" + mkdir -p data/${lang}_token_list/char/ + + echo "make a dictionary" + echo "" > ${token_list} + echo "" >> ${token_list} + echo "" >> ${token_list} + local/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \ + | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list} + num_token=$(cat ${token_list} | wc -l) + echo "" >> ${token_list} + vocab_size=$(cat ${token_list} | wc -l) + fi + + if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + log "Stage 6: Generate speaker settings" + mkdir -p "profile_log" + for dset in "${train_set}" "${valid_set}" "${test_sets}"; do + # generate text_id spk2id + python local/process_sot_fifo_textchar2spk.py --path ${data_feats}/${dset} + log "Successfully generate ${data_feats}/${dset}/text_id ${data_feats}/${dset}/spk2id" + # generate text_id_train for sot + python local/process_text_id.py ${data_feats}/${dset} + log "Successfully generate ${data_feats}/${dset}/text_id_train" + # generate oracle_embedding from single-speaker audio segment + python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> 
"profile_log/gen_oracle_embedding_${dset}.log" + log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)" + # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training) + if [ "${dset}" = "${train_set}" ]; then + python local/gen_oracle_profile_padding.py ${data_feats}/${dset} + log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_padding.scp)" + else + python local/gen_oracle_profile_nopadding.py ${data_feats}/${dset} + log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_nopadding.scp)" + fi + # generate cluster_profile with spectral-cluster directly (for infering and without oracle information) + if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then + python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log" + log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)" + fi + + done + fi + +else + log "Skip the stages for data preparation" +fi + + +# ========================== Data preparation is done here. ========================== + + +if ! "${skip_train}"; then + if "${use_lm}"; then + if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + log "Stage 7: LM collect stats: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}" + + _opts= + if [ -n "${lm_config}" ]; then + # To generate the config file: e.g. + # % python3 -m espnet2.bin.lm_train --print_config --optim adam + _opts+="--config ${lm_config} " + fi + + # 1. 
Split the key file + _logdir="${lm_stats_dir}/logdir" + mkdir -p "${_logdir}" + # Get the minimum number among ${nj} and the number lines of input files + _nj=$(min "${nj}" "$(<${data_feats}/lm_train.txt wc -l)" "$(<${lm_dev_text} wc -l)") + + key_file="${data_feats}/lm_train.txt" + split_scps="" + for n in $(seq ${_nj}); do + split_scps+=" ${_logdir}/train.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + key_file="${lm_dev_text}" + split_scps="" + for n in $(seq ${_nj}); do + split_scps+=" ${_logdir}/dev.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Generate run.sh + log "Generate '${lm_stats_dir}/run.sh'. You can resume the process from stage 6 using this script" + mkdir -p "${lm_stats_dir}"; echo "${run_args} --stage 6 \"\$@\"; exit \$?" > "${lm_stats_dir}/run.sh"; chmod +x "${lm_stats_dir}/run.sh" + + # 3. Submit jobs + log "LM collect-stats started... log: '${_logdir}/stats.*.log'" + # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted, + # but it's used only for deciding the sample ids. + # shellcheck disable=SC2086 + ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \ + ${python} -m funasr.bin.lm_train \ + --collect_stats true \ + --use_preprocessor true \ + --bpemodel "${bpemodel}" \ + --token_type "${lm_token_type}"\ + --token_list "${lm_token_list}" \ + --non_linguistic_symbols "${nlsyms_txt}" \ + --cleaner "${cleaner}" \ + --g2p "${g2p}" \ + --train_data_path_and_name_and_type "${data_feats}/lm_train.txt,text,text" \ + --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \ + --train_shape_file "${_logdir}/train.JOB.scp" \ + --valid_shape_file "${_logdir}/dev.JOB.scp" \ + --output_dir "${_logdir}/stats.JOB" \ + ${_opts} ${lm_args} || { cat "${_logdir}"/stats.1.log; exit 1; } + + # 4. 
Aggregate shape files + _opts= + for i in $(seq "${_nj}"); do + _opts+="--input_dir ${_logdir}/stats.${i} " + done + # shellcheck disable=SC2086 + ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${lm_stats_dir}" + + # Append the num-tokens at the last dimensions. This is used for batch-bins count + <"${lm_stats_dir}/train/text_shape" \ + awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \ + >"${lm_stats_dir}/train/text_shape.${lm_token_type}" + + <"${lm_stats_dir}/valid/text_shape" \ + awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \ + >"${lm_stats_dir}/valid/text_shape.${lm_token_type}" + fi + + + if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then + log "Stage 8: LM Training: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}" + + _opts= + if [ -n "${lm_config}" ]; then + # To generate the config file: e.g. + # % python3 -m espnet2.bin.lm_train --print_config --optim adam + _opts+="--config ${lm_config} " + fi + + if [ "${num_splits_lm}" -gt 1 ]; then + # If you met a memory error when parsing text files, this option may help you. + # The corpus is split into subsets and each subset is used for training one by one in order, + # so the memory footprint can be limited to the memory required for each dataset. + + _split_dir="${lm_stats_dir}/splits${num_splits_lm}" + if [ ! -f "${_split_dir}/.done" ]; then + rm -f "${_split_dir}/.done" + ${python} -m espnet2.bin.split_scps \ + --scps "${data_feats}/lm_train.txt" "${lm_stats_dir}/train/text_shape.${lm_token_type}" \ + --num_splits "${num_splits_lm}" \ + --output_dir "${_split_dir}" + touch "${_split_dir}/.done" + else + log "${_split_dir}/.done exists. 
Spliting is skipped" + fi + + _opts+="--train_data_path_and_name_and_type ${_split_dir}/lm_train.txt,text,text " + _opts+="--train_shape_file ${_split_dir}/text_shape.${lm_token_type} " + _opts+="--multiple_iterator true " + + else + _opts+="--train_data_path_and_name_and_type ${data_feats}/lm_train.txt,text,text " + _opts+="--train_shape_file ${lm_stats_dir}/train/text_shape.${lm_token_type} " + fi + + # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case + + log "Generate '${lm_exp}/run.sh'. You can resume the process from stage 8 using this script" + mkdir -p "${lm_exp}"; echo "${run_args} --stage 8 \"\$@\"; exit \$?" > "${lm_exp}/run.sh"; chmod +x "${lm_exp}/run.sh" + + log "LM training started... log: '${lm_exp}/train.log'" + if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then + # SGE can't include "/" in a job name + jobname="$(basename ${lm_exp})" + else + jobname="${lm_exp}/train.log" + fi + + mkdir -p ${lm_exp} + mkdir -p ${lm_exp}/log + INIT_FILE=${lm_exp}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $ngpu; ++i)); do + { + # i=0 + rank=$i + local_rank=$i + gpu_id=$(echo $device | cut -d',' -f$[$i+1]) + lm_train.py \ + --gpu_id $gpu_id \ + --use_preprocessor true \ + --bpemodel ${bpemodel} \ + --token_type ${token_type} \ + --token_list ${token_list} \ + --non_linguistic_symbols ${nlsyms_txt} \ + --cleaner ${cleaner} \ + --g2p ${g2p} \ + --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \ + --valid_shape_file "${lm_stats_dir}/valid/text_shape.${lm_token_type}" \ + --resume true \ + --output_dir ${lm_exp} \ + --config $lm_config \ + --ngpu $ngpu \ + --num_worker_count 1 \ + --multiprocessing_distributed true \ + --dist_init_method $init_method \ + --dist_world_size $ngpu \ + --dist_rank $rank \ + --local_rank $local_rank \ + ${_opts} 1> 
${lm_exp}/log/train.log.$i 2>&1 + } & + done + wait + + fi + + + if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then + log "Stage 9: Calc perplexity: ${lm_test_text}" + _opts= + # TODO(kamo): Parallelize? + log "Perplexity calculation started... log: '${lm_exp}/perplexity_test/lm_calc_perplexity.log'" + # shellcheck disable=SC2086 + CUDA_VISIBLE_DEVICES=${device}\ + ${cuda_cmd} --gpu "${ngpu}" "${lm_exp}"/perplexity_test/lm_calc_perplexity.log \ + ${python} -m funasr.bin.lm_calc_perplexity \ + --ngpu "${ngpu}" \ + --data_path_and_name_and_type "${lm_test_text},text,text" \ + --train_config "${lm_exp}"/config.yaml \ + --model_file "${lm_exp}/${inference_lm}" \ + --output_dir "${lm_exp}/perplexity_test" \ + ${_opts} + log "PPL: ${lm_test_text}: $(cat ${lm_exp}/perplexity_test/ppl)" + + fi + + else + log "Stage 7-9: Skip lm-related stages: use_lm=${use_lm}" + fi + + + if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then + _asr_train_dir="${data_feats}/${train_set}" + _asr_valid_dir="${data_feats}/${valid_set}" + log "Stage 10: ASR collect stats: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}" + + _opts= + if [ -n "${asr_config}" ]; then + # To generate the config file: e.g. + # % python3 -m espnet2.bin.asr_train --print_config --optim adam + _opts+="--config ${asr_config} " + fi + + _feats_type="$(<${_asr_train_dir}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + # "sound" supports "wav", "flac", etc. + _type=sound + fi + _opts+="--frontend_conf fs=${fs} " + else + _scp=feats.scp + _type=kaldi_ark + _input_size="$(<${_asr_train_dir}/feats_dim)" + _opts+="--input_size=${_input_size} " + fi + + # 1. 
Split the key file + _logdir="${asr_stats_dir}/logdir" + mkdir -p "${_logdir}" + + # Get the minimum number among ${nj} and the number lines of input files + _nj=$(min "${nj}" "$(<${_asr_train_dir}/${_scp} wc -l)" "$(<${_asr_valid_dir}/${_scp} wc -l)") + + key_file="${_asr_train_dir}/${_scp}" + split_scps="" + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/train.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + key_file="${_asr_valid_dir}/${_scp}" + split_scps="" + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/valid.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Generate run.sh + log "Generate '${asr_stats_dir}/run.sh'. You can resume the process from stage 9 using this script" + mkdir -p "${asr_stats_dir}"; echo "${run_args} --stage 9 \"\$@\"; exit \$?" > "${asr_stats_dir}/run.sh"; chmod +x "${asr_stats_dir}/run.sh" + + # 3. Submit jobs + log "ASR collect-stats started... log: '${_logdir}/stats.*.log'" + + # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted, + # but it's used only for deciding the sample ids. 
+ + # shellcheck disable=SC2086 + ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \ + ${python} -m funasr.bin.asr_train \ + --collect_stats true \ + --mc true \ + --use_preprocessor true \ + --bpemodel "${bpemodel}" \ + --token_type "${token_type}" \ + --token_list "${token_list}" \ + --split_with_space false \ + --non_linguistic_symbols "${nlsyms_txt}" \ + --cleaner "${cleaner}" \ + --g2p "${g2p}" \ + --train_data_path_and_name_and_type "${_asr_train_dir}/${_scp},speech,${_type}" \ + --train_data_path_and_name_and_type "${_asr_train_dir}/text,text,text" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \ + --train_shape_file "${_logdir}/train.JOB.scp" \ + --valid_shape_file "${_logdir}/valid.JOB.scp" \ + --output_dir "${_logdir}/stats.JOB" \ + ${_opts} ${asr_args} || { cat "${_logdir}"/stats.1.log; exit 1; } + + # 4. Aggregate shape files + _opts= + for i in $(seq "${_nj}"); do + _opts+="--input_dir ${_logdir}/stats.${i} " + done + # shellcheck disable=SC2086 + ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${asr_stats_dir}" + + # Append the num-tokens at the last dimensions. This is used for batch-bins count + <"${asr_stats_dir}/train/text_shape" \ + awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \ + >"${asr_stats_dir}/train/text_shape.${token_type}" + + <"${asr_stats_dir}/valid/text_shape" \ + awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \ + >"${asr_stats_dir}/valid/text_shape.${token_type}" + fi + + + if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then + _asr_train_dir="${data_feats}/${train_set}" + _asr_valid_dir="${data_feats}/${valid_set}" + log "Stage 11: ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}" + + _opts= + if [ -n "${asr_config}" ]; then + # To generate the config file: e.g. 
+ # % python3 -m espnet2.bin.asr_train --print_config --optim adam + _opts+="--config ${asr_config} " + fi + + _feats_type="$(<${_asr_train_dir}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + # "sound" supports "wav", "flac", etc. + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + _opts+="--frontend_conf fs=${fs} " + else + _scp=feats.scp + _type=kaldi_ark + _input_size="$(<${_asr_train_dir}/feats_dim)" + _opts+="--input_size=${_input_size} " + + fi + if [ "${feats_normalize}" = global_mvn ]; then + # Default normalization is utterance_mvn and changes to global_mvn + _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz " + fi + + if [ "${num_splits_asr}" -gt 1 ]; then + # If you met a memory error when parsing text files, this option may help you. + # The corpus is split into subsets and each subset is used for training one by one in order, + # so the memory footprint can be limited to the memory required for each dataset. + + _split_dir="${asr_stats_dir}/splits${num_splits_asr}" + if [ ! -f "${_split_dir}/.done" ]; then + rm -f "${_split_dir}/.done" + ${python} -m espnet2.bin.split_scps \ + --scps \ + "${_asr_train_dir}/${_scp}" \ + "${_asr_train_dir}/text" \ + "${asr_stats_dir}/train/speech_shape" \ + "${asr_stats_dir}/train/text_shape.${token_type}" \ + --num_splits "${num_splits_asr}" \ + --output_dir "${_split_dir}" + touch "${_split_dir}/.done" + else + log "${_split_dir}/.done exists. 
Spliting is skipped" + fi + + _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} " + _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text " + _opts+="--train_shape_file ${_split_dir}/speech_shape " + _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} " + _opts+="--multiple_iterator true " + + else + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} " + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text " + _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape " + _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} " + fi + + # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script" + # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh" + + # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case + log "ASR training started... 
log: '${asr_exp}/log/train.log'" + # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then + # # SGE can't include "/" in a job name + # jobname="$(basename ${asr_exp})" + # else + # jobname="${asr_exp}/train.log" + # fi + + mkdir -p ${asr_exp} + mkdir -p ${asr_exp}/log + INIT_FILE=${asr_exp}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $ngpu; ++i)); do + { + # i=0 + rank=$i + local_rank=$i + gpu_id=$(echo $device | cut -d',' -f$[$i+1]) + asr_train.py \ + --mc true \ + --gpu_id $gpu_id \ + --use_preprocessor true \ + --bpemodel ${bpemodel} \ + --token_type ${token_type} \ + --token_list ${token_list} \ + --split_with_space false \ + --non_linguistic_symbols ${nlsyms_txt} \ + --cleaner ${cleaner} \ + --g2p ${g2p} \ + --valid_data_path_and_name_and_type ${_asr_valid_dir}/${_scp},speech,${_type} \ + --valid_data_path_and_name_and_type ${_asr_valid_dir}/text,text,text \ + --valid_shape_file ${asr_stats_dir}/valid/speech_shape \ + --valid_shape_file ${asr_stats_dir}/valid/text_shape.${token_type} \ + --resume true \ + --output_dir ${asr_exp} \ + --config $asr_config \ + --ngpu $ngpu \ + --num_worker_count 1 \ + --multiprocessing_distributed true \ + --dist_init_method $init_method \ + --dist_world_size $ngpu \ + --dist_rank $rank \ + --local_rank $local_rank \ + ${_opts} 1> ${asr_exp}/log/train.log.$i 2>&1 + } & + done + wait + + fi + + if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then + _asr_train_dir="${data_feats}/${train_set}" + _asr_valid_dir="${data_feats}/${valid_set}" + log "Stage 12: SA-ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}" + + _opts= + if [ -n "${sa_asr_config}" ]; then + # To generate the config file: e.g. 
+ # % python3 -m espnet2.bin.asr_train --print_config --optim adam + _opts+="--config ${sa_asr_config} " + fi + + _feats_type="$(<${_asr_train_dir}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + # "sound" supports "wav", "flac", etc. + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + _opts+="--frontend_conf fs=${fs} " + else + _scp=feats.scp + _type=kaldi_ark + _input_size="$(<${_asr_train_dir}/feats_dim)" + _opts+="--input_size=${_input_size} " + + fi + if [ "${feats_normalize}" = global_mvn ]; then + # Default normalization is utterance_mvn and changes to global_mvn + _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz " + fi + + if [ "${num_splits_asr}" -gt 1 ]; then + # If you met a memory error when parsing text files, this option may help you. + # The corpus is split into subsets and each subset is used for training one by one in order, + # so the memory footprint can be limited to the memory required for each dataset. + + _split_dir="${asr_stats_dir}/splits${num_splits_asr}" + if [ ! -f "${_split_dir}/.done" ]; then + rm -f "${_split_dir}/.done" + ${python} -m espnet2.bin.split_scps \ + --scps \ + "${_asr_train_dir}/${_scp}" \ + "${_asr_train_dir}/text" \ + "${asr_stats_dir}/train/speech_shape" \ + "${asr_stats_dir}/train/text_shape.${token_type}" \ + --num_splits "${num_splits_asr}" \ + --output_dir "${_split_dir}" + touch "${_split_dir}/.done" + else + log "${_split_dir}/.done exists. 
Spliting is skipped" + fi + + _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} " + _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text " + _opts+="--train_data_path_and_name_and_type ${_split_dir}/text_id_train,text_id,text_int " + _opts+="--train_data_path_and_name_and_type ${_split_dir}/oracle_profile_padding.scp,profile,npy " + _opts+="--train_shape_file ${_split_dir}/speech_shape " + _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} " + _opts+="--multiple_iterator true " + + else + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} " + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text " + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/oracle_profile_padding.scp,profile,npy " + _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text_id_train,text_id,text_int " + _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape " + _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} " + fi + + # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script" + # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh" + + # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case + log "SA-ASR training started... log: '${sa_asr_exp}/log/train.log'" + # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then + # # SGE can't include "/" in a job name + # jobname="$(basename ${asr_exp})" + # else + # jobname="${asr_exp}/train.log" + # fi + + mkdir -p ${sa_asr_exp} + mkdir -p ${sa_asr_exp}/log + INIT_FILE=${sa_asr_exp}/ddp_init + + if [ ! 
-f "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then + # download xvector extractor model file + python local/download_xvector_model.py exp + log "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" + fi + + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $ngpu; ++i)); do + { + # i=0 + rank=$i + local_rank=$i + gpu_id=$(echo $device | cut -d',' -f$[$i+1]) + sa_asr_train.py \ + --gpu_id $gpu_id \ + --use_preprocessor true \ + --unused_parameters true \ + --bpemodel ${bpemodel} \ + --token_type ${token_type} \ + --token_list ${token_list} \ + --max_spk_num 4 \ + --split_with_space false \ + --non_linguistic_symbols ${nlsyms_txt} \ + --cleaner ${cleaner} \ + --g2p ${g2p} \ + --allow_variable_data_keys true \ + --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \ + --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \ + --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \ + --init_param 
"exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \ + --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/oracle_profile_nopadding.scp,profile,npy" \ + --valid_data_path_and_name_and_type "${_asr_valid_dir}/text_id_train,text_id,text_int" \ + --valid_shape_file "${asr_stats_dir}/valid/speech_shape" \ + --valid_shape_file "${asr_stats_dir}/valid/text_shape.${token_type}" \ + --resume true \ + --output_dir ${sa_asr_exp} \ + --config $sa_asr_config \ + --ngpu $ngpu \ + --num_worker_count 1 \ + --multiprocessing_distributed true \ + --dist_init_method $init_method \ + --dist_world_size $ngpu \ + --dist_rank $rank \ + --local_rank $local_rank \ + ${_opts} 1> ${sa_asr_exp}/log/train.log.$i 2>&1 + } & + done + wait + + fi + +else + log "Skip the training stages" +fi + + +if ! "${skip_eval}"; then + if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then + log "Stage 13: Decoding multi-talker ASR: training_dir=${asr_exp}" + + if ${gpu_inference}; then + _cmd="${cuda_cmd}" + inference_nj=$[${ngpu}*${njob_infer}] + _ngpu=1 + + else + _cmd="${decode_cmd}" + inference_nj=$inference_nj + _ngpu=0 + fi + + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + if "${use_lm}"; then + if "${use_word_lm}"; then + _opts+="--word_lm_train_config ${lm_exp}/config.yaml " + _opts+="--word_lm_file ${lm_exp}/${inference_lm} " + else + _opts+="--lm_train_config ${lm_exp}/config.yaml " + _opts+="--lm_file ${lm_exp}/${inference_lm} " + fi + fi + + # 2. Generate run.sh + log "Generate '${asr_exp}/${inference_tag}/run.sh'. 
You can resume the process from stage 13 using this script" + mkdir -p "${asr_exp}/${inference_tag}"; echo "${run_args} --stage 13 \"\$@\"; exit \$?" > "${asr_exp}/${inference_tag}/run.sh"; chmod +x "${asr_exp}/${inference_tag}/run.sh" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${asr_exp}/${inference_tag}/${dset}" + _logdir="${_dir}/logdir" + mkdir -p "${_logdir}" + + _feats_type="$(<${_data}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + else + _scp=feats.scp + _type=kaldi_ark + fi + + # 1. Split the key file + key_file=${_data}/${_scp} + split_scps="" + _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)") + echo $_nj + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Submit decoding jobs + log "Decoding started... log: '${_logdir}/asr_inference.*.log'" + + ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --nbest 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob_infer} \ + --gpuid_list ${device} \ + --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --asr_train_config "${asr_exp}"/config.yaml \ + --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode asr \ + ${_opts} + + # 3. 
Concatenates the output files from each jobs + for f in token token_int score text; do + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | LC_ALL=C sort -k1 >"${_dir}/${f}" + done + done + fi + + + if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then + log "Stage 14: Scoring multi-talker ASR" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${asr_exp}/${inference_tag}/${dset}" + + python local/proce_text.py ${_data}/text ${_data}/text.proc + python local/proce_text.py ${_dir}/text ${_dir}/text.proc + + python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt + cat ${_dir}/text.cer.txt + + done + + fi + + if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then + log "Stage 15: Decoding SA-ASR (oracle profile): training_dir=${sa_asr_exp}" + + if ${gpu_inference}; then + _cmd="${cuda_cmd}" + inference_nj=$[${ngpu}*${njob_infer}] + _ngpu=1 + + else + _cmd="${decode_cmd}" + inference_nj=$inference_nj + _ngpu=0 + fi + + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + if "${use_lm}"; then + if "${use_word_lm}"; then + _opts+="--word_lm_train_config ${lm_exp}/config.yaml " + _opts+="--word_lm_file ${lm_exp}/${inference_lm} " + else + _opts+="--lm_train_config ${lm_exp}/config.yaml " + _opts+="--lm_file ${lm_exp}/${inference_lm} " + fi + fi + + # 2. Generate run.sh + log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh'. You can resume the process from stage 15 using this script" + mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.oracle"; echo "${run_args} --stage 15 \"\$@\"; exit \$?" 
> "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}" + _logdir="${_dir}/logdir" + mkdir -p "${_logdir}" + + _feats_type="$(<${_data}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + else + _scp=feats.scp + _type=kaldi_ark + fi + + # 1. Split the key file + key_file=${_data}/${_scp} + split_scps="" + _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)") + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Submit decoding jobs + log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'" + # shellcheck disable=SC2086 + ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --nbest 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob_infer} \ + --gpuid_list ${device} \ + --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \ + --data_path_and_name_and_type "${_data}/oracle_profile_nopadding.scp,profile,npy" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --allow_variable_data_keys true \ + --asr_train_config "${sa_asr_exp}"/config.yaml \ + --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode sa_asr \ + ${_opts} + + + # 3. 
Concatenates the output files from each jobs + for f in token token_int score text text_id; do + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | LC_ALL=C sort -k1 >"${_dir}/${f}" + done + done + fi + + if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then + log "Stage 16: Scoring SA-ASR (oracle profile)" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}" + + python local/proce_text.py ${_data}/text ${_data}/text.proc + python local/proce_text.py ${_dir}/text ${_dir}/text.proc + + python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt + cat ${_dir}/text.cer.txt + + python local/process_text_spk_merge.py ${_dir} + python local/process_text_spk_merge.py ${_data} + + python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer + tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt + cat ${_dir}/text.cpcer.txt + + done + + fi + + if [ ${stage} -le 17 ] && [ ${stop_stage} -ge 17 ]; then + log "Stage 17: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}" + + if ${gpu_inference}; then + _cmd="${cuda_cmd}" + inference_nj=$[${ngpu}*${njob_infer}] + _ngpu=1 + + else + _cmd="${decode_cmd}" + inference_nj=$inference_nj + _ngpu=0 + fi + + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + if "${use_lm}"; then + if "${use_word_lm}"; then + _opts+="--word_lm_train_config ${lm_exp}/config.yaml " + _opts+="--word_lm_file ${lm_exp}/${inference_lm} " + else + _opts+="--lm_train_config ${lm_exp}/config.yaml " + _opts+="--lm_file ${lm_exp}/${inference_lm} " + fi + fi + + # 2. Generate run.sh + log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. 
You can resume the process from stage 17 using this script" + mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}" + _logdir="${_dir}/logdir" + mkdir -p "${_logdir}" + + _feats_type="$(<${_data}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + else + _scp=feats.scp + _type=kaldi_ark + fi + + # 1. Split the key file + key_file=${_data}/${_scp} + split_scps="" + _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)") + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Submit decoding jobs + log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'" + # shellcheck disable=SC2086 + ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --nbest 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob_infer} \ + --gpuid_list ${device} \ + --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \ + --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --allow_variable_data_keys true \ + --asr_train_config "${sa_asr_exp}"/config.yaml \ + --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode sa_asr \ + ${_opts} + + # 3. 
Concatenates the output files from each jobs + for f in token token_int score text text_id; do + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | LC_ALL=C sort -k1 >"${_dir}/${f}" + done + done + fi + + if [ ${stage} -le 18 ] && [ ${stop_stage} -ge 18 ]; then + log "Stage 18: Scoring SA-ASR (cluster profile)" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}" + + python local/proce_text.py ${_data}/text ${_data}/text.proc + python local/proce_text.py ${_dir}/text ${_dir}/text.proc + + python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt + cat ${_dir}/text.cer.txt + + python local/process_text_spk_merge.py ${_dir} + python local/process_text_spk_merge.py ${_data} + + python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer + tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt + cat ${_dir}/text.cpcer.txt + + done + + fi + +else + log "Skip the evaluation stages" +fi + + +log "Successfully finished. [elapsed=${SECONDS}s]" diff --git a/egs/alimeeting/sa-asr/asr_local_infer.sh b/egs/alimeeting/sa-asr/asr_local_infer.sh new file mode 100755 index 000000000..8e8148ff8 --- /dev/null +++ b/egs/alimeeting/sa-asr/asr_local_infer.sh @@ -0,0 +1,590 @@ +#!/usr/bin/env bash + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} +min() { + local a b + a=$1 + for b in "$@"; do + if [ "${b}" -le "${a}" ]; then + a="${b}" + fi + done + echo "${a}" +} +SECONDS=0 + +# General configuration +stage=1 # Processes starts from the specified stage. 
+stop_stage=10000 # Processes is stopped at the specified stage. +skip_data_prep=false # Skip data preparation stages. +skip_train=false # Skip training stages. +skip_eval=false # Skip decoding and evaluation stages. +skip_upload=true # Skip packing and uploading stages. +ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu). +num_nodes=1 # The number of nodes. +nj=16 # The number of parallel jobs. +inference_nj=16 # The number of parallel jobs in decoding. +gpu_inference=false # Whether to perform gpu decoding. +njob_infer=4 +dumpdir=dump2 # Directory to dump features. +expdir=exp # Directory to save experiments. +python=python3 # Specify python to execute espnet commands. +device=0 + +# Data preparation related +local_data_opts= # The options given to local/data.sh. + +# Speed perturbation related +speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space). + +# Feature extraction related +feats_type=raw # Feature type (raw or fbank_pitch). +audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw). +fs=16000 # Sampling rate. +min_wav_duration=0.1 # Minimum duration in second. +max_wav_duration=20 # Maximum duration in second. + +# Tokenization related +token_type=bpe # Tokenization type (char or bpe). +nbpe=30 # The number of BPE vocabulary. +bpemode=unigram # Mode of BPE (unigram or bpe). +oov="" # Out of vocabulary symbol. +blank="" # CTC blank symbol +sos_eos="" # sos and eos symbole +bpe_input_sentence_size=100000000 # Size of input sentence for BPE. +bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE +bpe_char_cover=1.0 # character coverage when modeling BPE + +# Language model related +use_lm=true # Use language model for ASR decoding. +lm_tag= # Suffix to the result dir for language model training. +lm_exp= # Specify the direcotry path for LM experiment. + # If this option is specified, lm_tag is ignored. +lm_stats_dir= # Specify the direcotry path for LM statistics. 
+lm_config= # Config for language model training. +lm_args= # Arguments for language model training, e.g., "--max_epoch 10". + # Note that it will overwrite args in lm config. +use_word_lm=false # Whether to use word language model. +num_splits_lm=1 # Number of splitting for lm corpus. +# shellcheck disable=SC2034 +word_vocab_size=10000 # Size of word vocabulary. + +# ASR model related +asr_tag= # Suffix to the result dir for asr model training. +asr_exp= # Specify the direcotry path for ASR experiment. + # If this option is specified, asr_tag is ignored. +sa_asr_exp= +asr_stats_dir= # Specify the direcotry path for ASR statistics. +asr_config= # Config for asr model training. +sa_asr_config= +asr_args= # Arguments for asr model training, e.g., "--max_epoch 10". + # Note that it will overwrite args in asr config. +feats_normalize=global_mvn # Normalizaton layer type. +num_splits_asr=1 # Number of splitting for lm corpus. + +# Decoding related +inference_tag= # Suffix to the result dir for decoding. +inference_config= # Config for decoding. +inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1". + # Note that it will overwrite args in inference config. +sa_asr_inference_tag= +sa_asr_inference_args= + +inference_lm=valid.loss.ave.pb # Language modle path for decoding. +inference_asr_model=valid.acc.ave.pb # ASR model path for decoding. + # e.g. + # inference_asr_model=train.loss.best.pth + # inference_asr_model=3epoch.pth + # inference_asr_model=valid.acc.best.pth + # inference_asr_model=valid.loss.ave.pth +inference_sa_asr_model=valid.acc_spk.ave.pb +download_model= # Download a model from Model Zoo and use it for decoding. + +# [Task dependent] Set the datadir name created by local/data.sh +train_set= # Name of training set. +valid_set= # Name of validation set used for monitoring/tuning network training. +test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified. 
+bpe_train_text= # Text file path of bpe training set. +lm_train_text= # Text file path of language model training set. +lm_dev_text= # Text file path of language model development set. +lm_test_text= # Text file path of language model evaluation set. +nlsyms_txt=none # Non-linguistic symbol list if existing. +cleaner=none # Text cleaner. +g2p=none # g2p method (needed if token_type=phn). +lang=zh # The language type of corpus. +score_opts= # The options given to sclite scoring +local_score_opts= # The options given to local/score.sh. + +help_message=$(cat << EOF +Usage: $0 --train-set "" --valid-set "" --test_sets "" + +Options: + # General configuration + --stage # Processes starts from the specified stage (default="${stage}"). + --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}"). + --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}"). + --skip_train # Skip training stages (default="${skip_train}"). + --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}"). + --skip_upload # Skip packing and uploading stages (default="${skip_upload}"). + --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}"). + --num_nodes # The number of nodes (default="${num_nodes}"). + --nj # The number of parallel jobs (default="${nj}"). + --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}"). + --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}"). + --dumpdir # Directory to dump features (default="${dumpdir}"). + --expdir # Directory to save experiments (default="${expdir}"). + --python # Specify python to execute espnet commands (default="${python}"). + --device # Which GPUs are use for local training (defalut="${device}"). + + # Data preparation related + --local_data_opts # The options given to local/data.sh (default="${local_data_opts}"). 
+ + # Speed perturbation related + --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}"). + + # Feature extraction related + --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}"). + --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}"). + --fs # Sampling rate (default="${fs}"). + --min_wav_duration # Minimum duration in second (default="${min_wav_duration}"). + --max_wav_duration # Maximum duration in second (default="${max_wav_duration}"). + + # Tokenization related + --token_type # Tokenization type (char or bpe, default="${token_type}"). + --nbpe # The number of BPE vocabulary (default="${nbpe}"). + --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}"). + --oov # Out of vocabulary symbol (default="${oov}"). + --blank # CTC blank symbol (default="${blank}"). + --sos_eos # sos and eos symbole (default="${sos_eos}"). + --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}"). + --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}"). + --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}"). + + # Language model related + --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}"). + --lm_exp # Specify the direcotry path for LM experiment. + # If this option is specified, lm_tag is ignored (default="${lm_exp}"). + --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}"). + --lm_config # Config for language model training (default="${lm_config}"). + --lm_args # Arguments for language model training (default="${lm_args}"). + # e.g., --lm_args "--max_epoch 10" + # Note that it will overwrite args in lm config. + --use_word_lm # Whether to use word language model (default="${use_word_lm}"). 
+ --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}"). + --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}"). + + # ASR model related + --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}"). + --asr_exp # Specify the direcotry path for ASR experiment. + # If this option is specified, asr_tag is ignored (default="${asr_exp}"). + --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}"). + --asr_config # Config for asr model training (default="${asr_config}"). + --asr_args # Arguments for asr model training (default="${asr_args}"). + # e.g., --asr_args "--max_epoch 10" + # Note that it will overwrite args in asr config. + --feats_normalize # Normalizaton layer type (default="${feats_normalize}"). + --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}"). + + # Decoding related + --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}"). + --inference_config # Config for decoding (default="${inference_config}"). + --inference_args # Arguments for decoding (default="${inference_args}"). + # e.g., --inference_args "--lm_weight 0.1" + # Note that it will overwrite args in inference config. + --inference_lm # Language modle path for decoding (default="${inference_lm}"). + --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}"). + --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}"). + + # [Task dependent] Set the datadir name created by local/data.sh + --train_set # Name of training set (required). + --valid_set # Name of validation set used for monitoring/tuning network training (required). + --test_sets # Names of test sets. + # Multiple items (e.g., both dev and eval sets) can be specified (required). + --bpe_train_text # Text file path of bpe training set. 
+ --lm_train_text # Text file path of language model training set. + --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}"). + --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}"). + --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}"). + --cleaner # Text cleaner (default="${cleaner}"). + --g2p # g2p method (default="${g2p}"). + --lang # The language type of corpus (default=${lang}). + --score_opts # The options given to sclite scoring (default="{score_opts}"). + --local_score_opts # The options given to local/score.sh (default="{local_score_opts}"). +EOF +) + +log "$0 $*" +# Save command line args for logging (they will be lost after utils/parse_options.sh) +run_args=$(python -m funasr.utils.cli_utils $0 "$@") +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + log "${help_message}" + log "Error: No positional arguments are required." + exit 2 +fi + +. ./path.sh + + +# Check required arguments +[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; }; +[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; }; +[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; }; + +# Check feature type +if [ "${feats_type}" = raw ]; then + data_feats=${dumpdir}/raw +elif [ "${feats_type}" = fbank_pitch ]; then + data_feats=${dumpdir}/fbank_pitch +elif [ "${feats_type}" = fbank ]; then + data_feats=${dumpdir}/fbank +elif [ "${feats_type}" == extracted ]; then + data_feats=${dumpdir}/extracted +else + log "${help_message}" + log "Error: not supported: --feats_type ${feats_type}" + exit 2 +fi + +# Use the same text as ASR for bpe training if not specified. +[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text" +# Use the same text as ASR for lm training if not specified. 
+[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text" +# Use the same text as ASR for lm training if not specified. +[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text" +# Use the text of the 1st evaldir if lm_test is not specified +[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text" + +# Check tokenization type +if [ "${lang}" != noinfo ]; then + token_listdir=data/${lang}_token_list +else + token_listdir=data/token_list +fi +bpedir="${token_listdir}/bpe_${bpemode}${nbpe}" +bpeprefix="${bpedir}"/bpe +bpemodel="${bpeprefix}".model +bpetoken_list="${bpedir}"/tokens.txt +chartoken_list="${token_listdir}"/char/tokens.txt +# NOTE: keep for future development. +# shellcheck disable=SC2034 +wordtoken_list="${token_listdir}"/word/tokens.txt + +if [ "${token_type}" = bpe ]; then + token_list="${bpetoken_list}" +elif [ "${token_type}" = char ]; then + token_list="${chartoken_list}" + bpemodel=none +elif [ "${token_type}" = word ]; then + token_list="${wordtoken_list}" + bpemodel=none +else + log "Error: not supported --token_type '${token_type}'" + exit 2 +fi +if ${use_word_lm}; then + log "Error: Word LM is not supported yet" + exit 2 + + lm_token_list="${wordtoken_list}" + lm_token_type=word +else + lm_token_list="${token_list}" + lm_token_type="${token_type}" +fi + + +# Set tag for naming of model directory +if [ -z "${asr_tag}" ]; then + if [ -n "${asr_config}" ]; then + asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}" + else + asr_tag="train_${feats_type}" + fi + if [ "${lang}" != noinfo ]; then + asr_tag+="_${lang}_${token_type}" + else + asr_tag+="_${token_type}" + fi + if [ "${token_type}" = bpe ]; then + asr_tag+="${nbpe}" + fi + # Add overwritten arg's info + if [ -n "${asr_args}" ]; then + asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")" + fi + if [ -n "${speed_perturb_factors}" ]; then + asr_tag+="_sp" + fi +fi +if [ -z "${lm_tag}" ]; then + if [ -n 
"${lm_config}" ]; then + lm_tag="$(basename "${lm_config}" .yaml)" + else + lm_tag="train" + fi + if [ "${lang}" != noinfo ]; then + lm_tag+="_${lang}_${lm_token_type}" + else + lm_tag+="_${lm_token_type}" + fi + if [ "${lm_token_type}" = bpe ]; then + lm_tag+="${nbpe}" + fi + # Add overwritten arg's info + if [ -n "${lm_args}" ]; then + lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")" + fi +fi + +# The directory used for collect-stats mode +if [ -z "${asr_stats_dir}" ]; then + if [ "${lang}" != noinfo ]; then + asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}" + else + asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}" + fi + if [ "${token_type}" = bpe ]; then + asr_stats_dir+="${nbpe}" + fi + if [ -n "${speed_perturb_factors}" ]; then + asr_stats_dir+="_sp" + fi +fi +if [ -z "${lm_stats_dir}" ]; then + if [ "${lang}" != noinfo ]; then + lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}" + else + lm_stats_dir="${expdir}/lm_stats_${lm_token_type}" + fi + if [ "${lm_token_type}" = bpe ]; then + lm_stats_dir+="${nbpe}" + fi +fi +# The directory used for training commands +if [ -z "${asr_exp}" ]; then + asr_exp="${expdir}/asr_${asr_tag}" +fi +if [ -z "${lm_exp}" ]; then + lm_exp="${expdir}/lm_${lm_tag}" +fi + + +if [ -z "${inference_tag}" ]; then + if [ -n "${inference_config}" ]; then + inference_tag="$(basename "${inference_config}" .yaml)" + else + inference_tag=inference + fi + # Add overwritten arg's info + if [ -n "${inference_args}" ]; then + inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")" + fi + if "${use_lm}"; then + inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" + fi + inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" +fi + +if [ -z "${sa_asr_inference_tag}" ]; then + if [ -n "${inference_config}" ]; then + sa_asr_inference_tag="$(basename 
"${inference_config}" .yaml)" + else + sa_asr_inference_tag=sa_asr_inference + fi + # Add overwritten arg's info + if [ -n "${sa_asr_inference_args}" ]; then + sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")" + fi + if "${use_lm}"; then + sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" + fi + sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")" +fi + +train_cmd="run.pl" +cuda_cmd="run.pl" +decode_cmd="run.pl" + +# ========================== Main stages start from here. ========================== + +if ! "${skip_data_prep}"; then + + if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + if [ "${feats_type}" = raw ]; then + log "Stage 1: Format wav.scp: data/ -> ${data_feats}" + + # ====== Recreating "wav.scp" ====== + # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |", + # shouldn't be used in training process. + # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file + # and it can also change the audio-format and sampling rate. + # If nothing is need, then format_wav_scp.sh does nothing: + # i.e. the input file format and rate is same as the output. + + for dset in "${test_sets}" ; do + + _suf="" + + utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" + + rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur} + _opts= + if [ -e data/"${dset}"/segments ]; then + # "segments" is used for splitting wav files which are written in "wav".scp + # into utterances. The file format of segments: + # + # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5" + # Where the time is written in seconds. 
+ _opts+="--segments data/${dset}/segments " + fi + # shellcheck disable=SC2086 + scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ + --audio-format "${audio_format}" --fs "${fs}" ${_opts} \ + "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}" + + echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type" + done + + else + log "Error: not supported: --feats_type ${feats_type}" + exit 2 + fi + fi + + if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + log "Stage 2: Generate speaker profile by spectral-cluster" + mkdir -p "profile_log" + for dset in "${test_sets}"; do + # generate cluster_profile with spectral-cluster directly (for infering and without oracle information) + python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log" + log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)" + done + fi + +else + log "Skip the stages for data preparation" +fi + + +# ========================== Data preparation is done here. ========================== + +if ! "${skip_eval}"; then + + if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + log "Stage 3: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}" + + if ${gpu_inference}; then + _cmd="${cuda_cmd}" + inference_nj=$[${ngpu}*${njob_infer}] + _ngpu=1 + + else + _cmd="${decode_cmd}" + inference_nj=$njob_infer + _ngpu=0 + fi + + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + if "${use_lm}"; then + if "${use_word_lm}"; then + _opts+="--word_lm_train_config ${lm_exp}/config.yaml " + _opts+="--word_lm_file ${lm_exp}/${inference_lm} " + else + _opts+="--lm_train_config ${lm_exp}/config.yaml " + _opts+="--lm_file ${lm_exp}/${inference_lm} " + fi + fi + + # 2. Generate run.sh + log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. 
You can resume the process from stage 17 using this script" + mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh" + + for dset in ${test_sets}; do + _data="${data_feats}/${dset}" + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}" + _logdir="${_dir}/logdir" + mkdir -p "${_logdir}" + + _feats_type="$(<${_data}/feats_type)" + if [ "${_feats_type}" = raw ]; then + _scp=wav.scp + if [[ "${audio_format}" == *ark* ]]; then + _type=kaldi_ark + else + _type=sound + fi + else + _scp=feats.scp + _type=kaldi_ark + fi + + # 1. Split the key file + key_file=${_data}/${_scp} + split_scps="" + _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)") + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + + # 2. Submit decoding jobs + log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'" + # shellcheck disable=SC2086 + ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --nbest 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob_infer} \ + --gpuid_list ${device} \ + --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \ + --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --allow_variable_data_keys true \ + --asr_train_config "${sa_asr_exp}"/config.yaml \ + --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode sa_asr \ + ${_opts} + + # 3. 
Concatenates the output files from each jobs + for f in token token_int score text text_id; do + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | LC_ALL=C sort -k1 >"${_dir}/${f}" + done + done + fi + + if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + log "Stage 4: Generate SA-ASR results (cluster profile)" + + for dset in ${test_sets}; do + _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}" + + python local/process_text_spk_merge.py ${_dir} + done + + fi + +else + log "Skip the evaluation stages" +fi + + +log "Successfully finished. [elapsed=${SECONDS}s]" diff --git a/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml new file mode 100644 index 000000000..88fdbc20b --- /dev/null +++ b/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml @@ -0,0 +1,6 @@ +beam_size: 20 +penalty: 0.0 +maxlenratio: 0.0 +minlenratio: 0.0 +ctc_weight: 0.6 +lm_weight: 0.3 diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml new file mode 100644 index 000000000..a8c996875 --- /dev/null +++ b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml @@ -0,0 +1,88 @@ +# network architecture +frontend: default +frontend_conf: + n_fft: 400 + win_length: 400 + hop_length: 160 + use_channel: 0 + +# encoder related +encoder: conformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder architecture type + normalize_before: true + rel_pos_type: latest + pos_enc_layer_type: rel_pos + selfattention_layer_type: rel_selfattn + activation_type: swish + macaron_style: true + use_cnn_module: true + cnn_module_kernel: 15 + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 
4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# ctc related +ctc_conf: + ignore_nan_grad: true + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +# minibatch related +batch_type: numel +batch_bins: 10000000 # reduce/increase this number according to your GPU memory + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 100 +val_scheduler_criterion: + - valid + - acc +best_model_criterion: +- - valid + - acc + - max +keep_nbest_models: 10 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_range: + - 0 + - 40 + num_time_mask: 2 diff --git a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml new file mode 100644 index 000000000..68520ae23 --- /dev/null +++ b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml @@ -0,0 +1,29 @@ +lm: transformer +lm_conf: + pos_enc: null + embed_unit: 128 + att_unit: 512 + head: 8 + unit: 2048 + layer: 16 + dropout_rate: 0.1 + +# optimization related +grad_clip: 5.0 +batch_type: numel +batch_bins: 500000 # 4gpus * 500000 +accum_grad: 1 +max_epoch: 15 # 15epoch is enougth + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + +best_model_criterion: +- - valid + - loss + - min +keep_nbest_models: 10 # 10 is good. 
diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml new file mode 100644 index 000000000..e91db1804 --- /dev/null +++ b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml @@ -0,0 +1,116 @@ +# network architecture +frontend: default +frontend_conf: + n_fft: 400 + win_length: 400 + hop_length: 160 + use_channel: 0 + +# encoder related +asr_encoder: conformer +asr_encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder architecture type + normalize_before: true + pos_enc_layer_type: rel_pos + selfattention_layer_type: rel_selfattn + activation_type: swish + macaron_style: true + use_cnn_module: true + cnn_module_kernel: 15 + +spk_encoder: resnet34_diar +spk_encoder_conf: + use_head_conv: true + batchnorm_momentum: 0.5 + use_head_maxpool: false + num_nodes_pooling_layer: 256 + layers_in_block: + - 3 + - 4 + - 6 + - 3 + filters_in_block: + - 32 + - 64 + - 128 + - 256 + pooling_type: statistic + num_nodes_resnet1: 256 + num_nodes_last_layer: 256 + batchnorm_momentum: 0.5 + +# decoder related +decoder: sa_decoder +decoder_conf: + attention_heads: 4 + linear_units: 2048 + asr_num_blocks: 6 + spk_num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + spk_weight: 0.5 + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +ctc_conf: + ignore_nan_grad: true + +# minibatch related +batch_type: numel +batch_bins: 10000000 + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 60 +val_scheduler_criterion: + - valid + - loss +best_model_criterion: +- - valid + - acc + - max +- - valid 
+ - acc_spk + - max +- - valid + - loss + - min +keep_nbest_models: 10 + +optim: adam +optim_conf: + lr: 0.0005 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 8000 + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_range: + - 0 + - 40 + num_time_mask: 2 + diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh new file mode 100755 index 000000000..8151bae30 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh @@ -0,0 +1,162 @@ +#!/usr/bin/env bash +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +help_messge=$(cat << EOF +Usage: $0 + +Options: + --no_overlap (bool): Whether to ignore the overlapping utterance in the training set. + --tgt (string): Which set to process, test or train. +EOF +) + +SECONDS=0 +tgt=Train #Train or Eval + + +log "$0 $*" +echo $tgt +. ./utils/parse_options.sh + +. ./path.sh + +AliMeeting="${PWD}/dataset" + +if [ $# -gt 2 ]; then + log "${help_message}" + exit 2 +fi + + +if [ ! -d "${AliMeeting}" ]; then + log "Error: ${AliMeeting} is empty." 
+ exit 2 +fi + +# To absolute path +AliMeeting=$(cd ${AliMeeting}; pwd) +echo $AliMeeting +far_raw_dir=${AliMeeting}/${tgt}_Ali_far/ +near_raw_dir=${AliMeeting}/${tgt}_Ali_near/ + +far_dir=data/local/${tgt}_Ali_far +near_dir=data/local/${tgt}_Ali_near +far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker +mkdir -p $far_single_speaker_dir + +stage=1 +stop_stage=4 +mkdir -p $far_dir +mkdir -p $near_dir + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + log "stage 1:process alimeeting near dir" + + find -L $near_raw_dir/audio_dir -iname "*.wav" > $near_dir/wavlist + awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid + find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist + n1_wav=$(wc -l < $near_dir/wavlist) + n2_text=$(wc -l < $near_dir/textgrid.flist) + log near file found $n1_wav wav and $n2_text text. + + paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp + + # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp + cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp + + python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False + cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text + utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk + #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1 + #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk + utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt + utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments + sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1 + sed -e 's/!//g' $near_dir/tmp1> $near_dir/tmp2 + sed -e 's/?//g' $near_dir/tmp2> $near_dir/text + +fi + + +if [ 
${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + log "stage 2:process alimeeting far dir" + + find -L $far_raw_dir/audio_dir -iname "*.wav" > $far_dir/wavlist + awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid + find -L $far_raw_dir/textgrid_dir -iname "*.TextGrid" > $far_dir/textgrid.flist + n1_wav=$(wc -l < $far_dir/wavlist) + n2_text=$(wc -l < $far_dir/textgrid.flist) + log far file found $n1_wav wav and $n2_text text. + + paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp + + cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp + + python local/alimeeting_process_overlap_force.py --path $far_dir \ + --no-overlap false --mars True \ + --overlap_length 0.8 --max_length 7 + + cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text + utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk + #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk + + utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt + utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments + sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1 + sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2 + sed -e 's/!//g' $far_dir/tmp2> $far_dir/tmp3 + sed -e 's/?//g' $far_dir/tmp3> $far_dir/text +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + log "stage 3: finali data process" + + utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near + utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far + + sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo + sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo + + # remove space in text + for x in ${tgt}_Ali_near ${tgt}_Ali_far; do + cp data/${x}/text data/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ + > 
data/${x}/text + rm data/${x}/text.org + done + + log "Successfully finished. [elapsed=${SECONDS}s]" +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + log "stage 4: process alimeeting far dir (single speaker by oracle time strap)" + cp -r $far_dir/* $far_single_speaker_dir + mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath + paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist + python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir + + cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text + utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt + + ./utils/fix_data_dir.sh $far_single_speaker_dir + utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker + + # remove space in text + for x in ${tgt}_Ali_far_single_speaker; do + cp data/${x}/text data/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ + > data/${x}/text + rm data/${x}/text.org + done + log "Successfully finished. [elapsed=${SECONDS}s]" +fi \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh new file mode 100755 index 000000000..382a05669 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh @@ -0,0 +1,129 @@ +#!/usr/bin/env bash +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +help_messge=$(cat << EOF +Usage: $0 + +Options: + --no_overlap (bool): Whether to ignore the overlapping utterance in the training set. 
+ --tgt (string): Which set to process, test or train. +EOF +) + +SECONDS=0 +tgt=Train #Train or Eval + + +log "$0 $*" +echo $tgt +. ./utils/parse_options.sh + +. ./path.sh + +AliMeeting="${PWD}/dataset" + +if [ $# -gt 2 ]; then + log "${help_message}" + exit 2 +fi + + +if [ ! -d "${AliMeeting}" ]; then + log "Error: ${AliMeeting} is empty." + exit 2 +fi + +# To absolute path +AliMeeting=$(cd ${AliMeeting}; pwd) +echo $AliMeeting +far_raw_dir=${AliMeeting}/${tgt}_Ali_far/ + +far_dir=data/local/${tgt}_Ali_far +far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker +mkdir -p $far_single_speaker_dir + +stage=1 +stop_stage=3 +mkdir -p $far_dir + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + log "stage 1:process alimeeting far dir" + + find -L $far_raw_dir/audio_dir -iname "*.wav" > $far_dir/wavlist + awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid + find -L $far_raw_dir/textgrid_dir -iname "*.TextGrid" > $far_dir/textgrid.flist + n1_wav=$(wc -l < $far_dir/wavlist) + n2_text=$(wc -l < $far_dir/textgrid.flist) + log far file found $n1_wav wav and $n2_text text. 
+ + paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp + + cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp + + python local/alimeeting_process_overlap_force.py --path $far_dir \ + --no-overlap false --mars True \ + --overlap_length 0.8 --max_length 7 + + cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text + utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk + #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk + + utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt + utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments + sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1 + sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2 + sed -e 's/!//g' $far_dir/tmp2> $far_dir/tmp3 + sed -e 's/?//g' $far_dir/tmp3> $far_dir/text +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + log "stage 2: finali data process" + + utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far + + sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo + sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo + + # remove space in text + for x in ${tgt}_Ali_far; do + cp data/${x}/text data/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ + > data/${x}/text + rm data/${x}/text.org + done + + log "Successfully finished. 
[elapsed=${SECONDS}s]" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + log "stage 3:process alimeeting far dir (single speaker by oracal time strap)" + cp -r $far_dir/* $far_single_speaker_dir + mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath + paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist + python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir + + cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text + utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt + + ./utils/fix_data_dir.sh $far_single_speaker_dir + utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker + + # remove space in text + for x in ${tgt}_Ali_far_single_speaker; do + cp data/${x}/text data/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ + > data/${x}/text + rm data/${x}/text.org + done + log "Successfully finished. 
[elapsed=${SECONDS}s]" +fi \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py new file mode 100755 index 000000000..8ece75706 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +""" +Process the textgrid files +""" +import argparse +import codecs +from distutils.util import strtobool +from pathlib import Path +import textgrid +import pdb + +class Segment(object): + def __init__(self, uttid, spkr, stime, etime, text): + self.uttid = uttid + self.spkr = spkr + self.spkr_all = uttid+"-"+spkr + self.stime = round(stime, 2) + self.etime = round(etime, 2) + self.text = text + self.spk_text = {uttid+"-"+spkr: text} + + def change_stime(self, time): + self.stime = time + + def change_etime(self, time): + self.etime = time + + +def get_args(): + parser = argparse.ArgumentParser(description="process the textgrid files") + parser.add_argument("--path", type=str, required=True, help="Data path") + parser.add_argument( + "--no-overlap", + type=strtobool, + default=False, + help="Whether to ignore the overlapping utterances.", + ) + parser.add_argument( + "--max_length", + default=100000, + type=float, + help="overlap speech max time,if longger than max length should cut", + ) + parser.add_argument( + "--overlap_length", + default=1, + type=float, + help="if length longer than max length, speech overlength shorter, is cut", + ) + parser.add_argument( + "--mars", + type=strtobool, + default=False, + help="Whether to process mars data set.", + ) + args = parser.parse_args() + return args + + +def preposs_overlap(segments,max_length,overlap_length): + new_segments = [] + # init a helper list to store all overlap segments + tmp_segments = segments[0] + min_stime = segments[0].stime + max_etime = segments[0].etime + overlap_length_big = 1.5 + max_length_big = 15 + for i in range(1, 
len(segments)): + if segments[i].stime >= max_etime: + # doesn't overlap with preivous segments + new_segments.append(tmp_segments) + tmp_segments = segments[i] + min_stime = segments[i].stime + max_etime = segments[i].etime + else: + # overlap with previous segments + dur_time = max_etime - min_stime + if dur_time < max_length: + if min_stime > segments[i].stime: + min_stime = segments[i].stime + if max_etime < segments[i].etime: + max_etime = segments[i].etime + tmp_segments.stime = min_stime + tmp_segments.etime = max_etime + tmp_segments.text = tmp_segments.text + "src" + segments[i].text + spk_name =segments[i].uttid +"-" + segments[i].spkr + if spk_name in tmp_segments.spk_text: + tmp_segments.spk_text[spk_name] += segments[i].text + else: + tmp_segments.spk_text[spk_name] = segments[i].text + tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name + else: + overlap_time = max_etime - segments[i].stime + if dur_time < max_length_big: + overlap_length_option = overlap_length + else: + overlap_length_option = overlap_length_big + if overlap_time > overlap_length_option: + if min_stime > segments[i].stime: + min_stime = segments[i].stime + if max_etime < segments[i].etime: + max_etime = segments[i].etime + tmp_segments.stime = min_stime + tmp_segments.etime = max_etime + tmp_segments.text = tmp_segments.text + "src" + segments[i].text + spk_name =segments[i].uttid +"-" + segments[i].spkr + if spk_name in tmp_segments.spk_text: + tmp_segments.spk_text[spk_name] += segments[i].text + else: + tmp_segments.spk_text[spk_name] = segments[i].text + tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name + else: + new_segments.append(tmp_segments) + tmp_segments = segments[i] + min_stime = segments[i].stime + max_etime = segments[i].etime + + return new_segments + +def filter_overlap(segments): + new_segments = [] + # init a helper list to store all overlap segments + tmp_segments = [segments[0]] + min_stime = segments[0].stime + max_etime = 
segments[0].etime + + for i in range(1, len(segments)): + if segments[i].stime >= max_etime: + # doesn't overlap with preivous segments + if len(tmp_segments) == 1: + new_segments.append(tmp_segments[0]) + # TODO: for multi-spkr asr, we can reset the stime/etime to + # min_stime/max_etime for generating a max length mixutre speech + tmp_segments = [segments[i]] + min_stime = segments[i].stime + max_etime = segments[i].etime + else: + # overlap with previous segments + tmp_segments.append(segments[i]) + if min_stime > segments[i].stime: + min_stime = segments[i].stime + if max_etime < segments[i].etime: + max_etime = segments[i].etime + + return new_segments + + +def main(args): + wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8") + textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8") + + # get the path of textgrid file for each utterance + utt2textgrid = {} + for line in textgrid_flist: + path = Path(line.strip()) + uttid = path.stem + utt2textgrid[uttid] = path + + # parse the textgrid file for each utterance + all_segments = [] + for line in wav_scp: + uttid = line.strip().split(" ")[0] + uttid_part=uttid + if args.mars == True: + uttid_list = uttid.split("_") + uttid_part= uttid_list[0]+"_"+uttid_list[1] + if uttid_part not in utt2textgrid: + print("%s doesn't have transcription" % uttid) + continue + + segments = [] + tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part]) + for i in range(tg.__len__()): + for j in range(tg[i].__len__()): + if tg[i][j].mark: + segments.append( + Segment( + uttid, + tg[i].name, + tg[i][j].minTime, + tg[i][j].maxTime, + tg[i][j].mark.strip(), + ) + ) + + segments = sorted(segments, key=lambda x: x.stime) + + if args.no_overlap: + segments = filter_overlap(segments) + else: + segments = preposs_overlap(segments,args.max_length,args.overlap_length) + all_segments += segments + + wav_scp.close() + textgrid_flist.close() + + segments_file = codecs.open(Path(args.path) / "segments_all", "w", 
"utf-8") + utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8") + text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8") + utt2spk_file_fifo = codecs.open(Path(args.path) / "utt2spk_all_fifo", "w", "utf-8") + + for i in range(len(all_segments)): + utt_name = "%s-%s-%07d-%07d" % ( + all_segments[i].uttid, + all_segments[i].spkr, + all_segments[i].stime * 100, + all_segments[i].etime * 100, + ) + + segments_file.write( + "%s %s %.2f %.2f\n" + % ( + utt_name, + all_segments[i].uttid, + all_segments[i].stime, + all_segments[i].etime, + ) + ) + utt2spk_file.write( + "%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr) + ) + utt2spk_file_fifo.write( + "%s %s\n" % (utt_name, all_segments[i].spkr_all) + ) + text_file.write("%s %s\n" % (utt_name, all_segments[i].text)) + + segments_file.close() + utt2spk_file.close() + text_file.close() + utt2spk_file_fifo.close() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py new file mode 100755 index 000000000..81c19659a --- /dev/null +++ b/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +""" +Process the textgrid files +""" +import argparse +import codecs +from distutils.util import strtobool +from pathlib import Path +import textgrid +import pdb + +class Segment(object): + def __init__(self, uttid, spkr, stime, etime, text): + self.uttid = uttid + self.spkr = spkr + self.stime = round(stime, 2) + self.etime = round(etime, 2) + self.text = text + + def change_stime(self, time): + self.stime = time + + def change_etime(self, time): + self.etime = time + + +def get_args(): + parser = argparse.ArgumentParser(description="process the textgrid files") + parser.add_argument("--path", type=str, required=True, help="Data path") + parser.add_argument( + "--no-overlap", + type=strtobool, + 
def filter_overlap(segments):
    """Keep only utterances that do not overlap in time with any other.

    Args:
        segments: list of Segment-like objects (with ``stime``/``etime``
            attributes), sorted by start time.

    Returns:
        list: the subset of ``segments`` whose time span does not intersect
        any other segment's span, in input order.
    """
    if not segments:
        # robustness fix: empty input used to raise IndexError on segments[0]
        return []

    new_segments = []
    # current cluster of mutually overlapping segments and its time span
    tmp_segments = [segments[0]]
    min_stime = segments[0].stime
    max_etime = segments[0].etime

    for i in range(1, len(segments)):
        if segments[i].stime >= max_etime:
            # doesn't overlap with the previous cluster: emit the cluster
            # only if it contained a single (i.e. non-overlapped) segment
            if len(tmp_segments) == 1:
                new_segments.append(tmp_segments[0])
            # TODO: for multi-spkr asr, we can reset the stime/etime to
            # min_stime/max_etime for generating a max length mixture speech
            tmp_segments = [segments[i]]
            min_stime = segments[i].stime
            max_etime = segments[i].etime
        else:
            # overlaps the current cluster: extend the cluster span
            tmp_segments.append(segments[i])
            if min_stime > segments[i].stime:
                min_stime = segments[i].stime
            if max_etime < segments[i].etime:
                max_etime = segments[i].etime

    # bug fix: flush the final cluster too; previously the last
    # non-overlapping utterance of every recording was silently dropped
    if len(tmp_segments) == 1:
        new_segments.append(tmp_segments[0])

    return new_segments
def calc_spk_trans(trans):
    """Split a '$'-joined multi-speaker transcript into (index, text) pairs.

    Args:
        trans: transcript string where each speaker's text is separated
            by the '$' delimiter.

    Returns:
        list[tuple[str, str]]: pairs of (speaker index as a string,
        whitespace-stripped text), in order of appearance.
    """
    stripped = (chunk.strip() for chunk in trans.split("$"))
    return [(str(idx), text) for idx, text in enumerate(stripped)]
def calc_cer(ref_trans, hyp_trans):
    # Compute the concatenated-permutation character error for one meeting.
    #
    # Both transcripts are '$'-separated per-speaker strings.  The hypothesis
    # speaker order is unknown, so every permutation of hypothesis speakers is
    # scored against the reference speakers and the cheapest one wins
    # (CP-CER-style scoring).
    #
    # Returns: (edit_errors, ref_char_count, n_ref_speakers, n_hyp_speakers)
    # for the best permutation (ranked by errors / ref chars).
    ref_spk_trans = calc_spk_trans(ref_trans)
    hyp_spk_trans = calc_spk_trans(hyp_trans)
    ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)
    # Pad the shorter side with empty speakers so both lists align 1:1.
    num_spk = max(len(ref_spk_trans), len(hyp_spk_trans))
    ref_spk_trans.extend([("", "")] * (num_spk - len(ref_spk_trans)))
    hyp_spk_trans.extend([("", "")] * (num_spk - len(hyp_spk_trans)))

    errors, counts, permutes = [], [], []
    # min_error == 0 means "no complete permutation scored yet"; once set, it
    # acts as a branch-and-bound threshold for pruning later permutations.
    min_error = 0
    # Memoized pairwise edit distances, keyed by "refIdx-hypIdx", so each
    # speaker pair is evaluated at most once across all permutations.
    cost_dict = {}
    for perm in permutations(range(num_spk)):
        flag = True
        p_err, p_count = 0, 0
        for idx, p in enumerate(perm):
            # Cheap lower bound: |len(ref) - len(hyp)| <= edit distance, so
            # if even that exceeds the current best total, skip this perm.
            if abs(len(ref_spk_trans[idx][1]) - len(hyp_spk_trans[p][1])) > min_error > 0:
                flag = False
                break
            cost_key = "{}-{}".format(idx, p)
            if cost_key in cost_dict:
                _e = cost_dict[cost_key]
            else:
                _e = editdistance.eval(ref_spk_trans[idx][1], hyp_spk_trans[p][1])
                cost_dict[cost_key] = _e
            # A single pair already worse than the best total cannot win.
            if _e > min_error > 0:
                flag = False
                break
            p_err += _e
            p_count += len(ref_spk_trans[idx][1])

        if flag:
            if p_err < min_error or min_error == 0:
                min_error = p_err

            errors.append(p_err)
            counts.append(p_count)
            permutes.append(perm)

    # Rank surviving permutations by error rate (errors / reference chars).
    # NOTE(review): err/cnt raises ZeroDivisionError when the reference has
    # zero characters — confirm upstream guarantees non-empty references.
    sd_cer = [(err, cnt, err/cnt, permute)
              for err, cnt, permute in zip(errors, counts, permutes)]
    best_rst = min(sd_cer, key=lambda x: x[2])

    return best_rst[0], best_rst[1], ref_spk_num, hyp_spk_num
{:.2f}\n".format(error / count * 100.0)) + result_file.close() + # print("Sum/Avg: {:.2f}".format(error / count * 100.0)) + + +if __name__ == '__main__': + main() diff --git a/egs/alimeeting/sa-asr/local/compute_wer.py b/egs/alimeeting/sa-asr/local/compute_wer.py new file mode 100755 index 000000000..349a3f609 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/compute_wer.py @@ -0,0 +1,157 @@ +import os +import numpy as np +import sys + +def compute_wer(ref_file, + hyp_file, + cer_detail_file): + rst = { + 'Wrd': 0, + 'Corr': 0, + 'Ins': 0, + 'Del': 0, + 'Sub': 0, + 'Snt': 0, + 'Err': 0.0, + 'S.Err': 0.0, + 'wrong_words': 0, + 'wrong_sentences': 0 + } + + hyp_dict = {} + ref_dict = {} + with open(hyp_file, 'r') as hyp_reader: + for line in hyp_reader: + key = line.strip().split()[0] + value = line.strip().split()[1:] + hyp_dict[key] = value + with open(ref_file, 'r') as ref_reader: + for line in ref_reader: + key = line.strip().split()[0] + value = line.strip().split()[1:] + ref_dict[key] = value + + cer_detail_writer = open(cer_detail_file, 'w') + for hyp_key in hyp_dict: + if hyp_key in ref_dict: + out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key]) + rst['Wrd'] += out_item['nwords'] + rst['Corr'] += out_item['cor'] + rst['wrong_words'] += out_item['wrong'] + rst['Ins'] += out_item['ins'] + rst['Del'] += out_item['del'] + rst['Sub'] += out_item['sub'] + rst['Snt'] += 1 + if out_item['wrong'] > 0: + rst['wrong_sentences'] += 1 + cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n') + cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n') + cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n') + + if rst['Wrd'] > 0: + rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) + if rst['Snt'] > 0: + rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) + + cer_detail_writer.write('\n') + cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + 
def compute_wer_by_line(hyp, ref):
    """Align one hypothesis against one reference and count edit operations.

    Args:
        hyp: list of hypothesis tokens (characters for CER).
        ref: list of reference tokens.

    Returns:
        dict with keys ``nwords`` (reference length), ``cor`` (matches),
        ``wrong`` (total edit distance), ``ins``, ``del``, ``sub``.
    """
    # Case-insensitive comparison.
    hyp = [x.lower() for x in hyp]
    ref = [x.lower() for x in ref]

    len_hyp = len(hyp)
    len_ref = len(ref)

    # bug fix: int32 instead of int16 — avoids silent overflow of the edit
    # distance on inputs longer than 32767 tokens.
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int32)
    # Operation taken to reach each cell: 0 = match, 1 = sub, 2 = ins, 3 = del.
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)

    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j

    # Standard Levenshtein DP: hyp indexes rows, ref indexes columns.
    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1  # extra token in hyp
                deletion = cost_matrix[i][j - 1] + 1   # token missing from hyp

                compare_val = [substitution, insertion, deletion]

                min_val = min(compare_val)
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = compare_val.index(min_val) + 1

    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }

    # Backtrace from the bottom-right corner.  Bug fix: the previous loop ran
    # the indices down to -1 and read ops_matrix row/column 0 (always 0, i.e.
    # "match"), so insertions/deletions at the start of the strings were not
    # counted.  Stop at the matrix edge and attribute the remaining prefix
    # explicitly instead.
    match_idx = []
    i, j = len_hyp, len_ref
    while i > 0 or j > 0:
        if i > 0 and j > 0 and ops_matrix[i][j] == 0:  # match
            match_idx.append((j - 1, i - 1))
            rst['cor'] += 1
            i -= 1
            j -= 1
        elif i > 0 and j > 0 and ops_matrix[i][j] == 1:  # substitution
            rst['sub'] += 1
            i -= 1
            j -= 1
        elif i > 0 and (j == 0 or ops_matrix[i][j] == 2):  # insertion
            rst['ins'] += 1
            i -= 1
        else:  # deletion (j > 0)
            rst['del'] += 1
            j -= 1

    match_idx.reverse()
    # The total edit distance is read directly from the DP table.
    rst['wrong'] = int(cost_matrix[len_hyp][len_ref])

    return rst
def print_cer_detail(rst):
    """Format one utterance's alignment counts as a one-line summary string."""
    nwords = rst['nwords']
    counts = "nwords={},cor={},ins={},del={},sub={}".format(
        nwords, rst['cor'], rst['ins'], rst['del'], rst['sub'])
    return "({}) corr:{:.2%},cer:{:.2%}".format(
        counts, rst['cor'] / nwords, rst['wrong'] / nwords)
def custom_spectral_clustering(affinity, min_n_clusters=2, max_n_clusters=4, refine=True,
                               threshold=0.995, laplacian_type="graph_cut"):
    """Spectral clustering of speaker embeddings with an automatically
    estimated cluster count.

    Args:
        affinity: (n, n) pairwise similarity matrix of segment embeddings.
        min_n_clusters: lower bound on the estimated number of speakers.
        max_n_clusters: upper bound on the estimated number of speakers.
        refine: if True, symmetrize, diffuse and row-max-normalize the
            affinity matrix before building the graph Laplacian.
        threshold: eigenvalue threshold used to pick the number of clusters.
        laplacian_type: "random_walk" for D^-1 L; any other value uses the
            symmetric ("graph_cut") normalization L / sqrt(d_i * d_j).

    Returns:
        ndarray of shape (n,): integer cluster label for each row of
        ``affinity`` (from sklearn KMeans).
    """
    if refine:
        # Symmetrization
        affinity = np.maximum(affinity, np.transpose(affinity))
        # Diffusion
        affinity = np.matmul(affinity, np.transpose(affinity))
        # Row-wise max normalization
        row_max = affinity.max(axis=1, keepdims=True)
        affinity = affinity / row_max

    # a) Construct S and set diagonal elements to 0
    affinity = affinity - np.diag(np.diag(affinity))
    # b) Compute Laplacian matrix L and perform normalization:
    degree = np.diag(np.sum(affinity, axis=1))
    laplacian = degree - affinity
    if laplacian_type == "random_walk":
        # epsilon guards against division by zero for isolated nodes
        degree_norm = np.diag(1 / (np.diag(degree) + 1e-10))
        laplacian_norm = degree_norm.dot(laplacian)
    else:
        degree_half = np.diag(degree) ** 0.5 + 1e-15
        laplacian_norm = laplacian / degree_half[:, np.newaxis] / degree_half

    # c) Compute eigenvalues and eigenvectors of L_norm
    # NOTE(review): np.linalg.eig on a (numerically) non-symmetric matrix can
    # return complex values; only the real parts are kept here.
    eigenvalues, eigenvectors = np.linalg.eig(laplacian_norm)
    eigenvalues = eigenvalues.real
    eigenvectors = eigenvectors.real
    index_array = np.argsort(eigenvalues)
    eigenvalues = eigenvalues[index_array]
    eigenvectors = eigenvectors[:, index_array]

    # d) Compute the number of clusters k: the first index in
    # [min_n_clusters, max_n_clusters] whose (sorted) eigenvalue exceeds
    # `threshold`, clipped from below to min_n_clusters.
    k = min_n_clusters
    for k in range(min_n_clusters, max_n_clusters + 1):
        if eigenvalues[k] > threshold:
            break
    k = max(k, min_n_clusters)
    spectral_embeddings = eigenvectors[:, :k]
    # print(mid, k, eigenvalues[:10])

    # L2-normalize each row of the spectral embedding, then run k-means with a
    # fixed random_state so the clustering is reproducible.
    spectral_embeddings = spectral_embeddings / np.linalg.norm(spectral_embeddings, axis=1, ord=2, keepdims=True)
    solver = cluster.KMeans(n_clusters=k, max_iter=1000, random_state=42)
    solver.fit(spectral_embeddings)
    return solver.labels_
all_seg_embedding_list = [] + for seg in segments_list: + wav_seg = wav[int(seg[0] * 16000): int(seg[1] * 16000)] + wav_seg_len = wav_seg.shape[0] + i = 0 + while i < wav_seg_len: + if i + chunk_len < wav_seg_len: + cur_wav_chunk = wav_seg[i: i+chunk_len] + else: + cur_wav_chunk=wav_seg[i: ] + # chunks under 0.2s are ignored + if cur_wav_chunk.shape[0] >= 0.2 * 16000: + cur_chunk_embedding = inference_sv_pipline(audio_in=cur_wav_chunk)["spk_embedding"] + all_seg_embedding_list.append(cur_chunk_embedding) + i += hop_len + all_seg_embedding = np.vstack(all_seg_embedding_list) + # all_seg_embedding (n, dim) + + # compute affinity + affinity=cosine_similarity(all_seg_embedding) + + affinity = np.maximum(affinity - sv_threshold, 0.0001) / (affinity.max() - sv_threshold) + + # clustering + labels = custom_spectral_clustering( + affinity=affinity, + min_n_clusters=2, + max_n_clusters=4, + refine=True, + threshold=threshold, + laplacian_type="graph_cut") + + + cluster_dict={} + for j in range(labels.shape[0]): + if labels[j] not in cluster_dict.keys(): + cluster_dict[labels[j]] = np.atleast_2d(all_seg_embedding[j]) + else: + cluster_dict[labels[j]] = np.concatenate((cluster_dict[labels[j]], np.atleast_2d(all_seg_embedding[j]))) + + emb_list = [] + # get cluster center + for k in cluster_dict.keys(): + cluster_dict[k] = np.mean(cluster_dict[k], axis=0) + emb_list.append(cluster_dict[k]) + + spk_num = len(emb_list) + profile_for_infer = np.vstack(emb_list) + # save profile for each meeting + np.save(path + '/cluster_profile_infer/' + meeting + '.npy', profile_for_infer) + meeting_map[meeting] = (path + '/cluster_profile_infer/' + meeting + '.npy', spk_num) + cluster_spk_num_file.write(meeting + ' ' + str(spk_num) + '\n') + cluster_spk_num_file.flush() + + cluster_spk_num_file.close() + + profile_scp = open(path + "/cluster_profile_infer.scp", 'w') + for line in wav_scp: + uttid = line.strip().split(' ')[0] + meeting = uttid.split('-')[0] + profile_scp.write(uttid + ' ' + 
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import numpy as np
import sys
import os
import soundfile

# AliMeeting far-field audio is 16 kHz.
SAMPLE_RATE = 16000
# Segments shorter than this many seconds are skipped: they are too short
# to yield a reliable x-vector.
MIN_SEG_DUR = 0.5


if __name__ == "__main__":
    # Usage: gen_oracle_embedding.py <dump_dir> <raw_dir>
    #   <dump_dir>: e.g. dump2/raw/Eval_Ali_far (outputs written here)
    #   <raw_dir>:  e.g. data/local/Eval_Ali_far_correct_single_speaker
    #               (must contain wav_raw.scp and segments)
    path = sys.argv[1]
    raw_path = sys.argv[2]

    with open(raw_path + '/wav_raw.scp', 'r') as f:
        raw_meeting_scp = f.readlines()
    with open(raw_path + '/segments', 'r') as f:
        segments_scp = f.readlines()

    oracle_emb_dir = path + '/oracle_embedding/'
    os.makedirs(oracle_emb_dir, exist_ok=True)

    # meeting id -> raw multi-channel wav path (wav_raw.scp is TAB separated)
    raw_wav_map = {}
    for line in raw_meeting_scp:
        fields = line.strip().split('\t')
        raw_wav_map[fields[0]] = fields[1]

    # "<meeting>_<SPKxxx>" -> list of (start_sample, end_sample) offsets
    spk_map = {}
    for line in segments_scp:
        line_list = line.strip().split(' ')
        meeting = line_list[1]
        # segments uttid layout puts the speaker id in the 4th "_" field
        # -- NOTE(review): confirm against the data-prep script's output
        spk_id = line_list[0].split('_')[3]
        spk = meeting + '_' + spk_id
        time_start = float(line_list[-2])
        time_end = float(line_list[-1])
        if time_end - time_start > MIN_SEG_DUR:
            seg = (int(time_start * SAMPLE_RATE), int(time_end * SAMPLE_RATE))
            spk_map.setdefault(spk, []).append(seg)

    inference_sv_pipline = pipeline(
        task=Tasks.speaker_verification,
        model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
    )

    with open(path + '/oracle_embedding.scp', 'w') as oracle_emb_scp_file:
        for spk in spk_map.keys():
            meeting = spk.split('_SPK')[0]
            wav = soundfile.read(raw_wav_map[meeting])[0]
            # take the first channel of multi-channel recordings
            if wav.ndim == 2:
                wav = wav[:, 0]
            all_seg_embedding_list = []
            for seg_start, seg_end in spk_map[spk]:
                # ignore segments starting inside the last MIN_SEG_DUR of the file
                if seg_start < wav.shape[0] - MIN_SEG_DUR * SAMPLE_RATE:
                    # clamp segments that run past the end of the recording
                    stop = min(seg_end, wav.shape[0])
                    cur_seg_embedding = inference_sv_pipline(
                        audio_in=wav[seg_start:stop])["spk_embedding"]
                    all_seg_embedding_list.append(cur_seg_embedding)
            if not all_seg_embedding_list:
                # Previously np.vstack([]) raised here; skip the speaker instead.
                print("WARNING: no usable segment for speaker " + spk,
                      file=sys.stderr)
                continue
            # speaker-level embedding = mean of the per-segment x-vectors
            all_seg_embedding = np.vstack(all_seg_embedding_list)
            spk_embedding = np.mean(all_seg_embedding, axis=0)
            np.save(oracle_emb_dir + spk + '.npy', spk_embedding)
            oracle_emb_scp_file.write(
                spk + ' ' + oracle_emb_dir + spk + '.npy' + '\n')
            oracle_emb_scp_file.flush()
global_spk_list.append(spk) + if meeting in meeting_map_tmp.keys(): + meeting_map_tmp[meeting].append(spk) + else: + meeting_map_tmp[meeting] = [spk] + + meeting_map = {} + os.system('mkdir -p ' + path + '/oracle_profile_nopadding') + for meeting in meeting_map_tmp.keys(): + emb_list = [] + for i in range(len(meeting_map_tmp[meeting])): + spk = meeting_map_tmp[meeting][i] + emb_list.append(embedding_map[spk]) + profile = np.vstack(emb_list) + np.save(path + '/oracle_profile_nopadding/' + meeting + '.npy', profile) + meeting_map[meeting] = path + '/oracle_profile_nopadding/' + meeting + '.npy' + + profile_scp = open(path + '/oracle_profile_nopadding.scp', 'w') + profile_map_scp = open(path + '/oracle_profile_nopadding_spk_list', 'w') + + for line in wav_scp: + uttid = line.strip().split(' ')[0] + meeting = uttid.split('-')[0] + profile_scp.write(uttid + ' ' + meeting_map[meeting] + '\n') + profile_map_scp.write(uttid + ' ' + '$'.join(meeting_map_tmp[meeting]) + '\n') + profile_scp.close() + profile_map_scp.close() \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py new file mode 100644 index 000000000..b70a32a19 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py @@ -0,0 +1,68 @@ +import random +import numpy as np +import os +import sys + + +if __name__=="__main__": + path = sys.argv[1] # dump2/raw/Train_Ali_far + wav_scp_file = open(path+"/wav.scp", 'r') + wav_scp = wav_scp_file.readlines() + wav_scp_file.close() + spk2id_file = open(path+"/spk2id", 'r') + spk2id = spk2id_file.readlines() + spk2id_file.close() + embedding_scp_file = open(path + "/oracle_embedding.scp", 'r') + embedding_scp = embedding_scp_file.readlines() + embedding_scp_file.close() + + embedding_map = {} + for line in embedding_scp: + spk = line.strip().split(' ')[0] + if spk not in embedding_map.keys(): + emb = np.load(line.strip().split(' ')[1]) + 
embedding_map[spk] = emb + + meeting_map_tmp = {} + global_spk_list = [] + for line in spk2id: + line_list = line.strip().split(' ') + meeting = line_list[0].split('-')[0] + spk_id = line_list[0].split('-')[-1].split('_')[-1] + spk = meeting+'_' + spk_id + global_spk_list.append(spk) + if meeting in meeting_map_tmp.keys(): + meeting_map_tmp[meeting].append(spk) + else: + meeting_map_tmp[meeting] = [spk] + + for meeting in meeting_map_tmp.keys(): + num = len(meeting_map_tmp[meeting]) + if num < 4: + global_spk_list_tmp = global_spk_list[: ] + for spk in meeting_map_tmp[meeting]: + global_spk_list_tmp.remove(spk) + padding_spk = random.sample(global_spk_list_tmp, 4 - num) + meeting_map_tmp[meeting] = meeting_map_tmp[meeting] + padding_spk + + meeting_map = {} + os.system('mkdir -p ' + path + '/oracle_profile_padding') + for meeting in meeting_map_tmp.keys(): + emb_list = [] + for i in range(len(meeting_map_tmp[meeting])): + spk = meeting_map_tmp[meeting][i] + emb_list.append(embedding_map[spk]) + profile = np.vstack(emb_list) + np.save(path + '/oracle_profile_padding/' + meeting + '.npy',profile) + meeting_map[meeting] = path + '/oracle_profile_padding/' + meeting + '.npy' + + profile_scp = open(path + '/oracle_profile_padding.scp', 'w') + profile_map_scp = open(path + '/oracle_profile_padding_spk_list', 'w') + + for line in wav_scp: + uttid = line.strip().split(' ')[0] + meeting = uttid.split('-')[0] + profile_scp.write(uttid+' ' + meeting_map[meeting] + '\n') + profile_map_scp.write(uttid+' ' + '$'.join(meeting_map_tmp[meeting]) + '\n') + profile_scp.close() + profile_map_scp.close() \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/proce_text.py b/egs/alimeeting/sa-asr/local/proce_text.py new file mode 100755 index 000000000..e56cc0f37 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/proce_text.py @@ -0,0 +1,32 @@ + +import sys +import re + +in_f = sys.argv[1] +out_f = sys.argv[2] + + +with open(in_f, "r", encoding="utf-8") as f: + lines = 
f.readlines() + +with open(out_f, "w", encoding="utf-8") as f: + for line in lines: + outs = line.strip().split(" ", 1) + if len(outs) == 2: + idx, text = outs + text = re.sub("", "", text) + text = re.sub("", "", text) + text = re.sub("@@", "", text) + text = re.sub("@", "", text) + text = re.sub("", "", text) + text = re.sub(" ", "", text) + text = re.sub("\$", "", text) + text = text.lower() + else: + idx = outs[0] + text = " " + + text = [x for x in text] + text = " ".join(text) + out = "{} {}\n".format(idx, text) + f.write(out) diff --git a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py new file mode 100755 index 000000000..d900bb17a --- /dev/null +++ b/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +""" +Process the textgrid files +""" +import argparse +import codecs +from distutils.util import strtobool +from pathlib import Path +import textgrid +import pdb + +def get_args(): + parser = argparse.ArgumentParser(description="process the textgrid files") + parser.add_argument("--path", type=str, required=True, help="Data path") + args = parser.parse_args() + return args + +class Segment(object): + def __init__(self, uttid, text): + self.uttid = uttid + self.text = text + +def main(args): + text = codecs.open(Path(args.path) / "text", "r", "utf-8") + spk2utt = codecs.open(Path(args.path) / "spk2utt", "r", "utf-8") + utt2spk = codecs.open(Path(args.path) / "utt2spk_all_fifo", "r", "utf-8") + spk2id = codecs.open(Path(args.path) / "spk2id", "w", "utf-8") + + spkid_map = {} + meetingid_map = {} + for line in spk2utt: + spkid = line.strip().split(" ")[0] + meeting_id_list = spkid.split("_")[:3] + meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2] + if meeting_id not in meetingid_map: + meetingid_map[meeting_id] = 1 + else: + meetingid_map[meeting_id] += 1 + spkid_map[spkid] = meetingid_map[meeting_id] 
+ spk2id.write("%s %s\n" % (spkid, meetingid_map[meeting_id])) + + utt2spklist = {} + for line in utt2spk: + uttid = line.strip().split(" ")[0] + spkid = line.strip().split(" ")[1] + spklist = spkid.split("$") + tmp = [] + for index in range(len(spklist)): + tmp.append(spkid_map[spklist[index]]) + utt2spklist[uttid] = tmp + # parse the textgrid file for each utterance + all_segments = [] + for line in text: + uttid = line.strip().split(" ")[0] + context = line.strip().split(" ")[1] + spklist = utt2spklist[uttid] + length_text = len(context) + cnt = 0 + tmp_text = "" + for index in range(length_text): + if context[index] != "$": + tmp_text += str(spklist[cnt]) + else: + tmp_text += "$" + cnt += 1 + tmp_seg = Segment(uttid,tmp_text) + all_segments.append(tmp_seg) + + text.close() + utt2spk.close() + spk2utt.close() + spk2id.close() + + text_id = codecs.open(Path(args.path) / "text_id", "w", "utf-8") + + for i in range(len(all_segments)): + uttid_tmp = all_segments[i].uttid + text_tmp = all_segments[i].text + + text_id.write("%s %s\n" % (uttid_tmp, text_tmp)) + + text_id.close() + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/alimeeting/sa-asr/local/process_text_id.py b/egs/alimeeting/sa-asr/local/process_text_id.py new file mode 100644 index 000000000..0a9506e29 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/process_text_id.py @@ -0,0 +1,24 @@ +import sys +if __name__=="__main__": + path=sys.argv[1] + + text_id_old_file=open(path+"/text_id",'r') + text_id_old=text_id_old_file.readlines() + text_id_old_file.close() + + text_id=open(path+"/text_id_train",'w') + for line in text_id_old: + uttid=line.strip().split(' ')[0] + old_id=line.strip().split(' ')[1] + pre_id='0' + new_id_list=[] + for i in old_id: + if i == '$': + new_id_list.append(pre_id) + else: + new_id_list.append(str(int(i)-1)) + pre_id=str(int(i)-1) + new_id_list.append(pre_id) + new_id=' '.join(new_id_list) + text_id.write(uttid+' '+new_id+'\n') + text_id.close() \ No 
newline at end of file diff --git a/egs/alimeeting/sa-asr/local/process_text_spk_merge.py b/egs/alimeeting/sa-asr/local/process_text_spk_merge.py new file mode 100644 index 000000000..f15d509b5 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/process_text_spk_merge.py @@ -0,0 +1,55 @@ +import sys + + +if __name__ == "__main__": + path=sys.argv[1] + text_scp_file = open(path + '/text', 'r') + text_scp = text_scp_file.readlines() + text_scp_file.close() + text_id_scp_file = open(path + '/text_id', 'r') + text_id_scp = text_id_scp_file.readlines() + text_id_scp_file.close() + text_spk_merge_file = open(path + '/text_spk_merge', 'w') + assert len(text_scp) == len(text_id_scp) + + meeting_map = {} # {meeting_id: [(start_time, text, text_id), (start_time, text, text_id), ...]} + for i in range(len(text_scp)): + text_line = text_scp[i].strip().split(' ') + text_id_line = text_id_scp[i].strip().split(' ') + assert text_line[0] == text_id_line[0] + if len(text_line) > 1: + uttid = text_line[0] + text = text_line[1] + text_id = text_id_line[1] + meeting_id = uttid.split('-')[0] + start_time = int(uttid.split('-')[-2]) + if meeting_id not in meeting_map: + meeting_map[meeting_id] = [(start_time,text,text_id)] + else: + meeting_map[meeting_id].append((start_time,text,text_id)) + + for meeting_id in sorted(meeting_map.keys()): + cur_meeting_list = sorted(meeting_map[meeting_id], key=lambda x: x[0]) + text_spk_merge_map = {} #{1: text1, 2: text2, ...} + for cur_utt in cur_meeting_list: + cur_text = cur_utt[1] + cur_text_id = cur_utt[2] + assert len(cur_text)==len(cur_text_id) + if len(cur_text) != 0: + cur_text_split = cur_text.split('$') + cur_text_id_split = cur_text_id.split('$') + assert len(cur_text_split) == len(cur_text_id_split) + for i in range(len(cur_text_split)): + if len(cur_text_split[i]) != 0: + spk_id = int(cur_text_id_split[i][0]) + if spk_id not in text_spk_merge_map.keys(): + text_spk_merge_map[spk_id] = cur_text_split[i] + else: + text_spk_merge_map[spk_id] += 
cur_text_split[i] + text_spk_merge_list = [] + for spk_id in sorted(text_spk_merge_map.keys()): + text_spk_merge_list.append(text_spk_merge_map[spk_id]) + text_spk_merge_file.write(meeting_id + ' ' + '$'.join(text_spk_merge_list) + '\n') + text_spk_merge_file.flush() + + text_spk_merge_file.close() \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py new file mode 100755 index 000000000..fdf246090 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +""" +Process the textgrid files +""" +import argparse +import codecs +from distutils.util import strtobool +from pathlib import Path +import textgrid +import pdb +import numpy as np +import sys +import math + + +class Segment(object): + def __init__(self, uttid, spkr, stime, etime, text): + self.uttid = uttid + self.spkr = spkr + self.stime = round(stime, 2) + self.etime = round(etime, 2) + self.text = text + + def change_stime(self, time): + self.stime = time + + def change_etime(self, time): + self.etime = time + + +def get_args(): + parser = argparse.ArgumentParser(description="process the textgrid files") + parser.add_argument("--path", type=str, required=True, help="Data path") + args = parser.parse_args() + return args + + + +def main(args): + textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8") + segment_file = codecs.open(Path(args.path)/"segments", "w", "utf-8") + utt2spk = codecs.open(Path(args.path)/"utt2spk", "w", "utf-8") + + # get the path of textgrid file for each utterance + for line in textgrid_flist: + line_array = line.strip().split(" ") + path = Path(line_array[1]) + uttid = line_array[0] + + try: + tg = textgrid.TextGrid.fromFile(path) + except: + pdb.set_trace() + num_spk = tg.__len__() + spk2textgrid = {} + spk2weight = {} + weight2spk = {} + cnt = 2 + xmax = 0 + 
for i in range(tg.__len__()): + spk_name = tg[i].name + if spk_name not in spk2weight: + spk2weight[spk_name] = cnt + weight2spk[cnt] = spk_name + cnt = cnt * 2 + segments = [] + for j in range(tg[i].__len__()): + if tg[i][j].mark: + if xmax < tg[i][j].maxTime: + xmax = tg[i][j].maxTime + segments.append( + Segment( + uttid, + tg[i].name, + tg[i][j].minTime, + tg[i][j].maxTime, + tg[i][j].mark.strip(), + ) + ) + segments = sorted(segments, key=lambda x: x.stime) + spk2textgrid[spk_name] = segments + olp_label = np.zeros((num_spk, int(xmax/0.01)), dtype=np.int32) + for spkid in spk2weight.keys(): + weight = spk2weight[spkid] + segments = spk2textgrid[spkid] + idx = int(math.log2(weight) )- 1 + for i in range(len(segments)): + stime = segments[i].stime + etime = segments[i].etime + olp_label[idx, int(stime/0.01): int(etime/0.01)] = weight + sum_label = olp_label.sum(axis=0) + stime = 0 + pre_value = 0 + for pos in range(sum_label.shape[0]): + if sum_label[pos] in weight2spk: + if pre_value in weight2spk: + if sum_label[pos] != pre_value: + spkids = weight2spk[pre_value] + spkid_array = spkids.split("_") + spkid = spkid_array[-1] + #spkid = uttid+spkid + if round(stime*0.01, 2) != round((pos-1)*0.01, 2): + segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2))) + utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid)) + stime = pos + pre_value = sum_label[pos] + else: + stime = pos + pre_value = sum_label[pos] + else: + if pre_value in weight2spk: + spkids = weight2spk[pre_value] + spkid_array = spkids.split("_") + spkid = spkid_array[-1] + #spkid = uttid+spkid + if round(stime*0.01, 2) != round((pos-1)*0.01, 2): + segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2))) + 
utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid)) + stime = pos + pre_value = sum_label[pos] + textgrid_flist.close() + segment_file.close() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/alimeeting/sa-asr/local/text_format.pl b/egs/alimeeting/sa-asr/local/text_format.pl new file mode 100755 index 000000000..45f1f6428 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/text_format.pl @@ -0,0 +1,14 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter +# Copyright Chao Weng + +# normalizations for hkust trascript +# see the docs/trans-guidelines.pdf for details + +while () { + @A = split(" ", $_); + if (@A == 1) { + next; + } + print $_ +} diff --git a/egs/alimeeting/sa-asr/local/text_normalize.pl b/egs/alimeeting/sa-asr/local/text_normalize.pl new file mode 100755 index 000000000..ac301d466 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/text_normalize.pl @@ -0,0 +1,38 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter +# Copyright Chao Weng + +# normalizations for hkust trascript +# see the docs/trans-guidelines.pdf for details + +while () { + @A = split(" ", $_); + print "$A[0] "; + for ($n = 1; $n < @A; $n++) { + $tmp = $A[$n]; + if ($tmp =~ //) {$tmp =~ s:::g;} + if ($tmp =~ /<%>/) {$tmp =~ s:<%>::g;} + if ($tmp =~ /<->/) {$tmp =~ s:<->::g;} + if ($tmp =~ /<\$>/) {$tmp =~ s:<\$>::g;} + if ($tmp =~ /<#>/) {$tmp =~ s:<#>::g;} + if ($tmp =~ /<_>/) {$tmp =~ s:<_>::g;} + if ($tmp =~ //) {$tmp =~ s:::g;} + if ($tmp =~ /`/) {$tmp =~ s:`::g;} + if ($tmp =~ /&/) {$tmp =~ s:&::g;} + if ($tmp =~ /,/) {$tmp =~ s:,::g;} + if ($tmp =~ /[a-zA-Z]/) {$tmp=uc($tmp);} + if ($tmp =~ /A/) {$tmp =~ s:A:A:g;} + if ($tmp =~ /a/) {$tmp =~ s:a:A:g;} + if ($tmp =~ /b/) {$tmp =~ s:b:B:g;} + if ($tmp =~ /c/) {$tmp =~ s:c:C:g;} + if ($tmp =~ /k/) {$tmp =~ s:k:K:g;} + if ($tmp =~ /t/) {$tmp =~ s:t:T:g;} + if ($tmp =~ /,/) {$tmp =~ s:,::g;} + 
def humanfriendly_or_none(value: str) -> Optional[int]:
    """Parse a human-friendly size string (e.g. "16k") into an int.

    Returns None for the none-like sentinels; the accepted sentinels match
    str2int_tuple's so both CLI options behave consistently.
    """
    if value.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return humanfriendly.parse_size(value)


def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
    """Parse a comma separated int list; none-like strings give None.

    >>> str2int_tuple('3,4,5')
    (3, 4, 5)

    """
    assert check_argument_types()
    if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return tuple(map(int, integers.strip().split(",")))
def _resample_if_needed(wave, rate, target_fs):
    """Resample int16 ``wave`` to ``target_fs`` (no-op when None or equal).

    Returns the (possibly new) wave and its sampling rate.
    """
    if target_fs is not None and target_fs != rate:
        # FIXME(kamo): To use sox?
        wave = resampy.resample(wave.astype(np.float64), rate, target_fs, axis=0)
        wave = wave.astype(np.int16)
        rate = target_fs
    return wave, rate


def _ark_suffix(audio_format):
    """Map an "...ark" audio format to the codec suffix embedded in the ark."""
    if "flac" in audio_format:
        return "flac"
    if "wav" in audio_format:
        return "wav"
    raise RuntimeError("wav.ark or flac")


def main():
    """Create actual audio files (or a wav/flac ark) from a kaldi "wav.scp".

    Reads <scp> (optionally cut by --segments), selects reference channels,
    optionally resamples to --fs, then writes the formatted audio plus a new
    "<name>.scp" and an "utt2num_samples" file into <outdir>.
    """
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    parser = argparse.ArgumentParser(
        description='Create waves list from "wav.scp"',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("scp")
    parser.add_argument("outdir")
    parser.add_argument(
        "--name",
        default="wav",
        help="Specify the prefix word of output file name " 'such as "wav.scp"',
    )
    parser.add_argument("--segments", default=None)
    parser.add_argument(
        "--fs",
        type=humanfriendly_or_none,
        default=None,
        help="If the sampling rate specified, " "Change the sampling rate.",
    )
    parser.add_argument("--audio-format", default="wav")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
    group.add_argument("--utt2ref-channels", default=None, type=str)
    args = parser.parse_args()

    out_num_samples = Path(args.outdir) / f"utt2num_samples"

    # Build the per-utterance channel selector (None keeps all channels).
    if args.ref_channels is not None:

        def utt2ref_channels(x) -> Tuple[int, ...]:
            return args.ref_channels

    elif args.utt2ref_channels is not None:
        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)

        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
            chs_str = d[x]
            return tuple(map(int, chs_str.split()))

    else:
        utt2ref_channels = None

    Path(args.outdir).mkdir(parents=True, exist_ok=True)
    out_wavscp = Path(args.outdir) / f"{args.name}.scp"
    is_ark = args.audio_format.endswith("ark")

    if args.segments is not None:
        # Note: kaldiio supports only wav-pcm-int16le file.
        loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
        if is_ark:
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
            fscp = out_wavscp.open("w")
        else:
            writer = SoundScpWriter(
                args.outdir,
                out_wavscp,
                format=args.audio_format,
            )

        with out_num_samples.open("w") as fnum_samples:
            for uttid, (rate, wave) in tqdm(loader):
                # wave: (Time,) or (Time, Nmic)
                if wave.ndim == 2 and utt2ref_channels is not None:
                    wave = wave[:, utt2ref_channels(uttid)]
                wave, rate = _resample_if_needed(wave, rate, args.fs)
                if is_ark:
                    # NOTE(kamo): Using extended ark format style here.
                    # This format is incompatible with Kaldi
                    kaldiio.save_ark(
                        fark,
                        {uttid: (wave, rate)},
                        scp=fscp,
                        append=True,
                        write_function=f"soundfile_{_ark_suffix(args.audio_format)}",
                    )
                else:
                    writer[uttid] = rate, wave
                fnum_samples.write(f"{uttid} {len(wave)}\n")
        if is_ark:
            # close the output handles explicitly (previously leaked)
            fark.close()
            fscp.close()
    else:
        if is_ark:
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
        else:
            wavdir = Path(args.outdir) / f"data_{args.name}"
            wavdir.mkdir(parents=True, exist_ok=True)

        with Path(args.scp).open("r") as fscp, out_wavscp.open(
            "w"
        ) as fout, out_num_samples.open("w") as fnum_samples:
            for line in tqdm(fscp):
                uttid, wavpath = line.strip().split(None, 1)

                if wavpath.endswith("|"):
                    # Streaming input e.g. cat a.wav |
                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
                        with BytesIO(f.read()) as g:
                            wave, rate = soundfile.read(g, dtype=np.int16)
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                    wave, rate = _resample_if_needed(wave, rate, args.fs)

                    if is_ark:
                        # NOTE(kamo): Using extended ark format style here.
                        # This format is incompatible with Kaldi
                        kaldiio.save_ark(
                            fark,
                            {uttid: (wave, rate)},
                            scp=fout,
                            append=True,
                            write_function=f"soundfile_{_ark_suffix(args.audio_format)}",
                        )
                    else:
                        owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                        soundfile.write(owavpath, wave, rate)
                        fout.write(f"{uttid} {owavpath}\n")
                else:
                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                        save_asis = False
                    elif is_ark:
                        save_asis = False
                    elif Path(wavpath).suffix == "." + args.audio_format and (
                        args.fs is None or args.fs == rate
                    ):
                        # Same container format and sample rate: keep the
                        # original file and just reference it from the scp.
                        save_asis = True
                    else:
                        save_asis = False

                    if save_asis:
                        fout.write(f"{uttid} {wavpath}\n")
                    else:
                        wave, rate = _resample_if_needed(wave, rate, args.fs)
                        if is_ark:
                            # NOTE(kamo): Using extended ark format style here.
                            # This format is not supported in Kaldi.
                            kaldiio.save_ark(
                                fark,
                                {uttid: (wave, rate)},
                                scp=fout,
                                append=True,
                                write_function=f"soundfile_{_ark_suffix(args.audio_format)}",
                            )
                        else:
                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                            soundfile.write(owavpath, wave, rate)
                            fout.write(f"{uttid} {owavpath}\n")
                fnum_samples.write(f"{uttid} {len(wave)}\n")
        if is_ark:
            # close the ark handle explicitly (previously leaked)
            fark.close()
#!/usr/bin/env python
import sys

# Characters that force an argument to be wrapped in single quotes so the
# printed command line can be pasted back into a shell verbatim.
_SHELL_SPECIAL = " ;&|<>?*~`\"'\\{}()"


def get_commandline_args(no_executable=True):
    """Return the current command line, quoted for safe shell re-use."""
    quoted = []
    for arg in sys.argv:
        escaped = arg.replace("'", "'\\''")
        if any(ch in arg for ch in _SHELL_SPECIAL):
            quoted.append("'" + escaped + "'")
        else:
            quoted.append(escaped)

    if no_executable:
        return " ".join(quoted[1:])
    return sys.executable + " " + " ".join(quoted)


def main():
    """Print the shell-escaped command line arguments."""
    print(get_commandline_args())


if __name__ == "__main__":
    main()
'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +ngpu=4 +device="0,1,2,3" + +#stage 1 creat both near and far +stage=1 +stop_stage=18 + + +train_set=Train_Ali_far +valid_set=Eval_Ali_far +test_sets="Test_Ali_far" +asr_config=conf/train_asr_conformer.yaml +sa_asr_config=conf/train_sa_asr_conformer.yaml +inference_config=conf/decode_asr_rnn.yaml + +lm_config=conf/train_lm_transformer.yaml +use_lm=false +use_wordlm=false +./asr_local.sh \ + --device ${device} \ + --ngpu ${ngpu} \ + --stage ${stage} \ + --stop_stage ${stop_stage} \ + --gpu_inference true \ + --njob_infer 4 \ + --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \ + --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \ + --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \ + --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \ + --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \ + --lang zh \ + --audio_format wav \ + --feats_type raw \ + --token_type char \ + --use_lm ${use_lm} \ + --use_word_lm ${use_wordlm} \ + --lm_config "${lm_config}" \ + --asr_config "${asr_config}" \ + --sa_asr_config "${sa_asr_config}" \ + --inference_config "${inference_config}" \ + --train_set "${train_set}" \ + --valid_set "${valid_set}" \ + --test_sets "${test_sets}" \ + --lm_train_text "data/${train_set}/text" "$@" diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh new file mode 100755 index 000000000..d35e6a693 --- /dev/null +++ b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 
'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +ngpu=4 +device="0,1,2,3" + +stage=1 +stop_stage=4 + + +train_set=Train_Ali_far +valid_set=Eval_Ali_far +test_sets="Test_2023_Ali_far" +asr_config=conf/train_asr_conformer.yaml +sa_asr_config=conf/train_sa_asr_conformer.yaml +inference_config=conf/decode_asr_rnn.yaml + +lm_config=conf/train_lm_transformer.yaml +use_lm=false +use_wordlm=false +./asr_local_infer.sh \ + --device ${device} \ + --ngpu ${ngpu} \ + --stage ${stage} \ + --stop_stage ${stop_stage} \ + --gpu_inference true \ + --njob_infer 4 \ + --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \ + --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \ + --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \ + --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \ + --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \ + --lang zh \ + --audio_format wav \ + --feats_type raw \ + --token_type char \ + --use_lm ${use_lm} \ + --use_word_lm ${use_wordlm} \ + --lm_config "${lm_config}" \ + --asr_config "${asr_config}" \ + --sa_asr_config "${sa_asr_config}" \ + --inference_config "${inference_config}" \ + --train_set "${train_set}" \ + --valid_set "${valid_set}" \ + --test_sets "${test_sets}" \ + --lm_train_text "data/${train_set}/text" "$@" diff --git a/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh new file mode 100755 index 000000000..15e4563f1 --- /dev/null +++ b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash +set -euo pipefail +SECONDS=0 +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} +help_message=$(cat << EOF +Usage: $0 [ []] +e.g. 
+$0 data/test/wav.scp data/test_format/ + +Format 'wav.scp': In short words, +changing "kaldi-datadir" to "modified-kaldi-datadir" + +The 'wav.scp' format in kaldi is very flexible, +e.g. It can use unix-pipe as describing that wav file, +but it sometime looks confusing and make scripts more complex. +This tools creates actual wav files from 'wav.scp' +and also segments wav files using 'segments'. + +Options + --fs + --segments + --nj + --cmd +EOF +) + +out_filename=wav.scp +cmd=utils/run.pl +nj=30 +fs=none +segments= + +ref_channels= +utt2ref_channels= + +audio_format=wav +write_utt2num_samples=true + +log "$0 $*" +. utils/parse_options.sh + +if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then + log "${help_message}" + log "Error: invalid command line arguments" + exit 1 +fi + +. ./path.sh # Setup the environment + +scp=$1 +if [ ! -f "${scp}" ]; then + log "${help_message}" + echo "$0: Error: No such file: ${scp}" + exit 1 +fi +dir=$2 + + +if [ $# -eq 2 ]; then + logdir=${dir}/logs + outdir=${dir}/data + +elif [ $# -eq 3 ]; then + logdir=$3 + outdir=${dir}/data + +elif [ $# -eq 4 ]; then + logdir=$3 + outdir=$4 +fi + + +mkdir -p ${logdir} + +rm -f "${dir}/${out_filename}" + + +opts= +if [ -n "${utt2ref_channels}" ]; then + opts="--utt2ref-channels ${utt2ref_channels} " +elif [ -n "${ref_channels}" ]; then + opts="--ref-channels ${ref_channels} " +fi + + +if [ -n "${segments}" ]; then + log "[info]: using ${segments}" + nutt=$(<${segments} wc -l) + nj=$((nj /dev/null + +# concatenate the .scp files together. +for n in $(seq ${nj}); do + cat "${outdir}/format.${n}/wav.scp" || exit 1; +done > "${dir}/${out_filename}" || exit 1 + +if "${write_utt2num_samples}"; then + for n in $(seq ${nj}); do + cat "${outdir}/format.${n}/utt2num_samples" || exit 1; + done > "${dir}/utt2num_samples" || exit 1 +fi + +log "Successfully finished. 
[elapsed=${SECONDS}s]" diff --git a/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh new file mode 100755 index 000000000..9e08dba72 --- /dev/null +++ b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh @@ -0,0 +1,116 @@ +#!/usr/bin/env bash + +# 2020 @kamo-naoyuki +# This file was copied from Kaldi and +# I deleted parts related to wav duration +# because we shouldn't use kaldi's command here +# and we don't need the files actually. + +# Copyright 2013 Johns Hopkins University (author: Daniel Povey) +# 2014 Tom Ko +# 2018 Emotech LTD (author: Pawel Swietojanski) +# Apache 2.0 + +# This script operates on a directory, such as in data/train/, +# that contains some subset of the following files: +# wav.scp +# spk2utt +# utt2spk +# text +# +# It generates the files which are used for perturbing the speed of the original data. + +export LC_ALL=C +set -euo pipefail + +if [[ $# != 3 ]]; then + echo "Usage: perturb_data_dir_speed.sh " + echo "e.g.:" + echo " $0 0.9 data/train_si284 data/train_si284p" + exit 1 +fi + +factor=$1 +srcdir=$2 +destdir=$3 +label="sp" +spk_prefix="${label}${factor}-" +utt_prefix="${label}${factor}-" + +#check is sox on the path + +! command -v sox &>/dev/null && echo "sox: command not found" && exit 1; + +if [[ ! -f ${srcdir}/utt2spk ]]; then + echo "$0: no such file ${srcdir}/utt2spk" + exit 1; +fi + +if [[ ${destdir} == "${srcdir}" ]]; then + echo "$0: this script requires and to be different." + exit 1 +fi + +mkdir -p "${destdir}" + +<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map" +<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map" +<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map" +if [[ ! 
-f ${srcdir}/utt2uniq ]]; then + <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq" +else + <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq" +fi + + +<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \ + utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk + +utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt + +if [[ -f ${srcdir}/segments ]]; then + + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \ + utils/apply_map.pl -f 2 "${destdir}"/reco_map | \ + awk -v factor="${factor}" \ + '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \ + >"${destdir}"/segments + + utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ + # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" + awk -v factor="${factor}" \ + '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} + else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" } + else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \ + > "${destdir}"/wav.scp + if [[ -f ${srcdir}/reco2file_and_channel ]]; then + utils/apply_map.pl -f 1 "${destdir}"/reco_map \ + <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel + fi + +else # no segments->wav indexed by utterance. 
+ if [[ -f ${srcdir}/wav.scp ]]; then + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ + # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" + awk -v factor="${factor}" \ + '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} + else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" } + else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \ + > "${destdir}"/wav.scp + fi +fi + +if [[ -f ${srcdir}/text ]]; then + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text +fi +if [[ -f ${srcdir}/spk2gender ]]; then + utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender +fi +if [[ -f ${srcdir}/utt2lang ]]; then + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang +fi + +rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null +echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}" + +utils/validate_data_dir.sh --no-feats --no-text "${destdir}" diff --git a/egs/alimeeting/sa-asr/utils/apply_map.pl b/egs/alimeeting/sa-asr/utils/apply_map.pl new file mode 100755 index 000000000..725d3463a --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/apply_map.pl @@ -0,0 +1,97 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0. + +# This program is a bit like ./sym2int.pl in that it applies a map +# to things in a file, but it's a bit more general in that it doesn't +# assume the things being mapped to are single tokens, they could +# be sequences of tokens. See the usage message. 
+ + +$permissive = 0; + +for ($x = 0; $x <= 2; $x++) { + + if (@ARGV > 0 && $ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. + } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } + } + + if (@ARGV > 0 && $ARGV[0] eq '--permissive') { + shift @ARGV; + # Mapping is optional (missing key is printed to output) + $permissive = 1; + } +} + +if(@ARGV != 1) { + print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n"; + print STDERR <<'EOF'; +Usage: apply_map.pl [options] map output + options: [-f ] [--permissive] + This applies a map to some specified fields of some input text: + For each line in the map file: the first field is the thing we + map from, and the remaining fields are the sequence we map it to. + The -f (field-range) option says which fields of the input file the map + map should apply to. + If the --permissive option is supplied, fields which are not present + in the map will be left as they were. + Applies the map 'map' to all input text, where each line of the map + is interpreted as a map from the first field to the list of the other fields + Note: can look like 4-5, or 4-, or 5-, or 1, it means the field + range in the input to apply the map to. 
+ e.g.: echo A B | apply_map.pl a.txt + where a.txt is: + A a1 a2 + B b + will produce: + a1 a2 b +EOF + exit(1); +} + +($map_file) = @ARGV; +open(M, "<$map_file") || die "Error opening map file $map_file: $!"; + +while () { + @A = split(" ", $_); + @A >= 1 || die "apply_map.pl: empty line."; + $i = shift @A; + $o = join(" ", @A); + $map{$i} = $o; +} + +while() { + @A = split(" ", $_); + for ($x = 0; $x < @A; $x++) { + if ( (!defined $field_begin || $x >= $field_begin) + && (!defined $field_end || $x <= $field_end)) { + $a = $A[$x]; + if (!defined $map{$a}) { + if (!$permissive) { + die "apply_map.pl: undefined key $a in $map_file\n"; + } else { + print STDERR "apply_map.pl: warning! missing key $a in $map_file\n"; + } + } else { + $A[$x] = $map{$a}; + } + } + } + print join(" ", @A) . "\n"; +} diff --git a/egs/alimeeting/sa-asr/utils/combine_data.sh b/egs/alimeeting/sa-asr/utils/combine_data.sh new file mode 100755 index 000000000..e1eba8539 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/combine_data.sh @@ -0,0 +1,146 @@ +#!/usr/bin/env bash +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. +# 2014 David Snyder + +# This script combines the data from multiple source directories into +# a single destination directory. + +# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information +# about what these directories contain. + +# Begin configuration section. +extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..." +skip_fix=false # skip the fix_data_dir.sh in the end +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# -lt 2 ]; then + echo "Usage: combine_data.sh [--extra-files 'file1 file2'] ..." + echo "Note, files that don't appear in all source dirs will not be combined," + echo "with the exception of utt2uniq and segments, which are created where necessary." 
+ exit 1 +fi + +dest=$1; +shift; + +first_src=$1; + +rm -r $dest 2>/dev/null || true +mkdir -p $dest; + +export LC_ALL=C + +for dir in $*; do + if [ ! -f $dir/utt2spk ]; then + echo "$0: no such file $dir/utt2spk" + exit 1; + fi +done + +# Check that frame_shift are compatible, where present together with features. +dir_with_frame_shift= +for dir in $*; do + if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then + if [[ $dir_with_frame_shift ]] && + ! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then + echo "$0:error: different frame_shift in directories $dir and " \ + "$dir_with_frame_shift. Cannot combine features." + exit 1; + fi + dir_with_frame_shift=$dir + fi +done + +# W.r.t. utt2uniq file the script has different behavior compared to other files +# it is not compulsary for it to exist in src directories, but if it exists in +# even one it should exist in all. We will create the files where necessary +has_utt2uniq=false +for in_dir in $*; do + if [ -f $in_dir/utt2uniq ]; then + has_utt2uniq=true + break + fi +done + +if $has_utt2uniq; then + # we are going to create an utt2uniq file in the destdir + for in_dir in $*; do + if [ ! -f $in_dir/utt2uniq ]; then + # we assume that utt2uniq is a one to one mapping + cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}' + else + cat $in_dir/utt2uniq + fi + done | sort -k1 > $dest/utt2uniq + echo "$0: combined utt2uniq" +else + echo "$0 [info]: not combining utt2uniq as it does not exist" +fi +# some of the old scripts might provide utt2uniq as an extrafile, so just remove it +extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g") + +# segments are treated similarly to utt2uniq. If it exists in some, but not all +# src directories, then we generate segments where necessary. +has_segments=false +for in_dir in $*; do + if [ -f $in_dir/segments ]; then + has_segments=true + break + fi +done + +if $has_segments; then + for in_dir in $*; do + if [ ! 
-f $in_dir/segments ]; then + echo "$0 [info]: will generate missing segments for $in_dir" 1>&2 + utils/data/get_segments_for_data.sh $in_dir + else + cat $in_dir/segments + fi + done | sort -k1 > $dest/segments + echo "$0: combined segments" +else + echo "$0 [info]: not combining segments as it does not exist" +fi + +for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do + exists_somewhere=false + absent_somewhere=false + for d in $*; do + if [ -f $d/$file ]; then + exists_somewhere=true + else + absent_somewhere=true + fi + done + + if ! $absent_somewhere; then + set -o pipefail + ( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1; + set +o pipefail + echo "$0: combined $file" + else + if ! $exists_somewhere; then + echo "$0 [info]: not combining $file as it does not exist" + else + echo "$0 [info]: **not combining $file as it does not exist everywhere**" + fi + fi +done + +utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt + +if [[ $dir_with_frame_shift ]]; then + cp $dir_with_frame_shift/frame_shift $dest +fi + +if ! $skip_fix ; then + utils/fix_data_dir.sh $dest || exit 1; +fi + +exit 0 diff --git a/egs/alimeeting/sa-asr/utils/copy_data_dir.sh b/egs/alimeeting/sa-asr/utils/copy_data_dir.sh new file mode 100755 index 000000000..9fd420c42 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/copy_data_dir.sh @@ -0,0 +1,145 @@ +#!/usr/bin/env bash + +# Copyright 2013 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +# This script operates on a directory, such as in data/train/, +# that contains some subset of the following files: +# feats.scp +# wav.scp +# vad.scp +# spk2utt +# utt2spk +# text +# +# It copies to another directory, possibly adding a specified prefix or a suffix +# to the utterance and/or speaker names. Note, the recording-ids stay the same. 
+# + + +# begin configuration section +spk_prefix= +utt_prefix= +spk_suffix= +utt_suffix= +validate_opts= # should rarely be needed. +# end configuration section + +. utils/parse_options.sh + +if [ $# != 2 ]; then + echo "Usage: " + echo " $0 [options] " + echo "e.g.:" + echo " $0 --spk-prefix=1- --utt-prefix=1- data/train data/train_1" + echo "Options" + echo " --spk-prefix= # Prefix for speaker ids, default empty" + echo " --utt-prefix= # Prefix for utterance ids, default empty" + echo " --spk-suffix= # Suffix for speaker ids, default empty" + echo " --utt-suffix= # Suffix for utterance ids, default empty" + exit 1; +fi + + +export LC_ALL=C + +srcdir=$1 +destdir=$2 + +if [ ! -f $srcdir/utt2spk ]; then + echo "copy_data_dir.sh: no such file $srcdir/utt2spk" + exit 1; +fi + +if [ "$destdir" == "$srcdir" ]; then + echo "$0: this script requires and to be different." + exit 1 +fi + +set -e; + +mkdir -p $destdir + +cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/utt_map +cat $srcdir/spk2utt | awk -v p=$spk_prefix -v s=$spk_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/spk_map + +if [ ! -f $srcdir/utt2uniq ]; then + if [[ ! -z $utt_prefix || ! 
-z $utt_suffix ]]; then + cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $1);}' > $destdir/utt2uniq + fi +else + cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq +fi + +cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \ + utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk + +utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt + +if [ -f $srcdir/feats.scp ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp +fi + +if [ -f $srcdir/vad.scp ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp +fi + +if [ -f $srcdir/segments ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments + cp $srcdir/wav.scp $destdir +else # no segments->wav indexed by utt. + if [ -f $srcdir/wav.scp ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp + fi +fi + +if [ -f $srcdir/reco2file_and_channel ]; then + cp $srcdir/reco2file_and_channel $destdir/ +fi + +if [ -f $srcdir/text ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text +fi +if [ -f $srcdir/utt2dur ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur +fi +if [ -f $srcdir/utt2num_frames ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames +fi +if [ -f $srcdir/reco2dur ]; then + if [ -f $srcdir/segments ]; then + cp $srcdir/reco2dur $destdir/reco2dur + else + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur + fi +fi +if [ -f $srcdir/spk2gender ]; then + utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender +fi +if [ -f $srcdir/cmvn.scp ]; then + utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp +fi +for f in frame_shift stm glm ctm; do + if [ -f $srcdir/$f ]; 
then + cp $srcdir/$f $destdir + fi +done + +rm $destdir/spk_map $destdir/utt_map + +echo "$0: copied data from $srcdir to $destdir" + +for f in feats.scp cmvn.scp vad.scp utt2lang utt2uniq utt2dur utt2num_frames text wav.scp reco2file_and_channel frame_shift stm glm ctm; do + if [ -f $destdir/$f ] && [ ! -f $srcdir/$f ]; then + echo "$0: file $f exists in dest $destdir but not in src $srcdir. Moving it to" + echo " ... $destdir/.backup/$f" + mkdir -p $destdir/.backup + mv $destdir/$f $destdir/.backup/ + fi +done + + +[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats" +[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text" + +utils/validate_data_dir.sh $validate_opts $destdir diff --git a/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh b/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh new file mode 100755 index 000000000..24f51e723 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash + +# Copyright 2016 Johns Hopkins University (author: Daniel Povey) +# 2018 Andrea Carmantini +# Apache 2.0 + +# This script operates on a data directory, such as in data/train/, and adds the +# reco2dur file if it does not already exist. The file 'reco2dur' maps from +# recording to the duration of the recording in seconds. This script works it +# out from the 'wav.scp' file, or, if utterance-ids are the same as recording-ids, from the +# utt2dur file (it first tries interrogating the headers, and if this fails, it reads the wave +# files in entirely.) +# We could use durations from segments file, but that's not the duration of the recordings +# but the sum of utterance lenghts (silence in between could be excluded from segments) +# For sum of utterance lenghts: +# awk 'FNR==NR{uttdur[$1]=$2;next} +# { for(i=2;i<=NF;i++){dur+=uttdur[$i];} +# print $1 FS dur; dur=0 }' $data/utt2dur $data/reco2utt + + +frame_shift=0.01 +cmd=run.pl +nj=4 + +. utils/parse_options.sh +. 
./path.sh + +if [ $# != 1 ]; then + echo "Usage: $0 [options] " + echo "e.g.:" + echo " $0 data/train" + echo " Options:" + echo " --frame-shift # frame shift in seconds. Only relevant when we are" + echo " # getting duration from feats.scp (default: 0.01). " + exit 1 +fi + +export LC_ALL=C + +data=$1 + + +if [ -s $data/reco2dur ] && \ + [ $(wc -l < $data/wav.scp) -eq $(wc -l < $data/reco2dur) ]; then + echo "$0: $data/reco2dur already exists with the expected length. We won't recompute it." + exit 0; +fi + +if [ -s $data/utt2dur ] && \ + [ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ] && \ + [ ! -s $data/segments ]; then + + echo "$0: $data/wav.scp indexed by utt-id; copying utt2dur to reco2dur" + cp $data/utt2dur $data/reco2dur && exit 0; + +elif [ -f $data/wav.scp ]; then + echo "$0: obtaining durations from recordings" + + # if the wav.scp contains only lines of the form + # utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph | + if cat $data/wav.scp | perl -e ' + while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space. + @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ && + $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); } + $reco = $A[0]; $sphere_file = $A[4]; + + if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; } + $sample_rate = -1; $sample_count = -1; + for ($n = 0; $n <= 30; $n++) { + $line = ; + if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; } + if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; } + if ($line =~ m/end_head/) { break; } + } + close(F); + if ($sample_rate == -1 || $sample_count == -1) { + die "could not parse sphere header from $sphere_file"; + } + $duration = $sample_count * 1.0 / $sample_rate; + print "$reco $duration\n"; + } ' > $data/reco2dur; then + echo "$0: successfully obtained recording lengths from sphere-file headers" + else + echo "$0: could not get recording lengths from sphere-file headers, using wav-to-duration" + if ! 
command -v wav-to-duration >/dev/null; then + echo "$0: wav-to-duration is not on your path" + exit 1; + fi + + read_entire_file=false + if grep -q 'sox.*speed' $data/wav.scp; then + read_entire_file=true + echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow." + echo "... It is much faster if you call get_reco2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or " + echo "... perturb_data_dir_speed_3way.sh." + fi + + num_recos=$(wc -l <$data/wav.scp) + if [ $nj -gt $num_recos ]; then + nj=$num_recos + fi + + temp_data_dir=$data/wav${nj}split + wavscps=$(for n in `seq $nj`; do echo $temp_data_dir/$n/wav.scp; done) + subdirs=$(for n in `seq $nj`; do echo $temp_data_dir/$n; done) + + if ! mkdir -p $subdirs >&/dev/null; then + for n in `seq $nj`; do + mkdir -p $temp_data_dir/$n + done + fi + + utils/split_scp.pl $data/wav.scp $wavscps + + + $cmd JOB=1:$nj $data/log/get_reco_durations.JOB.log \ + wav-to-duration --read-entire-file=$read_entire_file \ + scp:$temp_data_dir/JOB/wav.scp ark,t:$temp_data_dir/JOB/reco2dur || \ + { echo "$0: there was a problem getting the durations"; exit 1; } # This could + + for n in `seq $nj`; do + cat $temp_data_dir/$n/reco2dur + done > $data/reco2dur + fi + rm -r $temp_data_dir +else + echo "$0: Expected $data/wav.scp to exist" + exit 1 +fi + +len1=$(wc -l < $data/wav.scp) +len2=$(wc -l < $data/reco2dur) +if [ "$len1" != "$len2" ]; then + echo "$0: warning: length of reco2dur does not equal that of wav.scp, $len2 != $len1" + if [ $len1 -gt $[$len2*2] ]; then + echo "$0: less than half of recordings got a duration: failing." 
+ exit 1 + fi +fi + +echo "$0: computed $data/reco2dur" + +exit 0 diff --git a/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh b/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh new file mode 100755 index 000000000..6b161b31e --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +# This script operates on a data directory, such as in data/train/, +# and writes new segments to stdout. The file 'segments' maps from +# utterance to time offsets into a recording, with the format: +# +# This script assumes utterance and recording ids are the same (i.e., that +# wav.scp is indexed by utterance), and uses durations from 'utt2dur', +# created if necessary by get_utt2dur.sh. + +. ./path.sh + +if [ $# != 1 ]; then + echo "Usage: $0 [options] " + echo "e.g.:" + echo " $0 data/train > data/train/segments" + exit 1 +fi + +data=$1 + +if [ ! -s $data/utt2dur ]; then + utils/data/get_utt2dur.sh $data 1>&2 || exit 1; +fi + +# 0 +awk '{ print $1, $1, 0, $2 }' $data/utt2dur + +exit 0 diff --git a/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh b/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh new file mode 100755 index 000000000..5ee7ea30d --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh @@ -0,0 +1,135 @@ +#!/usr/bin/env bash + +# Copyright 2016 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +# This script operates on a data directory, such as in data/train/, and adds the +# utt2dur file if it does not already exist. The file 'utt2dur' maps from +# utterance to the duration of the utterance in seconds. This script works it +# out from the 'segments' file, or, if not present, from the wav.scp file (it +# first tries interrogating the headers, and if this fails, it reads the wave +# files in entirely.) + +frame_shift=0.01 +cmd=run.pl +nj=4 +read_entire_file=false + +. utils/parse_options.sh +. 
./path.sh + +if [ $# != 1 ]; then + echo "Usage: $0 [options] " + echo "e.g.:" + echo " $0 data/train" + echo " Options:" + echo " --frame-shift # frame shift in seconds. Only relevant when we are" + echo " # getting duration from feats.scp, and only if the " + echo " # file frame_shift does not exist (default: 0.01). " + exit 1 +fi + +export LC_ALL=C + +data=$1 + +if [ -s $data/utt2dur ] && \ + [ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ]; then + echo "$0: $data/utt2dur already exists with the expected length. We won't recompute it." + exit 0; +fi + +if [ -s $data/segments ]; then + echo "$0: working out $data/utt2dur from $data/segments" + awk '{len=$4-$3; print $1, len;}' < $data/segments > $data/utt2dur +elif [[ -s $data/frame_shift && -f $data/utt2num_frames ]]; then + echo "$0: computing $data/utt2dur from $data/{frame_shift,utt2num_frames}." + frame_shift=$(cat $data/frame_shift) || exit 1 + # The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift. + awk -v fs=$frame_shift '{ $2=($2+1.5)*fs; print }' <$data/utt2num_frames >$data/utt2dur +elif [ -f $data/wav.scp ]; then + echo "$0: segments file does not exist so getting durations from wave files" + + # if the wav.scp contains only lines of the form + # utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph | + if perl <$data/wav.scp -e ' + while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space. 
+ @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ && + $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); } + $utt = $A[0]; $sphere_file = $A[4]; + + if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; } + $sample_rate = -1; $sample_count = -1; + for ($n = 0; $n <= 30; $n++) { + $line = ; + if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; } + if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; } + if ($line =~ m/end_head/) { break; } + } + close(F); + if ($sample_rate == -1 || $sample_count == -1) { + die "could not parse sphere header from $sphere_file"; + } + $duration = $sample_count * 1.0 / $sample_rate; + print "$utt $duration\n"; + } ' > $data/utt2dur; then + echo "$0: successfully obtained utterance lengths from sphere-file headers" + else + echo "$0: could not get utterance lengths from sphere-file headers, using wav-to-duration" + if ! command -v wav-to-duration >/dev/null; then + echo "$0: wav-to-duration is not on your path" + exit 1; + fi + + if grep -q 'sox.*speed' $data/wav.scp; then + read_entire_file=true + echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow." + echo "... It is much faster if you call get_utt2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or " + echo "... perturb_data_dir_speed_3way.sh." 
+ fi + + + num_utts=$(wc -l <$data/utt2spk) + if [ $nj -gt $num_utts ]; then + nj=$num_utts + fi + + utils/data/split_data.sh --per-utt $data $nj + sdata=$data/split${nj}utt + + $cmd JOB=1:$nj $data/log/get_durations.JOB.log \ + wav-to-duration --read-entire-file=$read_entire_file \ + scp:$sdata/JOB/wav.scp ark,t:$sdata/JOB/utt2dur || \ + { echo "$0: there was a problem getting the durations"; exit 1; } + + for n in `seq $nj`; do + cat $sdata/$n/utt2dur + done > $data/utt2dur + fi +elif [ -f $data/feats.scp ]; then + echo "$0: wave file does not exist so getting durations from feats files" + if [[ -s $data/frame_shift ]]; then + frame_shift=$(cat $data/frame_shift) || exit 1 + echo "$0: using frame_shift=$frame_shift from file $data/frame_shift" + fi + # The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift. + feat-to-len scp:$data/feats.scp ark,t:- | + awk -v frame_shift=$frame_shift '{print $1, ($2+1.5)*frame_shift}' >$data/utt2dur +else + echo "$0: Expected $data/wav.scp, $data/segments or $data/feats.scp to exist" + exit 1 +fi + +len1=$(wc -l < $data/utt2spk) +len2=$(wc -l < $data/utt2dur) +if [ "$len1" != "$len2" ]; then + echo "$0: warning: length of utt2dur does not equal that of utt2spk, $len2 != $len1" + if [ $len1 -gt $[$len2*2] ]; then + echo "$0: less than half of utterances got a duration: failing." + exit 1 + fi +fi + +echo "$0: computed $data/utt2dur" + +exit 0 diff --git a/egs/alimeeting/sa-asr/utils/data/split_data.sh b/egs/alimeeting/sa-asr/utils/data/split_data.sh new file mode 100755 index 000000000..8aa71a1f2 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/data/split_data.sh @@ -0,0 +1,160 @@ +#!/usr/bin/env bash +# Copyright 2010-2013 Microsoft Corporation +# Johns Hopkins University (Author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +split_per_spk=true +if [ "$1" == "--per-utt" ]; then + split_per_spk=false + shift +fi + +if [ $# != 2 ]; then + echo "Usage: $0 [--per-utt] " + echo "E.g.: $0 data/train 50" + echo "It creates its output in e.g. data/train/split50/{1,2,3,...50}, or if the " + echo "--per-utt option was given, in e.g. data/train/split50utt/{1,2,3,...50}." + echo "" + echo "This script will not split the data-dir if it detects that the output is newer than the input." + echo "By default it splits per speaker (so each speaker is in only one split dir)," + echo "but with the --per-utt option it will ignore the speaker information while splitting." + exit 1 +fi + +data=$1 +numsplit=$2 + +if ! [ "$numsplit" -gt 0 ]; then + echo "Invalid num-split argument $numsplit"; + exit 1; +fi + +if $split_per_spk; then + warning_opt= +else + # suppress warnings from filter_scps.pl about 'some input lines were output + # to multiple files'. + warning_opt="--no-warn" +fi + +n=0; +feats="" +wavs="" +utt2spks="" +texts="" + +nu=`cat $data/utt2spk | wc -l` +nf=`cat $data/feats.scp 2>/dev/null | wc -l` +nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file +if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then + echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can " + echo "** use utils/fix_data_dir.sh $data to fix this." 
+fi +if [ -f $data/text ] && [ $nu -ne $nt ]; then + echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can " + echo "** use utils/fix_data_dir.sh to fix this." +fi + + +if $split_per_spk; then + utt2spk_opt="--utt2spk=$data/utt2spk" + utt="" +else + utt2spk_opt= + utt="utt" +fi + +s1=$data/split${numsplit}${utt}/1 +if [ ! -d $s1 ]; then + need_to_split=true +else + need_to_split=false + for f in utt2spk spk2utt spk2warp feats.scp text wav.scp cmvn.scp spk2gender \ + vad.scp segments reco2file_and_channel utt2lang; do + if [[ -f $data/$f && ( ! -f $s1/$f || $s1/$f -ot $data/$f ) ]]; then + need_to_split=true + fi + done +fi + +if ! $need_to_split; then + exit 0; +fi + +utt2spks=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2spk; done) + +directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n; done) + +# if this mkdir fails due to argument-list being too long, iterate. +if ! mkdir -p $directories >&/dev/null; then + for n in `seq $numsplit`; do + mkdir -p $data/split${numsplit}${utt}/$n + done +fi + +# If lockfile is not installed, just don't lock it. It's not a big deal. +which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock +trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM + +utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1 + +for n in `seq $numsplit`; do + dsn=$data/split${numsplit}${utt}/$n + utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1; +done + +maybe_wav_scp= +if [ ! -f $data/segments ]; then + maybe_wav_scp=wav.scp # If there is no segments file, then wav file is + # indexed per utt. +fi + +# split some things that are indexed by utterance. 
+for f in feats.scp text vad.scp utt2lang $maybe_wav_scp utt2dur utt2num_frames; do + if [ -f $data/$f ]; then + utils/filter_scps.pl JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/utt2spk $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1; + fi +done + +# split some things that are indexed by speaker +for f in spk2gender spk2warp cmvn.scp; do + if [ -f $data/$f ]; then + utils/filter_scps.pl $warning_opt JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/spk2utt $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1; + fi +done + +if [ -f $data/segments ]; then + utils/filter_scps.pl JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/utt2spk $data/segments $data/split${numsplit}${utt}/JOB/segments || exit 1 + for n in `seq $numsplit`; do + dsn=$data/split${numsplit}${utt}/$n + awk '{print $2;}' $dsn/segments | sort | uniq > $dsn/tmp.reco # recording-ids. + done + if [ -f $data/reco2file_and_channel ]; then + utils/filter_scps.pl $warning_opt JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/tmp.reco $data/reco2file_and_channel \ + $data/split${numsplit}${utt}/JOB/reco2file_and_channel || exit 1 + fi + if [ -f $data/wav.scp ]; then + utils/filter_scps.pl $warning_opt JOB=1:$numsplit \ + $data/split${numsplit}${utt}/JOB/tmp.reco $data/wav.scp \ + $data/split${numsplit}${utt}/JOB/wav.scp || exit 1 + fi + for f in $data/split${numsplit}${utt}/*/tmp.reco; do rm $f; done +fi + +exit 0 diff --git a/egs/alimeeting/sa-asr/utils/filter_scp.pl b/egs/alimeeting/sa-asr/utils/filter_scp.pl new file mode 100755 index 000000000..b76d37f41 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. 
+ while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/alimeeting/sa-asr/utils/fix_data_dir.sh b/egs/alimeeting/sa-asr/utils/fix_data_dir.sh new file mode 100755 index 000000000..ed4710d0b --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/fix_data_dir.sh @@ -0,0 +1,215 @@ +#!/usr/bin/env bash + +# This script makes sure that only the segments present in +# all of "feats.scp", "wav.scp" [if present], segments [if present] +# text, and utt2spk are present in any of them. +# It puts the original contents of data-dir into +# data-dir/.backup + +cmd="$@" + +utt_extra_files= +spk_extra_files= + +. utils/parse_options.sh + +if [ $# != 1 ]; then + echo "Usage: utils/data/fix_data_dir.sh " + echo "e.g.: utils/data/fix_data_dir.sh data/train" + echo "This script helps ensure that the various files in a data directory" + echo "are correctly sorted and filtered, for example removing utterances" + echo "that have no features (if feats.scp is present)" + exit 1 +fi + +data=$1 + +if [ -f $data/images.scp ]; then + image/fix_data_dir.sh $cmd + exit $? +fi + +mkdir -p $data/.backup + +[ ! -d $data ] && echo "$0: no such directory $data" && exit 1; + +[ ! 
-f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1; + +set -e -o pipefail -u + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM + +export LC_ALL=C + +function check_sorted { + file=$1 + sort -k1,1 -u <$file >$file.tmp + if ! cmp -s $file $file.tmp; then + echo "$0: file $1 is not in sorted order or not unique, sorting it" + mv $file.tmp $file + else + rm $file.tmp + fi +} + +for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \ + reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + check_sorted $data/$x + fi +done + + +function filter_file { + filter=$1 + file_to_filter=$2 + cp $file_to_filter ${file_to_filter}.tmp + utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter + if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then + length1=$(cat ${file_to_filter}.tmp | wc -l) + length2=$(cat ${file_to_filter} | wc -l) + if [ $length1 -ne $length2 ]; then + echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter." + fi + fi + rm $file_to_filter.tmp +} + +function filter_recordings { + # We call this once before the stage when we filter on utterance-id, and once + # after. + + if [ -f $data/segments ]; then + # We have a segments file -> we need to filter this and the file wav.scp, and + # reco2file_and_utt, if it exists, to make sure they have the same list of + # recording-ids. + + if [ ! -f $data/wav.scp ]; then + echo "$0: $data/segments exists but not $data/wav.scp" + exit 1; + fi + awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings + n1=$(cat $tmpdir/recordings | wc -l) + [ ! -s $tmpdir/recordings ] && \ + echo "Empty list of recordings (bad file $data/segments)?" 
&& exit 1; + utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp + mv $tmpdir/recordings.tmp $tmpdir/recordings + + + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + filter_file $tmpdir/recordings $data/segments + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + rm $data/segments.tmp + + filter_file $tmpdir/recordings $data/wav.scp + [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel + [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur + true + fi +} + +function filter_speakers { + # throughout this program, we regard utt2spk as primary and spk2utt as derived, so... + utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + for s in cmvn.scp spk2gender; do + f=$data/$s + if [ -f $f ]; then + filter_file $f $tmpdir/speakers + fi + done + + filter_file $tmpdir/speakers $data/spk2utt + utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk + + for s in cmvn.scp spk2gender $spk_extra_files; do + f=$data/$s + if [ -f $f ]; then + filter_file $tmpdir/speakers $f + fi + done +} + +function filter_utts { + cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts + + ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order (fix this yourself)" && exit 1; + + ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order when sorted first on speaker-id " && \ + echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1; + + ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \ + echo "spk2utt is not in sorted order (fix this yourself)" && exit 1; + + if [ -f $data/utt2uniq ]; then + ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \ + echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1; + fi + + maybe_wav= + maybe_reco2dur= + [ ! 
-f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist. + [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts + + maybe_utt2dur= + if [ -f $data/utt2dur ]; then + cat $data/utt2dur | \ + awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2dur.ok || exit 1 + maybe_utt2dur=utt2dur.ok + fi + + maybe_utt2num_frames= + if [ -f $data/utt2num_frames ]; then + cat $data/utt2num_frames | \ + awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2num_frames.ok || exit 1 + maybe_utt2num_frames=utt2num_frames.ok + fi + + for x in feats.scp text segments utt2lang $maybe_wav $maybe_utt2dur $maybe_utt2num_frames; do + if [ -f $data/$x ]; then + utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp + mv $tmpdir/utts.tmp $tmpdir/utts + fi + done + rm $data/utt2dur.ok 2>/dev/null || true + rm $data/utt2num_frames.ok 2>/dev/null || true + + [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \ + rm $tmpdir/utts && exit 1; + + + if [ -f $data/utt2spk ]; then + new_nutts=$(cat $tmpdir/utts | wc -l) + old_nutts=$(cat $data/utt2spk | wc -l) + if [ $new_nutts -ne $old_nutts ]; then + echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts" + else + echo "fix_data_dir.sh: kept all $old_nutts utterances." + fi + fi + + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + if ! 
cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then + utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x + fi + fi + done + +} + +filter_recordings +filter_speakers +filter_utts +filter_speakers +filter_recordings + +utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + +echo "fix_data_dir.sh: old files are kept in $data/.backup" diff --git a/egs/alimeeting/sa-asr/utils/parse_options.sh b/egs/alimeeting/sa-asr/utils/parse_options.sh new file mode 100755 index 000000000..71fb9e5ea --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... 
+### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl new file mode 100755 index 000000000..23992f25d --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl @@ -0,0 +1,27 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. 
+ + +while(<>){ + @A = split(" ", $_); + @A > 1 || die "Invalid line in spk2utt file: $_"; + $s = shift @A; + foreach $u ( @A ) { + print "$u $s\n"; + } +} + + diff --git a/egs/alimeeting/sa-asr/utils/split_scp.pl b/egs/alimeeting/sa-asr/utils/split_scp.pl new file mode 100755 index 000000000..0876dcb6d --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/split_scp.pl @@ -0,0 +1,246 @@ +#!/usr/bin/env perl + +# Copyright 2010-2011 Microsoft Corporation + +# See ../../COPYING for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This program splits up any kind of .scp or archive-type file. +# If there is no utt2spk option it will work on any text file and +# will split it up with an approximately equal number of lines in +# each but. +# With the --utt2spk option it will work on anything that has the +# utterance-id as the first entry on each line; the utt2spk file is +# of the form "utterance speaker" (on each line). +# It splits it into equal size chunks as far as it can. If you use the utt2spk +# option it will make sure these chunks coincide with speaker boundaries. In +# this case, if there are more chunks than speakers (and in some other +# circumstances), some of the resulting chunks will be empty and it will print +# an error message and exit with nonzero status. +# You will normally call this like: +# split_scp.pl scp scp.1 scp.2 scp.3 ... 
+# or +# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ... +# Note that you can use this script to split the utt2spk file itself, +# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ... + +# You can also call the scripts like: +# split_scp.pl -j 3 0 scp scp.0 +# [note: with this option, it assumes zero-based indexing of the split parts, +# i.e. the second number must be 0 <= n < num-jobs.] + +use warnings; + +$num_jobs = 0; +$job_id = 0; +$utt2spk_file = ""; +$one_based = 0; + +for ($x = 1; $x <= 3 && @ARGV > 0; $x++) { + if ($ARGV[0] eq "-j") { + shift @ARGV; + $num_jobs = shift @ARGV; + $job_id = shift @ARGV; + } + if ($ARGV[0] =~ /--utt2spk=(.+)/) { + $utt2spk_file=$1; + shift; + } + if ($ARGV[0] eq '--one-based') { + $one_based = 1; + shift @ARGV; + } +} + +if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 || + $job_id - $one_based >= $num_jobs)) { + die "$0: Invalid job number/index values for '-j $num_jobs $job_id" . + ($one_based ? " --one-based" : "") . "'\n" +} + +$one_based + and $job_id--; + +if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) { + die +"Usage: split_scp.pl [--utt2spk=] in.scp out1.scp out2.scp ... + or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=] in.scp [out.scp] + ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n"; +} + +$error = 0; +$inscp = shift @ARGV; +if ($num_jobs == 0) { # without -j option + @OUTPUTS = @ARGV; +} else { + for ($j = 0; $j < $num_jobs; $j++) { + if ($j == $job_id) { + if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; } + else { push @OUTPUTS, "-"; } + } else { + push @OUTPUTS, "/dev/null"; + } + } +} + +if ($utt2spk_file ne "") { # We have the --utt2spk option... 
+ open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n"; + while(<$u_fh>) { + @A = split; + @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n"; + ($u,$s) = @A; + $utt2spk{$u} = $s; + } + close $u_fh; + open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n"; + @spkrs = (); + while(<$i_fh>) { + @A = split; + if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; } + $u = $A[0]; + $s = $utt2spk{$u}; + defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n"; + if(!defined $spk_count{$s}) { + push @spkrs, $s; + $spk_count{$s} = 0; + $spk_data{$s} = []; # ref to new empty array. + } + $spk_count{$s}++; + push @{$spk_data{$s}}, $_; + } + # Now split as equally as possible .. + # First allocate spks to files by allocating an approximately + # equal number of speakers. + $numspks = @spkrs; # number of speakers. + $numscps = @OUTPUTS; # number of output files. + if ($numspks < $numscps) { + die "$0: Refusing to split data because number of speakers $numspks " . + "is less than the number of output .scp files $numscps\n"; + } + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + $scparray[$scpidx] = []; # [] is array reference. + } + for ($spkidx = 0; $spkidx < $numspks; $spkidx++) { + $scpidx = int(($spkidx*$numscps) / $numspks); + $spk = $spkrs[$spkidx]; + push @{$scparray[$scpidx]}, $spk; + $scpcount[$scpidx] += $spk_count{$spk}; + } + + # Now will try to reassign beginning + ending speakers + # to different scp's and see if it gets more balanced. + # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2. + # We can show that if considering changing just 2 scp's, we minimize + # this by minimizing the squared difference in sizes. This is + # equivalent to minimizing the absolute difference in sizes. This + # shows this method is bound to converge. 
+ + $changed = 1; + while($changed) { + $changed = 0; + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + # First try to reassign ending spk of this scp. + if($scpidx < $numscps-1) { + $sz = @{$scparray[$scpidx]}; + if($sz > 0) { + $spk = $scparray[$scpidx]->[$sz-1]; + $count = $spk_count{$spk}; + $nutt1 = $scpcount[$scpidx]; + $nutt2 = $scpcount[$scpidx+1]; + if( abs( ($nutt2+$count) - ($nutt1-$count)) + < abs($nutt2 - $nutt1)) { # Would decrease + # size-diff by reassigning spk... + $scpcount[$scpidx+1] += $count; + $scpcount[$scpidx] -= $count; + pop @{$scparray[$scpidx]}; + unshift @{$scparray[$scpidx+1]}, $spk; + $changed = 1; + } + } + } + if($scpidx > 0 && @{$scparray[$scpidx]} > 0) { + $spk = $scparray[$scpidx]->[0]; + $count = $spk_count{$spk}; + $nutt1 = $scpcount[$scpidx-1]; + $nutt2 = $scpcount[$scpidx]; + if( abs( ($nutt2-$count) - ($nutt1+$count)) + < abs($nutt2 - $nutt1)) { # Would decrease + # size-diff by reassigning spk... + $scpcount[$scpidx-1] += $count; + $scpcount[$scpidx] -= $count; + shift @{$scparray[$scpidx]}; + push @{$scparray[$scpidx-1]}, $spk; + $changed = 1; + } + } + } + } + # Now print out the files... + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + $scpfile = $OUTPUTS[$scpidx]; + ($scpfile ne '-' ? open($f_fh, '>', $scpfile) + : open($f_fh, '>&', \*STDOUT)) || + die "$0: Could not open scp file $scpfile for writing: $!\n"; + $count = 0; + if(@{$scparray[$scpidx]} == 0) { + print STDERR "$0: eError: split_scp.pl producing empty .scp file " . + "$scpfile (too many splits and too few speakers?)\n"; + $error = 1; + } else { + foreach $spk ( @{$scparray[$scpidx]} ) { + print $f_fh @{$spk_data{$spk}}; + $count += $spk_count{$spk}; + } + $count == $scpcount[$scpidx] || die "Count mismatch [code error]"; + } + close($f_fh); + } +} else { + # This block is the "normal" case where there is no --utt2spk + # option and we just break into equal size chunks. 
+ + open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n"; + + $numscps = @OUTPUTS; # size of array. + @F = (); + while(<$i_fh>) { + push @F, $_; + } + $numlines = @F; + if($numlines == 0) { + print STDERR "$0: error: empty input scp file $inscp\n"; + $error = 1; + } + $linesperscp = int( $numlines / $numscps); # the "whole part".. + $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n"; + $remainder = $numlines - ($linesperscp * $numscps); + ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder"; + # [just doing int() rounds down]. + $n = 0; + for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) { + $scpfile = $OUTPUTS[$scpidx]; + ($scpfile ne '-' ? open($o_fh, '>', $scpfile) + : open($o_fh, '>&', \*STDOUT)) || + die "$0: Could not open scp file $scpfile for writing: $!\n"; + for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) { + print $o_fh $F[$n++]; + } + close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n"; + } + $n == $numlines || die "$n != $numlines [code error]"; +} + +exit ($error); diff --git a/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl new file mode 100755 index 000000000..6e0e438ca --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl @@ -0,0 +1,38 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# converts an utt2spk file to a spk2utt file. +# Takes input from the stdin or from a file argument; +# output goes to the standard out. + +if ( @ARGV > 1 ) { + die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt"; +} + +while(<>){ + @A = split(" ", $_); + @A == 2 || die "Invalid line in utt2spk file: $_"; + ($u,$s) = @A; + if(!$seen_spk{$s}) { + $seen_spk{$s} = 1; + push @spklist, $s; + } + push (@{$spk_hash{$s}}, "$u"); +} +foreach $s (@spklist) { + $l = join(' ',@{$spk_hash{$s}}); + print "$s $l\n"; +} diff --git a/egs/alimeeting/sa-asr/utils/validate_data_dir.sh b/egs/alimeeting/sa-asr/utils/validate_data_dir.sh new file mode 100755 index 000000000..3eec443a0 --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/validate_data_dir.sh @@ -0,0 +1,404 @@ +#!/usr/bin/env bash + +cmd="$@" + +no_feats=false +no_wav=false +no_text=false +no_spk_sort=false +non_print=false + + +function show_help +{ + echo "Usage: $0 [--no-feats] [--no-text] [--non-print] [--no-wav] [--no-spk-sort] " + echo "The --no-xxx options mean that the script does not require " + echo "xxx.scp to be present, but it will check it if it is present." + echo "--no-spk-sort means that the script does not require the utt2spk to be " + echo "sorted by the speaker-id in addition to being sorted by utterance-id." + echo "--non-print ignore the presence of non-printable characters." 
+ echo "By default, utt2spk is expected to be sorted by both, which can be " + echo "achieved by making the speaker-id prefixes of the utterance-ids" + echo "e.g.: $0 data/train" +} + +while [ $# -ne 0 ] ; do + case "$1" in + "--no-feats") + no_feats=true; + ;; + "--no-text") + no_text=true; + ;; + "--non-print") + non_print=true; + ;; + "--no-wav") + no_wav=true; + ;; + "--no-spk-sort") + no_spk_sort=true; + ;; + *) + if ! [ -z "$data" ] ; then + show_help; + exit 1 + fi + data=$1 + ;; + esac + shift +done + + + +if [ ! -d $data ]; then + echo "$0: no such directory $data" + exit 1; +fi + +if [ -f $data/images.scp ]; then + cmd=${cmd/--no-wav/} # remove --no-wav if supplied + image/validate_data_dir.sh $cmd + exit $? +fi + +for f in spk2utt utt2spk; do + if [ ! -f $data/$f ]; then + echo "$0: no such file $f" + exit 1; + fi + if [ ! -s $data/$f ]; then + echo "$0: empty file $f" + exit 1; + fi +done + +! cat $data/utt2spk | awk '{if (NF != 2) exit(1); }' && \ + echo "$0: $data/utt2spk has wrong format." && exit; + +ns=$(wc -l < $data/spk2utt) +if [ "$ns" == 1 ]; then + echo "$0: WARNING: you have only one speaker. This probably a bad idea." + echo " Search for the word 'bold' in http://kaldi-asr.org/doc/data_prep.html" + echo " for more information." +fi + + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM + +export LC_ALL=C + +function check_sorted_and_uniq { + ! perl -ne '((substr $_,-1) eq "\n") or die "file $ARGV has invalid newline";' $1 && exit 1; + ! awk '{print $1}' < $1 | sort -uC && echo "$0: file $1 is not sorted or has duplicates" && exit 1; +} + +function partial_diff { + diff -U1 $1 $2 | (head -n 6; echo "..."; tail -n 6) + n1=`cat $1 | wc -l` + n2=`cat $2 | wc -l` + echo "[Lengths are $1=$n1 versus $2=$n2]" +} + +check_sorted_and_uniq $data/utt2spk + +if ! $no_spk_sort; then + ! 
sort -k2 -C $data/utt2spk && \ + echo "$0: utt2spk is not in sorted order when sorted first on speaker-id " && \ + echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1; +fi + +check_sorted_and_uniq $data/spk2utt + +! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \ + <(utils/spk2utt_to_utt2spk.pl $data/spk2utt) && \ + echo "$0: spk2utt and utt2spk do not seem to match" && exit 1; + +cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts + +if [ ! -f $data/text ] && ! $no_text; then + echo "$0: no such file $data/text (if this is by design, specify --no-text)" + exit 1; +fi + +num_utts=`cat $tmpdir/utts | wc -l` +if ! $no_text; then + if ! $non_print; then + if locale -a | grep "C.UTF-8" >/dev/null; then + L=C.UTF-8 + else + L=en_US.UTF-8 + fi + n_non_print=$(LC_ALL="$L" grep -c '[^[:print:][:space:]]' $data/text) && \ + echo "$0: text contains $n_non_print lines with non-printable characters" &&\ + exit 1; + fi + utils/validate_text.pl $data/text || exit 1; + check_sorted_and_uniq $data/text + text_len=`cat $data/text | wc -l` + illegal_sym_list=" #0" + for x in $illegal_sym_list; do + if grep -w "$x" $data/text > /dev/null; then + echo "$0: Error: in $data, text contains illegal symbol $x" + exit 1; + fi + done + awk '{print $1}' < $data/text > $tmpdir/utts.txt + if ! cmp -s $tmpdir/utts{,.txt}; then + echo "$0: Error: in $data, utterance lists extracted from utt2spk and text" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.txt} + exit 1; + fi +fi + +if [ -f $data/segments ] && [ ! -f $data/wav.scp ]; then + echo "$0: in directory $data, segments file exists but no wav.scp" + exit 1; +fi + + +if [ ! -f $data/wav.scp ] && ! 
$no_wav; then + echo "$0: no such file $data/wav.scp (if this is by design, specify --no-wav)" + exit 1; +fi + +if [ -f $data/wav.scp ]; then + check_sorted_and_uniq $data/wav.scp + + if grep -E -q '^\S+\s+~' $data/wav.scp; then + # note: it's not a good idea to have any kind of tilde in wav.scp, even if + # part of a command, as it would cause compatibility problems if run by + # other users, but this used to be not checked for so we let it slide unless + # it's something of the form "foo ~/foo.wav" (i.e. a plain file name) which + # would definitely cause problems as the fopen system call does not do + # tilde expansion. + echo "$0: Please do not use tilde (~) in your wav.scp." + exit 1; + fi + + if [ -f $data/segments ]; then + + check_sorted_and_uniq $data/segments + # We have a segments file -> interpret wav file as "recording-ids" not utterance-ids. + ! cat $data/segments | \ + awk '{if (NF != 4 || $4 <= $3) { print "Bad line in segments file", $0; exit(1); }}' && \ + echo "$0: badly formatted segments file" && exit 1; + + segments_len=`cat $data/segments | wc -l` + if [ -f $data/text ]; then + ! cmp -s $tmpdir/utts <(awk '{print $1}' <$data/segments) && \ + echo "$0: Utterance list differs between $data/utt2spk and $data/segments " && \ + echo "$0: Lengths are $segments_len vs $num_utts" && \ + exit 1 + fi + + cat $data/segments | awk '{print $2}' | sort | uniq > $tmpdir/recordings + awk '{print $1}' $data/wav.scp > $tmpdir/recordings.wav + if ! cmp -s $tmpdir/recordings{,.wav}; then + echo "$0: Error: in $data, recording-ids extracted from segments and wav.scp" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/recordings{,.wav} + exit 1; + fi + if [ -f $data/reco2file_and_channel ]; then + # this file is needed only for ctm scoring; it's indexed by recording-id. + check_sorted_and_uniq $data/reco2file_and_channel + ! 
cat $data/reco2file_and_channel | \ + awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) { + if ( NF == 3 && $3 == "1" ) { + warning_issued = 1; + } else { + print "Bad line ", $0; exit 1; + } + } + } + END { + if (warning_issued == 1) { + print "The channel should be marked as A or B, not 1! You should change it ASAP! " + } + }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1; + cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/recordings.r2fc + if ! cmp -s $tmpdir/recordings{,.r2fc}; then + echo "$0: Error: in $data, recording-ids extracted from segments and reco2file_and_channel" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/recordings{,.r2fc} + exit 1; + fi + fi + else + # No segments file -> assume wav.scp indexed by utterance. + cat $data/wav.scp | awk '{print $1}' > $tmpdir/utts.wav + if ! cmp -s $tmpdir/utts{,.wav}; then + echo "$0: Error: in $data, utterance lists extracted from utt2spk and wav.scp" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.wav} + exit 1; + fi + + if [ -f $data/reco2file_and_channel ]; then + # this file is needed only for ctm scoring; it's indexed by recording-id. + check_sorted_and_uniq $data/reco2file_and_channel + ! cat $data/reco2file_and_channel | \ + awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) { + if ( NF == 3 && $3 == "1" ) { + warning_issued = 1; + } else { + print "Bad line ", $0; exit 1; + } + } + } + END { + if (warning_issued == 1) { + print "The channel should be marked as A or B, not 1! You should change it ASAP! " + } + }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1; + cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/utts.r2fc + if ! cmp -s $tmpdir/utts{,.r2fc}; then + echo "$0: Error: in $data, utterance-ids extracted from segments and reco2file_and_channel" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.r2fc} + exit 1; + fi + fi + fi +fi + +if [ ! -f $data/feats.scp ] && ! 
$no_feats; then + echo "$0: no such file $data/feats.scp (if this is by design, specify --no-feats)" + exit 1; +fi + +if [ -f $data/feats.scp ]; then + check_sorted_and_uniq $data/feats.scp + cat $data/feats.scp | awk '{print $1}' > $tmpdir/utts.feats + if ! cmp -s $tmpdir/utts{,.feats}; then + echo "$0: Error: in $data, utterance-ids extracted from utt2spk and features" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.feats} + exit 1; + fi +fi + + +if [ -f $data/cmvn.scp ]; then + check_sorted_and_uniq $data/cmvn.scp + cat $data/cmvn.scp | awk '{print $1}' > $tmpdir/speakers.cmvn + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + if ! cmp -s $tmpdir/speakers{,.cmvn}; then + echo "$0: Error: in $data, speaker lists extracted from spk2utt and cmvn" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/speakers{,.cmvn} + exit 1; + fi +fi + +if [ -f $data/spk2gender ]; then + check_sorted_and_uniq $data/spk2gender + ! cat $data/spk2gender | awk '{if (!((NF == 2 && ($2 == "m" || $2 == "f")))) exit 1; }' && \ + echo "$0: Mal-formed spk2gender file" && exit 1; + cat $data/spk2gender | awk '{print $1}' > $tmpdir/speakers.spk2gender + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + if ! cmp -s $tmpdir/speakers{,.spk2gender}; then + echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2gender" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/speakers{,.spk2gender} + exit 1; + fi +fi + +if [ -f $data/spk2warp ]; then + check_sorted_and_uniq $data/spk2warp + ! cat $data/spk2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \ + echo "$0: Mal-formed spk2warp file" && exit 1; + cat $data/spk2warp | awk '{print $1}' > $tmpdir/speakers.spk2warp + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + if ! 
cmp -s $tmpdir/speakers{,.spk2warp}; then + echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2warp" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/speakers{,.spk2warp} + exit 1; + fi +fi + +if [ -f $data/utt2warp ]; then + check_sorted_and_uniq $data/utt2warp + ! cat $data/utt2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \ + echo "$0: Mal-formed utt2warp file" && exit 1; + cat $data/utt2warp | awk '{print $1}' > $tmpdir/utts.utt2warp + cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts + if ! cmp -s $tmpdir/utts{,.utt2warp}; then + echo "$0: Error: in $data, utterance lists extracted from utt2spk and utt2warp" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.utt2warp} + exit 1; + fi +fi + +# check some optionally-required things +for f in vad.scp utt2lang utt2uniq; do + if [ -f $data/$f ]; then + check_sorted_and_uniq $data/$f + if ! cmp -s <( awk '{print $1}' $data/utt2spk ) \ + <( awk '{print $1}' $data/$f ); then + echo "$0: error: in $data, $f and utt2spk do not have identical utterance-id list" + exit 1; + fi + fi +done + + +if [ -f $data/utt2dur ]; then + check_sorted_and_uniq $data/utt2dur + cat $data/utt2dur | awk '{print $1}' > $tmpdir/utts.utt2dur + if ! cmp -s $tmpdir/utts{,.utt2dur}; then + echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2dur file" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.utt2dur} + exit 1; + fi + cat $data/utt2dur | \ + awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line utt2dur:" NR ":" $0; exit(1) }}' || exit 1 +fi + +if [ -f $data/utt2num_frames ]; then + check_sorted_and_uniq $data/utt2num_frames + cat $data/utt2num_frames | awk '{print $1}' > $tmpdir/utts.utt2num_frames + if ! 
cmp -s $tmpdir/utts{,.utt2num_frames}; then + echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2num_frames file" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/utts{,.utt2num_frames} + exit 1 + fi + awk <$data/utt2num_frames '{ + if (NF != 2 || !($2 > 0) || $2 != int($2)) { + print "Bad line utt2num_frames:" NR ":" $0 + exit 1 } }' || exit 1 +fi + +if [ -f $data/reco2dur ]; then + check_sorted_and_uniq $data/reco2dur + cat $data/reco2dur | awk '{print $1}' > $tmpdir/recordings.reco2dur + if [ -f $tmpdir/recordings ]; then + if ! cmp -s $tmpdir/recordings{,.reco2dur}; then + echo "$0: Error: in $data, recording-ids extracted from segments and reco2dur file" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/recordings{,.reco2dur} + exit 1; + fi + else + if ! cmp -s $tmpdir/{utts,recordings.reco2dur}; then + echo "$0: Error: in $data, recording-ids extracted from wav.scp and reco2dur file" + echo "$0: differ, partial diff is:" + partial_diff $tmpdir/{utts,recordings.reco2dur} + exit 1; + fi + fi + cat $data/reco2dur | \ + awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line : " $0; exit(1) }}' || exit 1 +fi + + +echo "$0: Successfully validated data-directory $data" diff --git a/egs/alimeeting/sa-asr/utils/validate_text.pl b/egs/alimeeting/sa-asr/utils/validate_text.pl new file mode 100755 index 000000000..7f75cf12f --- /dev/null +++ b/egs/alimeeting/sa-asr/utils/validate_text.pl @@ -0,0 +1,136 @@ +#!/usr/bin/env perl +# +#=============================================================================== +# Copyright 2017 Johns Hopkins University (author: Yenda Trmal ) +# Johns Hopkins University (author: Daniel Povey) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +# validation script for data//text +# to be called (preferably) from utils/validate_data_dir.sh +use strict; +use warnings; +use utf8; +use Fcntl qw< SEEK_SET >; + +# this function reads the opened file (supplied as a first +# parameter) into an array of lines. For each +# line, it tests whether it's a valid utf-8 compatible +# line. If all lines are valid utf-8, it returns the lines +# decoded as utf-8, otherwise it assumes the file's encoding +# is one of those 1-byte encodings, such as ISO-8859-x +# or Windows CP-X. +# Please recall we do not really care about +# the actually encoding, we just need to +# make sure the length of the (decoded) string +# is correct (to make the output formatting looking right). 
+sub get_utf8_or_bytestream { + use Encode qw(decode encode); + my $is_utf_compatible = 1; + my @unicode_lines; + my @raw_lines; + my $raw_text; + my $lineno = 0; + my $file = shift; + + while (<$file>) { + $raw_text = $_; + last unless $raw_text; + if ($is_utf_compatible) { + my $decoded_text = eval { decode("UTF-8", $raw_text, Encode::FB_CROAK) } ; + $is_utf_compatible = $is_utf_compatible && defined($decoded_text); + push @unicode_lines, $decoded_text; + } else { + #print STDERR "WARNING: the line $raw_text cannot be interpreted as UTF-8: $decoded_text\n"; + ; + } + push @raw_lines, $raw_text; + $lineno += 1; + } + + if (!$is_utf_compatible) { + return (0, @raw_lines); + } else { + return (1, @unicode_lines); + } +} + +# check if the given unicode string contain unicode whitespaces +# other than the usual four: TAB, LF, CR and SPACE +sub validate_utf8_whitespaces { + my $unicode_lines = shift; + use feature 'unicode_strings'; + for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) { + my $current_line = $unicode_lines->[$i]; + if ((substr $current_line, -1) ne "\n"){ + print STDERR "$0: The current line (nr. $i) has invalid newline\n"; + return 1; + } + my @A = split(" ", $current_line); + my $utt_id = $A[0]; + # we replace TAB, LF, CR, and SPACE + # this is to simplify the test + if ($current_line =~ /\x{000d}/) { + print STDERR "$0: The line for utterance $utt_id contains CR (0x0D) character\n"; + return 1; + } + $current_line =~ s/[\x{0009}\x{000a}\x{0020}]/./g; + if ($current_line =~/\s/) { + print STDERR "$0: The line for utterance $utt_id contains disallowed Unicode whitespaces\n"; + return 1; + } + } + return 0; +} + +# checks if the text in the file (supplied as the argument) is utf-8 compatible +# if yes, checks if it contains only allowed whitespaces. If no, then does not +# do anything. The function seeks to the original position in the file after +# reading the text. 
+sub check_allowed_whitespace { + my $file = shift; + my $filename = shift; + my $pos = tell($file); + (my $is_utf, my @lines) = get_utf8_or_bytestream($file); + seek($file, $pos, SEEK_SET); + if ($is_utf) { + my $has_invalid_whitespaces = validate_utf8_whitespaces(\@lines); + if ($has_invalid_whitespaces) { + print STDERR "$0: ERROR: text file '$filename' contains disallowed UTF-8 whitespace character(s)\n"; + return 0; + } + } + return 1; +} + +if(@ARGV != 1) { + die "Usage: validate_text.pl \n" . + "e.g.: validate_text.pl data/train/text\n"; +} + +my $text = shift @ARGV; + +if (-z "$text") { + print STDERR "$0: ERROR: file '$text' is empty or does not exist\n"; + exit 1; +} + +if(!open(FILE, "<$text")) { + print STDERR "$0: ERROR: failed to open $text\n"; + exit 1; +} + +check_allowed_whitespace(\*FILE, $text) or exit 1; +close(FILE); diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py index 47226021f..c18472f51 100644 --- a/funasr/bin/asr_inference.py +++ b/funasr/bin/asr_inference.py @@ -40,7 +40,6 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend header_colors = '\033[95m' @@ -91,8 +90,6 @@ class Speech2Text: asr_train_config, asr_model_file, cmvn_file, device ) frontend = None - if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) logging.info("asr_model: {}".format(asr_model)) logging.info("asr_train_args: {}".format(asr_train_args)) @@ -111,7 +108,7 @@ class Speech2Text: # 2. 
Build Language model if lm_train_config is not None: lm, lm_train_args = LMTask.build_model_from_file( - lm_train_config, lm_file, device + lm_train_config, lm_file, None, device ) scorers["lm"] = lm.lm @@ -142,6 +139,13 @@ class Speech2Text: pre_beam_score_key=None if ctc_weight == 1.0 else "full", ) + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + logging.info(f"Beam_search: {beam_search}") + logging.info(f"Decoding device={device}, dtype={dtype}") + # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text if token_type is None: token_type = asr_train_args.token_type @@ -198,16 +202,7 @@ class Speech2Text: if isinstance(speech, np.ndarray): speech = torch.tensor(speech) - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None - else: - feats = speech - feats_len = speech_lengths - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - batch = {"speech": feats, "speech_lengths": feats_len} + batch = {"speech": speech, "speech_lengths": speech_lengths} # a. 
To device batch = to_device(batch, device=self.device) @@ -355,6 +350,9 @@ def inference_modelscope( if ngpu > 1: raise NotImplementedError("only single GPU decoding is supported") + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.basicConfig( level=log_level, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", @@ -408,6 +406,7 @@ def inference_modelscope( data_path_and_name_and_type, dtype=dtype, fs=fs, + mc=True, batch_size=batch_size, key_file=key_file, num_workers=num_workers, @@ -452,7 +451,7 @@ def inference_modelscope( # Write the result to each file ibest_writer["token"][key] = " ".join(token) - # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) + ibest_writer["token_int"][key] = " ".join(map(str, token_int)) ibest_writer["score"][key] = str(hyp.score) if text is not None: @@ -463,6 +462,9 @@ def inference_modelscope( asr_utils.print_progress(finish_count / file_count) if writer is not None: ibest_writer["text"][key] = text + + logging.info("uttid: {}".format(key)) + logging.info("text predictions: {}\n".format(text)) return asr_result_list return _forward @@ -637,4 +639,4 @@ def main(cmd=None): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index e10ebf404..e165531f8 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -288,6 +288,9 @@ def inference_launch_funasr(**kwargs): if mode == "asr": from funasr.bin.asr_inference import inference return inference(**kwargs) + elif mode == "sa_asr": + from funasr.bin.sa_asr_inference import inference + return inference(**kwargs) elif mode == "uniasr": from funasr.bin.asr_inference_uniasr import inference return inference(**kwargs) @@ -342,4 +345,4 @@ def main(cmd=None): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/funasr/bin/asr_train.py 
b/funasr/bin/asr_train.py index bba50daf0..c1e2cb2a1 100755 --- a/funasr/bin/asr_train.py +++ b/funasr/bin/asr_train.py @@ -2,6 +2,14 @@ import os +import logging + +logging.basicConfig( + level='INFO', + format=f"[{os.uname()[1].split('.')[0]}]" + f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", +) + from funasr.tasks.asr import ASRTask @@ -27,7 +35,8 @@ if __name__ == '__main__': args = parse_args() # setup local gpu_id - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) + if args.ngpu > 0: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) # DDP settings if args.ngpu > 1: @@ -38,9 +47,9 @@ if __name__ == '__main__': # re-compute batch size: when dataset type is small if args.dataset_type == "small": - if args.batch_size is not None: + if args.batch_size is not None and args.ngpu > 0: args.batch_size = args.batch_size * args.ngpu - if args.batch_bins is not None: + if args.batch_bins is not None and args.ngpu > 0: args.batch_bins = args.batch_bins * args.ngpu main(args=args) diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py new file mode 100644 index 000000000..be63af111 --- /dev/null +++ b/funasr/bin/sa_asr_inference.py @@ -0,0 +1,674 @@ +import argparse +import logging +import sys +from pathlib import Path +from typing import Any +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from typing import Dict + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type + +from funasr.fileio.datadir_writer import DatadirWriter +from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim +from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch +from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis +from funasr.modules.scorers.ctc import CTCPrefixScorer +from funasr.modules.scorers.length_bonus import LengthBonus 
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface +from funasr.modules.subsampling import TooShortUttError +from funasr.tasks.sa_asr import ASRTask +from funasr.tasks.lm import LMTask +from funasr.text.build_tokenizer import build_tokenizer +from funasr.text.token_id_converter import TokenIDConverter +from funasr.torch_utils.device_funcs import to_device +from funasr.torch_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import config_argparse +from funasr.utils.cli_utils import get_commandline_args +from funasr.utils.types import str2bool +from funasr.utils.types import str2triple_str +from funasr.utils.types import str_or_none +from funasr.utils import asr_utils, wav_utils, postprocess_utils + + +header_colors = '\033[95m' +end_colors = '\033[0m' + + +class Speech2Text: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, hypothesis object), ...] + + """ + + def __init__( + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + batch_size: int = 1, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + ngram_weight: float = 0.9, + penalty: float = 0.0, + nbest: int = 1, + streaming: bool = False, + frontend_conf: dict = None, + **kwargs, + ): + assert check_argument_types() + + # 1. 
Build ASR model + scorers = {} + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, cmvn_file, device + ) + frontend = None + + logging.info("asr_model: {}".format(asr_model)) + logging.info("asr_train_args: {}".format(asr_train_args)) + asr_model.to(dtype=getattr(torch, dtype)).eval() + + decoder = asr_model.decoder + + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + token_list = asr_model.token_list + scorers.update( + decoder=decoder, + ctc=ctc, + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, None, device + ) + scorers["lm"] = lm.lm + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + # 4. Build BeamSearch object + # transducer is not supported now + beam_search_transducer = None + + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + ngram=ngram_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + logging.info(f"Beam_search: {beam_search}") + logging.info(f"Decoding device={device}, dtype={dtype}") + + # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + self.beam_search = beam_search + self.beam_search_transducer = beam_search_transducer + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + self.frontend = frontend + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray] + ) -> List[ + Tuple[ + Optional[str], + Optional[str], + List[str], + List[int], + Union[Hypothesis], + ] + ]: + """Inference + + Args: + speech: Input speech data + Returns: + text, text_id, token, token_int, hyp + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + if isinstance(profile, np.ndarray): + profile = torch.tensor(profile) + + batch = {"speech": speech, "speech_lengths": speech_lengths} + + # a. To device + batch = to_device(batch, device=self.device) + + # b. Forward Encoder + asr_enc, _, spk_enc = self.asr_model.encode(**batch) + if isinstance(asr_enc, tuple): + asr_enc = asr_enc[0] + if isinstance(spk_enc, tuple): + spk_enc = spk_enc[0] + assert len(asr_enc) == 1, len(asr_enc) + assert len(spk_enc) == 1, len(spk_enc) + + # c. 
Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio + ) + + nbest_hyps = nbest_hyps[: self.nbest] + + results = [] + for hyp in nbest_hyps: + assert isinstance(hyp, (Hypothesis)), type(hyp) + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1: last_pos] + else: + token_int = hyp.yseq[1: last_pos].tolist() + + spk_weigths=torch.stack(hyp.spk_weigths, dim=0) + + token_ori = self.converter.ids2tokens(token_int) + text_ori = self.tokenizer.tokens2text(token_ori) + + text_ori_spklist = text_ori.split('$') + cur_index = 0 + spk_choose = [] + for i in range(len(text_ori_spklist)): + text_ori_split = text_ori_spklist[i] + n = len(text_ori_split) + spk_weights_local = spk_weigths[cur_index: cur_index + n] + cur_index = cur_index + n + 1 + spk_weights_local = spk_weights_local.mean(dim=0) + spk_choose_local = spk_weights_local.argmax(-1) + spk_choose.append(spk_choose_local.item() + 1) + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + + text_spklist = text.split('$') + assert len(spk_choose) == len(text_spklist) + + spk_list=[] + for i in range(len(text_spklist)): + text_split = text_spklist[i] + n = len(text_split) + spk_list.append(str(spk_choose[i]) * n) + + text_id = '$'.join(spk_list) + + assert len(text) == len(text_id) + + results.append((text, text_id, token, token_int, hyp)) + + assert check_return_type(results) + return results + +def inference( + maxlenratio: float, + minlenratio: float, + batch_size: int, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + log_level: Union[int, str], + 
data_path_and_name_and_type, + asr_train_config: Optional[str], + asr_model_file: Optional[str], + cmvn_file: Optional[str] = None, + raw_inputs: Union[np.ndarray, torch.Tensor] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + token_type: Optional[str] = None, + key_file: Optional[str] = None, + word_lm_train_config: Optional[str] = None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + streaming: bool = False, + output_dir: Optional[str] = None, + dtype: str = "float32", + seed: int = 0, + ngram_weight: float = 0.9, + nbest: int = 1, + num_workers: int = 1, + **kwargs, +): + inference_pipeline = inference_modelscope( + maxlenratio=maxlenratio, + minlenratio=minlenratio, + batch_size=batch_size, + beam_size=beam_size, + ngpu=ngpu, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + penalty=penalty, + log_level=log_level, + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + cmvn_file=cmvn_file, + raw_inputs=raw_inputs, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + key_file=key_file, + word_lm_train_config=word_lm_train_config, + bpemodel=bpemodel, + allow_variable_data_keys=allow_variable_data_keys, + streaming=streaming, + output_dir=output_dir, + dtype=dtype, + seed=seed, + ngram_weight=ngram_weight, + nbest=nbest, + num_workers=num_workers, + **kwargs, + ) + return inference_pipeline(data_path_and_name_and_type, raw_inputs) + +def inference_modelscope( + maxlenratio: float, + minlenratio: float, + batch_size: int, + beam_size: int, + ngpu: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + log_level: Union[int, str], + # data_path_and_name_and_type, + asr_train_config: Optional[str], + asr_model_file: Optional[str], + cmvn_file: Optional[str] = None, + lm_train_config: Optional[str] = None, + lm_file: Optional[str] = None, + token_type: Optional[str] = None, + key_file: Optional[str] = None, + word_lm_train_config: Optional[str] = 
None, + bpemodel: Optional[str] = None, + allow_variable_data_keys: bool = False, + streaming: bool = False, + output_dir: Optional[str] = None, + dtype: str = "float32", + seed: int = 0, + ngram_weight: float = 0.9, + nbest: int = 1, + num_workers: int = 1, + param_dict: dict = None, + **kwargs, +): + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") + if word_lm_train_config is not None: + raise NotImplementedError("Word LM is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if ngpu >= 1 and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build speech2text + speech2text_kwargs = dict( + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + cmvn_file=cmvn_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + bpemodel=bpemodel, + device=device, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + dtype=dtype, + beam_size=beam_size, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + ngram_weight=ngram_weight, + penalty=penalty, + nbest=nbest, + streaming=streaming, + ) + logging.info("speech2text_kwargs: {}".format(speech2text_kwargs)) + speech2text = Speech2Text(**speech2text_kwargs) + + def _forward(data_path_and_name_and_type, + raw_inputs: Union[np.ndarray, torch.Tensor] = None, + output_dir_v2: Optional[str] = None, + fs: dict = None, + param_dict: dict = None, + **kwargs, + ): + # 3. 
Build data-iterator + if data_path_and_name_and_type is None and raw_inputs is not None: + if isinstance(raw_inputs, torch.Tensor): + raw_inputs = raw_inputs.numpy() + data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] + loader = ASRTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + fs=fs, + mc=True, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), + collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + finish_count = 0 + file_count = 1 + # 7 .Start for-loop + # FIXME(kamo): The output format should be discussed about + asr_result_list = [] + output_path = output_dir_v2 if output_dir_v2 is not None else output_dir + if output_path is not None: + writer = DatadirWriter(output_path) + else: + writer = None + + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + # N-best list of (text, token, token_int, hyp_object) + try: + results = speech2text(**batch) + except TooShortUttError as e: + logging.warning(f"Utterance {keys} {e}") + hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) + results = [[" ", ["sil"], [2], hyp]] * nbest + + # Only supporting batch_size==1 + key = keys[0] + for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results): + # Create a directory: outdir/{n}best_recog + if writer is not None: + ibest_writer = writer[f"{n}best_recog"] + + # Write the result to each file + ibest_writer["token"][key] = " ".join(token) + ibest_writer["token_int"][key] = " ".join(map(str, token_int)) + ibest_writer["score"][key] = str(hyp.score) + 
ibest_writer["text_id"][key] = text_id + + if text is not None: + text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + item = {'key': key, 'value': text_postprocessed} + asr_result_list.append(item) + finish_count += 1 + asr_utils.print_progress(finish_count / file_count) + if writer is not None: + ibest_writer["text"][key] = text + + logging.info("uttid: {}".format(key)) + logging.info("text predictions: {}".format(text)) + logging.info("text_id predictions: {}\n".format(text_id)) + return asr_result_list + + return _forward + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="ASR Decoding", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 
0 indicates CPU mode", + ) + parser.add_argument( + "--gpuid_list", + type=str, + default="", + help="The visible gpus", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=False, + action="append", + ) + group.add_argument("--raw_inputs", type=list, default=None) + # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("The model configuration related") + group.add_argument( + "--asr_train_config", + type=str, + help="ASR training configuration", + ) + group.add_argument( + "--asr_model_file", + type=str, + help="ASR model parameter file", + ) + group.add_argument( + "--cmvn_file", + type=str, + help="Global cmvn file", + ) + group.add_argument( + "--lm_train_config", + type=str, + help="LM training configuration", + ) + group.add_argument( + "--lm_file", + type=str, + help="LM parameter file", + ) + group.add_argument( + "--word_lm_train_config", + type=str, + help="Word LM training configuration", + ) + group.add_argument( + "--word_lm_file", + type=str, + help="Word LM parameter file", + ) + group.add_argument( + "--ngram_file", + type=str, + help="N-gram parameter file", + ) + group.add_argument( + "--model_tag", + type=str, + help="Pretrained model tag. 
If specify this option, *_train_config and " + "*_file will be overwritten", + ) + + group = parser.add_argument_group("Beam-search related") + group.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + group.add_argument("--beam_size", type=int, default=20, help="Beam size") + group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") + group.add_argument( + "--maxlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain max output length. " + "If maxlenratio=0.0 (default), it uses a end-detect " + "function " + "to automatically find maximum hypothesis lengths." + "If maxlenratio<0.0, its absolute value is interpreted" + "as a constant max output length", + ) + group.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain min output length", + ) + group.add_argument( + "--ctc_weight", + type=float, + default=0.5, + help="CTC weight in joint decoding", + ) + group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") + group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") + group.add_argument("--streaming", type=str2bool, default=False) + + group = parser.add_argument_group("Text converter related") + group.add_argument( + "--token_type", + type=str_or_none, + default=None, + choices=["char", "bpe", None], + help="The token type for ASR model. " + "If not given, refers from the training args", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model path of sentencepiece. 
" + "If not given, refers from the training args", + ) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + inference(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py new file mode 100755 index 000000000..c7c7c42a4 --- /dev/null +++ b/funasr/bin/sa_asr_train.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 + +import os + +import logging + +logging.basicConfig( + level='INFO', + format=f"[{os.uname()[1].split('.')[0]}]" + f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", +) + +from funasr.tasks.sa_asr import ASRTask + + +# for ASR Training +def parse_args(): + parser = ASRTask.get_parser() + parser.add_argument( + "--gpu_id", + type=int, + default=0, + help="local gpu id.", + ) + args = parser.parse_args() + return args + + +def main(args=None, cmd=None): + # for ASR Training + ASRTask.main(args=args, cmd=cmd) + + +if __name__ == '__main__': + args = parse_args() + + # setup local gpu_id + if args.ngpu > 0: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) + + # DDP settings + if args.ngpu > 1: + args.distributed = True + else: + args.distributed = False + assert args.num_worker_count == 1 + + # re-compute batch size: when dataset type is small + if args.dataset_type == "small": + if args.batch_size is not None and args.ngpu > 0: + args.batch_size = args.batch_size * args.ngpu + if args.batch_bins is not None and args.ngpu > 0: + args.batch_bins = args.batch_bins * args.ngpu + + main(args=args) diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py index dc872b047..d757f7f8c 100644 --- a/funasr/fileio/sound_scp.py +++ b/funasr/fileio/sound_scp.py @@ -46,13 +46,15 @@ class SoundScpReader(collections.abc.Mapping): if self.normalize: # soundfile.read normalizes data to [-1,1] if dtype is not given array, rate = 
class NllLoss(nn.Module):
    """Negative log-likelihood loss with padding-aware masking.

    :param int size: the number of classes (last dimension of the prediction)
    :param int padding_idx: ignored class id (padding positions)
    :param bool normalize_length: normalize loss by the number of
        non-padding tokens if True, otherwise by the batch size
    :param torch.nn.Module criterion: elementwise loss; must use
        reduction='none' so padding can be masked before the reduction.
        NOTE: the default module instance is created once at definition
        time and shared between NllLoss instances; this is safe only
        because nn.NLLLoss is stateless.
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=nn.NLLLoss(reduction='none'),
    ):
        """Construct an NllLoss object."""
        super(NllLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.size = size
        self.true_dist = None  # kept for attribute parity with LabelSmoothingLoss
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: log-probabilities (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            # Replace padding ids (e.g. -1) with a valid class index so the
            # gather inside NLLLoss does not fail; those positions are
            # zeroed out below, before the reduction.
            target = target.masked_fill(ignore, 0)
        # Loss stays outside no_grad so gradients flow through x.
        loss = self.criterion(x, target)
        denom = total if self.normalize_length else batch_size
        return loss.masked_fill(ignore, 0).sum() / denom
cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = None + if tgt_mask is not None: + tgt_q_mask = tgt_mask[:, -1:, :] + + if self.concat_after: + tgt_concat = torch.cat( + (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 + ) + x = residual + self.concat_linear1(tgt_concat) + else: + x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) + if not self.normalize_before: + x = self.norm1(x) + z = x + + residual = x + if self.normalize_before: + x = self.norm1(x) + + skip = self.src_attn(x, asr_memory, spk_memory, memory_mask) + + if self.concat_after: + x_concat = torch.cat( + (x, skip), dim=-1 + ) + x = residual + self.concat_linear2(x_concat) + else: + x = residual + self.dropout(skip) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, asr_memory, spk_memory, memory_mask, z + +class SpeakerAttributeAsrDecoderFirstLayer(nn.Module): + + def __init__( + self, + size, + d_size, + src_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an DecoderLayer object.""" + super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__() + self.size = size + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.norm3 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + self.spk_linear = nn.Linear(d_size, size, bias=False) + if self.concat_after: + self.concat_linear1 = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + 
size, size) + + def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None): + + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = None + if tgt_mask is not None: + tgt_q_mask = tgt_mask[:, -1:, :] + + x = tgt_q + if self.normalize_before: + x = self.norm2(x) + if self.concat_after: + x_concat = torch.cat( + (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 + ) + x = residual + self.concat_linear2(x_concat) + else: + x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) + if not self.normalize_before: + x = self.norm2(x) + residual = x + + if dn!=None: + x = x + self.spk_linear(dn) + if self.normalize_before: + x = self.norm3(x) + + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm3(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, memory, memory_mask + + + diff --git a/funasr/models/decoder/transformer_decoder_sa_asr.py b/funasr/models/decoder/transformer_decoder_sa_asr.py new file mode 100644 index 000000000..949f9c898 --- /dev/null +++ b/funasr/models/decoder/transformer_decoder_sa_asr.py @@ -0,0 +1,291 @@ +from typing import Any +from typing import List +from typing import Sequence +from typing import Tuple + +import torch +from typeguard import check_argument_types + +from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.attention import MultiHeadedAttention +from funasr.modules.attention import CosineDistanceAttention +from funasr.models.decoder.transformer_decoder import DecoderLayer +from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer +from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer +from funasr.modules.dynamic_conv import DynamicConvolution +from 
class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class of the SA-ASR transformer decoder.

    Runs a speaker-decoder branch (decoder1 + decoder2) that attends over
    both ASR and speaker encoder outputs, matches its output against
    speaker profiles with a cosine-distance attention, and feeds the
    selected speaker embedding into an ASR-decoder branch
    (decoder3 + decoder4).  The concrete layer stacks are assigned by
    subclasses (see SAAsrTransformerDecoder below).
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        """Construct the shared embedding / output layers.

        Args:
            vocab_size: output vocabulary size.
            encoder_output_size: hidden dimension of the encoder outputs
                (also used as the decoder attention dimension).
            spker_embedding_dim: dimension of the speaker profile vectors.
            input_layer: "embed" (token embedding) or "linear" input stem.
            use_asr_output_layer: add the vocabulary projection if True.
            use_spk_output_layer: add the speaker-embedding projection if True.
            pos_enc_class: positional encoding class for the input stem.
            normalize_before: pre-LayerNorm layout if True.
        """
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None

        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None

        # Matches speaker-decoder outputs against profile vectors.
        self.cos_distance_att = CosineDistanceAttention()

        # Layer stacks are installed by subclasses.
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None

    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Teacher-forced decoding pass.

        Args:
            asr_hs_pad: padded ASR encoder output (B, T, D).
            spk_hs_pad: padded speaker encoder output (B, T, D).
            hlens: encoder output lengths (B,).
            ys_in_pad: decoder input token ids (B, L).
            ys_in_lens: decoder input lengths (B,).
            profile: speaker profile vectors (B, N, D_spk).
            profile_lens: number of valid profiles per utterance (B,).

        Returns:
            Tuple of (token scores (B, L, vocab), profile attention
            weights (B, L, N), output lengths).
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
        # Spk decoder
        x = self.embed(tgt)

        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        # dn: selected speaker embedding; weights: attention over profiles.
        dn, weights = self.cos_distance_att(x, profile, profile_lens)
        # Asr decoder: restarts from z (post-self-attention state of decoder1).
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )

        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)

        olens = tgt_mask.sum(1)
        return x, weights, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        asr_memory: torch.Tensor,
        spk_memory: torch.Tensor,
        profile: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Incrementally decode one step.

        The cache layout is one entry per layer, in execution order:
        [decoder1, *decoder2, decoder3, *decoder4].

        Returns:
            Tuple of (log-prob of the last token (B, vocab), profile
            attention weights, updated cache).
        """
        x = self.embed(tgt)

        if cache is None:
            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
        new_cache = []
        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
            x, tgt_mask, spk_memory, _ = decoder(
                x, tgt_mask, spk_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            x = self.after_norm(x)
        else:
            x = x  # no-op, kept for symmetry with the normalize_before branch
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        # Decoding path: profile_lens=None scores only the newest frame.
        dn, weights = self.cos_distance_att(x, profile, None)

        x, tgt_mask, asr_memory, _ = self.decoder3(
            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
        )
        new_cache.append(x)

        for c, decoder in zip(cache[len(self.decoder2) + 2:], self.decoder4):
            x, tgt_mask, asr_memory, _ = decoder(
                x, tgt_mask, asr_memory, None, cache=c
            )
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.asr_output_layer is not None:
            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)

        return y, weights, new_cache

    def score(self, ys, state, asr_enc, spk_enc, profile):
        """Score a hypothesis prefix.

        NOTE(review): this signature differs from the plain
        ScorerInterface.score(ys, state, x) contract — it is only usable
        by the SA-ASR beam search, which passes both encoder streams and
        the speaker profiles.
        """
        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
        logp, weights, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), weights.squeeze(), state
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
    """Concrete SA-ASR decoder: wires up the four layer stacks.

    decoder1: first speaker-decoder block (self + cross attention).
    decoder2: remaining ``spk_num_blocks - 1`` standard decoder layers.
    decoder3: first ASR-decoder block (cross attention + speaker injection).
    decoder4: remaining ``asr_num_blocks - 1`` standard decoder layers.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        asr_num_blocks: int = 6,
        spk_num_blocks: int = 3,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            spker_embedding_dim=spker_embedding_dim,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_asr_output_layer=use_asr_output_layer,
            use_spk_output_layer=use_spk_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size

        # Local factories keep the sub-module construction in one place.
        # They are invoked in the same order as the original inline calls,
        # so parameter initialization (and thus the RNG stream) is unchanged.
        def self_attention():
            return MultiHeadedAttention(
                attention_heads, attention_dim, self_attention_dropout_rate
            )

        def source_attention():
            return MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            )

        def feed_forward():
            return PositionwiseFeedForward(attention_dim, linear_units, dropout_rate)

        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
            attention_dim,
            self_attention(),
            source_attention(),
            feed_forward(),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder2 = repeat(
            spk_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                self_attention(),
                source_attention(),
                feed_forward(),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
            attention_dim,
            spker_embedding_dim,
            source_attention(),
            feed_forward(),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder4 = repeat(
            asr_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                self_attention(),
                source_attention(),
                feed_forward(),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class ESPnetASRModel(AbsESPnetModel):
    """CTC-attention hybrid Encoder-Decoder model for speaker-attributed ASR.

    On top of the standard ASR encoder/decoder + CTC, this model runs a
    speaker encoder over the un-augmented features and trains a speaker
    classification branch (NLL over profile-attention weights) jointly
    with the ASR losses, weighted by ``spk_weight``.
    """

    def __init__(
        self,
        vocab_size: int,
        max_spk_num: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        asr_encoder: AbsEncoder,
        spk_encoder: torch.nn.Module,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        spk_weight: float = 0.5,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "",
        sym_blank: str = "",
        extract_feats_in_collect_stats: bool = True,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()
        # NOTE: fixed special ids (blank=0, sos=1, eos=2) rather than the
        # usual espnet convention of sos = eos = vocab_size - 1.
        self.blank_id = 0
        self.sos = 1
        self.eos = 2
        self.vocab_size = vocab_size
        self.max_spk_num = max_spk_num
        self.ignore_id = ignore_id
        self.spk_weight = spk_weight
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.token_list = token_list.copy()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.asr_encoder = asr_encoder
        self.spk_encoder = spk_encoder

        if not hasattr(self.asr_encoder, "interctc_use_conditioning"):
            self.asr_encoder.interctc_use_conditioning = False
        if self.asr_encoder.interctc_use_conditioning:
            self.asr_encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.asr_encoder.output_size()
            )

        self.error_calculator = None

        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        self.criterion_spk = NllLoss(
            size=max_spk_num,
            padding_idx=ignore_id,
            normalize_length=length_normalized_loss,
        )

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        profile: torch.Tensor,
        profile_lengths: torch.Tensor,
        text_id: torch.Tensor,
        text_id_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            profile: (Batch, Length, Dim) speaker profile vectors
            profile_lengths: (Batch,)
            text_id: (Batch, Length) per-token speaker id targets
            text_id_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        asr_encoder_out, encoder_out_lens, spk_encoder_out = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(asr_encoder_out, tuple):
            intermediate_outs = asr_encoder_out[1]
            asr_encoder_out = asr_encoder_out[0]

        loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = None, None, None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                asr_encoder_out, encoder_out_lens, text, text_lengths
            )

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch (ASR + speaker classification)
        if self.ctc_weight != 1.0:
            loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss(
                asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths,
                profile, profile_lengths, text_id, text_id_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss_asr = loss_att
        elif self.ctc_weight == 1.0:
            loss_asr = loss_ctc
        else:
            loss_asr = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # 4. ASR / speaker loss interpolation
        if self.spk_weight == 0.0:
            loss = loss_asr
        else:
            loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr

        # Bug fix: update the existing dict instead of rebuilding it, so
        # the intermediate-CTC stats collected above are not discarded.
        stats.update(
            loss=loss.detach(),
            loss_asr=loss_asr.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            loss_spk=loss_spk.detach() if loss_spk is not None else None,
            acc=acc_att,
            acc_spk=acc_spk,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Extract features for statistics collection (CMVN etc.)."""
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            Tuple of (ASR encoder output — or (output, intermediate_outs)
            when intermediate CTC is active —, output lengths, speaker
            encoder output aligned to the ASR frame rate).
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            # Keep an un-augmented copy for the speaker encoder.
            feats_raw = feats.clone()
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.asr_encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.asr_encoder(
                feats, feats_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0]
        if encoder_out_spk_ori.size(1) != encoder_out.size(1):
            # Align the speaker-encoder frame rate with the ASR encoder output.
            encoder_out_spk = F.interpolate(
                encoder_out_spk_ori.transpose(-2, -1),
                size=(encoder_out.size(1)),
                mode='nearest',
            ).transpose(-2, -1)
        else:
            encoder_out_spk = encoder_out_spk_ori

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        assert encoder_out_spk.size(0) == speech.size(0), (
            encoder_out_spk.size(),
            speech.size(0),
        )

        if intermediate_outs is not None:
            # Bug fix: also return the speaker-encoder output here —
            # forward() always unpacks three values from encode().
            return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk

        return encoder_out, encoder_out_lens, encoder_out_spk

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the frontend (or pass features through unchanged)."""
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            # e.g. STFT and Feature extract
            # data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder

        Normally, this function is called in batchify_nll.

        NOTE(review): inherited from the standard ASR model — it calls the
        decoder with four arguments, while the SA-ASR decoder also needs
        speaker profiles; confirm before using this path with this model.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)
        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood(nll) from transformer-decoder

        To avoid OOM, this function separates the input into batches.
        Then call nll for each batch and combine and return results.
        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
            batch_size: int, samples each batch contain when computing nll,
                        you may change this to avoid OOM or increase
                        GPU memory usage
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            nll = []
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
                batch_ys_pad = ys_pad[start_idx:end_idx, :]
                batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
                batch_nll = self.nll(
                    batch_encoder_out,
                    batch_encoder_out_lens,
                    batch_ys_pad,
                    batch_ys_pad_lens,
                )
                nll.append(batch_nll)
                start_idx = end_idx
                if start_idx == total_num:
                    break
            nll = torch.cat(nll)
        assert nll.size(0) == total_num
        return nll

    def _calc_att_loss(
        self,
        asr_encoder_out: torch.Tensor,
        spk_encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
        text_id: torch.Tensor,
        text_id_lengths: torch.Tensor,
    ):
        """Joint attention loss: label-smoothed ASR loss + speaker NLL."""
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, weights_no_pad, _ = self.decoder(
            asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens
        )

        # Pad the profile-attention weights up to max_spk_num so the
        # speaker NLL sees a fixed class dimension.
        spk_num_no_pad = weights_no_pad.size(-1)
        pad = (0, self.max_spk_num - spk_num_no_pad)
        weights = F.pad(weights_no_pad, pad, mode='constant', value=0)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        loss_spk = self.criterion_spk(torch.log(weights), text_id)

        acc_spk = th_accuracy(
            weights.view(-1, self.max_spk_num),
            text_id,
            ignore_label=self.ignore_id,
        )
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """CTC loss (and CER from CTC greedy output during evaluation)."""
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
np.random.randint(input_stft.size(2)) - input_stft = input_stft[:, :, ch, :] + if self.use_channel == None: + input_stft = input_stft[:, :, 0, :] + else: + # Select 1ch randomly + ch = np.random.randint(input_stft.size(2)) + input_stft = input_stft[:, :, ch, :] else: # Use the first channel input_stft = input_stft[:, :, 0, :] diff --git a/funasr/models/pooling/statistic_pooling.py b/funasr/models/pooling/statistic_pooling.py index 8f85de99d..39d94be5b 100644 --- a/funasr/models/pooling/statistic_pooling.py +++ b/funasr/models/pooling/statistic_pooling.py @@ -83,9 +83,9 @@ def windowed_statistic_pooling( num_chunk = int(math.ceil(tt / pooling_stride)) pad = pooling_size // 2 if len(xs_pad.shape) == 4: - features = F.pad(xs_pad, (0, 0, pad, pad), "reflect") + features = F.pad(xs_pad, (0, 0, pad, pad), "replicate") else: - features = F.pad(xs_pad, (pad, pad), "reflect") + features = F.pad(xs_pad, (pad, pad), "replicate") stat_list = [] for i in range(num_chunk): diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py index 62020796e..fcb3ed412 100644 --- a/funasr/modules/attention.py +++ b/funasr/modules/attention.py @@ -13,6 +13,9 @@ import torch from torch import nn from typing import Optional, Tuple +import torch.nn.functional as F +from funasr.modules.nets_utils import make_pad_mask + class MultiHeadedAttention(nn.Module): """Multi-Head Attention layer. 
@@ -959,3 +962,37 @@ class RelPositionMultiHeadedAttentionChunk(torch.nn.Module): q, k, v = self.forward_qkv(query, key, value) scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask) + + +class CosineDistanceAttention(nn.Module): + """ Compute Cosine Distance between spk decoder output and speaker profile + Args: + profile_path: speaker profile file path (.npy file) + """ + + def __init__(self): + super().__init__() + self.softmax = nn.Softmax(dim=-1) + + def forward(self, spk_decoder_out, profile, profile_lens=None): + """ + Args: + spk_decoder_out(torch.Tensor):(B, L, D) + spk_profiles(torch.Tensor):(B, N, D) + """ + x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D) + if profile_lens is not None: + + mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device) + min_value = float( + numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min + ) + weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value) + weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0) # (B, L, N) + else: + x = x[:, -1:, :, :] + weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1) + weights = self.softmax(weights_not_softmax) # (B, 1, N) + spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D) + + return spk_embedding, weights diff --git a/funasr/modules/beam_search/beam_search_sa_asr.py b/funasr/modules/beam_search/beam_search_sa_asr.py new file mode 100755 index 000000000..b2b6833c8 --- /dev/null +++ b/funasr/modules/beam_search/beam_search_sa_asr.py @@ -0,0 +1,525 @@ +"""Beam search module.""" + +from itertools import chain +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Tuple +from typing import Union + +import torch + +from funasr.modules.e2e_asr_common import end_detect +from 
funasr.modules.scorers.scorer_interface import PartialScorerInterface +from funasr.modules.scorers.scorer_interface import ScorerInterface +from funasr.models.decoder.abs_decoder import AbsDecoder + + +class Hypothesis(NamedTuple): + """Hypothesis data type.""" + + yseq: torch.Tensor + spk_weigths : List + score: Union[float, torch.Tensor] = 0 + scores: Dict[str, Union[float, torch.Tensor]] = dict() + states: Dict[str, Any] = dict() + + def asdict(self) -> dict: + """Convert data to JSON-friendly dict.""" + return self._replace( + yseq=self.yseq.tolist(), + score=float(self.score), + scores={k: float(v) for k, v in self.scores.items()}, + )._asdict() + + +class BeamSearch(torch.nn.Module): + """Beam search implementation.""" + + def __init__( + self, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + beam_size: int, + vocab_size: int, + sos: int, + eos: int, + token_list: List[str] = None, + pre_beam_ratio: float = 1.5, + pre_beam_score_key: str = None, + ): + """Initialize beam search. 
+ + Args: + scorers (dict[str, ScorerInterface]): Dict of decoder modules + e.g., Decoder, CTCPrefixScorer, LM + The scorer will be ignored if it is `None` + weights (dict[str, float]): Dict of weights for each scorers + The scorer will be ignored if its weight is 0 + beam_size (int): The number of hypotheses kept during search + vocab_size (int): The number of vocabulary + sos (int): Start of sequence id + eos (int): End of sequence id + token_list (list[str]): List of tokens for debug log + pre_beam_score_key (str): key of scores to perform pre-beam search + pre_beam_ratio (float): beam size in the pre-beam search + will be `int(pre_beam_ratio * beam_size)` + + """ + super().__init__() + # set scorers + self.weights = weights + self.scorers = dict() + self.full_scorers = dict() + self.part_scorers = dict() + # this module dict is required for recursive cast + # `self.to(device, dtype)` in `recog.py` + self.nn_dict = torch.nn.ModuleDict() + for k, v in scorers.items(): + w = weights.get(k, 0) + if w == 0 or v is None: + continue + assert isinstance( + v, ScorerInterface + ), f"{k} ({type(v)}) does not implement ScorerInterface" + self.scorers[k] = v + if isinstance(v, PartialScorerInterface): + self.part_scorers[k] = v + else: + self.full_scorers[k] = v + if isinstance(v, torch.nn.Module): + self.nn_dict[k] = v + + # set configurations + self.sos = sos + self.eos = eos + self.token_list = token_list + self.pre_beam_size = int(pre_beam_ratio * beam_size) + self.beam_size = beam_size + self.n_vocab = vocab_size + if ( + pre_beam_score_key is not None + and pre_beam_score_key != "full" + and pre_beam_score_key not in self.full_scorers + ): + raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}") + self.pre_beam_score_key = pre_beam_score_key + self.do_pre_beam = ( + self.pre_beam_score_key is not None + and self.pre_beam_size < self.n_vocab + and len(self.part_scorers) > 0 + ) + + def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]: + """Get 
an initial hypothesis data. + + Args: + x (torch.Tensor): The encoder output feature + + Returns: + Hypothesis: The initial hypothesis. + + """ + init_states = dict() + init_scores = dict() + for k, d in self.scorers.items(): + init_states[k] = d.init_state(x) + init_scores[k] = 0.0 + return [ + Hypothesis( + score=0.0, + scores=init_scores, + states=init_states, + yseq=torch.tensor([self.sos], device=x.device), + spk_weigths=[], + ) + ] + + @staticmethod + def append_token(xs: torch.Tensor, x: int) -> torch.Tensor: + """Append new token to prefix tokens. + + Args: + xs (torch.Tensor): The prefix token + x (int): The new token to append + + Returns: + torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device + + """ + x = torch.tensor([x], dtype=xs.dtype, device=xs.device) + return torch.cat((xs, x)) + + def score_full( + self, hyp: Hypothesis, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.full_scorers`. 
+ + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.full_scorers` + and tensor score values of shape: `(self.n_vocab,)`, + and state dict that has string keys + and state values of `self.full_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.full_scorers.items(): + if isinstance(d, AbsDecoder): + scores[k], spk_weigths, states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc, spk_enc, profile) + else: + scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc) + return scores, spk_weigths, states + + def score_partial( + self, hyp: Hypothesis, ids: torch.Tensor, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.part_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + ids (torch.Tensor): 1D tensor of new partial tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.part_scorers` + and tensor score values of shape: `(len(ids),)`, + and state dict that has string keys + and state values of `self.part_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.part_scorers.items(): + if isinstance(d, AbsDecoder): + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc, spk_enc, profile) + else: + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc) + return scores, states + + def beam( + self, weighted_scores: torch.Tensor, ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute topk full token ids and partial token ids. + + Args: + weighted_scores (torch.Tensor): The weighted sum scores for each tokens. 
+ Its shape is `(self.n_vocab,)`. + ids (torch.Tensor): The partial token ids to compute topk + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + The topk full token ids and partial token ids. + Their shapes are `(self.beam_size,)` + + """ + # no pre beam performed + if weighted_scores.size(0) == ids.size(0): + top_ids = weighted_scores.topk(self.beam_size)[1] + return top_ids, top_ids + + # mask pruned in pre-beam not to select in topk + tmp = weighted_scores[ids] + weighted_scores[:] = -float("inf") + weighted_scores[ids] = tmp + top_ids = weighted_scores.topk(self.beam_size)[1] + local_ids = weighted_scores[ids].topk(self.beam_size)[1] + return top_ids, local_ids + + @staticmethod + def merge_scores( + prev_scores: Dict[str, float], + next_full_scores: Dict[str, torch.Tensor], + full_idx: int, + next_part_scores: Dict[str, torch.Tensor], + part_idx: int, + ) -> Dict[str, torch.Tensor]: + """Merge scores for new hypothesis. + + Args: + prev_scores (Dict[str, float]): + The previous hypothesis scores by `self.scorers` + next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers` + full_idx (int): The next token id for `next_full_scores` + next_part_scores (Dict[str, torch.Tensor]): + scores of partial tokens by `self.part_scorers` + part_idx (int): The new token id for `next_part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are scalar tensors by the scorers. + + """ + new_scores = dict() + for k, v in next_full_scores.items(): + new_scores[k] = prev_scores[k] + v[full_idx] + for k, v in next_part_scores.items(): + new_scores[k] = prev_scores[k] + v[part_idx] + return new_scores + + def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any: + """Merge states for new hypothesis. 
+ + Args: + states: states of `self.full_scorers` + part_states: states of `self.part_scorers` + part_idx (int): The new token id for `part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are states of the scorers. + + """ + new_states = dict() + for k, v in states.items(): + new_states[k] = v + for k, d in self.part_scorers.items(): + new_states[k] = d.select_state(part_states[k], part_idx) + return new_states + + def search( + self, running_hyps: List[Hypothesis], asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor + ) -> List[Hypothesis]: + """Search new tokens for running hypotheses and encoded speech x. + + Args: + running_hyps (List[Hypothesis]): Running hypotheses on beam + x (torch.Tensor): Encoded speech feature (T, D) + + Returns: + List[Hypotheses]: Best sorted hypotheses + + """ + # import ipdb;ipdb.set_trace() + best_hyps = [] + part_ids = torch.arange(self.n_vocab, device=asr_enc.device) # no pre-beam + for hyp in running_hyps: + # scoring + weighted_scores = torch.zeros(self.n_vocab, dtype=asr_enc.dtype, device=asr_enc.device) + scores, spk_weigths, states = self.score_full(hyp, asr_enc, spk_enc, profile) + for k in self.full_scorers: + weighted_scores += self.weights[k] * scores[k] + # partial scoring + if self.do_pre_beam: + pre_beam_scores = ( + weighted_scores + if self.pre_beam_score_key == "full" + else scores[self.pre_beam_score_key] + ) + part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1] + part_scores, part_states = self.score_partial(hyp, part_ids, asr_enc, spk_enc, profile) + for k in self.part_scorers: + weighted_scores[part_ids] += self.weights[k] * part_scores[k] + # add previous hyp score + weighted_scores += hyp.score + + # update hyps + for j, part_j in zip(*self.beam(weighted_scores, part_ids)): + # will be (2 x beam at most) + best_hyps.append( + Hypothesis( + score=weighted_scores[j], + 
yseq=self.append_token(hyp.yseq, j), + scores=self.merge_scores( + hyp.scores, scores, j, part_scores, part_j + ), + states=self.merge_states(states, part_states, part_j), + spk_weigths=hyp.spk_weigths+[spk_weigths], + ) + ) + + # sort and prune 2 x beam -> beam + best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ + : min(len(best_hyps), self.beam_size) + ] + return best_hyps + + def forward( + self, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 + ) -> List[Hypothesis]: + """Perform beam search. + + Args: + x (torch.Tensor): Encoded speech feature (T, D) + maxlenratio (float): Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths + minlenratio (float): Input length ratio to obtain min output length. + + Returns: + list[Hypothesis]: N-best decoding results + + """ + # import ipdb;ipdb.set_trace() + # set length bounds + if maxlenratio == 0: + maxlen = asr_enc.shape[0] + else: + maxlen = max(1, int(maxlenratio * asr_enc.size(0))) + minlen = int(minlenratio * asr_enc.size(0)) + logging.info("decoder input length: " + str(asr_enc.shape[0])) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # main loop of prefix search + running_hyps = self.init_hyp(asr_enc) + ended_hyps = [] + for i in range(maxlen): + logging.debug("position " + str(i)) + best = self.search(running_hyps, asr_enc, spk_enc, profile) + #import pdb;pdb.set_trace() + # post process of one iteration + running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) + # end detection + if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): + logging.info(f"end detected at {i}") + break + if len(running_hyps) == 0: + logging.info("no hypothesis. 
Finish decoding.") + break + else: + logging.debug(f"remained hypotheses: {len(running_hyps)}") + + nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) + # check the number of hypotheses reaching to eos + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, perform recognition " + "again with smaller minlenratio." + ) + return ( + [] + if minlenratio < 0.1 + else self.forward(asr_enc, spk_enc, profile, maxlenratio, max(0.0, minlenratio - 0.1)) + ) + + # report the best result + best = nbest_hyps[0] + for k, v in best.scores.items(): + logging.info( + f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}" + ) + logging.info(f"total log probability: {best.score:.2f}") + logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}") + logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}") + if self.token_list is not None: + logging.info( + "best hypo: " + + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + + "\n" + ) + return nbest_hyps + + def post_process( + self, + i: int, + maxlen: int, + maxlenratio: float, + running_hyps: List[Hypothesis], + ended_hyps: List[Hypothesis], + ) -> List[Hypothesis]: + """Perform post-processing of beam search iterations. + + Args: + i (int): The length of hypothesis tokens. + maxlen (int): The maximum length of tokens in beam search. + maxlenratio (int): The maximum length ratio in beam search. + running_hyps (List[Hypothesis]): The running hypotheses in beam search. + ended_hyps (List[Hypothesis]): The ended hypotheses in beam search. + + Returns: + List[Hypothesis]: The new running hypotheses. 
+ + """ + logging.debug(f"the number of running hypotheses: {len(running_hyps)}") + if self.token_list is not None: + logging.debug( + "best hypo: " + + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) + ) + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last position in the loop") + running_hyps = [ + h._replace(yseq=self.append_token(h.yseq, self.eos)) + for h in running_hyps + ] + + # add ended hypotheses to a final list, and removed them from current hypotheses + # (this will be a problem, number of hyps < beam) + remained_hyps = [] + for hyp in running_hyps: + if hyp.yseq[-1] == self.eos: + # e.g., Word LM needs to add final score + for k, d in chain(self.full_scorers.items(), self.part_scorers.items()): + s = d.final_score(hyp.states[k]) + hyp.scores[k] += s + hyp = hyp._replace(score=hyp.score + self.weights[k] * s) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + return remained_hyps + + +def beam_search( + x: torch.Tensor, + sos: int, + eos: int, + beam_size: int, + vocab_size: int, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + token_list: List[str] = None, + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + pre_beam_ratio: float = 1.5, + pre_beam_score_key: str = "full", +) -> list: + """Perform beam search with scorers. + + Args: + x (torch.Tensor): Encoded speech feature (T, D) + sos (int): Start of sequence id + eos (int): End of sequence id + beam_size (int): The number of hypotheses kept during search + vocab_size (int): The number of vocabulary + scorers (dict[str, ScorerInterface]): Dict of decoder modules + e.g., Decoder, CTCPrefixScorer, LM + The scorer will be ignored if it is `None` + weights (dict[str, float]): Dict of weights for each scorers + The scorer will be ignored if its weight is 0 + token_list (list[str]): List of tokens for debug log + maxlenratio (float): Input length ratio to obtain max output length. 
+ If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths + minlenratio (float): Input length ratio to obtain min output length. + pre_beam_score_key (str): key of scores to perform pre-beam search + pre_beam_ratio (float): beam size in the pre-beam search + will be `int(pre_beam_ratio * beam_size)` + + Returns: + list: N-best decoding results + + """ + ret = BeamSearch( + scorers, + weights, + beam_size=beam_size, + vocab_size=vocab_size, + pre_beam_ratio=pre_beam_ratio, + pre_beam_score_key=pre_beam_score_key, + sos=sos, + eos=eos, + token_list=token_list, + ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) + return [h.asdict() for h in ret] diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index 3d2004c2d..f8c100961 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -444,6 +444,12 @@ class AbsTask(ABC): default=False, help='Perform on "collect stats" mode', ) + group.add_argument( + "--mc", + type=bool, + default=False, + help="MultiChannel input", + ) group.add_argument( "--write_collected_feats", type=str2bool, @@ -635,8 +641,8 @@ class AbsTask(ABC): group.add_argument( "--init_param", type=str, + action="append", default=[], - nargs="*", help="Specify the file path used for initialization of parameters. 
" "The format is ':::', " "where file_path is the model file path, " @@ -662,7 +668,7 @@ class AbsTask(ABC): "--freeze_param", type=str, default=[], - nargs="*", + action="append", help="Freeze parameters", ) @@ -1153,10 +1159,10 @@ class AbsTask(ABC): elif args.distributed and args.simple_ddp: distributed_option.init_torch_distributed_pai(args) args.ngpu = dist.get_world_size() - if args.dataset_type == "small": + if args.dataset_type == "small" and args.ngpu > 0: if args.batch_size is not None: args.batch_size = args.batch_size * args.ngpu - if args.batch_bins is not None: + if args.batch_bins is not None and args.ngpu > 0: args.batch_bins = args.batch_bins * args.ngpu # filter samples if wav.scp and text are mismatch @@ -1316,6 +1322,7 @@ class AbsTask(ABC): data_path_and_name_and_type=args.train_data_path_and_name_and_type, key_file=train_key_file, batch_size=args.batch_size, + mc=args.mc, dtype=args.train_dtype, num_workers=args.num_workers, allow_variable_data_keys=args.allow_variable_data_keys, @@ -1327,6 +1334,7 @@ class AbsTask(ABC): data_path_and_name_and_type=args.valid_data_path_and_name_and_type, key_file=valid_key_file, batch_size=args.valid_batch_size, + mc=args.mc, dtype=args.train_dtype, num_workers=args.num_workers, allow_variable_data_keys=args.allow_variable_data_keys, diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py new file mode 100644 index 000000000..738ec522d --- /dev/null +++ b/funasr/tasks/sa_asr.py @@ -0,0 +1,623 @@ +import argparse +import logging +import os +from pathlib import Path +from typing import Callable +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +import torch +import yaml +from typeguard import check_argument_types +from typeguard import check_return_type + +from funasr.datasets.collate_fn import CommonCollateFn +from funasr.datasets.preprocessor import CommonPreprocessor +from 
funasr.layers.abs_normalize import AbsNormalize +from funasr.layers.global_mvn import GlobalMVN +from funasr.layers.utterance_mvn import UtteranceMVN +from funasr.models.ctc import CTC +from funasr.models.decoder.abs_decoder import AbsDecoder +from funasr.models.decoder.rnn_decoder import RNNDecoder +from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt +from funasr.models.decoder.transformer_decoder import ( + DynamicConvolution2DTransformerDecoder, # noqa: H301 +) +from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder +from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder +from funasr.models.decoder.transformer_decoder import ( + LightweightConvolution2DTransformerDecoder, # noqa: H301 +) +from funasr.models.decoder.transformer_decoder import ( + LightweightConvolutionTransformerDecoder, # noqa: H301 +) +from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN +from funasr.models.decoder.transformer_decoder import TransformerDecoder +from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder +from funasr.models.e2e_sa_asr import ESPnetASRModel +from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer +from funasr.models.e2e_tp import TimestampPredictor +from funasr.models.e2e_asr_mfcca import MFCCA +from funasr.models.e2e_uni_asr import UniASR +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.models.encoder.conformer_encoder import ConformerEncoder +from funasr.models.encoder.data2vec_encoder import Data2VecEncoder +from funasr.models.encoder.rnn_encoder import RNNEncoder +from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt +from funasr.models.encoder.transformer_encoder import TransformerEncoder +from funasr.models.encoder.mfcca_encoder import MFCCAEncoder +from funasr.models.encoder.resnet34_encoder import 
ResNet34,ResNet34Diar +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.frontend.default import DefaultFrontend +from funasr.models.frontend.default import MultiChannelFrontend +from funasr.models.frontend.fused import FusedFrontends +from funasr.models.frontend.s3prl import S3prlFrontend +from funasr.models.frontend.wav_frontend import WavFrontend +from funasr.models.frontend.windowing import SlidingWindow +from funasr.models.postencoder.abs_postencoder import AbsPostEncoder +from funasr.models.postencoder.hugging_face_transformers_postencoder import ( + HuggingFaceTransformersPostEncoder, # noqa: H301 +) +from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3 +from funasr.models.preencoder.abs_preencoder import AbsPreEncoder +from funasr.models.preencoder.linear import LinearProjection +from funasr.models.preencoder.sinc import LightweightSincConvs +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.specaug.specaug import SpecAug +from funasr.models.specaug.specaug import SpecAugLFR +from funasr.modules.subsampling import Conv1dSubsampling +from funasr.tasks.abs_task import AbsTask +from funasr.text.phoneme_tokenizer import g2p_choices +from funasr.torch_utils.initialize import initialize +from funasr.train.abs_espnet_model import AbsESPnetModel +from funasr.train.class_choices import ClassChoices +from funasr.train.trainer import Trainer +from funasr.utils.get_default_kwargs import get_default_kwargs +from funasr.utils.nested_dict_action import NestedDictAction +from funasr.utils.types import float_or_none +from funasr.utils.types import int_or_none +from funasr.utils.types import str2bool +from funasr.utils.types import str_or_none + +frontend_choices = ClassChoices( + name="frontend", + classes=dict( + default=DefaultFrontend, + sliding_window=SlidingWindow, + s3prl=S3prlFrontend, + fused=FusedFrontends, + wav_frontend=WavFrontend, + 
multichannelfrontend=MultiChannelFrontend, + ), + type_check=AbsFrontend, + default="default", +) +specaug_choices = ClassChoices( + name="specaug", + classes=dict( + specaug=SpecAug, + specaug_lfr=SpecAugLFR, + ), + type_check=AbsSpecAug, + default=None, + optional=True, +) +normalize_choices = ClassChoices( + "normalize", + classes=dict( + global_mvn=GlobalMVN, + utterance_mvn=UtteranceMVN, + ), + type_check=AbsNormalize, + default=None, + optional=True, +) +model_choices = ClassChoices( + "model", + classes=dict( + asr=ESPnetASRModel, + uniasr=UniASR, + paraformer=Paraformer, + paraformer_bert=ParaformerBert, + bicif_paraformer=BiCifParaformer, + contextual_paraformer=ContextualParaformer, + mfcca=MFCCA, + timestamp_prediction=TimestampPredictor, + ), + type_check=AbsESPnetModel, + default="asr", +) +preencoder_choices = ClassChoices( + name="preencoder", + classes=dict( + sinc=LightweightSincConvs, + linear=LinearProjection, + ), + type_check=AbsPreEncoder, + default=None, + optional=True, +) +asr_encoder_choices = ClassChoices( + "asr_encoder", + classes=dict( + conformer=ConformerEncoder, + transformer=TransformerEncoder, + rnn=RNNEncoder, + sanm=SANMEncoder, + sanm_chunk_opt=SANMEncoderChunkOpt, + data2vec_encoder=Data2VecEncoder, + mfcca_enc=MFCCAEncoder, + ), + type_check=AbsEncoder, + default="rnn", +) + +spk_encoder_choices = ClassChoices( + "spk_encoder", + classes=dict( + resnet34_diar=ResNet34Diar, + ), + default="resnet34_diar", +) + +encoder_choices2 = ClassChoices( + "encoder2", + classes=dict( + conformer=ConformerEncoder, + transformer=TransformerEncoder, + rnn=RNNEncoder, + sanm=SANMEncoder, + sanm_chunk_opt=SANMEncoderChunkOpt, + ), + type_check=AbsEncoder, + default="rnn", +) +postencoder_choices = ClassChoices( + name="postencoder", + classes=dict( + hugging_face_transformers=HuggingFaceTransformersPostEncoder, + ), + type_check=AbsPostEncoder, + default=None, + optional=True, +) +decoder_choices = ClassChoices( + "decoder", + classes=dict( + 
transformer=TransformerDecoder, + lightweight_conv=LightweightConvolutionTransformerDecoder, + lightweight_conv2d=LightweightConvolution2DTransformerDecoder, + dynamic_conv=DynamicConvolutionTransformerDecoder, + dynamic_conv2d=DynamicConvolution2DTransformerDecoder, + rnn=RNNDecoder, + fsmn_scama_opt=FsmnDecoderSCAMAOpt, + paraformer_decoder_sanm=ParaformerSANMDecoder, + paraformer_decoder_san=ParaformerDecoderSAN, + contextual_paraformer_decoder=ContextualParaformerDecoder, + sa_decoder=SAAsrTransformerDecoder, + ), + type_check=AbsDecoder, + default="sa_decoder", +) +decoder_choices2 = ClassChoices( + "decoder2", + classes=dict( + transformer=TransformerDecoder, + lightweight_conv=LightweightConvolutionTransformerDecoder, + lightweight_conv2d=LightweightConvolution2DTransformerDecoder, + dynamic_conv=DynamicConvolutionTransformerDecoder, + dynamic_conv2d=DynamicConvolution2DTransformerDecoder, + rnn=RNNDecoder, + fsmn_scama_opt=FsmnDecoderSCAMAOpt, + paraformer_decoder_sanm=ParaformerSANMDecoder, + ), + type_check=AbsDecoder, + default="rnn", +) +predictor_choices = ClassChoices( + name="predictor", + classes=dict( + cif_predictor=CifPredictor, + ctc_predictor=None, + cif_predictor_v2=CifPredictorV2, + cif_predictor_v3=CifPredictorV3, + ), + type_check=None, + default="cif_predictor", + optional=True, +) +predictor_choices2 = ClassChoices( + name="predictor2", + classes=dict( + cif_predictor=CifPredictor, + ctc_predictor=None, + cif_predictor_v2=CifPredictorV2, + ), + type_check=None, + default="cif_predictor", + optional=True, +) +stride_conv_choices = ClassChoices( + name="stride_conv", + classes=dict( + stride_conv1d=Conv1dSubsampling + ), + type_check=None, + default="stride_conv1d", + optional=True, +) + + +class ASRTask(AbsTask): + # If you need more than one optimizers, change this value + num_optimizers: int = 1 + + # Add variable objects configurations + class_choices_list = [ + # --frontend and --frontend_conf + frontend_choices, + # --specaug and 
--specaug_conf + specaug_choices, + # --normalize and --normalize_conf + normalize_choices, + # --model and --model_conf + model_choices, + # --preencoder and --preencoder_conf + preencoder_choices, + # --asr_encoder and --asr_encoder_conf + asr_encoder_choices, + # --spk_encoder and --spk_encoder_conf + spk_encoder_choices, + # --postencoder and --postencoder_conf + postencoder_choices, + # --decoder and --decoder_conf + decoder_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description="Task related") + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. Instead of it, do as + # required = parser.get_default("required") + # required += ["token_list"] + + group.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="A text mapping int-id to token", + ) + group.add_argument( + "--split_with_space", + type=str2bool, + default=True, + help="whether to split text using ", + ) + group.add_argument( + "--max_spk_num", + type=int_or_none, + default=None, + help="A text mapping int-id to token", + ) + group.add_argument( + "--seg_dict_file", + type=str, + default=None, + help="seg_dict_file for text processing", + ) + group.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + + group.add_argument( + "--input_size", + type=int_or_none, + default=None, + help="The number of input dimension of the feature", + ) + + group.add_argument( + "--ctc_conf", + action=NestedDictAction, + default=get_default_kwargs(CTC), + help="The keyword arguments for CTC class.", + ) + group.add_argument( + "--joint_net_conf", + action=NestedDictAction, + 
default=None, + help="The keyword arguments for joint network class.", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=True, + help="Apply preprocessing to data or not", + ) + group.add_argument( + "--token_type", + type=str, + default="bpe", + choices=["bpe", "char", "word", "phn"], + help="The text will be tokenized " "in the specified level token", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model file of sentencepiece", + ) + parser.add_argument( + "--non_linguistic_symbols", + type=str_or_none, + default=None, + help="non_linguistic_symbols file path", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese"], + default=None, + help="Apply text cleaning", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=g2p_choices, + default=None, + help="Specify g2p method if --token_type=phn", + ) + parser.add_argument( + "--speech_volume_normalize", + type=float_or_none, + default=None, + help="Scale the maximum amplitude to the given value.", + ) + parser.add_argument( + "--rir_scp", + type=str_or_none, + default=None, + help="The file path of rir scp file.", + ) + parser.add_argument( + "--rir_apply_prob", + type=float, + default=1.0, + help="THe probability for applying RIR convolution.", + ) + parser.add_argument( + "--cmvn_file", + type=str_or_none, + default=None, + help="The file path of noise scp file.", + ) + parser.add_argument( + "--noise_scp", + type=str_or_none, + default=None, + help="The file path of noise scp file.", + ) + parser.add_argument( + "--noise_apply_prob", + type=float, + default=1.0, + help="The probability applying Noise adding.", + ) + parser.add_argument( + "--noise_db_range", + type=str, + default="13_15", + help="The range of noise decibel level.", + ) + + for class_choices in cls.class_choices_list: + # Append -- and 
--_conf. + # e.g. --encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=args.bpemodel, + non_linguistic_symbols=args.non_linguistic_symbols, + text_cleaner=args.cleaner, + g2p_type=args.g2p, + split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, + seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, + # NOTE(kamo): Check attribute existence for backward compatibility + rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, + rir_apply_prob=args.rir_apply_prob + if hasattr(args, "rir_apply_prob") + else 1.0, + noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, + noise_apply_prob=args.noise_apply_prob + if hasattr(args, "noise_apply_prob") + else 1.0, + noise_db_range=args.noise_db_range + if hasattr(args, "noise_db_range") + else "13_15", + speech_volume_normalize=args.speech_volume_normalize + if hasattr(args, "rir_scp") + else None, + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("speech", "text") + else: + # Recognition mode + retval = ("speech",) + return retval + + @classmethod + def 
optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + retval = () + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace): + assert check_argument_types() + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] + + # Overwriting token_list to keep it as "portable". + args.token_list = list(token_list) + elif isinstance(args.token_list, (tuple, list)): + token_list = list(args.token_list) + else: + raise RuntimeError("token_list must be str or list") + vocab_size = len(token_list) + logging.info(f"Vocabulary size: {vocab_size}") + + # 1. frontend + if args.input_size is None: + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + if args.frontend == 'wav_frontend': + frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) + else: + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + args.frontend = None + args.frontend_conf = {} + frontend = None + input_size = args.input_size + + # 2. Data augmentation for spectrogram + if args.specaug is not None: + specaug_class = specaug_choices.get_class(args.specaug) + specaug = specaug_class(**args.specaug_conf) + else: + specaug = None + + # 3. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 4. Pre-encoder input block + # NOTE(kan-bayashi): Use getattr to keep the compatibility + if getattr(args, "preencoder", None) is not None: + preencoder_class = preencoder_choices.get_class(args.preencoder) + preencoder = preencoder_class(**args.preencoder_conf) + input_size = preencoder.output_size() + else: + preencoder = None + + # 5. 
Encoder + asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder) + asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf) + spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder) + spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf) + + # 6. Post-encoder block + # NOTE(kan-bayashi): Use getattr to keep the compatibility + asr_encoder_output_size = asr_encoder.output_size() + if getattr(args, "postencoder", None) is not None: + postencoder_class = postencoder_choices.get_class(args.postencoder) + postencoder = postencoder_class( + input_size=asr_encoder_output_size, **args.postencoder_conf + ) + asr_encoder_output_size = postencoder.output_size() + else: + postencoder = None + + # 7. Decoder + decoder_class = decoder_choices.get_class(args.decoder) + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=asr_encoder_output_size, + **args.decoder_conf, + ) + + # 8. CTC + ctc = CTC( + odim=vocab_size, encoder_output_size=asr_encoder_output_size, **args.ctc_conf + ) + + max_spk_num=int(args.max_spk_num) + + # import ipdb;ipdb.set_trace() + # 9. Build model + try: + model_class = model_choices.get_class(args.model) + except AttributeError: + model_class = model_choices.get_class("asr") + model = model_class( + vocab_size=vocab_size, + max_spk_num=max_spk_num, + frontend=frontend, + specaug=specaug, + normalize=normalize, + preencoder=preencoder, + asr_encoder=asr_encoder, + spk_encoder=spk_encoder, + postencoder=postencoder, + decoder=decoder, + ctc=ctc, + token_list=token_list, + **args.model_conf, + ) + + # 10. 
Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py index b607e1da0..014a79ffa 100644 --- a/funasr/utils/postprocess_utils.py +++ b/funasr/utils/postprocess_utils.py @@ -106,18 +106,17 @@ def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]: if num in abbr_begin: if time_stamp is not None: begin = time_stamp[ts_nums[num]][0] - abbr_word = words[num].upper() + word_lists.append(words[num].upper()) num += 1 while num < words_size: if num in abbr_end: - abbr_word += words[num].upper() + word_lists.append(words[num].upper()) last_num = num break else: if words[num].encode('utf-8').isalpha(): - abbr_word += words[num].upper() + word_lists.append(words[num].upper()) num += 1 - word_lists.append(abbr_word) if time_stamp is not None: end = time_stamp[ts_nums[num]][1] ts_lists.append([begin, end]) diff --git a/setup.py b/setup.py index e83763726..ea556066b 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ requirements = { "install": [ "setuptools>=38.5.1", # "configargparse>=1.2.1", - "typeguard<=2.13.3", + "typeguard==2.13.3", "humanfriendly", "scipy>=1.4.1", # "filelock", @@ -42,7 +42,10 @@ requirements = { "oss2", # "kaldi-native-fbank", # timestamp - "edit-distance" + "edit-distance", + # textgrid + "textgrid", + "protobuf==3.20.0", ], # train: The modules invoked when training only. 
"train": [ From 3b7e4b0d34ab0989b942ba84e077fadc5b96c036 Mon Sep 17 00:00:00 2001 From: smohan-speech Date: Sat, 6 May 2023 16:38:09 +0800 Subject: [PATCH 2/5] add speaker-attributed ASR task for alimeeting --- funasr/utils/postprocess_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py index 014a79ffa..f4efea66f 100644 --- a/funasr/utils/postprocess_utils.py +++ b/funasr/utils/postprocess_utils.py @@ -106,17 +106,18 @@ def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]: if num in abbr_begin: if time_stamp is not None: begin = time_stamp[ts_nums[num]][0] - word_lists.append(words[num].upper()) + abbr_word = words[num].upper() num += 1 while num < words_size: if num in abbr_end: - word_lists.append(words[num].upper()) + abbr_word += words[num].upper() last_num = num break else: if words[num].encode('utf-8').isalpha(): - word_lists.append(words[num].upper()) + abbr_word += words[num].upper() num += 1 + word_lists.append(abbr_word) if time_stamp is not None: end = time_stamp[ts_nums[num]][1] ts_lists.append([begin, end]) @@ -241,4 +242,4 @@ def sentence_postprocess(words: List[Any], time_stamp: List[List] = None): if ch != ' ': real_word_lists.append(ch) sentence = ''.join(word_lists).strip() - return sentence, real_word_lists + return sentence, real_word_lists \ No newline at end of file From d76aea23d9f5daac4df7ee1985d07f7428abc719 Mon Sep 17 00:00:00 2001 From: smohan-speech Date: Sun, 7 May 2023 02:21:58 +0800 Subject: [PATCH 3/5] add speaker-attributed ASR task for alimeeting --- egs/alimeeting/sa-asr/asr_local.sh | 33 +- egs/alimeeting/sa-asr/asr_local_infer.sh | 3 +- .../sa-asr/conf/train_asr_conformer.yaml | 1 - .../sa-asr/conf/train_sa_asr_conformer.yaml | 1 - .../sa-asr/local/alimeeting_data_prep.sh | 14 +- .../local/alimeeting_data_prep_test_2023.sh | 10 +- .../sa-asr/{utils => local}/apply_map.pl | 0 .../sa-asr/{utils => 
local}/combine_data.sh | 6 +- .../sa-asr/{utils => local}/copy_data_dir.sh | 28 +- .../{utils => local}/data/get_reco2dur.sh | 0 .../data/get_segments_for_data.sh | 2 +- .../{utils => local}/data/get_utt2dur.sh | 2 +- .../{utils => local}/data/split_data.sh | 6 +- .../sa-asr/{utils => local}/fix_data_dir.sh | 6 +- egs/alimeeting/sa-asr/local/format_wav_scp.py | 243 ++++++++++ egs/alimeeting/sa-asr/local/format_wav_scp.sh | 142 ++++++ .../sa-asr/local/perturb_data_dir_speed.sh | 116 +++++ .../{utils => local}/spk2utt_to_utt2spk.pl | 0 .../{utils => local}/utt2spk_to_spk2utt.pl | 0 .../{utils => local}/validate_data_dir.sh | 4 +- .../sa-asr/{utils => local}/validate_text.pl | 0 egs/alimeeting/sa-asr/path.sh | 3 +- egs/alimeeting/sa-asr/utils | 1 + egs/alimeeting/sa-asr/utils/filter_scp.pl | 87 ---- egs/alimeeting/sa-asr/utils/parse_options.sh | 97 ---- egs/alimeeting/sa-asr/utils/split_scp.pl | 246 ---------- funasr/bin/asr_inference.py | 28 +- funasr/bin/asr_inference_launch.py | 8 +- funasr/bin/asr_train.py | 8 - funasr/bin/sa_asr_inference.py | 24 +- funasr/bin/sa_asr_train.py | 8 - funasr/losses/label_smoothing_loss.py | 46 ++ funasr/models/decoder/transformer_decoder.py | 428 +++++++++++++++++- funasr/models/e2e_sa_asr.py | 3 +- funasr/tasks/sa_asr.py | 2 +- 35 files changed, 1090 insertions(+), 516 deletions(-) rename egs/alimeeting/sa-asr/{utils => local}/apply_map.pl (100%) rename egs/alimeeting/sa-asr/{utils => local}/combine_data.sh (96%) rename egs/alimeeting/sa-asr/{utils => local}/copy_data_dir.sh (80%) rename egs/alimeeting/sa-asr/{utils => local}/data/get_reco2dur.sh (100%) rename egs/alimeeting/sa-asr/{utils => local}/data/get_segments_for_data.sh (93%) rename egs/alimeeting/sa-asr/{utils => local}/data/get_utt2dur.sh (99%) rename egs/alimeeting/sa-asr/{utils => local}/data/split_data.sh (96%) rename egs/alimeeting/sa-asr/{utils => local}/fix_data_dir.sh (97%) create mode 100755 egs/alimeeting/sa-asr/local/format_wav_scp.py create mode 100755 
egs/alimeeting/sa-asr/local/format_wav_scp.sh create mode 100755 egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh rename egs/alimeeting/sa-asr/{utils => local}/spk2utt_to_utt2spk.pl (100%) rename egs/alimeeting/sa-asr/{utils => local}/utt2spk_to_spk2utt.pl (100%) rename egs/alimeeting/sa-asr/{utils => local}/validate_data_dir.sh (99%) rename egs/alimeeting/sa-asr/{utils => local}/validate_text.pl (100%) create mode 120000 egs/alimeeting/sa-asr/utils delete mode 100755 egs/alimeeting/sa-asr/utils/filter_scp.pl delete mode 100755 egs/alimeeting/sa-asr/utils/parse_options.sh delete mode 100755 egs/alimeeting/sa-asr/utils/split_scp.pl diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh index c0359eb35..419e34144 100755 --- a/egs/alimeeting/sa-asr/asr_local.sh +++ b/egs/alimeeting/sa-asr/asr_local.sh @@ -434,14 +434,14 @@ if ! "${skip_data_prep}"; then log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp" for factor in ${speed_perturb_factors}; do if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then - scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}" + local/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}" _dirs+="data/${train_set}_sp${factor} " else # If speed factor is 1, same as the original _dirs+="data/${train_set} " fi done - utils/combine_data.sh "data/${train_set}_sp" ${_dirs} + local/combine_data.sh "data/${train_set}_sp" ${_dirs} else log "Skip stage 2: Speed perturbation" fi @@ -473,7 +473,7 @@ if ! "${skip_data_prep}"; then _suf="" fi fi - utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" + local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/" @@ -488,7 +488,7 @@ if ! 
"${skip_data_prep}"; then _opts+="--segments data/${dset}/segments " fi # shellcheck disable=SC2086 - scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ + local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ --audio-format "${audio_format}" --fs "${fs}" ${_opts} \ "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}" @@ -515,7 +515,7 @@ if ! "${skip_data_prep}"; then for dset in $rm_dset; do # Copy data dir - utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}" + local/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}" cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type" # Remove short utterances @@ -564,7 +564,7 @@ if ! "${skip_data_prep}"; then awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text" # fix_data_dir.sh leaves only utts which exist in all files - utils/fix_data_dir.sh "${data_feats}/${dset}" + local/fix_data_dir.sh "${data_feats}/${dset}" # generate uttid cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid" @@ -1283,6 +1283,7 @@ if ! "${skip_eval}"; then ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ python -m funasr.bin.asr_inference_launch \ --batch_size 1 \ + --mc True \ --nbest 1 \ --ngpu "${_ngpu}" \ --njob ${njob_infer} \ @@ -1312,10 +1313,10 @@ if ! 
"${skip_eval}"; then _data="${data_feats}/${dset}" _dir="${asr_exp}/${inference_tag}/${dset}" - python local/proce_text.py ${_data}/text ${_data}/text.proc - python local/proce_text.py ${_dir}/text ${_dir}/text.proc + python utils/proce_text.py ${_data}/text ${_data}/text.proc + python utils/proce_text.py ${_dir}/text ${_dir}/text.proc - python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt cat ${_dir}/text.cer.txt @@ -1390,6 +1391,7 @@ if ! "${skip_eval}"; then ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ python -m funasr.bin.asr_inference_launch \ --batch_size 1 \ + --mc True \ --nbest 1 \ --ngpu "${_ngpu}" \ --njob ${njob_infer} \ @@ -1421,10 +1423,10 @@ if ! "${skip_eval}"; then _data="${data_feats}/${dset}" _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}" - python local/proce_text.py ${_data}/text ${_data}/text.proc - python local/proce_text.py ${_dir}/text ${_dir}/text.proc + python utils/proce_text.py ${_data}/text ${_data}/text.proc + python utils/proce_text.py ${_dir}/text ${_dir}/text.proc - python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt cat ${_dir}/text.cer.txt @@ -1506,6 +1508,7 @@ if ! "${skip_eval}"; then ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ python -m funasr.bin.asr_inference_launch \ --batch_size 1 \ + --mc True \ --nbest 1 \ --ngpu "${_ngpu}" \ --njob ${njob_infer} \ @@ -1536,10 +1539,10 @@ if ! 
"${skip_eval}"; then _data="${data_feats}/${dset}" _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}" - python local/proce_text.py ${_data}/text ${_data}/text.proc - python local/proce_text.py ${_dir}/text ${_dir}/text.proc + python utils/proce_text.py ${_data}/text ${_data}/text.proc + python utils/proce_text.py ${_dir}/text ${_dir}/text.proc - python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt cat ${_dir}/text.cer.txt diff --git a/egs/alimeeting/sa-asr/asr_local_infer.sh b/egs/alimeeting/sa-asr/asr_local_infer.sh index 8e8148ff8..b7a928977 100755 --- a/egs/alimeeting/sa-asr/asr_local_infer.sh +++ b/egs/alimeeting/sa-asr/asr_local_infer.sh @@ -436,7 +436,7 @@ if ! "${skip_data_prep}"; then _suf="" - utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" + local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur} _opts= @@ -548,6 +548,7 @@ if ! 
"${skip_eval}"; then ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ python -m funasr.bin.asr_inference_launch \ --batch_size 1 \ + --mc True \ --nbest 1 \ --ngpu "${_ngpu}" \ --njob ${njob_infer} \ diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml index a8c996875..7865763b8 100644 --- a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml +++ b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml @@ -4,7 +4,6 @@ frontend_conf: n_fft: 400 win_length: 400 hop_length: 160 - use_channel: 0 # encoder related encoder: conformer diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml index e91db1804..421d7df5b 100644 --- a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml +++ b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml @@ -4,7 +4,6 @@ frontend_conf: n_fft: 400 win_length: 400 hop_length: 160 - use_channel: 0 # encoder related asr_encoder: conformer diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh index 8151bae30..7d39cdc14 100755 --- a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh +++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh @@ -78,7 +78,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1 #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk - utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt + local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1 sed -e 's/!//g' $near_dir/tmp1> $near_dir/tmp2 @@ -109,7 +109,7 @@ 
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk - utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt + local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1 sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2 @@ -121,8 +121,8 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then log "stage 3: finali data process" - utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near - utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far + local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near + local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo @@ -146,10 +146,10 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text - utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt + local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt - ./utils/fix_data_dir.sh $far_single_speaker_dir - utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker + ./local/fix_data_dir.sh $far_single_speaker_dir + local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker # remove space in text for x in ${tgt}_Ali_far_single_speaker; do diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh index 382a05669..e3ce934dc 100755 --- a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh +++ 
b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh @@ -77,7 +77,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk - utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt + local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1 sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2 @@ -89,7 +89,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then log "stage 2: finali data process" - utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far + local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo @@ -113,10 +113,10 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text - utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt + local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt - ./utils/fix_data_dir.sh $far_single_speaker_dir - utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker + ./local/fix_data_dir.sh $far_single_speaker_dir + local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker # remove space in text for x in ${tgt}_Ali_far_single_speaker; do diff --git a/egs/alimeeting/sa-asr/utils/apply_map.pl b/egs/alimeeting/sa-asr/local/apply_map.pl similarity index 100% rename from egs/alimeeting/sa-asr/utils/apply_map.pl rename to egs/alimeeting/sa-asr/local/apply_map.pl diff --git 
a/egs/alimeeting/sa-asr/utils/combine_data.sh b/egs/alimeeting/sa-asr/local/combine_data.sh similarity index 96% rename from egs/alimeeting/sa-asr/utils/combine_data.sh rename to egs/alimeeting/sa-asr/local/combine_data.sh index e1eba8539..a3436b503 100755 --- a/egs/alimeeting/sa-asr/utils/combine_data.sh +++ b/egs/alimeeting/sa-asr/local/combine_data.sh @@ -98,7 +98,7 @@ if $has_segments; then for in_dir in $*; do if [ ! -f $in_dir/segments ]; then echo "$0 [info]: will generate missing segments for $in_dir" 1>&2 - utils/data/get_segments_for_data.sh $in_dir + local/data/get_segments_for_data.sh $in_dir else cat $in_dir/segments fi @@ -133,14 +133,14 @@ for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn fi done -utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt +local/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt if [[ $dir_with_frame_shift ]]; then cp $dir_with_frame_shift/frame_shift $dest fi if ! $skip_fix ; then - utils/fix_data_dir.sh $dest || exit 1; + local/fix_data_dir.sh $dest || exit 1; fi exit 0 diff --git a/egs/alimeeting/sa-asr/utils/copy_data_dir.sh b/egs/alimeeting/sa-asr/local/copy_data_dir.sh similarity index 80% rename from egs/alimeeting/sa-asr/utils/copy_data_dir.sh rename to egs/alimeeting/sa-asr/local/copy_data_dir.sh index 9fd420c42..6e748dd9f 100755 --- a/egs/alimeeting/sa-asr/utils/copy_data_dir.sh +++ b/egs/alimeeting/sa-asr/local/copy_data_dir.sh @@ -71,25 +71,25 @@ else cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq fi -cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \ - utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk +cat $srcdir/utt2spk | local/apply_map.pl -f 1 $destdir/utt_map | \ + local/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk -utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt +local/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt if [ -f 
$srcdir/feats.scp ]; then - utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp + local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp fi if [ -f $srcdir/vad.scp ]; then - utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp + local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp fi if [ -f $srcdir/segments ]; then - utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments + local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments cp $srcdir/wav.scp $destdir else # no segments->wav indexed by utt. if [ -f $srcdir/wav.scp ]; then - utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp + local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp fi fi @@ -98,26 +98,26 @@ if [ -f $srcdir/reco2file_and_channel ]; then fi if [ -f $srcdir/text ]; then - utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text + local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text fi if [ -f $srcdir/utt2dur ]; then - utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur + local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur fi if [ -f $srcdir/utt2num_frames ]; then - utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames + local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames fi if [ -f $srcdir/reco2dur ]; then if [ -f $srcdir/segments ]; then cp $srcdir/reco2dur $destdir/reco2dur else - utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur + local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur fi fi if [ -f $srcdir/spk2gender ]; then - utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender + local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender fi if [ -f 
$srcdir/cmvn.scp ]; then - utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp + local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp fi for f in frame_shift stm glm ctm; do if [ -f $srcdir/$f ]; then @@ -142,4 +142,4 @@ done [ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats" [ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text" -utils/validate_data_dir.sh $validate_opts $destdir +local/validate_data_dir.sh $validate_opts $destdir diff --git a/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh b/egs/alimeeting/sa-asr/local/data/get_reco2dur.sh similarity index 100% rename from egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh rename to egs/alimeeting/sa-asr/local/data/get_reco2dur.sh diff --git a/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh b/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh similarity index 93% rename from egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh rename to egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh index 6b161b31e..93107157b 100755 --- a/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh +++ b/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh @@ -20,7 +20,7 @@ fi data=$1 if [ ! 
-s $data/utt2dur ]; then - utils/data/get_utt2dur.sh $data 1>&2 || exit 1; + local/data/get_utt2dur.sh $data 1>&2 || exit 1; fi # 0 diff --git a/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh b/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh similarity index 99% rename from egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh rename to egs/alimeeting/sa-asr/local/data/get_utt2dur.sh index 5ee7ea30d..833a7fc59 100755 --- a/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh +++ b/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh @@ -94,7 +94,7 @@ elif [ -f $data/wav.scp ]; then nj=$num_utts fi - utils/data/split_data.sh --per-utt $data $nj + local/data/split_data.sh --per-utt $data $nj sdata=$data/split${nj}utt $cmd JOB=1:$nj $data/log/get_durations.JOB.log \ diff --git a/egs/alimeeting/sa-asr/utils/data/split_data.sh b/egs/alimeeting/sa-asr/local/data/split_data.sh similarity index 96% rename from egs/alimeeting/sa-asr/utils/data/split_data.sh rename to egs/alimeeting/sa-asr/local/data/split_data.sh index 8aa71a1f2..97ad8c53b 100755 --- a/egs/alimeeting/sa-asr/utils/data/split_data.sh +++ b/egs/alimeeting/sa-asr/local/data/split_data.sh @@ -60,11 +60,11 @@ nf=`cat $data/feats.scp 2>/dev/null | wc -l` nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can " - echo "** use utils/fix_data_dir.sh $data to fix this." + echo "** use local/fix_data_dir.sh $data to fix this." fi if [ -f $data/text ] && [ $nu -ne $nt ]; then echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can " - echo "** use utils/fix_data_dir.sh to fix this." + echo "** use local/fix_data_dir.sh to fix this." 
fi @@ -112,7 +112,7 @@ utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1 for n in `seq $numsplit`; do dsn=$data/split${numsplit}${utt}/$n - utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1; + local/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1; done maybe_wav_scp= diff --git a/egs/alimeeting/sa-asr/utils/fix_data_dir.sh b/egs/alimeeting/sa-asr/local/fix_data_dir.sh similarity index 97% rename from egs/alimeeting/sa-asr/utils/fix_data_dir.sh rename to egs/alimeeting/sa-asr/local/fix_data_dir.sh index ed4710d0b..3abd4652a 100755 --- a/egs/alimeeting/sa-asr/utils/fix_data_dir.sh +++ b/egs/alimeeting/sa-asr/local/fix_data_dir.sh @@ -112,7 +112,7 @@ function filter_recordings { function filter_speakers { # throughout this program, we regard utt2spk as primary and spk2utt as derived, so... - utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers for s in cmvn.scp spk2gender; do @@ -123,7 +123,7 @@ function filter_speakers { done filter_file $tmpdir/speakers $data/spk2utt - utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk + local/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk for s in cmvn.scp spk2gender $spk_extra_files; do f=$data/$s @@ -210,6 +210,6 @@ filter_utts filter_speakers filter_recordings -utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt +local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt echo "fix_data_dir.sh: old files are kept in $data/.backup" diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.py b/egs/alimeeting/sa-asr/local/format_wav_scp.py new file mode 100755 index 000000000..1fd63d690 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/format_wav_scp.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +import argparse +import logging +from io import BytesIO +from pathlib import Path +from typing import Tuple, Optional + +import kaldiio +import humanfriendly 
+import numpy as np +import resampy +import soundfile +from tqdm import tqdm +from typeguard import check_argument_types + +from funasr.utils.cli_utils import get_commandline_args +from funasr.fileio.read_text import read_2column_text +from funasr.fileio.sound_scp import SoundScpWriter + + +def humanfriendly_or_none(value: str): + if value in ("none", "None", "NONE"): + return None + return humanfriendly.parse_size(value) + + +def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]: + """ + + >>> str2int_tuple('3,4,5') + (3, 4, 5) + + """ + assert check_argument_types() + if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"): + return None + return tuple(map(int, integers.strip().split(","))) + + +def main(): + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + logging.basicConfig(level=logging.INFO, format=logfmt) + logging.info(get_commandline_args()) + + parser = argparse.ArgumentParser( + description='Create waves list from "wav.scp"', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("scp") + parser.add_argument("outdir") + parser.add_argument( + "--name", + default="wav", + help="Specify the prefix word of output file name " 'such as "wav.scp"', + ) + parser.add_argument("--segments", default=None) + parser.add_argument( + "--fs", + type=humanfriendly_or_none, + default=None, + help="If the sampling rate specified, " "Change the sampling rate.", + ) + parser.add_argument("--audio-format", default="wav") + group = parser.add_mutually_exclusive_group() + group.add_argument("--ref-channels", default=None, type=str2int_tuple) + group.add_argument("--utt2ref-channels", default=None, type=str) + args = parser.parse_args() + + out_num_samples = Path(args.outdir) / f"utt2num_samples" + + if args.ref_channels is not None: + + def utt2ref_channels(x) -> Tuple[int, ...]: + return args.ref_channels + + elif args.utt2ref_channels is not None: + utt2ref_channels_dict = 
read_2column_text(args.utt2ref_channels) + + def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]: + chs_str = d[x] + return tuple(map(int, chs_str.split())) + + else: + utt2ref_channels = None + + Path(args.outdir).mkdir(parents=True, exist_ok=True) + out_wavscp = Path(args.outdir) / f"{args.name}.scp" + if args.segments is not None: + # Note: kaldiio supports only wav-pcm-int16le file. + loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments) + if args.audio_format.endswith("ark"): + fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb") + fscp = out_wavscp.open("w") + else: + writer = SoundScpWriter( + args.outdir, + out_wavscp, + format=args.audio_format, + ) + + with out_num_samples.open("w") as fnum_samples: + for uttid, (rate, wave) in tqdm(loader): + # wave: (Time,) or (Time, Nmic) + if wave.ndim == 2 and utt2ref_channels is not None: + wave = wave[:, utt2ref_channels(uttid)] + + if args.fs is not None and args.fs != rate: + # FIXME(kamo): To use sox? + wave = resampy.resample( + wave.astype(np.float64), rate, args.fs, axis=0 + ) + wave = wave.astype(np.int16) + rate = args.fs + if args.audio_format.endswith("ark"): + if "flac" in args.audio_format: + suf = "flac" + elif "wav" in args.audio_format: + suf = "wav" + else: + raise RuntimeError("wav.ark or flac") + + # NOTE(kamo): Using extended ark format style here. 
+ # This format is incompatible with Kaldi + kaldiio.save_ark( + fark, + {uttid: (wave, rate)}, + scp=fscp, + append=True, + write_function=f"soundfile_{suf}", + ) + + else: + writer[uttid] = rate, wave + fnum_samples.write(f"{uttid} {len(wave)}\n") + else: + if args.audio_format.endswith("ark"): + fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb") + else: + wavdir = Path(args.outdir) / f"data_{args.name}" + wavdir.mkdir(parents=True, exist_ok=True) + + with Path(args.scp).open("r") as fscp, out_wavscp.open( + "w" + ) as fout, out_num_samples.open("w") as fnum_samples: + for line in tqdm(fscp): + uttid, wavpath = line.strip().split(None, 1) + + if wavpath.endswith("|"): + # Streaming input e.g. cat a.wav | + with kaldiio.open_like_kaldi(wavpath, "rb") as f: + with BytesIO(f.read()) as g: + wave, rate = soundfile.read(g, dtype=np.int16) + if wave.ndim == 2 and utt2ref_channels is not None: + wave = wave[:, utt2ref_channels(uttid)] + + if args.fs is not None and args.fs != rate: + # FIXME(kamo): To use sox? + wave = resampy.resample( + wave.astype(np.float64), rate, args.fs, axis=0 + ) + wave = wave.astype(np.int16) + rate = args.fs + + if args.audio_format.endswith("ark"): + if "flac" in args.audio_format: + suf = "flac" + elif "wav" in args.audio_format: + suf = "wav" + else: + raise RuntimeError("wav.ark or flac") + + # NOTE(kamo): Using extended ark format style here. + # This format is incompatible with Kaldi + kaldiio.save_ark( + fark, + {uttid: (wave, rate)}, + scp=fout, + append=True, + write_function=f"soundfile_{suf}", + ) + else: + owavpath = str(wavdir / f"{uttid}.{args.audio_format}") + soundfile.write(owavpath, wave, rate) + fout.write(f"{uttid} {owavpath}\n") + else: + wave, rate = soundfile.read(wavpath, dtype=np.int16) + if wave.ndim == 2 and utt2ref_channels is not None: + wave = wave[:, utt2ref_channels(uttid)] + save_asis = False + + elif args.audio_format.endswith("ark"): + save_asis = False + + elif Path(wavpath).suffix == "." 
+ args.audio_format and ( + args.fs is None or args.fs == rate + ): + save_asis = True + + else: + save_asis = False + + if save_asis: + # Neither --segments nor --fs are specified and + # the line doesn't end with "|", + # i.e. not using unix-pipe, + # only in this case, + # just using the original file as is. + fout.write(f"{uttid} {wavpath}\n") + else: + if args.fs is not None and args.fs != rate: + # FIXME(kamo): To use sox? + wave = resampy.resample( + wave.astype(np.float64), rate, args.fs, axis=0 + ) + wave = wave.astype(np.int16) + rate = args.fs + + if args.audio_format.endswith("ark"): + if "flac" in args.audio_format: + suf = "flac" + elif "wav" in args.audio_format: + suf = "wav" + else: + raise RuntimeError("wav.ark or flac") + + # NOTE(kamo): Using extended ark format style here. + # This format is not supported in Kaldi. + kaldiio.save_ark( + fark, + {uttid: (wave, rate)}, + scp=fout, + append=True, + write_function=f"soundfile_{suf}", + ) + else: + owavpath = str(wavdir / f"{uttid}.{args.audio_format}") + soundfile.write(owavpath, wave, rate) + fout.write(f"{uttid} {owavpath}\n") + fnum_samples.write(f"{uttid} {len(wave)}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.sh b/egs/alimeeting/sa-asr/local/format_wav_scp.sh new file mode 100755 index 000000000..04fc4a59e --- /dev/null +++ b/egs/alimeeting/sa-asr/local/format_wav_scp.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash +set -euo pipefail +SECONDS=0 +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} +help_message=$(cat << EOF +Usage: $0 [ []] +e.g. +$0 data/test/wav.scp data/test_format/ + +Format 'wav.scp': In short words, +changing "kaldi-datadir" to "modified-kaldi-datadir" + +The 'wav.scp' format in kaldi is very flexible, +e.g. It can use unix-pipe as describing that wav file, +but it sometime looks confusing and make scripts more complex. 
+This tools creates actual wav files from 'wav.scp' +and also segments wav files using 'segments'. + +Options + --fs + --segments + --nj + --cmd +EOF +) + +out_filename=wav.scp +cmd=utils/run.pl +nj=30 +fs=none +segments= + +ref_channels= +utt2ref_channels= + +audio_format=wav +write_utt2num_samples=true + +log "$0 $*" +. utils/parse_options.sh + +if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then + log "${help_message}" + log "Error: invalid command line arguments" + exit 1 +fi + +. ./path.sh # Setup the environment + +scp=$1 +if [ ! -f "${scp}" ]; then + log "${help_message}" + echo "$0: Error: No such file: ${scp}" + exit 1 +fi +dir=$2 + + +if [ $# -eq 2 ]; then + logdir=${dir}/logs + outdir=${dir}/data + +elif [ $# -eq 3 ]; then + logdir=$3 + outdir=${dir}/data + +elif [ $# -eq 4 ]; then + logdir=$3 + outdir=$4 +fi + + +mkdir -p ${logdir} + +rm -f "${dir}/${out_filename}" + + +opts= +if [ -n "${utt2ref_channels}" ]; then + opts="--utt2ref-channels ${utt2ref_channels} " +elif [ -n "${ref_channels}" ]; then + opts="--ref-channels ${ref_channels} " +fi + + +if [ -n "${segments}" ]; then + log "[info]: using ${segments}" + nutt=$(<${segments} wc -l) + nj=$((nj /dev/null + +# concatenate the .scp files together. +for n in $(seq ${nj}); do + cat "${outdir}/format.${n}/wav.scp" || exit 1; +done > "${dir}/${out_filename}" || exit 1 + +if "${write_utt2num_samples}"; then + for n in $(seq ${nj}); do + cat "${outdir}/format.${n}/utt2num_samples" || exit 1; + done > "${dir}/utt2num_samples" || exit 1 +fi + +log "Successfully finished. 
[elapsed=${SECONDS}s]" diff --git a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh new file mode 100755 index 000000000..9e08dba72 --- /dev/null +++ b/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh @@ -0,0 +1,116 @@ +#!/usr/bin/env bash + +# 2020 @kamo-naoyuki +# This file was copied from Kaldi and +# I deleted parts related to wav duration +# because we shouldn't use kaldi's command here +# and we don't need the files actually. + +# Copyright 2013 Johns Hopkins University (author: Daniel Povey) +# 2014 Tom Ko +# 2018 Emotech LTD (author: Pawel Swietojanski) +# Apache 2.0 + +# This script operates on a directory, such as in data/train/, +# that contains some subset of the following files: +# wav.scp +# spk2utt +# utt2spk +# text +# +# It generates the files which are used for perturbing the speed of the original data. + +export LC_ALL=C +set -euo pipefail + +if [[ $# != 3 ]]; then + echo "Usage: perturb_data_dir_speed.sh " + echo "e.g.:" + echo " $0 0.9 data/train_si284 data/train_si284p" + exit 1 +fi + +factor=$1 +srcdir=$2 +destdir=$3 +label="sp" +spk_prefix="${label}${factor}-" +utt_prefix="${label}${factor}-" + +#check is sox on the path + +! command -v sox &>/dev/null && echo "sox: command not found" && exit 1; + +if [[ ! -f ${srcdir}/utt2spk ]]; then + echo "$0: no such file ${srcdir}/utt2spk" + exit 1; +fi + +if [[ ${destdir} == "${srcdir}" ]]; then + echo "$0: this script requires and to be different." + exit 1 +fi + +mkdir -p "${destdir}" + +<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map" +<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map" +<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map" +if [[ ! 
-f ${srcdir}/utt2uniq ]]; then + <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq" +else + <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq" +fi + + +<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \ + utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk + +utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt + +if [[ -f ${srcdir}/segments ]]; then + + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \ + utils/apply_map.pl -f 2 "${destdir}"/reco_map | \ + awk -v factor="${factor}" \ + '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \ + >"${destdir}"/segments + + utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ + # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" + awk -v factor="${factor}" \ + '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} + else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" } + else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \ + > "${destdir}"/wav.scp + if [[ -f ${srcdir}/reco2file_and_channel ]]; then + utils/apply_map.pl -f 1 "${destdir}"/reco_map \ + <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel + fi + +else # no segments->wav indexed by utterance. 
+ if [[ -f ${srcdir}/wav.scp ]]; then + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ + # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" + awk -v factor="${factor}" \ + '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} + else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" } + else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \ + > "${destdir}"/wav.scp + fi +fi + +if [[ -f ${srcdir}/text ]]; then + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text +fi +if [[ -f ${srcdir}/spk2gender ]]; then + utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender +fi +if [[ -f ${srcdir}/utt2lang ]]; then + utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang +fi + +rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null +echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}" + +utils/validate_data_dir.sh --no-feats --no-text "${destdir}" diff --git a/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl similarity index 100% rename from egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl rename to egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl diff --git a/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl similarity index 100% rename from egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl rename to egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl diff --git a/egs/alimeeting/sa-asr/utils/validate_data_dir.sh b/egs/alimeeting/sa-asr/local/validate_data_dir.sh similarity index 99% rename from egs/alimeeting/sa-asr/utils/validate_data_dir.sh rename to egs/alimeeting/sa-asr/local/validate_data_dir.sh index 3eec443a0..37c99aec0 100755 --- 
a/egs/alimeeting/sa-asr/utils/validate_data_dir.sh +++ b/egs/alimeeting/sa-asr/local/validate_data_dir.sh @@ -113,7 +113,7 @@ fi check_sorted_and_uniq $data/spk2utt ! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \ - <(utils/spk2utt_to_utt2spk.pl $data/spk2utt) && \ + <(local/spk2utt_to_utt2spk.pl $data/spk2utt) && \ echo "$0: spk2utt and utt2spk do not seem to match" && exit 1; cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts @@ -135,7 +135,7 @@ if ! $no_text; then echo "$0: text contains $n_non_print lines with non-printable characters" &&\ exit 1; fi - utils/validate_text.pl $data/text || exit 1; + local/validate_text.pl $data/text || exit 1; check_sorted_and_uniq $data/text text_len=`cat $data/text | wc -l` illegal_sym_list=" #0" diff --git a/egs/alimeeting/sa-asr/utils/validate_text.pl b/egs/alimeeting/sa-asr/local/validate_text.pl similarity index 100% rename from egs/alimeeting/sa-asr/utils/validate_text.pl rename to egs/alimeeting/sa-asr/local/validate_text.pl diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa-asr/path.sh index 3aa13d0c2..5721f3f48 100755 --- a/egs/alimeeting/sa-asr/path.sh +++ b/egs/alimeeting/sa-asr/path.sh @@ -2,5 +2,4 @@ export FUNASR_DIR=$PWD/../../.. 
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C export PYTHONIOENCODING=UTF-8 -export PATH=$FUNASR_DIR/funasr/bin:$PATH -export PATH=$PWD/utils/:$PATH \ No newline at end of file +export PATH=$FUNASR_DIR/funasr/bin:$PATH \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/utils b/egs/alimeeting/sa-asr/utils new file mode 120000 index 000000000..fe070dd3a --- /dev/null +++ b/egs/alimeeting/sa-asr/utils @@ -0,0 +1 @@ +../../aishell/transformer/utils \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/utils/filter_scp.pl b/egs/alimeeting/sa-asr/utils/filter_scp.pl deleted file mode 100755 index b76d37f41..000000000 --- a/egs/alimeeting/sa-asr/utils/filter_scp.pl +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env perl -# Copyright 2010-2012 Microsoft Corporation -# Johns Hopkins University (author: Daniel Povey) - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# This script takes a list of utterance-ids or any file whose first field -# of each line is an utterance-id, and filters an scp -# file (or any file whose "n-th" field is an utterance id), printing -# out only those lines whose "n-th" field is in id_list. 
The index of -# the "n-th" field is 1, by default, but can be changed by using -# the -f switch - -$exclude = 0; -$field = 1; -$shifted = 0; - -do { - $shifted=0; - if ($ARGV[0] eq "--exclude") { - $exclude = 1; - shift @ARGV; - $shifted=1; - } - if ($ARGV[0] eq "-f") { - $field = $ARGV[1]; - shift @ARGV; shift @ARGV; - $shifted=1 - } -} while ($shifted); - -if(@ARGV < 1 || @ARGV > 2) { - die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . - "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . - "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . - "only the lines that were *not* in id_list.\n" . - "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . - "If your older scripts (written before Oct 2014) stopped working and you used the\n" . - "-f option, add 1 to the argument.\n" . - "See also: utils/filter_scp.pl .\n"; -} - - -$idlist = shift @ARGV; -open(F, "<$idlist") || die "Could not open id-list file $idlist"; -while() { - @A = split; - @A>=1 || die "Invalid id-list file line $_"; - $seen{$A[0]} = 1; -} - -if ($field == 1) { # Treat this as special case, since it is common. - while(<>) { - $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; - # $1 is what we filter on. - if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { - print $_; - } - } -} else { - while(<>) { - @A = split; - @A > 0 || die "Invalid scp file line $_"; - @A >= $field || die "Invalid scp file line $_"; - if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { - print $_; - } - } -} - -# tests: -# the following should print "foo 1" -# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) -# the following should print "bar 2". 
-# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/alimeeting/sa-asr/utils/parse_options.sh b/egs/alimeeting/sa-asr/utils/parse_options.sh deleted file mode 100755 index 71fb9e5ea..000000000 --- a/egs/alimeeting/sa-asr/utils/parse_options.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env bash - -# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); -# Arnab Ghoshal, Karel Vesely - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# Parse command-line options. -# To be sourced by another script (as in ". parse_options.sh"). -# Option format is: --option-name arg -# and shell variable "option_name" gets set to value "arg." -# The exception is --help, which takes no arguments, but prints the -# $help_message variable (if defined). - - -### -### The --config file options have lower priority to command line -### options, so we need to import them first... -### - -# Now import all the configs specified by command-line, in left-to-right order -for ((argpos=1; argpos<$#; argpos++)); do - if [ "${!argpos}" == "--config" ]; then - argpos_plus1=$((argpos+1)) - config=${!argpos_plus1} - [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 - . $config # source the config file. 
- fi -done - - -### -### Now we process the command line options -### -while true; do - [ -z "${1:-}" ] && break; # break if there are no arguments - case "$1" in - # If the enclosing script is called with --help option, print the help - # message and exit. Scripts should put help messages in $help_message - --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; - else printf "$help_message\n" 1>&2 ; fi; - exit 0 ;; - --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" - exit 1 ;; - # If the first command-line argument begins with "--" (e.g. --foo-bar), - # then work out the variable name as $name, which will equal "foo_bar". - --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; - # Next we test whether the variable in question is undefned-- if so it's - # an invalid option and we die. Note: $0 evaluates to the name of the - # enclosing script. - # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar - # is undefined. We then have to wrap this test inside "eval" because - # foo_bar is itself inside a variable ($name). - eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; - - oldval="`eval echo \\$$name`"; - # Work out whether we seem to be expecting a Boolean argument. - if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then - was_bool=true; - else - was_bool=false; - fi - - # Set the variable to the right value-- the escaped quotes make it work if - # the option had spaces, like --cmd "queue.pl -sync y" - eval $name=\"$2\"; - - # Check that Boolean-valued arguments are really Boolean. - if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then - echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 - exit 1; - fi - shift 2; - ;; - *) break; - esac -done - - -# Check for an empty argument to the --cmd option, which can easily occur as a -# result of scripting errors. -[ ! 
-z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; - - -true; # so this script returns exit code 0. diff --git a/egs/alimeeting/sa-asr/utils/split_scp.pl b/egs/alimeeting/sa-asr/utils/split_scp.pl deleted file mode 100755 index 0876dcb6d..000000000 --- a/egs/alimeeting/sa-asr/utils/split_scp.pl +++ /dev/null @@ -1,246 +0,0 @@ -#!/usr/bin/env perl - -# Copyright 2010-2011 Microsoft Corporation - -# See ../../COPYING for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# This program splits up any kind of .scp or archive-type file. -# If there is no utt2spk option it will work on any text file and -# will split it up with an approximately equal number of lines in -# each but. -# With the --utt2spk option it will work on anything that has the -# utterance-id as the first entry on each line; the utt2spk file is -# of the form "utterance speaker" (on each line). -# It splits it into equal size chunks as far as it can. If you use the utt2spk -# option it will make sure these chunks coincide with speaker boundaries. In -# this case, if there are more chunks than speakers (and in some other -# circumstances), some of the resulting chunks will be empty and it will print -# an error message and exit with nonzero status. -# You will normally call this like: -# split_scp.pl scp scp.1 scp.2 scp.3 ... 
-# or -# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ... -# Note that you can use this script to split the utt2spk file itself, -# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ... - -# You can also call the scripts like: -# split_scp.pl -j 3 0 scp scp.0 -# [note: with this option, it assumes zero-based indexing of the split parts, -# i.e. the second number must be 0 <= n < num-jobs.] - -use warnings; - -$num_jobs = 0; -$job_id = 0; -$utt2spk_file = ""; -$one_based = 0; - -for ($x = 1; $x <= 3 && @ARGV > 0; $x++) { - if ($ARGV[0] eq "-j") { - shift @ARGV; - $num_jobs = shift @ARGV; - $job_id = shift @ARGV; - } - if ($ARGV[0] =~ /--utt2spk=(.+)/) { - $utt2spk_file=$1; - shift; - } - if ($ARGV[0] eq '--one-based') { - $one_based = 1; - shift @ARGV; - } -} - -if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 || - $job_id - $one_based >= $num_jobs)) { - die "$0: Invalid job number/index values for '-j $num_jobs $job_id" . - ($one_based ? " --one-based" : "") . "'\n" -} - -$one_based - and $job_id--; - -if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) { - die -"Usage: split_scp.pl [--utt2spk=] in.scp out1.scp out2.scp ... - or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=] in.scp [out.scp] - ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n"; -} - -$error = 0; -$inscp = shift @ARGV; -if ($num_jobs == 0) { # without -j option - @OUTPUTS = @ARGV; -} else { - for ($j = 0; $j < $num_jobs; $j++) { - if ($j == $job_id) { - if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; } - else { push @OUTPUTS, "-"; } - } else { - push @OUTPUTS, "/dev/null"; - } - } -} - -if ($utt2spk_file ne "") { # We have the --utt2spk option... 
- open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n"; - while(<$u_fh>) { - @A = split; - @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n"; - ($u,$s) = @A; - $utt2spk{$u} = $s; - } - close $u_fh; - open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n"; - @spkrs = (); - while(<$i_fh>) { - @A = split; - if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; } - $u = $A[0]; - $s = $utt2spk{$u}; - defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n"; - if(!defined $spk_count{$s}) { - push @spkrs, $s; - $spk_count{$s} = 0; - $spk_data{$s} = []; # ref to new empty array. - } - $spk_count{$s}++; - push @{$spk_data{$s}}, $_; - } - # Now split as equally as possible .. - # First allocate spks to files by allocating an approximately - # equal number of speakers. - $numspks = @spkrs; # number of speakers. - $numscps = @OUTPUTS; # number of output files. - if ($numspks < $numscps) { - die "$0: Refusing to split data because number of speakers $numspks " . - "is less than the number of output .scp files $numscps\n"; - } - for($scpidx = 0; $scpidx < $numscps; $scpidx++) { - $scparray[$scpidx] = []; # [] is array reference. - } - for ($spkidx = 0; $spkidx < $numspks; $spkidx++) { - $scpidx = int(($spkidx*$numscps) / $numspks); - $spk = $spkrs[$spkidx]; - push @{$scparray[$scpidx]}, $spk; - $scpcount[$scpidx] += $spk_count{$spk}; - } - - # Now will try to reassign beginning + ending speakers - # to different scp's and see if it gets more balanced. - # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2. - # We can show that if considering changing just 2 scp's, we minimize - # this by minimizing the squared difference in sizes. This is - # equivalent to minimizing the absolute difference in sizes. This - # shows this method is bound to converge. 
- - $changed = 1; - while($changed) { - $changed = 0; - for($scpidx = 0; $scpidx < $numscps; $scpidx++) { - # First try to reassign ending spk of this scp. - if($scpidx < $numscps-1) { - $sz = @{$scparray[$scpidx]}; - if($sz > 0) { - $spk = $scparray[$scpidx]->[$sz-1]; - $count = $spk_count{$spk}; - $nutt1 = $scpcount[$scpidx]; - $nutt2 = $scpcount[$scpidx+1]; - if( abs( ($nutt2+$count) - ($nutt1-$count)) - < abs($nutt2 - $nutt1)) { # Would decrease - # size-diff by reassigning spk... - $scpcount[$scpidx+1] += $count; - $scpcount[$scpidx] -= $count; - pop @{$scparray[$scpidx]}; - unshift @{$scparray[$scpidx+1]}, $spk; - $changed = 1; - } - } - } - if($scpidx > 0 && @{$scparray[$scpidx]} > 0) { - $spk = $scparray[$scpidx]->[0]; - $count = $spk_count{$spk}; - $nutt1 = $scpcount[$scpidx-1]; - $nutt2 = $scpcount[$scpidx]; - if( abs( ($nutt2-$count) - ($nutt1+$count)) - < abs($nutt2 - $nutt1)) { # Would decrease - # size-diff by reassigning spk... - $scpcount[$scpidx-1] += $count; - $scpcount[$scpidx] -= $count; - shift @{$scparray[$scpidx]}; - push @{$scparray[$scpidx-1]}, $spk; - $changed = 1; - } - } - } - } - # Now print out the files... - for($scpidx = 0; $scpidx < $numscps; $scpidx++) { - $scpfile = $OUTPUTS[$scpidx]; - ($scpfile ne '-' ? open($f_fh, '>', $scpfile) - : open($f_fh, '>&', \*STDOUT)) || - die "$0: Could not open scp file $scpfile for writing: $!\n"; - $count = 0; - if(@{$scparray[$scpidx]} == 0) { - print STDERR "$0: eError: split_scp.pl producing empty .scp file " . - "$scpfile (too many splits and too few speakers?)\n"; - $error = 1; - } else { - foreach $spk ( @{$scparray[$scpidx]} ) { - print $f_fh @{$spk_data{$spk}}; - $count += $spk_count{$spk}; - } - $count == $scpcount[$scpidx] || die "Count mismatch [code error]"; - } - close($f_fh); - } -} else { - # This block is the "normal" case where there is no --utt2spk - # option and we just break into equal size chunks. 
- - open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n"; - - $numscps = @OUTPUTS; # size of array. - @F = (); - while(<$i_fh>) { - push @F, $_; - } - $numlines = @F; - if($numlines == 0) { - print STDERR "$0: error: empty input scp file $inscp\n"; - $error = 1; - } - $linesperscp = int( $numlines / $numscps); # the "whole part".. - $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n"; - $remainder = $numlines - ($linesperscp * $numscps); - ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder"; - # [just doing int() rounds down]. - $n = 0; - for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) { - $scpfile = $OUTPUTS[$scpidx]; - ($scpfile ne '-' ? open($o_fh, '>', $scpfile) - : open($o_fh, '>&', \*STDOUT)) || - die "$0: Could not open scp file $scpfile for writing: $!\n"; - for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) { - print $o_fh $F[$n++]; - } - close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n"; - } - $n == $numlines || die "$n != $numlines [code error]"; -} - -exit ($error); diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py index c18472f51..18f0add55 100644 --- a/funasr/bin/asr_inference.py +++ b/funasr/bin/asr_inference.py @@ -40,6 +40,8 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils +from funasr.models.frontend.wav_frontend import WavFrontend +from funasr.tasks.asr import frontend_choices header_colors = '\033[95m' @@ -90,6 +92,12 @@ class Speech2Text: asr_train_config, asr_model_file, cmvn_file, device ) frontend = None + if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: + if asr_train_args.frontend=='wav_frontend': + frontend = WavFrontend(cmvn_file=cmvn_file, 
**asr_train_args.frontend_conf).eval() + else: + frontend_class=frontend_choices.get_class(asr_train_args.frontend) + frontend = frontend_class(**asr_train_args.frontend_conf).eval() logging.info("asr_model: {}".format(asr_model)) logging.info("asr_train_args: {}".format(asr_train_args)) @@ -197,12 +205,21 @@ class Speech2Text: """ assert check_argument_types() - + # Input as audio signal if isinstance(speech, np.ndarray): speech = torch.tensor(speech) - batch = {"speech": speech, "speech_lengths": speech_lengths} + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths + lfr_factor = max(1, (feats.size()[-1] // 80) - 1) + batch = {"speech": feats, "speech_lengths": feats_len} # a. To device batch = to_device(batch, device=self.device) @@ -275,6 +292,7 @@ def inference( ngram_weight: float = 0.9, nbest: int = 1, num_workers: int = 1, + mc: bool = False, **kwargs, ): inference_pipeline = inference_modelscope( @@ -305,6 +323,7 @@ def inference( ngram_weight=ngram_weight, nbest=nbest, num_workers=num_workers, + mc=mc, **kwargs, ) return inference_pipeline(data_path_and_name_and_type, raw_inputs) @@ -337,6 +356,7 @@ def inference_modelscope( ngram_weight: float = 0.9, nbest: int = 1, num_workers: int = 1, + mc: bool = False, param_dict: dict = None, **kwargs, ): @@ -406,7 +426,7 @@ def inference_modelscope( data_path_and_name_and_type, dtype=dtype, fs=fs, - mc=True, + mc=mc, batch_size=batch_size, key_file=key_file, num_workers=num_workers, @@ -415,7 +435,7 @@ def inference_modelscope( allow_variable_data_keys=allow_variable_data_keys, inference=True, ) - + finish_count = 0 file_count = 1 # 7 .Start for-loop diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index e165531f8..9a1ffe5ee 100644 --- a/funasr/bin/asr_inference_launch.py 
+++ b/funasr/bin/asr_inference_launch.py @@ -71,7 +71,13 @@ def get_parser(): ) group.add_argument("--key_file", type=str_or_none) group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) - + group.add_argument( + "--mc", + type=bool, + default=False, + help="MultiChannel input", + ) + group = parser.add_argument_group("The model configuration related") group.add_argument( "--vad_infer_config", diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py index c1e2cb2a1..a43472c01 100755 --- a/funasr/bin/asr_train.py +++ b/funasr/bin/asr_train.py @@ -2,14 +2,6 @@ import os -import logging - -logging.basicConfig( - level='INFO', - format=f"[{os.uname()[1].split('.')[0]}]" - f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", -) - from funasr.tasks.asr import ASRTask diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py index be63af111..ec575df2f 100644 --- a/funasr/bin/sa_asr_inference.py +++ b/funasr/bin/sa_asr_inference.py @@ -35,6 +35,8 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils +from funasr.models.frontend.wav_frontend import WavFrontend +from funasr.tasks.asr import frontend_choices header_colors = '\033[95m' @@ -85,6 +87,12 @@ class Speech2Text: asr_train_config, asr_model_file, cmvn_file, device ) frontend = None + if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: + if asr_train_args.frontend=='wav_frontend': + frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval() + else: + frontend_class=frontend_choices.get_class(asr_train_args.frontend) + frontend = frontend_class(**asr_train_args.frontend_conf).eval() logging.info("asr_model: {}".format(asr_model)) logging.info("asr_train_args: {}".format(asr_train_args)) @@ -201,7 +209,16 @@ class Speech2Text: if isinstance(profile, np.ndarray): 
profile = torch.tensor(profile) - batch = {"speech": speech, "speech_lengths": speech_lengths} + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths + lfr_factor = max(1, (feats.size()[-1] // 80) - 1) + batch = {"speech": feats, "speech_lengths": feats_len} # a. To device batch = to_device(batch, device=self.device) @@ -308,6 +325,7 @@ def inference( ngram_weight: float = 0.9, nbest: int = 1, num_workers: int = 1, + mc: bool = False, **kwargs, ): inference_pipeline = inference_modelscope( @@ -338,6 +356,7 @@ def inference( ngram_weight=ngram_weight, nbest=nbest, num_workers=num_workers, + mc=mc, **kwargs, ) return inference_pipeline(data_path_and_name_and_type, raw_inputs) @@ -370,6 +389,7 @@ def inference_modelscope( ngram_weight: float = 0.9, nbest: int = 1, num_workers: int = 1, + mc: bool = False, param_dict: dict = None, **kwargs, ): @@ -437,7 +457,7 @@ def inference_modelscope( data_path_and_name_and_type, dtype=dtype, fs=fs, - mc=True, + mc=mc, batch_size=batch_size, key_file=key_file, num_workers=num_workers, diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py index c7c7c42a4..07b9b19db 100755 --- a/funasr/bin/sa_asr_train.py +++ b/funasr/bin/sa_asr_train.py @@ -2,14 +2,6 @@ import os -import logging - -logging.basicConfig( - level='INFO', - format=f"[{os.uname()[1].split('.')[0]}]" - f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", -) - from funasr.tasks.sa_asr import ASRTask diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py index 28df73fbb..8f63df9bc 100644 --- a/funasr/losses/label_smoothing_loss.py +++ b/funasr/losses/label_smoothing_loss.py @@ -79,3 +79,49 @@ class SequenceBinaryCrossEntropy(nn.Module): loss = self.criterion(pred, label) denom = (~pad_mask).sum() if 
self.normalize_length else pred.shape[0] return loss.masked_fill(pad_mask, 0).sum() / denom + + +class NllLoss(nn.Module): + """Nll loss. + + :param int size: the number of class + :param int padding_idx: ignored class id + :param bool normalize_length: normalize loss by sequence length if True + :param torch.nn.Module criterion: loss function + """ + + def __init__( + self, + size, + padding_idx, + normalize_length=False, + criterion=nn.NLLLoss(reduction='none'), + ): + """Construct an LabelSmoothingLoss object.""" + super(NllLoss, self).__init__() + self.criterion = criterion + self.padding_idx = padding_idx + self.size = size + self.true_dist = None + self.normalize_length = normalize_length + + def forward(self, x, target): + """Compute loss between x and target. + + :param torch.Tensor x: prediction (batch, seqlen, class) + :param torch.Tensor target: + target signal masked with self.padding_id (batch, seqlen) + :return: scalar float value + :rtype torch.Tensor + """ + assert x.size(2) == self.size + batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + with torch.no_grad(): + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + kl = self.criterion(x , target) + denom = total if self.normalize_length else batch_size + return kl.masked_fill(ignore, 0).sum() / denom diff --git a/funasr/models/decoder/transformer_decoder.py b/funasr/models/decoder/transformer_decoder.py index aed7f206d..45fdda818 100644 --- a/funasr/models/decoder/transformer_decoder.py +++ b/funasr/models/decoder/transformer_decoder.py @@ -13,6 +13,7 @@ from typeguard import check_argument_types from funasr.models.decoder.abs_decoder import AbsDecoder from funasr.modules.attention import MultiHeadedAttention +from funasr.modules.attention import CosineDistanceAttention from funasr.modules.dynamic_conv import DynamicConvolution from funasr.modules.dynamic_conv2d import 
DynamicConvolution2D from funasr.modules.embedding import PositionalEncoding @@ -763,4 +764,429 @@ class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): normalize_before, concat_after, ), - ) \ No newline at end of file + ) + +class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface): + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + spker_embedding_dim: int = 256, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + input_layer: str = "embed", + use_asr_output_layer: bool = True, + use_spk_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + ): + assert check_argument_types() + super().__init__() + attention_dim = encoder_output_size + + if input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(vocab_size, attention_dim), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(vocab_size, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + else: + raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}") + + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + if use_asr_output_layer: + self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size) + else: + self.asr_output_layer = None + + if use_spk_output_layer: + self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim) + else: + self.spk_output_layer = None + + self.cos_distance_att = CosineDistanceAttention() + + self.decoder1 = None + self.decoder2 = None + self.decoder3 = None + self.decoder4 = None + + def forward( + self, + asr_hs_pad: torch.Tensor, + spk_hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + profile: 
torch.Tensor, + profile_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + tgt = ys_in_pad + # tgt_mask: (B, 1, L) + tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) + # m: (1, L, L) + m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) + tgt_mask = tgt_mask & m + + asr_memory = asr_hs_pad + spk_memory = spk_hs_pad + memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device) + # Spk decoder + x = self.embed(tgt) + + x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1( + x, tgt_mask, asr_memory, spk_memory, memory_mask + ) + x, tgt_mask, spk_memory, memory_mask = self.decoder2( + x, tgt_mask, spk_memory, memory_mask + ) + if self.normalize_before: + x = self.after_norm(x) + if self.spk_output_layer is not None: + x = self.spk_output_layer(x) + dn, weights = self.cos_distance_att(x, profile, profile_lens) + # Asr decoder + x, tgt_mask, asr_memory, memory_mask = self.decoder3( + z, tgt_mask, asr_memory, memory_mask, dn + ) + x, tgt_mask, asr_memory, memory_mask = self.decoder4( + x, tgt_mask, asr_memory, memory_mask + ) + + if self.normalize_before: + x = self.after_norm(x) + if self.asr_output_layer is not None: + x = self.asr_output_layer(x) + + olens = tgt_mask.sum(1) + return x, weights, olens + + + def forward_one_step( + self, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + asr_memory: torch.Tensor, + spk_memory: torch.Tensor, + profile: torch.Tensor, + cache: List[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + + x = self.embed(tgt) + + if cache is None: + cache = [None] * (2 + len(self.decoder2) + len(self.decoder4)) + new_cache = [] + x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1( + x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0] + ) + new_cache.append(x) + for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2): + x, tgt_mask, spk_memory, _ = decoder( + x, tgt_mask, spk_memory, None, 
cache=c + ) + new_cache.append(x) + if self.normalize_before: + x = self.after_norm(x) + else: + x = x + if self.spk_output_layer is not None: + x = self.spk_output_layer(x) + dn, weights = self.cos_distance_att(x, profile, None) + + x, tgt_mask, asr_memory, _ = self.decoder3( + z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1] + ) + new_cache.append(x) + + for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4): + x, tgt_mask, asr_memory, _ = decoder( + x, tgt_mask, asr_memory, None, cache=c + ) + new_cache.append(x) + + if self.normalize_before: + y = self.after_norm(x[:, -1]) + else: + y = x[:, -1] + if self.asr_output_layer is not None: + y = torch.log_softmax(self.asr_output_layer(y), dim=-1) + + return y, weights, new_cache + + def score(self, ys, state, asr_enc, spk_enc, profile): + """Score.""" + ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0) + logp, weights, state = self.forward_one_step( + ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state + ) + return logp.squeeze(0), weights.squeeze(), state + +class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder): + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + spker_embedding_dim: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + asr_num_blocks: int = 6, + spk_num_blocks: int = 3, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_asr_output_layer: bool = True, + use_spk_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + ): + assert check_argument_types() + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + spker_embedding_dim=spker_embedding_dim, + dropout_rate=dropout_rate, + 
positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_asr_output_layer=use_asr_output_layer, + use_spk_output_layer=use_spk_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + + self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer( + attention_dim, + MultiHeadedAttention( + attention_heads, attention_dim, self_attention_dropout_rate + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ) + self.decoder2 = repeat( + spk_num_blocks - 1, + lambda lnum: DecoderLayer( + attention_dim, + MultiHeadedAttention( + attention_heads, attention_dim, self_attention_dropout_rate + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + + self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer( + attention_dim, + spker_embedding_dim, + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ) + self.decoder4 = repeat( + asr_num_blocks - 1, + lambda lnum: DecoderLayer( + attention_dim, + MultiHeadedAttention( + attention_heads, attention_dim, self_attention_dropout_rate + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + +class SpeakerAttributeSpkDecoderFirstLayer(nn.Module): + + def __init__( + self, + size, + self_attn, + src_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + 
"""Construct an DecoderLayer object.""" + super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear1 = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + size, size) + + def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None): + + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + # compute only the last frame query keeping dim: max_time_out -> 1 + assert cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = None + if tgt_mask is not None: + tgt_q_mask = tgt_mask[:, -1:, :] + + if self.concat_after: + tgt_concat = torch.cat( + (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 + ) + x = residual + self.concat_linear1(tgt_concat) + else: + x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) + if not self.normalize_before: + x = self.norm1(x) + z = x + + residual = x + if self.normalize_before: + x = self.norm1(x) + + skip = self.src_attn(x, asr_memory, spk_memory, memory_mask) + + if self.concat_after: + x_concat = torch.cat( + (x, skip), dim=-1 + ) + x = residual + self.concat_linear2(x_concat) + else: + x = residual + self.dropout(skip) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + 
return x, tgt_mask, asr_memory, spk_memory, memory_mask, z + +class SpeakerAttributeAsrDecoderFirstLayer(nn.Module): + + def __init__( + self, + size, + d_size, + src_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an DecoderLayer object.""" + super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__() + self.size = size + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.norm3 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + self.spk_linear = nn.Linear(d_size, size, bias=False) + if self.concat_after: + self.concat_linear1 = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + size, size) + + def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None): + + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = None + if tgt_mask is not None: + tgt_q_mask = tgt_mask[:, -1:, :] + + x = tgt_q + if self.normalize_before: + x = self.norm2(x) + if self.concat_after: + x_concat = torch.cat( + (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 + ) + x = residual + self.concat_linear2(x_concat) + else: + x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) + if not self.normalize_before: + x = self.norm2(x) + residual = x + + if dn!=None: + x = x + self.spk_linear(dn) + if self.normalize_before: + x = self.norm3(x) + + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm3(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, memory, memory_mask \ No newline at end of file diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py index 
0d4097ec2..f694cc2df 100644 --- a/funasr/models/e2e_sa_asr.py +++ b/funasr/models/e2e_sa_asr.py @@ -16,9 +16,8 @@ from typeguard import check_argument_types from funasr.layers.abs_normalize import AbsNormalize from funasr.losses.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 + LabelSmoothingLoss, NllLoss # noqa: H301 ) -from funasr.losses.nll_loss import NllLoss from funasr.models.ctc import CTC from funasr.models.decoder.abs_decoder import AbsDecoder from funasr.models.encoder.abs_encoder import AbsEncoder diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py index 738ec522d..7cfcbd0ea 100644 --- a/funasr/tasks/sa_asr.py +++ b/funasr/tasks/sa_asr.py @@ -28,7 +28,7 @@ from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecode from funasr.models.decoder.transformer_decoder import ( DynamicConvolution2DTransformerDecoder, # noqa: H301 ) -from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder +from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder from funasr.models.decoder.transformer_decoder import ( LightweightConvolution2DTransformerDecoder, # noqa: H301 From 49f13908deaed06bb4b0a01631e85e2833f1f051 Mon Sep 17 00:00:00 2001 From: smohan-speech Date: Sun, 7 May 2023 02:27:58 +0800 Subject: [PATCH 4/5] add speaker-attributed ASR task for alimeeting --- .../sa-asr/pyscripts/audio/format_wav_scp.py | 243 --------------- .../sa-asr/pyscripts/utils/print_args.py | 45 --- .../sa-asr/scripts/audio/format_wav_scp.sh | 142 --------- .../scripts/utils/perturb_data_dir_speed.sh | 116 ------- funasr/losses/nll_loss.py | 47 --- funasr/models/decoder/decoder_layer_sa_asr.py | 169 ---------- .../decoder/transformer_decoder_sa_asr.py | 291 ------------------ 7 files changed, 1053 deletions(-) delete mode 100755 egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py delete mode 100755 
egs/alimeeting/sa-asr/pyscripts/utils/print_args.py delete mode 100755 egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh delete mode 100755 egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh delete mode 100644 funasr/losses/nll_loss.py delete mode 100644 funasr/models/decoder/decoder_layer_sa_asr.py delete mode 100644 funasr/models/decoder/transformer_decoder_sa_asr.py diff --git a/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py b/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py deleted file mode 100755 index 1fd63d690..000000000 --- a/egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py +++ /dev/null @@ -1,243 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import logging -from io import BytesIO -from pathlib import Path -from typing import Tuple, Optional - -import kaldiio -import humanfriendly -import numpy as np -import resampy -import soundfile -from tqdm import tqdm -from typeguard import check_argument_types - -from funasr.utils.cli_utils import get_commandline_args -from funasr.fileio.read_text import read_2column_text -from funasr.fileio.sound_scp import SoundScpWriter - - -def humanfriendly_or_none(value: str): - if value in ("none", "None", "NONE"): - return None - return humanfriendly.parse_size(value) - - -def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]: - """ - - >>> str2int_tuple('3,4,5') - (3, 4, 5) - - """ - assert check_argument_types() - if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"): - return None - return tuple(map(int, integers.strip().split(","))) - - -def main(): - logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" - logging.basicConfig(level=logging.INFO, format=logfmt) - logging.info(get_commandline_args()) - - parser = argparse.ArgumentParser( - description='Create waves list from "wav.scp"', - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("scp") - parser.add_argument("outdir") - parser.add_argument( - 
"--name", - default="wav", - help="Specify the prefix word of output file name " 'such as "wav.scp"', - ) - parser.add_argument("--segments", default=None) - parser.add_argument( - "--fs", - type=humanfriendly_or_none, - default=None, - help="If the sampling rate specified, " "Change the sampling rate.", - ) - parser.add_argument("--audio-format", default="wav") - group = parser.add_mutually_exclusive_group() - group.add_argument("--ref-channels", default=None, type=str2int_tuple) - group.add_argument("--utt2ref-channels", default=None, type=str) - args = parser.parse_args() - - out_num_samples = Path(args.outdir) / f"utt2num_samples" - - if args.ref_channels is not None: - - def utt2ref_channels(x) -> Tuple[int, ...]: - return args.ref_channels - - elif args.utt2ref_channels is not None: - utt2ref_channels_dict = read_2column_text(args.utt2ref_channels) - - def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]: - chs_str = d[x] - return tuple(map(int, chs_str.split())) - - else: - utt2ref_channels = None - - Path(args.outdir).mkdir(parents=True, exist_ok=True) - out_wavscp = Path(args.outdir) / f"{args.name}.scp" - if args.segments is not None: - # Note: kaldiio supports only wav-pcm-int16le file. - loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments) - if args.audio_format.endswith("ark"): - fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb") - fscp = out_wavscp.open("w") - else: - writer = SoundScpWriter( - args.outdir, - out_wavscp, - format=args.audio_format, - ) - - with out_num_samples.open("w") as fnum_samples: - for uttid, (rate, wave) in tqdm(loader): - # wave: (Time,) or (Time, Nmic) - if wave.ndim == 2 and utt2ref_channels is not None: - wave = wave[:, utt2ref_channels(uttid)] - - if args.fs is not None and args.fs != rate: - # FIXME(kamo): To use sox? 
- wave = resampy.resample( - wave.astype(np.float64), rate, args.fs, axis=0 - ) - wave = wave.astype(np.int16) - rate = args.fs - if args.audio_format.endswith("ark"): - if "flac" in args.audio_format: - suf = "flac" - elif "wav" in args.audio_format: - suf = "wav" - else: - raise RuntimeError("wav.ark or flac") - - # NOTE(kamo): Using extended ark format style here. - # This format is incompatible with Kaldi - kaldiio.save_ark( - fark, - {uttid: (wave, rate)}, - scp=fscp, - append=True, - write_function=f"soundfile_{suf}", - ) - - else: - writer[uttid] = rate, wave - fnum_samples.write(f"{uttid} {len(wave)}\n") - else: - if args.audio_format.endswith("ark"): - fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb") - else: - wavdir = Path(args.outdir) / f"data_{args.name}" - wavdir.mkdir(parents=True, exist_ok=True) - - with Path(args.scp).open("r") as fscp, out_wavscp.open( - "w" - ) as fout, out_num_samples.open("w") as fnum_samples: - for line in tqdm(fscp): - uttid, wavpath = line.strip().split(None, 1) - - if wavpath.endswith("|"): - # Streaming input e.g. cat a.wav | - with kaldiio.open_like_kaldi(wavpath, "rb") as f: - with BytesIO(f.read()) as g: - wave, rate = soundfile.read(g, dtype=np.int16) - if wave.ndim == 2 and utt2ref_channels is not None: - wave = wave[:, utt2ref_channels(uttid)] - - if args.fs is not None and args.fs != rate: - # FIXME(kamo): To use sox? - wave = resampy.resample( - wave.astype(np.float64), rate, args.fs, axis=0 - ) - wave = wave.astype(np.int16) - rate = args.fs - - if args.audio_format.endswith("ark"): - if "flac" in args.audio_format: - suf = "flac" - elif "wav" in args.audio_format: - suf = "wav" - else: - raise RuntimeError("wav.ark or flac") - - # NOTE(kamo): Using extended ark format style here. 
- # This format is incompatible with Kaldi - kaldiio.save_ark( - fark, - {uttid: (wave, rate)}, - scp=fout, - append=True, - write_function=f"soundfile_{suf}", - ) - else: - owavpath = str(wavdir / f"{uttid}.{args.audio_format}") - soundfile.write(owavpath, wave, rate) - fout.write(f"{uttid} {owavpath}\n") - else: - wave, rate = soundfile.read(wavpath, dtype=np.int16) - if wave.ndim == 2 and utt2ref_channels is not None: - wave = wave[:, utt2ref_channels(uttid)] - save_asis = False - - elif args.audio_format.endswith("ark"): - save_asis = False - - elif Path(wavpath).suffix == "." + args.audio_format and ( - args.fs is None or args.fs == rate - ): - save_asis = True - - else: - save_asis = False - - if save_asis: - # Neither --segments nor --fs are specified and - # the line doesn't end with "|", - # i.e. not using unix-pipe, - # only in this case, - # just using the original file as is. - fout.write(f"{uttid} {wavpath}\n") - else: - if args.fs is not None and args.fs != rate: - # FIXME(kamo): To use sox? - wave = resampy.resample( - wave.astype(np.float64), rate, args.fs, axis=0 - ) - wave = wave.astype(np.int16) - rate = args.fs - - if args.audio_format.endswith("ark"): - if "flac" in args.audio_format: - suf = "flac" - elif "wav" in args.audio_format: - suf = "wav" - else: - raise RuntimeError("wav.ark or flac") - - # NOTE(kamo): Using extended ark format style here. - # This format is not supported in Kaldi. 
- kaldiio.save_ark( - fark, - {uttid: (wave, rate)}, - scp=fout, - append=True, - write_function=f"soundfile_{suf}", - ) - else: - owavpath = str(wavdir / f"{uttid}.{args.audio_format}") - soundfile.write(owavpath, wave, rate) - fout.write(f"{uttid} {owavpath}\n") - fnum_samples.write(f"{uttid} {len(wave)}\n") - - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py b/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py deleted file mode 100755 index b0c61e5b4..000000000 --- a/egs/alimeeting/sa-asr/pyscripts/utils/print_args.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python -import sys - - -def get_commandline_args(no_executable=True): - extra_chars = [ - " ", - ";", - "&", - "|", - "<", - ">", - "?", - "*", - "~", - "`", - '"', - "'", - "\\", - "{", - "}", - "(", - ")", - ] - - # Escape the extra characters for shell - argv = [ - arg.replace("'", "'\\''") - if all(char not in arg for char in extra_chars) - else "'" + arg.replace("'", "'\\''") + "'" - for arg in sys.argv - ] - - if no_executable: - return " ".join(argv[1:]) - else: - return sys.executable + " " + " ".join(argv) - - -def main(): - print(get_commandline_args()) - - -if __name__ == "__main__": - main() diff --git a/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh b/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh deleted file mode 100755 index 15e4563f1..000000000 --- a/egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail -SECONDS=0 -log() { - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} -help_message=$(cat << EOF -Usage: $0 [ []] -e.g. -$0 data/test/wav.scp data/test_format/ - -Format 'wav.scp': In short words, -changing "kaldi-datadir" to "modified-kaldi-datadir" - -The 'wav.scp' format in kaldi is very flexible, -e.g. 
It can use unix-pipe as describing that wav file, -but it sometime looks confusing and make scripts more complex. -This tools creates actual wav files from 'wav.scp' -and also segments wav files using 'segments'. - -Options - --fs - --segments - --nj - --cmd -EOF -) - -out_filename=wav.scp -cmd=utils/run.pl -nj=30 -fs=none -segments= - -ref_channels= -utt2ref_channels= - -audio_format=wav -write_utt2num_samples=true - -log "$0 $*" -. utils/parse_options.sh - -if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then - log "${help_message}" - log "Error: invalid command line arguments" - exit 1 -fi - -. ./path.sh # Setup the environment - -scp=$1 -if [ ! -f "${scp}" ]; then - log "${help_message}" - echo "$0: Error: No such file: ${scp}" - exit 1 -fi -dir=$2 - - -if [ $# -eq 2 ]; then - logdir=${dir}/logs - outdir=${dir}/data - -elif [ $# -eq 3 ]; then - logdir=$3 - outdir=${dir}/data - -elif [ $# -eq 4 ]; then - logdir=$3 - outdir=$4 -fi - - -mkdir -p ${logdir} - -rm -f "${dir}/${out_filename}" - - -opts= -if [ -n "${utt2ref_channels}" ]; then - opts="--utt2ref-channels ${utt2ref_channels} " -elif [ -n "${ref_channels}" ]; then - opts="--ref-channels ${ref_channels} " -fi - - -if [ -n "${segments}" ]; then - log "[info]: using ${segments}" - nutt=$(<${segments} wc -l) - nj=$((nj /dev/null - -# concatenate the .scp files together. -for n in $(seq ${nj}); do - cat "${outdir}/format.${n}/wav.scp" || exit 1; -done > "${dir}/${out_filename}" || exit 1 - -if "${write_utt2num_samples}"; then - for n in $(seq ${nj}); do - cat "${outdir}/format.${n}/utt2num_samples" || exit 1; - done > "${dir}/utt2num_samples" || exit 1 -fi - -log "Successfully finished. 
[elapsed=${SECONDS}s]" diff --git a/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh deleted file mode 100755 index 9e08dba72..000000000 --- a/egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh +++ /dev/null @@ -1,116 +0,0 @@ -#!/usr/bin/env bash - -# 2020 @kamo-naoyuki -# This file was copied from Kaldi and -# I deleted parts related to wav duration -# because we shouldn't use kaldi's command here -# and we don't need the files actually. - -# Copyright 2013 Johns Hopkins University (author: Daniel Povey) -# 2014 Tom Ko -# 2018 Emotech LTD (author: Pawel Swietojanski) -# Apache 2.0 - -# This script operates on a directory, such as in data/train/, -# that contains some subset of the following files: -# wav.scp -# spk2utt -# utt2spk -# text -# -# It generates the files which are used for perturbing the speed of the original data. - -export LC_ALL=C -set -euo pipefail - -if [[ $# != 3 ]]; then - echo "Usage: perturb_data_dir_speed.sh " - echo "e.g.:" - echo " $0 0.9 data/train_si284 data/train_si284p" - exit 1 -fi - -factor=$1 -srcdir=$2 -destdir=$3 -label="sp" -spk_prefix="${label}${factor}-" -utt_prefix="${label}${factor}-" - -#check is sox on the path - -! command -v sox &>/dev/null && echo "sox: command not found" && exit 1; - -if [[ ! -f ${srcdir}/utt2spk ]]; then - echo "$0: no such file ${srcdir}/utt2spk" - exit 1; -fi - -if [[ ${destdir} == "${srcdir}" ]]; then - echo "$0: this script requires and to be different." - exit 1 -fi - -mkdir -p "${destdir}" - -<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map" -<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map" -<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map" -if [[ ! 
-f ${srcdir}/utt2uniq ]]; then - <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq" -else - <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq" -fi - - -<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \ - utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk - -utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt - -if [[ -f ${srcdir}/segments ]]; then - - utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \ - utils/apply_map.pl -f 2 "${destdir}"/reco_map | \ - awk -v factor="${factor}" \ - '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \ - >"${destdir}"/segments - - utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ - # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" - awk -v factor="${factor}" \ - '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} - else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" } - else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \ - > "${destdir}"/wav.scp - if [[ -f ${srcdir}/reco2file_and_channel ]]; then - utils/apply_map.pl -f 1 "${destdir}"/reco_map \ - <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel - fi - -else # no segments->wav indexed by utterance. 
- if [[ -f ${srcdir}/wav.scp ]]; then - utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ - # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" - awk -v factor="${factor}" \ - '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} - else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" } - else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \ - > "${destdir}"/wav.scp - fi -fi - -if [[ -f ${srcdir}/text ]]; then - utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text -fi -if [[ -f ${srcdir}/spk2gender ]]; then - utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender -fi -if [[ -f ${srcdir}/utt2lang ]]; then - utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang -fi - -rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null -echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}" - -utils/validate_data_dir.sh --no-feats --no-text "${destdir}" diff --git a/funasr/losses/nll_loss.py b/funasr/losses/nll_loss.py deleted file mode 100644 index 7e4e29496..000000000 --- a/funasr/losses/nll_loss.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -from torch import nn - -class NllLoss(nn.Module): - """Nll loss. 
- - :param int size: the number of class - :param int padding_idx: ignored class id - :param bool normalize_length: normalize loss by sequence length if True - :param torch.nn.Module criterion: loss function - """ - - def __init__( - self, - size, - padding_idx, - normalize_length=False, - criterion=nn.NLLLoss(reduction='none'), - ): - """Construct an LabelSmoothingLoss object.""" - super(NllLoss, self).__init__() - self.criterion = criterion - self.padding_idx = padding_idx - self.size = size - self.true_dist = None - self.normalize_length = normalize_length - - def forward(self, x, target): - """Compute loss between x and target. - - :param torch.Tensor x: prediction (batch, seqlen, class) - :param torch.Tensor target: - target signal masked with self.padding_id (batch, seqlen) - :return: scalar float value - :rtype torch.Tensor - """ - assert x.size(2) == self.size - batch_size = x.size(0) - x = x.view(-1, self.size) - target = target.view(-1) - with torch.no_grad(): - ignore = target == self.padding_idx # (B,) - total = len(target) - ignore.sum().item() - target = target.masked_fill(ignore, 0) # avoid -1 index - kl = self.criterion(x , target) - denom = total if self.normalize_length else batch_size - return kl.masked_fill(ignore, 0).sum() / denom diff --git a/funasr/models/decoder/decoder_layer_sa_asr.py b/funasr/models/decoder/decoder_layer_sa_asr.py deleted file mode 100644 index 80afc5168..000000000 --- a/funasr/models/decoder/decoder_layer_sa_asr.py +++ /dev/null @@ -1,169 +0,0 @@ -import torch -from torch import nn - -from funasr.modules.layer_norm import LayerNorm - - -class SpeakerAttributeSpkDecoderFirstLayer(nn.Module): - - def __init__( - self, - size, - self_attn, - src_attn, - feed_forward, - dropout_rate, - normalize_before=True, - concat_after=False, - ): - """Construct an DecoderLayer object.""" - super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__() - self.size = size - self.self_attn = self_attn - self.src_attn = src_attn - 
self.feed_forward = feed_forward - self.norm1 = LayerNorm(size) - self.norm2 = LayerNorm(size) - self.dropout = nn.Dropout(dropout_rate) - self.normalize_before = normalize_before - self.concat_after = concat_after - if self.concat_after: - self.concat_linear1 = nn.Linear(size + size, size) - self.concat_linear2 = nn.Linear(size + size, size) - - def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None): - - residual = tgt - if self.normalize_before: - tgt = self.norm1(tgt) - - if cache is None: - tgt_q = tgt - tgt_q_mask = tgt_mask - else: - # compute only the last frame query keeping dim: max_time_out -> 1 - assert cache.shape == ( - tgt.shape[0], - tgt.shape[1] - 1, - self.size, - ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" - tgt_q = tgt[:, -1:, :] - residual = residual[:, -1:, :] - tgt_q_mask = None - if tgt_mask is not None: - tgt_q_mask = tgt_mask[:, -1:, :] - - if self.concat_after: - tgt_concat = torch.cat( - (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 - ) - x = residual + self.concat_linear1(tgt_concat) - else: - x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) - if not self.normalize_before: - x = self.norm1(x) - z = x - - residual = x - if self.normalize_before: - x = self.norm1(x) - - skip = self.src_attn(x, asr_memory, spk_memory, memory_mask) - - if self.concat_after: - x_concat = torch.cat( - (x, skip), dim=-1 - ) - x = residual + self.concat_linear2(x_concat) - else: - x = residual + self.dropout(skip) - if not self.normalize_before: - x = self.norm1(x) - - residual = x - if self.normalize_before: - x = self.norm2(x) - x = residual + self.dropout(self.feed_forward(x)) - if not self.normalize_before: - x = self.norm2(x) - - if cache is not None: - x = torch.cat([cache, x], dim=1) - - return x, tgt_mask, asr_memory, spk_memory, memory_mask, z - -class SpeakerAttributeAsrDecoderFirstLayer(nn.Module): - - def __init__( - self, - size, - d_size, - src_attn, - 
feed_forward, - dropout_rate, - normalize_before=True, - concat_after=False, - ): - """Construct an DecoderLayer object.""" - super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__() - self.size = size - self.src_attn = src_attn - self.feed_forward = feed_forward - self.norm1 = LayerNorm(size) - self.norm2 = LayerNorm(size) - self.norm3 = LayerNorm(size) - self.dropout = nn.Dropout(dropout_rate) - self.normalize_before = normalize_before - self.concat_after = concat_after - self.spk_linear = nn.Linear(d_size, size, bias=False) - if self.concat_after: - self.concat_linear1 = nn.Linear(size + size, size) - self.concat_linear2 = nn.Linear(size + size, size) - - def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None): - - residual = tgt - if self.normalize_before: - tgt = self.norm1(tgt) - - if cache is None: - tgt_q = tgt - tgt_q_mask = tgt_mask - else: - - tgt_q = tgt[:, -1:, :] - residual = residual[:, -1:, :] - tgt_q_mask = None - if tgt_mask is not None: - tgt_q_mask = tgt_mask[:, -1:, :] - - x = tgt_q - if self.normalize_before: - x = self.norm2(x) - if self.concat_after: - x_concat = torch.cat( - (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 - ) - x = residual + self.concat_linear2(x_concat) - else: - x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) - if not self.normalize_before: - x = self.norm2(x) - residual = x - - if dn!=None: - x = x + self.spk_linear(dn) - if self.normalize_before: - x = self.norm3(x) - - x = residual + self.dropout(self.feed_forward(x)) - if not self.normalize_before: - x = self.norm3(x) - - if cache is not None: - x = torch.cat([cache, x], dim=1) - - return x, tgt_mask, memory, memory_mask - - - diff --git a/funasr/models/decoder/transformer_decoder_sa_asr.py b/funasr/models/decoder/transformer_decoder_sa_asr.py deleted file mode 100644 index 949f9c898..000000000 --- a/funasr/models/decoder/transformer_decoder_sa_asr.py +++ /dev/null @@ -1,291 +0,0 @@ -from typing import Any -from 
typing import List -from typing import Sequence -from typing import Tuple - -import torch -from typeguard import check_argument_types - -from funasr.modules.nets_utils import make_pad_mask -from funasr.modules.attention import MultiHeadedAttention -from funasr.modules.attention import CosineDistanceAttention -from funasr.models.decoder.transformer_decoder import DecoderLayer -from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer -from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer -from funasr.modules.dynamic_conv import DynamicConvolution -from funasr.modules.dynamic_conv2d import DynamicConvolution2D -from funasr.modules.embedding import PositionalEncoding -from funasr.modules.layer_norm import LayerNorm -from funasr.modules.lightconv import LightweightConvolution -from funasr.modules.lightconv2d import LightweightConvolution2D -from funasr.modules.mask import subsequent_mask -from funasr.modules.positionwise_feed_forward import ( - PositionwiseFeedForward, # noqa: H301 -) -from funasr.modules.repeat import repeat -from funasr.modules.scorers.scorer_interface import BatchScorerInterface -from funasr.models.decoder.abs_decoder import AbsDecoder - -class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface): - - def __init__( - self, - vocab_size: int, - encoder_output_size: int, - spker_embedding_dim: int = 256, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - input_layer: str = "embed", - use_asr_output_layer: bool = True, - use_spk_output_layer: bool = True, - pos_enc_class=PositionalEncoding, - normalize_before: bool = True, - ): - assert check_argument_types() - super().__init__() - attention_dim = encoder_output_size - - if input_layer == "embed": - self.embed = torch.nn.Sequential( - torch.nn.Embedding(vocab_size, attention_dim), - pos_enc_class(attention_dim, positional_dropout_rate), - ) - elif input_layer == "linear": - self.embed = 
torch.nn.Sequential( - torch.nn.Linear(vocab_size, attention_dim), - torch.nn.LayerNorm(attention_dim), - torch.nn.Dropout(dropout_rate), - torch.nn.ReLU(), - pos_enc_class(attention_dim, positional_dropout_rate), - ) - else: - raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}") - - self.normalize_before = normalize_before - if self.normalize_before: - self.after_norm = LayerNorm(attention_dim) - if use_asr_output_layer: - self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size) - else: - self.asr_output_layer = None - - if use_spk_output_layer: - self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim) - else: - self.spk_output_layer = None - - self.cos_distance_att = CosineDistanceAttention() - - self.decoder1 = None - self.decoder2 = None - self.decoder3 = None - self.decoder4 = None - - def forward( - self, - asr_hs_pad: torch.Tensor, - spk_hs_pad: torch.Tensor, - hlens: torch.Tensor, - ys_in_pad: torch.Tensor, - ys_in_lens: torch.Tensor, - profile: torch.Tensor, - profile_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - - tgt = ys_in_pad - # tgt_mask: (B, 1, L) - tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) - # m: (1, L, L) - m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) - # tgt_mask: (B, L, L) - tgt_mask = tgt_mask & m - - asr_memory = asr_hs_pad - spk_memory = spk_hs_pad - memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device) - # Spk decoder - x = self.embed(tgt) - - x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1( - x, tgt_mask, asr_memory, spk_memory, memory_mask - ) - x, tgt_mask, spk_memory, memory_mask = self.decoder2( - x, tgt_mask, spk_memory, memory_mask - ) - if self.normalize_before: - x = self.after_norm(x) - if self.spk_output_layer is not None: - x = self.spk_output_layer(x) - dn, weights = self.cos_distance_att(x, profile, profile_lens) - # Asr decoder - x, tgt_mask, asr_memory, memory_mask 
= self.decoder3( - z, tgt_mask, asr_memory, memory_mask, dn - ) - x, tgt_mask, asr_memory, memory_mask = self.decoder4( - x, tgt_mask, asr_memory, memory_mask - ) - - if self.normalize_before: - x = self.after_norm(x) - if self.asr_output_layer is not None: - x = self.asr_output_layer(x) - - olens = tgt_mask.sum(1) - return x, weights, olens - - - def forward_one_step( - self, - tgt: torch.Tensor, - tgt_mask: torch.Tensor, - asr_memory: torch.Tensor, - spk_memory: torch.Tensor, - profile: torch.Tensor, - cache: List[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, List[torch.Tensor]]: - - x = self.embed(tgt) - - if cache is None: - cache = [None] * (2 + len(self.decoder2) + len(self.decoder4)) - new_cache = [] - x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1( - x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0] - ) - new_cache.append(x) - for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2): - x, tgt_mask, spk_memory, _ = decoder( - x, tgt_mask, spk_memory, None, cache=c - ) - new_cache.append(x) - if self.normalize_before: - x = self.after_norm(x) - else: - x = x - if self.spk_output_layer is not None: - x = self.spk_output_layer(x) - dn, weights = self.cos_distance_att(x, profile, None) - - x, tgt_mask, asr_memory, _ = self.decoder3( - z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1] - ) - new_cache.append(x) - - for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4): - x, tgt_mask, asr_memory, _ = decoder( - x, tgt_mask, asr_memory, None, cache=c - ) - new_cache.append(x) - - if self.normalize_before: - y = self.after_norm(x[:, -1]) - else: - y = x[:, -1] - if self.asr_output_layer is not None: - y = torch.log_softmax(self.asr_output_layer(y), dim=-1) - - return y, weights, new_cache - - def score(self, ys, state, asr_enc, spk_enc, profile): - """Score.""" - ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0) - logp, weights, state = self.forward_one_step( - ys.unsqueeze(0), 
ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state - ) - return logp.squeeze(0), weights.squeeze(), state - -class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder): - def __init__( - self, - vocab_size: int, - encoder_output_size: int, - spker_embedding_dim: int = 256, - attention_heads: int = 4, - linear_units: int = 2048, - asr_num_blocks: int = 6, - spk_num_blocks: int = 3, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - self_attention_dropout_rate: float = 0.0, - src_attention_dropout_rate: float = 0.0, - input_layer: str = "embed", - use_asr_output_layer: bool = True, - use_spk_output_layer: bool = True, - pos_enc_class=PositionalEncoding, - normalize_before: bool = True, - concat_after: bool = False, - ): - assert check_argument_types() - super().__init__( - vocab_size=vocab_size, - encoder_output_size=encoder_output_size, - spker_embedding_dim=spker_embedding_dim, - dropout_rate=dropout_rate, - positional_dropout_rate=positional_dropout_rate, - input_layer=input_layer, - use_asr_output_layer=use_asr_output_layer, - use_spk_output_layer=use_spk_output_layer, - pos_enc_class=pos_enc_class, - normalize_before=normalize_before, - ) - - attention_dim = encoder_output_size - - self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer( - attention_dim, - MultiHeadedAttention( - attention_heads, attention_dim, self_attention_dropout_rate - ), - MultiHeadedAttention( - attention_heads, attention_dim, src_attention_dropout_rate - ), - PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), - dropout_rate, - normalize_before, - concat_after, - ) - self.decoder2 = repeat( - spk_num_blocks - 1, - lambda lnum: DecoderLayer( - attention_dim, - MultiHeadedAttention( - attention_heads, attention_dim, self_attention_dropout_rate - ), - MultiHeadedAttention( - attention_heads, attention_dim, src_attention_dropout_rate - ), - PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), - dropout_rate, 
- normalize_before, - concat_after, - ), - ) - - - self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer( - attention_dim, - spker_embedding_dim, - MultiHeadedAttention( - attention_heads, attention_dim, src_attention_dropout_rate - ), - PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), - dropout_rate, - normalize_before, - concat_after, - ) - self.decoder4 = repeat( - asr_num_blocks - 1, - lambda lnum: DecoderLayer( - attention_dim, - MultiHeadedAttention( - attention_heads, attention_dim, self_attention_dropout_rate - ), - MultiHeadedAttention( - attention_heads, attention_dim, src_attention_dropout_rate - ), - PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), - dropout_rate, - normalize_before, - concat_after, - ), - ) From af6740a2207840a772261b8a033ab9996f862529 Mon Sep 17 00:00:00 2001 From: smohan-speech Date: Mon, 8 May 2023 16:13:23 +0800 Subject: [PATCH 5/5] add speaker-attributed ASR task for alimeeting --- egs/alimeeting/sa-asr/README.md | 79 +++++++++ egs/alimeeting/sa-asr/asr_local.sh | 15 +- ...infer.sh => asr_local_m2met_2023_infer.sh} | 4 +- egs/alimeeting/sa-asr/local/compute_wer.py | 157 ------------------ .../sa-asr/local/perturb_data_dir_speed.sh | 24 +-- egs/alimeeting/sa-asr/local/proce_text.py | 32 ---- .../sa-asr/{run_m2met_2023.sh => run.sh} | 1 - egs/alimeeting/sa-asr/run_m2met_2023_infer.sh | 2 +- funasr/bin/asr_inference.py | 9 +- funasr/bin/sa_asr_inference.py | 9 +- funasr/losses/label_smoothing_loss.py | 2 +- 11 files changed, 108 insertions(+), 226 deletions(-) create mode 100644 egs/alimeeting/sa-asr/README.md rename egs/alimeeting/sa-asr/{asr_local_infer.sh => asr_local_m2met_2023_infer.sh} (99%) delete mode 100755 egs/alimeeting/sa-asr/local/compute_wer.py delete mode 100755 egs/alimeeting/sa-asr/local/proce_text.py rename egs/alimeeting/sa-asr/{run_m2met_2023.sh => run.sh} (98%) diff --git a/egs/alimeeting/sa-asr/README.md b/egs/alimeeting/sa-asr/README.md new file mode 100644 index 
000000000..882345c25 --- /dev/null +++ b/egs/alimeeting/sa-asr/README.md @@ -0,0 +1,79 @@ +# Get Started +Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to solve "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example is referenced in the paper: [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf). +To run this recipe, first you need to install FunASR and ModelScope. ([installation](https://alibaba-damo-academy.github.io/FunASR/en/installation.html)) +There are two startup scripts, `run.sh` for training and evaluating on the old eval and test sets, and `run_m2met_2023_infer.sh` for inference on the new test set of the Multi-Channel Multi-Party Meeting Transcription 2.0 ([M2MET2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)) Challenge. +Before running `run.sh`, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory: +```shell +dataset +|—— Eval_Ali_far +|—— Eval_Ali_near +|—— Test_Ali_far +|—— Test_Ali_near +|—— Train_Ali_far +|—— Train_Ali_near +``` +There are 18 stages in `run.sh`: +```shell +stage 1 - 5: Data preparation and processing. +stage 6: Generate speaker profiles (Stage 6 takes a lot of time). +stage 7 - 9: Language model training (Optional). +stage 10 - 11: ASR training (SA-ASR requires loading the pre-trained ASR model). +stage 12: SA-ASR training. +stage 13 - 18: Inference and evaluation. +``` +Before running `run_m2met_2023_infer.sh`, you need to place the new test set `Test_2023_Ali_far` (to be released after the challenge starts) in the `./dataset` directory, which contains only raw audio.
Then put the given `wav.scp`, `wav_raw.scp`, `segments`, `utt2spk` and `spk2utt` in the `./data/Test_2023_Ali_far` directory. +```shell +data/Test_2023_Ali_far +|—— wav.scp +|—— wav_raw.scp +|—— segments +|—— utt2spk +|—— spk2utt +``` +There are 4 stages in `run_m2met_2023_infer.sh`: +```shell +stage 1: Data preparation and processing. +stage 2: Generate speaker profiles for inference. +stage 3: Inference. +stage 4: Generation of SA-ASR results required for final submission. +``` +# Format of Final Submission +Finally, you need to submit a file called `text_spk_merge` with the following format: +```shell +Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ... +Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ... +... +``` +Here, text_spk_1_A represents the full transcription of speaker_A of Meeting_1 (merged in chronological order), and $ represents the separator symbol. There's no need to worry about the speaker permutation as the optimal permutation will be computed in the end. For more information, please refer to the results generated after executing the baseline code. +# Baseline Results +The results of the baseline system are as follows. The baseline results include speaker independent character error rate (SI-CER) and concatenated minimum permutation character error rate (cpCER), the former is speaker independent and the latter is speaker dependent. The speaker profile adopts the oracle speaker embedding during training. However, due to the lack of oracle speaker label during evaluation, the speaker profile provided by an additional spectral clustering is used. Meanwhile, the results of using the oracle speaker profile on Eval and Test Set are also provided to show the impact of speaker profile accuracy. + + + + + + + + + + + + + + + + + + + + + + + + + + +
SI-CER(%)cpCER(%)
EvalTestEvalTest
oracle profile31.9332.7548.5653.33
cluster profile31.9432.7755.4958.17
+ +# Reference +N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413–4417. \ No newline at end of file diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh index 419e34144..f8cdcd3b6 100755 --- a/egs/alimeeting/sa-asr/asr_local.sh +++ b/egs/alimeeting/sa-asr/asr_local.sh @@ -475,7 +475,9 @@ if ! "${skip_data_prep}"; then fi local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" - cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/" + if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then + cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/" + fi rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur} _opts= @@ -568,8 +570,11 @@ if ! "${skip_data_prep}"; then # generate uttid cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid" - # filter utt2spk_all_fifo - python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset} + + if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then + # filter utt2spk_all_fifo + python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset} + fi done # shellcheck disable=SC2002 @@ -585,7 +590,7 @@ if ! "${skip_data_prep}"; then echo "" > ${token_list} echo "" >> ${token_list} echo "" >> ${token_list} - local/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \ + utils/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list} num_token=$(cat ${token_list} | wc -l) echo "" >> ${token_list} @@ -603,6 +608,7 @@ if ! 
"${skip_data_prep}"; then python local/process_text_id.py ${data_feats}/${dset} log "Successfully generate ${data_feats}/${dset}/text_id_train" # generate oracle_embedding from single-speaker audio segment + log "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${dset}.log" python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> "profile_log/gen_oracle_embedding_${dset}.log" log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)" # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training) @@ -615,6 +621,7 @@ if ! "${skip_data_prep}"; then fi # generate cluster_profile with spectral-cluster directly (for infering and without oracle information) if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then + log "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${dset}.log" python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log" log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)" fi diff --git a/egs/alimeeting/sa-asr/asr_local_infer.sh b/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh similarity index 99% rename from egs/alimeeting/sa-asr/asr_local_infer.sh rename to egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh index b7a928977..a23215c04 100755 --- a/egs/alimeeting/sa-asr/asr_local_infer.sh +++ b/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh @@ -449,7 +449,7 @@ if ! 
"${skip_data_prep}"; then _opts+="--segments data/${dset}/segments " fi # shellcheck disable=SC2086 - scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ + local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ --audio-format "${audio_format}" --fs "${fs}" ${_opts} \ "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}" @@ -467,7 +467,7 @@ if ! "${skip_data_prep}"; then mkdir -p "profile_log" for dset in "${test_sets}"; do # generate cluster_profile with spectral-cluster directly (for infering and without oracle information) - python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log" + python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log" log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)" done fi diff --git a/egs/alimeeting/sa-asr/local/compute_wer.py b/egs/alimeeting/sa-asr/local/compute_wer.py deleted file mode 100755 index 349a3f609..000000000 --- a/egs/alimeeting/sa-asr/local/compute_wer.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -import numpy as np -import sys - -def compute_wer(ref_file, - hyp_file, - cer_detail_file): - rst = { - 'Wrd': 0, - 'Corr': 0, - 'Ins': 0, - 'Del': 0, - 'Sub': 0, - 'Snt': 0, - 'Err': 0.0, - 'S.Err': 0.0, - 'wrong_words': 0, - 'wrong_sentences': 0 - } - - hyp_dict = {} - ref_dict = {} - with open(hyp_file, 'r') as hyp_reader: - for line in hyp_reader: - key = line.strip().split()[0] - value = line.strip().split()[1:] - hyp_dict[key] = value - with open(ref_file, 'r') as ref_reader: - for line in ref_reader: - key = line.strip().split()[0] - value = line.strip().split()[1:] - ref_dict[key] = value - - cer_detail_writer = open(cer_detail_file, 'w') - for hyp_key in hyp_dict: - if hyp_key in ref_dict: - out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key]) - 
rst['Wrd'] += out_item['nwords'] - rst['Corr'] += out_item['cor'] - rst['wrong_words'] += out_item['wrong'] - rst['Ins'] += out_item['ins'] - rst['Del'] += out_item['del'] - rst['Sub'] += out_item['sub'] - rst['Snt'] += 1 - if out_item['wrong'] > 0: - rst['wrong_sentences'] += 1 - cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n') - cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n') - cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n') - - if rst['Wrd'] > 0: - rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) - if rst['Snt'] > 0: - rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) - - cer_detail_writer.write('\n') - cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) + - ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n') - cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n') - cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." 
+ '\n') - - -def compute_wer_by_line(hyp, - ref): - hyp = list(map(lambda x: x.lower(), hyp)) - ref = list(map(lambda x: x.lower(), ref)) - - len_hyp = len(hyp) - len_ref = len(ref) - - cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16) - - ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8) - - for i in range(len_hyp + 1): - cost_matrix[i][0] = i - for j in range(len_ref + 1): - cost_matrix[0][j] = j - - for i in range(1, len_hyp + 1): - for j in range(1, len_ref + 1): - if hyp[i - 1] == ref[j - 1]: - cost_matrix[i][j] = cost_matrix[i - 1][j - 1] - else: - substitution = cost_matrix[i - 1][j - 1] + 1 - insertion = cost_matrix[i - 1][j] + 1 - deletion = cost_matrix[i][j - 1] + 1 - - compare_val = [substitution, insertion, deletion] - - min_val = min(compare_val) - operation_idx = compare_val.index(min_val) + 1 - cost_matrix[i][j] = min_val - ops_matrix[i][j] = operation_idx - - match_idx = [] - i = len_hyp - j = len_ref - rst = { - 'nwords': len_ref, - 'cor': 0, - 'wrong': 0, - 'ins': 0, - 'del': 0, - 'sub': 0 - } - while i >= 0 or j >= 0: - i_idx = max(0, i) - j_idx = max(0, j) - - if ops_matrix[i_idx][j_idx] == 0: # correct - if i - 1 >= 0 and j - 1 >= 0: - match_idx.append((j - 1, i - 1)) - rst['cor'] += 1 - - i -= 1 - j -= 1 - - elif ops_matrix[i_idx][j_idx] == 2: # insert - i -= 1 - rst['ins'] += 1 - - elif ops_matrix[i_idx][j_idx] == 3: # delete - j -= 1 - rst['del'] += 1 - - elif ops_matrix[i_idx][j_idx] == 1: # substitute - i -= 1 - j -= 1 - rst['sub'] += 1 - - if i < 0 and j >= 0: - rst['del'] += 1 - elif j < 0 and i >= 0: - rst['ins'] += 1 - - match_idx.reverse() - wrong_cnt = cost_matrix[len_hyp][len_ref] - rst['wrong'] = wrong_cnt - - return rst - -def print_cer_detail(rst): - return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor']) - + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub=" - + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords']) - + ",cer:" + 
'{:.2%}'.format(rst['wrong']/rst['nwords'])) - -if __name__ == '__main__': - if len(sys.argv) != 4: - print("usage : python compute-wer.py test.ref test.hyp test.wer") - sys.exit(0) - - ref_file = sys.argv[1] - hyp_file = sys.argv[2] - cer_detail_file = sys.argv[3] - compute_wer(ref_file, hyp_file, cer_detail_file) diff --git a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh index 9e08dba72..1022ae62b 100755 --- a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh +++ b/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh @@ -63,20 +63,20 @@ else fi -<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \ - utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk +<"${srcdir}"/utt2spk local/apply_map.pl -f 1 "${destdir}"/utt_map | \ + local/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk -utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt +local/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt if [[ -f ${srcdir}/segments ]]; then - utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \ - utils/apply_map.pl -f 2 "${destdir}"/reco_map | \ + local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \ + local/apply_map.pl -f 2 "${destdir}"/reco_map | \ awk -v factor="${factor}" \ '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \ >"${destdir}"/segments - utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ + local/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" awk -v factor="${factor}" \ '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} @@ -84,13 +84,13 @@ if [[ -f ${srcdir}/segments ]]; then else {print wid " sox" $_ " -t wav - 
speed " factor " |"}}' \ > "${destdir}"/wav.scp if [[ -f ${srcdir}/reco2file_and_channel ]]; then - utils/apply_map.pl -f 1 "${destdir}"/reco_map \ + local/apply_map.pl -f 1 "${destdir}"/reco_map \ <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel fi else # no segments->wav indexed by utterance. if [[ -f ${srcdir}/wav.scp ]]; then - utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ + local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \ # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename" awk -v factor="${factor}" \ '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"} @@ -101,16 +101,16 @@ else # no segments->wav indexed by utterance. fi if [[ -f ${srcdir}/text ]]; then - utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text + local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text fi if [[ -f ${srcdir}/spk2gender ]]; then - utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender + local/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender fi if [[ -f ${srcdir}/utt2lang ]]; then - utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang + local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang fi rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}" -utils/validate_data_dir.sh --no-feats --no-text "${destdir}" +local/validate_data_dir.sh --no-feats --no-text "${destdir}" diff --git a/egs/alimeeting/sa-asr/local/proce_text.py b/egs/alimeeting/sa-asr/local/proce_text.py deleted file mode 100755 index e56cc0f37..000000000 --- a/egs/alimeeting/sa-asr/local/proce_text.py +++ /dev/null @@ -1,32 
+0,0 @@ - -import sys -import re - -in_f = sys.argv[1] -out_f = sys.argv[2] - - -with open(in_f, "r", encoding="utf-8") as f: - lines = f.readlines() - -with open(out_f, "w", encoding="utf-8") as f: - for line in lines: - outs = line.strip().split(" ", 1) - if len(outs) == 2: - idx, text = outs - text = re.sub("
", "", text) - text = re.sub("", "", text) - text = re.sub("@@", "", text) - text = re.sub("@", "", text) - text = re.sub("", "", text) - text = re.sub(" ", "", text) - text = re.sub("\$", "", text) - text = text.lower() - else: - idx = outs[0] - text = " " - - text = [x for x in text] - text = " ".join(text) - out = "{} {}\n".format(idx, text) - f.write(out) diff --git a/egs/alimeeting/sa-asr/run_m2met_2023.sh b/egs/alimeeting/sa-asr/run.sh similarity index 98% rename from egs/alimeeting/sa-asr/run_m2met_2023.sh rename to egs/alimeeting/sa-asr/run.sh index 807e49948..e5297b8b3 100755 --- a/egs/alimeeting/sa-asr/run_m2met_2023.sh +++ b/egs/alimeeting/sa-asr/run.sh @@ -8,7 +8,6 @@ set -o pipefail ngpu=4 device="0,1,2,3" -#stage 1 creat both near and far stage=1 stop_stage=18 diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh index d35e6a693..1967864d8 100755 --- a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh +++ b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh @@ -22,7 +22,7 @@ inference_config=conf/decode_asr_rnn.yaml lm_config=conf/train_lm_transformer.yaml use_lm=false use_wordlm=false -./asr_local_infer.sh \ +./asr_local_m2met_2023_infer.sh \ --device ${device} \ --ngpu ${ngpu} \ --stage ${stage} \ diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py index 18f0add55..a52e94a7f 100644 --- a/funasr/bin/asr_inference.py +++ b/funasr/bin/asr_inference.py @@ -94,7 +94,7 @@ class Speech2Text: frontend = None if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: if asr_train_args.frontend=='wav_frontend': - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval() + frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) else: frontend_class=frontend_choices.get_class(asr_train_args.frontend) frontend = frontend_class(**asr_train_args.frontend_conf).eval() @@ -147,13 +147,6 @@ class Speech2Text: pre_beam_score_key=None if 
ctc_weight == 1.0 else "full", ) - beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() - for scorer in scorers.values(): - if isinstance(scorer, torch.nn.Module): - scorer.to(device=device, dtype=getattr(torch, dtype)).eval() - logging.info(f"Beam_search: {beam_search}") - logging.info(f"Decoding device={device}, dtype={dtype}") - # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text if token_type is None: token_type = asr_train_args.token_type diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py index ec575df2f..c894f5460 100644 --- a/funasr/bin/sa_asr_inference.py +++ b/funasr/bin/sa_asr_inference.py @@ -89,7 +89,7 @@ class Speech2Text: frontend = None if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: if asr_train_args.frontend=='wav_frontend': - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval() + frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) else: frontend_class=frontend_choices.get_class(asr_train_args.frontend) frontend = frontend_class(**asr_train_args.frontend_conf).eval() @@ -142,13 +142,6 @@ class Speech2Text: pre_beam_score_key=None if ctc_weight == 1.0 else "full", ) - beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() - for scorer in scorers.values(): - if isinstance(scorer, torch.nn.Module): - scorer.to(device=device, dtype=getattr(torch, dtype)).eval() - logging.info(f"Beam_search: {beam_search}") - logging.info(f"Decoding device={device}, dtype={dtype}") - # 5. [Optional] Build Text converter: e.g. 
bpe-sym -> Text if token_type is None: token_type = asr_train_args.token_type diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py index 8f63df9bc..3ea34c048 100644 --- a/funasr/losses/label_smoothing_loss.py +++ b/funasr/losses/label_smoothing_loss.py @@ -97,7 +97,7 @@ class NllLoss(nn.Module): normalize_length=False, criterion=nn.NLLLoss(reduction='none'), ): - """Construct an LabelSmoothingLoss object.""" + """Construct an NllLoss object.""" super(NllLoss, self).__init__() self.criterion = criterion self.padding_idx = padding_idx