mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add speaker-attributed ASR task for alimeeting
This commit is contained in:
parent
5fc6f7899a
commit
a73123bcfc
1562
egs/alimeeting/sa-asr/asr_local.sh
Executable file
1562
egs/alimeeting/sa-asr/asr_local.sh
Executable file
File diff suppressed because it is too large
Load Diff
590
egs/alimeeting/sa-asr/asr_local_infer.sh
Executable file
590
egs/alimeeting/sa-asr/asr_local_infer.sh
Executable file
@ -0,0 +1,590 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Set bash to 'debug' mode, it will exit on :
|
||||
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
|
||||
set -e
|
||||
set -u
|
||||
set -o pipefail
|
||||
|
||||
# Print a timestamped message to stdout, tagged with the caller's
# file name, line number and enclosing function name.
# Arguments: the message words (joined with spaces; backslash escapes are interpreted by echo -e).
log() {
    local src_file caller_line caller_func
    src_file=${BASH_SOURCE[1]##*/}   # basename of the file that called log
    caller_line=${BASH_LINENO[0]}    # line in that file where log was invoked
    caller_func=${FUNCNAME[1]}       # function (or "main") that invoked log
    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${src_file}:${caller_line}:${caller_func}) $*"
}
|
||||
# Echo the smallest of the given integer arguments.
# Arguments: one or more integers (compared with the `-le` test operator).
min() {
    local smallest candidate
    smallest=$1
    # "for x; do" iterates over the positional parameters ("$@").
    for candidate; do
        if [ "${candidate}" -le "${smallest}" ]; then
            smallest="${candidate}"
        fi
    done
    echo "${smallest}"
}
|
||||
SECONDS=0
|
||||
|
||||
# General configuration
|
||||
stage=1 # Processes starts from the specified stage.
|
||||
stop_stage=10000 # Processes is stopped at the specified stage.
|
||||
skip_data_prep=false # Skip data preparation stages.
|
||||
skip_train=false # Skip training stages.
|
||||
skip_eval=false # Skip decoding and evaluation stages.
|
||||
skip_upload=true # Skip packing and uploading stages.
|
||||
ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu).
|
||||
num_nodes=1 # The number of nodes.
|
||||
nj=16 # The number of parallel jobs.
|
||||
inference_nj=16 # The number of parallel jobs in decoding.
|
||||
gpu_inference=false # Whether to perform gpu decoding.
|
||||
njob_infer=4
|
||||
dumpdir=dump2 # Directory to dump features.
|
||||
expdir=exp # Directory to save experiments.
|
||||
python=python3 # Specify python to execute espnet commands.
|
||||
device=0
|
||||
|
||||
# Data preparation related
|
||||
local_data_opts= # The options given to local/data.sh.
|
||||
|
||||
# Speed perturbation related
|
||||
speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
|
||||
|
||||
# Feature extraction related
|
||||
feats_type=raw # Feature type (raw or fbank_pitch).
|
||||
audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw).
|
||||
fs=16000 # Sampling rate.
|
||||
min_wav_duration=0.1 # Minimum duration in second.
|
||||
max_wav_duration=20 # Maximum duration in second.
|
||||
|
||||
# Tokenization related
|
||||
token_type=bpe # Tokenization type (char or bpe).
|
||||
nbpe=30 # The number of BPE vocabulary.
|
||||
bpemode=unigram # Mode of BPE (unigram or bpe).
|
||||
oov="<unk>" # Out of vocabulary symbol.
|
||||
blank="<blank>" # CTC blank symbol
|
||||
sos_eos="<sos/eos>" # sos and eos symbole
|
||||
bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
|
||||
bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE
|
||||
bpe_char_cover=1.0 # character coverage when modeling BPE
|
||||
|
||||
# Language model related
|
||||
use_lm=true # Use language model for ASR decoding.
|
||||
lm_tag= # Suffix to the result dir for language model training.
|
||||
lm_exp= # Specify the directory path for LM experiment.
|
||||
# If this option is specified, lm_tag is ignored.
|
||||
lm_stats_dir= # Specify the directory path for LM statistics.
|
||||
lm_config= # Config for language model training.
|
||||
lm_args= # Arguments for language model training, e.g., "--max_epoch 10".
|
||||
# Note that it will overwrite args in lm config.
|
||||
use_word_lm=false # Whether to use word language model.
|
||||
num_splits_lm=1 # Number of splitting for lm corpus.
|
||||
# shellcheck disable=SC2034
|
||||
word_vocab_size=10000 # Size of word vocabulary.
|
||||
|
||||
# ASR model related
|
||||
asr_tag= # Suffix to the result dir for asr model training.
|
||||
asr_exp= # Specify the directory path for ASR experiment.
|
||||
# If this option is specified, asr_tag is ignored.
|
||||
sa_asr_exp=
|
||||
asr_stats_dir= # Specify the directory path for ASR statistics.
|
||||
asr_config= # Config for asr model training.
|
||||
sa_asr_config=
|
||||
asr_args= # Arguments for asr model training, e.g., "--max_epoch 10".
|
||||
# Note that it will overwrite args in asr config.
|
||||
feats_normalize=global_mvn # Normalization layer type.
|
||||
num_splits_asr=1 # Number of splits for the ASR corpus.
|
||||
|
||||
# Decoding related
|
||||
inference_tag= # Suffix to the result dir for decoding.
|
||||
inference_config= # Config for decoding.
|
||||
inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1".
|
||||
# Note that it will overwrite args in inference config.
|
||||
sa_asr_inference_tag=
|
||||
sa_asr_inference_args=
|
||||
|
||||
inference_lm=valid.loss.ave.pb # Language model path for decoding.
|
||||
inference_asr_model=valid.acc.ave.pb # ASR model path for decoding.
|
||||
# e.g.
|
||||
# inference_asr_model=train.loss.best.pth
|
||||
# inference_asr_model=3epoch.pth
|
||||
# inference_asr_model=valid.acc.best.pth
|
||||
# inference_asr_model=valid.loss.ave.pth
|
||||
inference_sa_asr_model=valid.acc_spk.ave.pb
|
||||
download_model= # Download a model from Model Zoo and use it for decoding.
|
||||
|
||||
# [Task dependent] Set the datadir name created by local/data.sh
|
||||
train_set= # Name of training set.
|
||||
valid_set= # Name of validation set used for monitoring/tuning network training.
|
||||
test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
|
||||
bpe_train_text= # Text file path of bpe training set.
|
||||
lm_train_text= # Text file path of language model training set.
|
||||
lm_dev_text= # Text file path of language model development set.
|
||||
lm_test_text= # Text file path of language model evaluation set.
|
||||
nlsyms_txt=none # Non-linguistic symbol list if existing.
|
||||
cleaner=none # Text cleaner.
|
||||
g2p=none # g2p method (needed if token_type=phn).
|
||||
lang=zh # The language type of corpus.
|
||||
score_opts= # The options given to sclite scoring
|
||||
local_score_opts= # The options given to local/score.sh.
|
||||
|
||||
# Usage/help text printed on argument errors. The heredoc is unquoted, so every
# ${var} below is expanded to the current default value when this is evaluated.
# Fix: the defaults of --score_opts and --local_score_opts were written as
# "{score_opts}" / "{local_score_opts}" (missing "$") and were printed literally.
help_message=$(cat << EOF
Usage: $0 --train-set "<train_set_name>" --valid-set "<valid_set_name>" --test_sets "<test_set_names>"

Options:
    # General configuration
    --stage # Processes starts from the specified stage (default="${stage}").
    --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}").
    --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
    --skip_train # Skip training stages (default="${skip_train}").
    --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}").
    --skip_upload # Skip packing and uploading stages (default="${skip_upload}").
    --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
    --num_nodes # The number of nodes (default="${num_nodes}").
    --nj # The number of parallel jobs (default="${nj}").
    --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}").
    --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}").
    --dumpdir # Directory to dump features (default="${dumpdir}").
    --expdir # Directory to save experiments (default="${expdir}").
    --python # Specify python to execute espnet commands (default="${python}").
    --device # Which GPUs are use for local training (defalut="${device}").

    # Data preparation related
    --local_data_opts # The options given to local/data.sh (default="${local_data_opts}").

    # Speed perturbation related
    --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").

    # Feature extraction related
    --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
    --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}").
    --fs # Sampling rate (default="${fs}").
    --min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
    --max_wav_duration # Maximum duration in second (default="${max_wav_duration}").

    # Tokenization related
    --token_type # Tokenization type (char or bpe, default="${token_type}").
    --nbpe # The number of BPE vocabulary (default="${nbpe}").
    --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}").
    --oov # Out of vocabulary symbol (default="${oov}").
    --blank # CTC blank symbol (default="${blank}").
    --sos_eos # sos and eos symbole (default="${sos_eos}").
    --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
    --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
    --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}").

    # Language model related
    --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}").
    --lm_exp # Specify the direcotry path for LM experiment.
             # If this option is specified, lm_tag is ignored (default="${lm_exp}").
    --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}").
    --lm_config # Config for language model training (default="${lm_config}").
    --lm_args # Arguments for language model training (default="${lm_args}").
              # e.g., --lm_args "--max_epoch 10"
              # Note that it will overwrite args in lm config.
    --use_word_lm # Whether to use word language model (default="${use_word_lm}").
    --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
    --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}").

    # ASR model related
    --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}").
    --asr_exp # Specify the direcotry path for ASR experiment.
              # If this option is specified, asr_tag is ignored (default="${asr_exp}").
    --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}").
    --asr_config # Config for asr model training (default="${asr_config}").
    --asr_args # Arguments for asr model training (default="${asr_args}").
               # e.g., --asr_args "--max_epoch 10"
               # Note that it will overwrite args in asr config.
    --feats_normalize # Normalizaton layer type (default="${feats_normalize}").
    --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}").

    # Decoding related
    --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}").
    --inference_config # Config for decoding (default="${inference_config}").
    --inference_args # Arguments for decoding (default="${inference_args}").
                     # e.g., --inference_args "--lm_weight 0.1"
                     # Note that it will overwrite args in inference config.
    --inference_lm # Language modle path for decoding (default="${inference_lm}").
    --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
    --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}").

    # [Task dependent] Set the datadir name created by local/data.sh
    --train_set # Name of training set (required).
    --valid_set # Name of validation set used for monitoring/tuning network training (required).
    --test_sets # Names of test sets.
                # Multiple items (e.g., both dev and eval sets) can be specified (required).
    --bpe_train_text # Text file path of bpe training set.
    --lm_train_text # Text file path of language model training set.
    --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}").
    --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}").
    --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
    --cleaner # Text cleaner (default="${cleaner}").
    --g2p # g2p method (default="${g2p}").
    --lang # The language type of corpus (default=${lang}).
    --score_opts # The options given to sclite scoring (default="${score_opts}").
    --local_score_opts # The options given to local/score.sh (default="${local_score_opts}").
EOF
)
|
||||
|
||||
log "$0 $*"
|
||||
# Save command line args for logging (they will be lost after utils/parse_options.sh)
|
||||
run_args=$(python -m funasr.utils.cli_utils $0 "$@")
|
||||
. utils/parse_options.sh
|
||||
|
||||
if [ $# -ne 0 ]; then
|
||||
log "${help_message}"
|
||||
log "Error: No positional arguments are required."
|
||||
exit 2
|
||||
fi
|
||||
|
||||
. ./path.sh
|
||||
|
||||
|
||||
# Check required arguments
|
||||
[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; };
|
||||
[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; };
|
||||
[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
|
||||
|
||||
# Check feature type
|
||||
if [ "${feats_type}" = raw ]; then
|
||||
data_feats=${dumpdir}/raw
|
||||
elif [ "${feats_type}" = fbank_pitch ]; then
|
||||
data_feats=${dumpdir}/fbank_pitch
|
||||
elif [ "${feats_type}" = fbank ]; then
|
||||
data_feats=${dumpdir}/fbank
|
||||
elif [ "${feats_type}" == extracted ]; then
|
||||
data_feats=${dumpdir}/extracted
|
||||
else
|
||||
log "${help_message}"
|
||||
log "Error: not supported: --feats_type ${feats_type}"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
# Use the same text as ASR for bpe training if not specified.
|
||||
[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text"
|
||||
# Use the same text as ASR for lm training if not specified.
|
||||
[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text"
|
||||
# Use the same text as ASR for lm training if not specified.
|
||||
[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text"
|
||||
# Use the text of the 1st evaldir if lm_test is not specified
|
||||
[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
|
||||
|
||||
# Check tokenization type
|
||||
if [ "${lang}" != noinfo ]; then
|
||||
token_listdir=data/${lang}_token_list
|
||||
else
|
||||
token_listdir=data/token_list
|
||||
fi
|
||||
bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
|
||||
bpeprefix="${bpedir}"/bpe
|
||||
bpemodel="${bpeprefix}".model
|
||||
bpetoken_list="${bpedir}"/tokens.txt
|
||||
chartoken_list="${token_listdir}"/char/tokens.txt
|
||||
# NOTE: keep for future development.
|
||||
# shellcheck disable=SC2034
|
||||
wordtoken_list="${token_listdir}"/word/tokens.txt
|
||||
|
||||
if [ "${token_type}" = bpe ]; then
|
||||
token_list="${bpetoken_list}"
|
||||
elif [ "${token_type}" = char ]; then
|
||||
token_list="${chartoken_list}"
|
||||
bpemodel=none
|
||||
elif [ "${token_type}" = word ]; then
|
||||
token_list="${wordtoken_list}"
|
||||
bpemodel=none
|
||||
else
|
||||
log "Error: not supported --token_type '${token_type}'"
|
||||
exit 2
|
||||
fi
|
||||
if ${use_word_lm}; then
|
||||
log "Error: Word LM is not supported yet"
|
||||
exit 2
|
||||
|
||||
lm_token_list="${wordtoken_list}"
|
||||
lm_token_type=word
|
||||
else
|
||||
lm_token_list="${token_list}"
|
||||
lm_token_type="${token_type}"
|
||||
fi
|
||||
|
||||
|
||||
# Set tag for naming of model directory
|
||||
if [ -z "${asr_tag}" ]; then
|
||||
if [ -n "${asr_config}" ]; then
|
||||
asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
|
||||
else
|
||||
asr_tag="train_${feats_type}"
|
||||
fi
|
||||
if [ "${lang}" != noinfo ]; then
|
||||
asr_tag+="_${lang}_${token_type}"
|
||||
else
|
||||
asr_tag+="_${token_type}"
|
||||
fi
|
||||
if [ "${token_type}" = bpe ]; then
|
||||
asr_tag+="${nbpe}"
|
||||
fi
|
||||
# Add overwritten arg's info
|
||||
if [ -n "${asr_args}" ]; then
|
||||
asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
|
||||
fi
|
||||
if [ -n "${speed_perturb_factors}" ]; then
|
||||
asr_tag+="_sp"
|
||||
fi
|
||||
fi
|
||||
if [ -z "${lm_tag}" ]; then
|
||||
if [ -n "${lm_config}" ]; then
|
||||
lm_tag="$(basename "${lm_config}" .yaml)"
|
||||
else
|
||||
lm_tag="train"
|
||||
fi
|
||||
if [ "${lang}" != noinfo ]; then
|
||||
lm_tag+="_${lang}_${lm_token_type}"
|
||||
else
|
||||
lm_tag+="_${lm_token_type}"
|
||||
fi
|
||||
if [ "${lm_token_type}" = bpe ]; then
|
||||
lm_tag+="${nbpe}"
|
||||
fi
|
||||
# Add overwritten arg's info
|
||||
if [ -n "${lm_args}" ]; then
|
||||
lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
|
||||
fi
|
||||
fi
|
||||
|
||||
# The directory used for collect-stats mode
|
||||
if [ -z "${asr_stats_dir}" ]; then
|
||||
if [ "${lang}" != noinfo ]; then
|
||||
asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
|
||||
else
|
||||
asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
|
||||
fi
|
||||
if [ "${token_type}" = bpe ]; then
|
||||
asr_stats_dir+="${nbpe}"
|
||||
fi
|
||||
if [ -n "${speed_perturb_factors}" ]; then
|
||||
asr_stats_dir+="_sp"
|
||||
fi
|
||||
fi
|
||||
if [ -z "${lm_stats_dir}" ]; then
|
||||
if [ "${lang}" != noinfo ]; then
|
||||
lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
|
||||
else
|
||||
lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
|
||||
fi
|
||||
if [ "${lm_token_type}" = bpe ]; then
|
||||
lm_stats_dir+="${nbpe}"
|
||||
fi
|
||||
fi
|
||||
# The directory used for training commands
|
||||
if [ -z "${asr_exp}" ]; then
|
||||
asr_exp="${expdir}/asr_${asr_tag}"
|
||||
fi
|
||||
if [ -z "${lm_exp}" ]; then
|
||||
lm_exp="${expdir}/lm_${lm_tag}"
|
||||
fi
|
||||
|
||||
|
||||
if [ -z "${inference_tag}" ]; then
|
||||
if [ -n "${inference_config}" ]; then
|
||||
inference_tag="$(basename "${inference_config}" .yaml)"
|
||||
else
|
||||
inference_tag=inference
|
||||
fi
|
||||
# Add overwritten arg's info
|
||||
if [ -n "${inference_args}" ]; then
|
||||
inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
|
||||
fi
|
||||
if "${use_lm}"; then
|
||||
inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
|
||||
fi
|
||||
inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
|
||||
fi
|
||||
|
||||
if [ -z "${sa_asr_inference_tag}" ]; then
|
||||
if [ -n "${inference_config}" ]; then
|
||||
sa_asr_inference_tag="$(basename "${inference_config}" .yaml)"
|
||||
else
|
||||
sa_asr_inference_tag=sa_asr_inference
|
||||
fi
|
||||
# Add overwritten arg's info
|
||||
if [ -n "${sa_asr_inference_args}" ]; then
|
||||
sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
|
||||
fi
|
||||
if "${use_lm}"; then
|
||||
sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
|
||||
fi
|
||||
sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
|
||||
fi
|
||||
|
||||
train_cmd="run.pl"
|
||||
cuda_cmd="run.pl"
|
||||
decode_cmd="run.pl"
|
||||
|
||||
# ========================== Main stages start from here. ==========================
|
||||
|
||||
if ! "${skip_data_prep}"; then
|
||||
|
||||
# Stage 1: re-create wav.scp for every test set under ${data_feats}.
# Only feats_type=raw is handled here; any other type aborts with exit code 2.
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    if [ "${feats_type}" = raw ]; then
        log "Stage 1: Format wav.scp: data/ -> ${data_feats}"

        # ====== Recreating "wav.scp" ======
        # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |",
        # shouldn't be used in training process.
        # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file
        # and it can also change the audio-format and sampling rate.
        # If nothing is needed, then format_wav_scp.sh does nothing:
        # i.e. the input file format and rate is same as the output.

        # NOTE: ${test_sets} is intentionally unquoted so that a space-separated
        # list of test sets is split into one iteration per set (a quoted
        # expansion would run the loop once with the whole list as a single name).
        # shellcheck disable=SC2086
        for dset in ${test_sets}; do

            _suf=""

            utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"

            # Remove files that are regenerated (or no longer valid) after reformatting.
            rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
            _opts=
            if [ -e data/"${dset}"/segments ]; then
                # "segments" is used for splitting wav files which are written in "wav".scp
                # into utterances. The file format of segments:
                #   <segment_id> <record_id> <start_time> <end_time>
                #   "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
                # Where the time is written in seconds.
                _opts+="--segments data/${dset}/segments "
            fi
            # shellcheck disable=SC2086
            scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
                --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
                "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"

            # Record the feature type; stage 3 reads this file back.
            echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
        done

    else
        log "Error: not supported: --feats_type ${feats_type}"
        exit 2
    fi
fi
|
||||
|
||||
# Stage 2: generate a speaker profile for every test set via spectral clustering.
# The per-set output of the python script is captured in profile_log/.
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    log "Stage 2: Generate speaker profile by spectral-cluster"
    mkdir -p "profile_log"
    # NOTE: ${test_sets} is intentionally unquoted so that a space-separated
    # list of test sets yields one iteration per set.
    # shellcheck disable=SC2086
    for dset in ${test_sets}; do
        # generate cluster_profile with spectral-cluster directly (for inference, without oracle information)
        # NOTE(review): 0.996 and 0.815 are passed through as-is; their meaning is
        # defined by local/gen_cluster_profile_infer.py (presumably clustering thresholds - confirm).
        python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
        log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
    done
fi
|
||||
|
||||
else
|
||||
log "Skip the stages for data preparation"
|
||||
fi
|
||||
|
||||
|
||||
# ========================== Data preparation is done here. ==========================
|
||||
|
||||
if ! "${skip_eval}"; then
|
||||
|
||||
# Stage 3: decode every test set with the SA-ASR model, using the cluster
# profiles written by stage 2, then merge the per-job outputs.
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    log "Stage 3: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"

    if ${gpu_inference}; then
        _cmd="${cuda_cmd}"
        # One group of njob_infer jobs per GPU. $((...)) replaces the deprecated $[...] form.
        inference_nj=$((ngpu * njob_infer))
        _ngpu=1

    else
        _cmd="${decode_cmd}"
        inference_nj=$njob_infer
        _ngpu=0
    fi

    # Extra options forwarded to the inference launcher.
    _opts=
    if [ -n "${inference_config}" ]; then
        _opts+="--config ${inference_config} "
    fi
    if "${use_lm}"; then
        if "${use_word_lm}"; then
            _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
            _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
        else
            _opts+="--lm_train_config ${lm_exp}/config.yaml "
            _opts+="--lm_file ${lm_exp}/${inference_lm} "
        fi
    fi

    # Generate run.sh so that decoding can be re-run with the same arguments.
    # The resume stage is 3 (this stage); the previous value 17 does not exist
    # in this script, so the generated run.sh would have executed nothing.
    log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 3 using this script"
    mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"
    echo "${run_args} --stage 3 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"
    chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"

    for dset in ${test_sets}; do
        _data="${data_feats}/${dset}"
        _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
        _logdir="${_dir}/logdir"
        mkdir -p "${_logdir}"

        # feats_type is the marker file written by stage 1.
        _feats_type="$(<${_data}/feats_type)"
        if [ "${_feats_type}" = raw ]; then
            _scp=wav.scp
            if [[ "${audio_format}" == *ark* ]]; then
                _type=kaldi_ark
            else
                _type=sound
            fi
        else
            _scp=feats.scp
            _type=kaldi_ark
        fi

        # 1. Split the key file
        key_file=${_data}/${_scp}
        split_scps=""
        # Never start more jobs than there are entries in the key file.
        _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}

        # 2. Submit decoding jobs
        # The message names the log files actually written below (asr_inference.JOB.log).
        log "Decoding started... log: '${_logdir}/asr_inference.*.log'"
        # shellcheck disable=SC2086
        ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --nbest 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob_infer} \
                --gpuid_list ${device} \
                --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
                --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --allow_variable_data_keys true \
                --asr_train_config "${sa_asr_exp}"/config.yaml \
                --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode sa_asr \
                ${_opts}

        # 3. Concatenates the output files from each jobs
        for f in token token_int score text text_id; do
            for i in $(seq "${_nj}"); do
                cat "${_logdir}/output.${i}/1best_recog/${f}"
            done | LC_ALL=C sort -k1 >"${_dir}/${f}"
        done
    done
fi
|
||||
|
||||
# Stage 4: post-process the stage-3 decoding directory of every test set.
# NOTE(review): the exact output of local/process_text_spk_merge.py is not
# visible from this script; it only receives the decoding directory.
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    log "Stage 4: Generate SA-ASR results (cluster profile)"

    for dset in ${test_sets}; do
        _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"

        # Quoted so that a path containing whitespace is passed as one argument.
        python local/process_text_spk_merge.py "${_dir}"
    done

fi
|
||||
|
||||
else
|
||||
log "Skip the evaluation stages"
|
||||
fi
|
||||
|
||||
|
||||
log "Successfully finished. [elapsed=${SECONDS}s]"
|
||||
6
egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
Normal file
6
egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
Normal file
@ -0,0 +1,6 @@
|
||||
beam_size: 20
|
||||
penalty: 0.0
|
||||
maxlenratio: 0.0
|
||||
minlenratio: 0.0
|
||||
ctc_weight: 0.6
|
||||
lm_weight: 0.3
|
||||
88
egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
Normal file
88
egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
Normal file
@ -0,0 +1,88 @@
|
||||
# network architecture
|
||||
frontend: default
|
||||
frontend_conf:
|
||||
n_fft: 400
|
||||
win_length: 400
|
||||
hop_length: 160
|
||||
use_channel: 0
|
||||
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder architecture type
|
||||
normalize_before: true
|
||||
rel_pos_type: latest
|
||||
pos_enc_layer_type: rel_pos
|
||||
selfattention_layer_type: rel_selfattn
|
||||
activation_type: swish
|
||||
macaron_style: true
|
||||
use_cnn_module: true
|
||||
cnn_module_kernel: 15
|
||||
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# ctc related
|
||||
ctc_conf:
|
||||
ignore_nan_grad: true
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
# minibatch related
|
||||
batch_type: numel
|
||||
batch_bins: 10000000 # reduce/increase this number according to your GPU memory
|
||||
|
||||
# optimization related
|
||||
accum_grad: 1
|
||||
grad_clip: 5
|
||||
max_epoch: 100
|
||||
val_scheduler_criterion:
|
||||
- valid
|
||||
- acc
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- acc
|
||||
- max
|
||||
keep_nbest_models: 10
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
|
||||
specaug: specaug
|
||||
specaug_conf:
|
||||
apply_time_warp: true
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
apply_freq_mask: true
|
||||
freq_mask_width_range:
|
||||
- 0
|
||||
- 30
|
||||
num_freq_mask: 2
|
||||
apply_time_mask: true
|
||||
time_mask_width_range:
|
||||
- 0
|
||||
- 40
|
||||
num_time_mask: 2
|
||||
29
egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
Normal file
29
egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
Normal file
@ -0,0 +1,29 @@
|
||||
lm: transformer
|
||||
lm_conf:
|
||||
pos_enc: null
|
||||
embed_unit: 128
|
||||
att_unit: 512
|
||||
head: 8
|
||||
unit: 2048
|
||||
layer: 16
|
||||
dropout_rate: 0.1
|
||||
|
||||
# optimization related
|
||||
grad_clip: 5.0
|
||||
batch_type: numel
|
||||
batch_bins: 500000 # 4gpus * 500000
|
||||
accum_grad: 1
|
||||
max_epoch: 15 # 15 epochs is enough
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- loss
|
||||
- min
|
||||
keep_nbest_models: 10 # 10 is good.
|
||||
116
egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
Normal file
116
egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
Normal file
@ -0,0 +1,116 @@
|
||||
# network architecture
|
||||
frontend: default
|
||||
frontend_conf:
|
||||
n_fft: 400
|
||||
win_length: 400
|
||||
hop_length: 160
|
||||
use_channel: 0
|
||||
|
||||
# encoder related
|
||||
asr_encoder: conformer
|
||||
asr_encoder_conf:
|
||||
output_size: 256 # dimension of attention
|
||||
attention_heads: 4
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder architecture type
|
||||
normalize_before: true
|
||||
pos_enc_layer_type: rel_pos
|
||||
selfattention_layer_type: rel_selfattn
|
||||
activation_type: swish
|
||||
macaron_style: true
|
||||
use_cnn_module: true
|
||||
cnn_module_kernel: 15
|
||||
|
||||
# speaker encoder related
spk_encoder: resnet34_diar
spk_encoder_conf:
    use_head_conv: true
    # batchnorm_momentum was listed twice in this mapping with the same value;
    # duplicate keys are invalid YAML, so only this entry is kept.
    batchnorm_momentum: 0.5
    use_head_maxpool: false
    num_nodes_pooling_layer: 256
    layers_in_block:
    - 3
    - 4
    - 6
    - 3
    filters_in_block:
    - 32
    - 64
    - 128
    - 256
    pooling_type: statistic
    num_nodes_resnet1: 256
    num_nodes_last_layer: 256
|
||||
|
||||
# decoder related
|
||||
decoder: sa_decoder
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
asr_num_blocks: 6
|
||||
spk_num_blocks: 3
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
spk_weight: 0.5
|
||||
ctc_weight: 0.3
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
|
||||
ctc_conf:
|
||||
ignore_nan_grad: true
|
||||
|
||||
# minibatch related
|
||||
batch_type: numel
|
||||
batch_bins: 10000000
|
||||
|
||||
# optimization related
|
||||
accum_grad: 1
|
||||
grad_clip: 5
|
||||
max_epoch: 60
|
||||
val_scheduler_criterion:
|
||||
- valid
|
||||
- loss
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- acc
|
||||
- max
|
||||
- - valid
|
||||
- acc_spk
|
||||
- max
|
||||
- - valid
|
||||
- loss
|
||||
- min
|
||||
keep_nbest_models: 10
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.0005
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 8000
|
||||
|
||||
specaug: specaug
|
||||
specaug_conf:
|
||||
apply_time_warp: true
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
apply_freq_mask: true
|
||||
freq_mask_width_range:
|
||||
- 0
|
||||
- 30
|
||||
num_freq_mask: 2
|
||||
apply_time_mask: true
|
||||
time_mask_width_range:
|
||||
- 0
|
||||
- 40
|
||||
num_time_mask: 2
|
||||
|
||||
162
egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
Executable file
162
egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
Executable file
@ -0,0 +1,162 @@
|
||||
#!/usr/bin/env bash
# Prepare Kaldi-style data directories for the AliMeeting near-field and
# far-field recordings found under ./dataset.

# Strict mode: exit on
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail

# Print a timestamped message tagged with the calling file, line and function.
log() {
    local fname=${BASH_SOURCE[1]##*/}
    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# NOTE: the variable must be spelled help_message. It is read below as
# ${help_message}, and with `set -u` a misspelled name aborts the script with
# "unbound variable" instead of printing the usage text.
help_message=$(cat << EOF
Usage: $0

Options:
    --no_overlap (bool): Whether to ignore the overlapping utterance in the training set.
    --tgt (string): Which set to process, test or train.
EOF
)

SECONDS=0
tgt=Train #Train or Eval

log "$0 $*"
echo "${tgt}"
# Only variables assigned before this line (i.e. tgt) can be overridden from
# the command line as --name value.
. ./utils/parse_options.sh

. ./path.sh

AliMeeting="${PWD}/dataset"

if [ $# -gt 2 ]; then
    log "${help_message}"
    exit 2
fi

if [ ! -d "${AliMeeting}" ]; then
    log "Error: ${AliMeeting} does not exist or is not a directory."
    exit 2
fi

# To absolute path
AliMeeting=$(cd "${AliMeeting}"; pwd)
echo "${AliMeeting}"
far_raw_dir=${AliMeeting}/${tgt}_Ali_far/
near_raw_dir=${AliMeeting}/${tgt}_Ali_near/

far_dir=data/local/${tgt}_Ali_far
near_dir=data/local/${tgt}_Ali_near
far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker
mkdir -p "${far_single_speaker_dir}"

# Stages are fixed here (assigned after option parsing), so every stage runs.
stage=1
stop_stage=4
mkdir -p "${far_dir}"
mkdir -p "${near_dir}"
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    log "stage 1:process alimeeting near dir"

    # List the audio and TextGrid files; the utterance id is the wav file name
    # up to its first dot.
    find -L "${near_raw_dir}/audio_dir" -iname "*.wav" > "${near_dir}/wavlist"
    awk -F '/' '{print $NF}' "${near_dir}/wavlist" | awk -F '.' '{print $1}' > "${near_dir}/uttid"
    find -L "${near_raw_dir}/textgrid_dir" -iname "*.TextGrid" > "${near_dir}/textgrid.flist"
    n1_wav=$(wc -l < "${near_dir}/wavlist")
    n2_text=$(wc -l < "${near_dir}/textgrid.flist")
    log near file found $n1_wav wav and $n2_text text.

    paste "${near_dir}/uttid" "${near_dir}/wavlist" > "${near_dir}/wav_raw.scp"

    # Each wav.scp entry is a sox pipe that resamples to 16 kHz / 16 bit.
    # No "-c 1" is passed, so the channel count of the input is kept (the
    # single-channel variant of this command was disabled by the authors).
    awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' \
        "${near_dir}/wav_raw.scp" > "${near_dir}/wav.scp"

    python local/alimeeting_process_textgrid.py --path "${near_dir}" --no-overlap False
    local/text_normalize.pl < "${near_dir}/text_all" | local/text_format.pl | sort -u > "${near_dir}/text"
    # Keep only utt2spk/segments entries whose utterance survived normalization.
    utils/filter_scp.pl -f 1 "${near_dir}/text" "${near_dir}/utt2spk_all" | sort -u > "${near_dir}/utt2spk"
    utils/utt2spk_to_spk2utt.pl "${near_dir}/utt2spk" > "${near_dir}/spk2utt"
    utils/filter_scp.pl -f 1 "${near_dir}/text" "${near_dir}/segments_all" | sort -u > "${near_dir}/segments"
    # Strip a trailing space, then all '!' and '?' characters from the text.
    sed -e 's/ $//g' "${near_dir}/text" > "${near_dir}/tmp1"
    sed -e 's/!//g' "${near_dir}/tmp1" > "${near_dir}/tmp2"
    sed -e 's/?//g' "${near_dir}/tmp2" > "${near_dir}/text"

fi
|
||||
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    log "stage 2:process alimeeting far dir"

    # List the audio and TextGrid files; the utterance id is the wav file name
    # up to its first dot.
    find -L "${far_raw_dir}/audio_dir" -iname "*.wav" > "${far_dir}/wavlist"
    awk -F '/' '{print $NF}' "${far_dir}/wavlist" | awk -F '.' '{print $1}' > "${far_dir}/uttid"
    find -L "${far_raw_dir}/textgrid_dir" -iname "*.TextGrid" > "${far_dir}/textgrid.flist"
    n1_wav=$(wc -l < "${far_dir}/wavlist")
    n2_text=$(wc -l < "${far_dir}/textgrid.flist")
    log far file found $n1_wav wav and $n2_text text.

    paste "${far_dir}/uttid" "${far_dir}/wavlist" > "${far_dir}/wav_raw.scp"

    # Each wav.scp entry is a sox pipe that resamples to 16 kHz / 16 bit.
    awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' \
        "${far_dir}/wav_raw.scp" > "${far_dir}/wav.scp"

    # Merge overlapping speaker turns into multi-speaker utterances.
    python local/alimeeting_process_overlap_force.py --path "${far_dir}" \
        --no-overlap false --mars True \
        --overlap_length 0.8 --max_length 7

    local/text_normalize.pl < "${far_dir}/text_all" | local/text_format.pl | sort -u > "${far_dir}/text"
    # Keep only utt2spk/segments entries whose utterance survived normalization.
    utils/filter_scp.pl -f 1 "${far_dir}/text" "${far_dir}/utt2spk_all" | sort -u > "${far_dir}/utt2spk"

    utils/utt2spk_to_spk2utt.pl "${far_dir}/utt2spk" > "${far_dir}/spk2utt"
    utils/filter_scp.pl -f 1 "${far_dir}/text" "${far_dir}/segments_all" | sort -u > "${far_dir}/segments"
    # Turn the SRC marker into a literal '$', then strip a trailing space and
    # all '!' and '?' characters.
    sed -e 's/SRC/$/g' "${far_dir}/text" > "${far_dir}/tmp1"
    sed -e 's/ $//g' "${far_dir}/tmp1" > "${far_dir}/tmp2"
    sed -e 's/!//g' "${far_dir}/tmp2" > "${far_dir}/tmp3"
    sed -e 's/?//g' "${far_dir}/tmp3" > "${far_dir}/text"
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    log "stage 3: finali data process"

    utils/copy_data_dir.sh "${near_dir}" "data/${tgt}_Ali_near"
    utils/copy_data_dir.sh "${far_dir}" "data/${tgt}_Ali_far"

    # Carry the multi-speaker utt2spk along and turn its "src" marker into a
    # literal '$'.
    sort "${far_dir}/utt2spk_all_fifo" > "data/${tgt}_Ali_far/utt2spk_all_fifo"
    sed -i 's/src/$/g' "data/${tgt}_Ali_far/utt2spk_all_fifo"

    # remove space in text
    for x in "${tgt}_Ali_near" "${tgt}_Ali_far"; do
        cp "data/${x}/text" "data/${x}/text.org"
        # Column 1 (utterance id) is kept; spaces inside the transcript are deleted.
        paste -d " " <(cut -f 1 -d" " "data/${x}/text.org") <(cut -f 2- -d" " "data/${x}/text.org" | tr -d " ") \
            > "data/${x}/text"
        rm "data/${x}/text.org"
    done

    log "Successfully finished. [elapsed=${SECONDS}s]"
fi
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
    # Start from a copy of the far-field working directory. The glob stays
    # outside the quotes so it still expands.
    cp -r "${far_dir}"/* "${far_single_speaker_dir}"
    # Rewrite textgrid.flist as "<uttid> <TextGrid path>" pairs.
    mv "${far_single_speaker_dir}/textgrid.flist" "${far_single_speaker_dir}/textgrid_oldpath"
    paste -d " " "${far_single_speaker_dir}/uttid" "${far_single_speaker_dir}/textgrid_oldpath" > "${far_single_speaker_dir}/textgrid.flist"
    python local/process_textgrid_to_single_speaker_wav.py --path "${far_single_speaker_dir}"

    # utt2spk is reused as the text file of this directory.
    cp "${far_single_speaker_dir}/utt2spk" "${far_single_speaker_dir}/text"
    utils/utt2spk_to_spk2utt.pl "${far_single_speaker_dir}/utt2spk" > "${far_single_speaker_dir}/spk2utt"

    ./utils/fix_data_dir.sh "${far_single_speaker_dir}"
    utils/copy_data_dir.sh "${far_single_speaker_dir}" "data/${tgt}_Ali_far_single_speaker"

    # remove space in text
    for x in "${tgt}_Ali_far_single_speaker"; do
        cp "data/${x}/text" "data/${x}/text.org"
        # Column 1 (utterance id) is kept; spaces inside the rest are deleted.
        paste -d " " <(cut -f 1 -d" " "data/${x}/text.org") <(cut -f 2- -d" " "data/${x}/text.org" | tr -d " ") \
            > "data/${x}/text"
        rm "data/${x}/text.org"
    done
    log "Successfully finished. [elapsed=${SECONDS}s]"
fi
|
||||
129
egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
Executable file
129
egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
Executable file
@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env bash
# Prepare Kaldi-style data directories for the AliMeeting far-field recordings
# found under ./dataset.

# Strict mode: exit on
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail

# Print a timestamped message tagged with the calling file, line and function.
log() {
    local fname=${BASH_SOURCE[1]##*/}
    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# NOTE: the variable must be spelled help_message. It is read below as
# ${help_message}, and with `set -u` a misspelled name aborts the script with
# "unbound variable" instead of printing the usage text.
help_message=$(cat << EOF
Usage: $0

Options:
    --no_overlap (bool): Whether to ignore the overlapping utterance in the training set.
    --tgt (string): Which set to process, test or train.
EOF
)

SECONDS=0
tgt=Train #Train or Eval

log "$0 $*"
echo "${tgt}"
# Only variables assigned before this line (i.e. tgt) can be overridden from
# the command line as --name value.
. ./utils/parse_options.sh

. ./path.sh

AliMeeting="${PWD}/dataset"

if [ $# -gt 2 ]; then
    log "${help_message}"
    exit 2
fi

if [ ! -d "${AliMeeting}" ]; then
    log "Error: ${AliMeeting} does not exist or is not a directory."
    exit 2
fi

# To absolute path
AliMeeting=$(cd "${AliMeeting}"; pwd)
echo "${AliMeeting}"
far_raw_dir=${AliMeeting}/${tgt}_Ali_far/

far_dir=data/local/${tgt}_Ali_far
far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker
mkdir -p "${far_single_speaker_dir}"

# Stages are fixed here (assigned after option parsing), so every stage runs.
stage=1
stop_stage=3
mkdir -p "${far_dir}"
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    log "stage 1:process alimeeting far dir"

    # List the audio and TextGrid files; the utterance id is the wav file name
    # up to its first dot.
    find -L "${far_raw_dir}/audio_dir" -iname "*.wav" > "${far_dir}/wavlist"
    awk -F '/' '{print $NF}' "${far_dir}/wavlist" | awk -F '.' '{print $1}' > "${far_dir}/uttid"
    find -L "${far_raw_dir}/textgrid_dir" -iname "*.TextGrid" > "${far_dir}/textgrid.flist"
    n1_wav=$(wc -l < "${far_dir}/wavlist")
    n2_text=$(wc -l < "${far_dir}/textgrid.flist")
    log far file found $n1_wav wav and $n2_text text.

    paste "${far_dir}/uttid" "${far_dir}/wavlist" > "${far_dir}/wav_raw.scp"

    # Each wav.scp entry is a sox pipe that resamples to 16 kHz / 16 bit.
    awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' \
        "${far_dir}/wav_raw.scp" > "${far_dir}/wav.scp"

    # Merge overlapping speaker turns into multi-speaker utterances.
    python local/alimeeting_process_overlap_force.py --path "${far_dir}" \
        --no-overlap false --mars True \
        --overlap_length 0.8 --max_length 7

    local/text_normalize.pl < "${far_dir}/text_all" | local/text_format.pl | sort -u > "${far_dir}/text"
    # Keep only utt2spk/segments entries whose utterance survived normalization.
    utils/filter_scp.pl -f 1 "${far_dir}/text" "${far_dir}/utt2spk_all" | sort -u > "${far_dir}/utt2spk"

    utils/utt2spk_to_spk2utt.pl "${far_dir}/utt2spk" > "${far_dir}/spk2utt"
    utils/filter_scp.pl -f 1 "${far_dir}/text" "${far_dir}/segments_all" | sort -u > "${far_dir}/segments"
    # Turn the SRC marker into a literal '$', then strip a trailing space and
    # all '!' and '?' characters.
    sed -e 's/SRC/$/g' "${far_dir}/text" > "${far_dir}/tmp1"
    sed -e 's/ $//g' "${far_dir}/tmp1" > "${far_dir}/tmp2"
    sed -e 's/!//g' "${far_dir}/tmp2" > "${far_dir}/tmp3"
    sed -e 's/?//g' "${far_dir}/tmp3" > "${far_dir}/text"
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    log "stage 2: finali data process"

    utils/copy_data_dir.sh "${far_dir}" "data/${tgt}_Ali_far"

    # Carry the multi-speaker utt2spk along and turn its "src" marker into a
    # literal '$'.
    sort "${far_dir}/utt2spk_all_fifo" > "data/${tgt}_Ali_far/utt2spk_all_fifo"
    sed -i 's/src/$/g' "data/${tgt}_Ali_far/utt2spk_all_fifo"

    # remove space in text
    for x in "${tgt}_Ali_far"; do
        cp "data/${x}/text" "data/${x}/text.org"
        # Column 1 (utterance id) is kept; spaces inside the transcript are deleted.
        paste -d " " <(cut -f 1 -d" " "data/${x}/text.org") <(cut -f 2- -d" " "data/${x}/text.org" | tr -d " ") \
            > "data/${x}/text"
        rm "data/${x}/text.org"
    done

    log "Successfully finished. [elapsed=${SECONDS}s]"
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    log "stage 3:process alimeeting far dir (single speaker by oracal time strap)"
    # Start from a copy of the far-field working directory. The glob stays
    # outside the quotes so it still expands.
    cp -r "${far_dir}"/* "${far_single_speaker_dir}"
    # Rewrite textgrid.flist as "<uttid> <TextGrid path>" pairs.
    mv "${far_single_speaker_dir}/textgrid.flist" "${far_single_speaker_dir}/textgrid_oldpath"
    paste -d " " "${far_single_speaker_dir}/uttid" "${far_single_speaker_dir}/textgrid_oldpath" > "${far_single_speaker_dir}/textgrid.flist"
    python local/process_textgrid_to_single_speaker_wav.py --path "${far_single_speaker_dir}"

    # utt2spk is reused as the text file of this directory.
    cp "${far_single_speaker_dir}/utt2spk" "${far_single_speaker_dir}/text"
    utils/utt2spk_to_spk2utt.pl "${far_single_speaker_dir}/utt2spk" > "${far_single_speaker_dir}/spk2utt"

    ./utils/fix_data_dir.sh "${far_single_speaker_dir}"
    utils/copy_data_dir.sh "${far_single_speaker_dir}" "data/${tgt}_Ali_far_single_speaker"

    # remove space in text
    for x in "${tgt}_Ali_far_single_speaker"; do
        cp "data/${x}/text" "data/${x}/text.org"
        # Column 1 (utterance id) is kept; spaces inside the rest are deleted.
        paste -d " " <(cut -f 1 -d" " "data/${x}/text.org") <(cut -f 2- -d" " "data/${x}/text.org" | tr -d " ") \
            > "data/${x}/text"
        rm "data/${x}/text.org"
    done
    log "Successfully finished. [elapsed=${SECONDS}s]"
fi
|
||||
235
egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
Executable file
235
egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
Executable file
@ -0,0 +1,235 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Process the textgrid files
|
||||
"""
|
||||
import argparse
|
||||
import codecs
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
import textgrid
|
||||
import pdb
|
||||
|
||||
class Segment(object):
    """One labelled interval of a recording, attributed to a speaker.

    Besides the plain fields, the segment tracks what is needed when several
    segments are merged into one multi-speaker utterance:

    * ``spkr_all``: "<uttid>-<spkr>" keys of the contributing speakers,
      initially just this segment's own key.
    * ``spk_text``: mapping from speaker key to that speaker's text.
    """

    def __init__(self, uttid, spkr, stime, etime, text):
        speaker_key = uttid + "-" + spkr

        self.uttid = uttid
        self.spkr = spkr
        self.text = text
        # Times are kept at centisecond resolution.
        self.stime = round(stime, 2)
        self.etime = round(etime, 2)
        self.spkr_all = speaker_key
        self.spk_text = {speaker_key: text}

    def change_stime(self, time):
        """Overwrite the start time (no rounding is applied)."""
        self.stime = time

    def change_etime(self, time):
        """Overwrite the end time (no rounding is applied)."""
        self.etime = time
|
||||
|
||||
|
||||
def _str2bool(value):
    """Parse a command-line truth value into a real ``bool``.

    Accepts the same spellings as ``distutils.util.strtobool``
    (y/yes/t/true/on/1 and n/no/f/false/off/0, case-insensitive) so existing
    invocations keep working. Parsing it here means this function does not
    need ``distutils``, which was removed from the standard library in
    Python 3.12.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a known spelling.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise argparse.ArgumentTypeError("invalid truth value %r" % (value,))


def get_args():
    """Parse the command-line options of the overlap-merging script.

    Returns:
        argparse.Namespace with ``path`` (str), ``no_overlap`` (bool),
        ``max_length`` (float), ``overlap_length`` (float) and ``mars`` (bool).
    """
    parser = argparse.ArgumentParser(description="process the textgrid files")
    parser.add_argument("--path", type=str, required=True, help="Data path")
    parser.add_argument(
        "--no-overlap",
        type=_str2bool,
        default=False,
        help="Whether to ignore the overlapping utterances.",
    )
    parser.add_argument(
        "--max_length",
        default=100000,
        type=float,
        help="overlap speech max time,if longger than max length should cut",
    )
    parser.add_argument(
        "--overlap_length",
        default=1,
        type=float,
        help="if length longer than max length, speech overlength shorter, is cut",
    )
    parser.add_argument(
        "--mars",
        type=_str2bool,
        default=False,
        help="Whether to process mars data set.",
    )
    args = parser.parse_args()
    return args
|
||||
|
||||
|
||||
def preposs_overlap(segments, max_length, overlap_length):
    """Merge time-overlapping segments into multi-speaker chunks.

    ``segments`` must be sorted by start time. Walking through them, a segment
    is merged into the current chunk when it starts before the chunk ends and
    either

    * the chunk is still shorter than ``max_length`` seconds, or
    * the overlap with the chunk exceeds a threshold: ``overlap_length`` while
      the chunk is shorter than 15 s, 1.5 s afterwards.

    Otherwise the current chunk is emitted and the segment starts a new one.

    Merging mutates the first segment of the chunk in place: its time span is
    widened, the texts are joined with the marker "src", and ``spk_text`` /
    ``spkr_all`` record each contributing speaker.

    Args:
        segments: list of Segment-like objects (``uttid``, ``spkr``, ``stime``,
            ``etime``, ``text``, ``spk_text``, ``spkr_all``).
        max_length: chunk duration (seconds) below which overlaps always merge.
        overlap_length: minimum overlap (seconds) required to extend a chunk
            that already reached ``max_length``.

    Returns:
        List of merged segments; empty if ``segments`` is empty.
    """
    if not segments:
        return []

    # Stricter limits that apply once a chunk has grown very long.
    overlap_length_big = 1.5
    max_length_big = 15

    new_segments = []
    current = segments[0]
    min_stime = current.stime
    max_etime = current.etime

    for seg in segments[1:]:
        if seg.stime >= max_etime:
            # doesn't overlap with previous segments
            merge = False
        else:
            dur_time = max_etime - min_stime
            if dur_time < max_length:
                merge = True
            else:
                # The chunk is already long: only a substantial overlap may
                # extend it further.
                overlap_time = max_etime - seg.stime
                if dur_time < max_length_big:
                    threshold = overlap_length
                else:
                    threshold = overlap_length_big
                merge = overlap_time > threshold

        if not merge:
            new_segments.append(current)
            current = seg
            min_stime = seg.stime
            max_etime = seg.etime
            continue

        min_stime = min(min_stime, seg.stime)
        max_etime = max(max_etime, seg.etime)
        current.stime = min_stime
        current.etime = max_etime
        current.text = current.text + "src" + seg.text
        spk_name = seg.uttid + "-" + seg.spkr
        if spk_name in current.spk_text:
            current.spk_text[spk_name] += seg.text
        else:
            # First contribution of this speaker: remember the order in which
            # speakers entered the chunk.
            current.spk_text[spk_name] = seg.text
            current.spkr_all = current.spkr_all + "src" + spk_name

    # Emit the chunk that was still open when the input ended; without this
    # the last utterance of every recording would be lost.
    new_segments.append(current)
    return new_segments
|
||||
|
||||
def filter_overlap(segments):
    """Return only the segments that overlap with no other segment.

    ``segments`` must be sorted by start time. Consecutive segments are
    grouped while each one starts before the latest end time seen in the
    group; a group consisting of a single segment is kept, larger groups
    (overlapping speech) are discarded.

    Args:
        segments: list of objects with ``stime`` and ``etime`` attributes.

    Returns:
        List of the non-overlapping segments, in input order; empty if
        ``segments`` is empty.
    """
    new_segments = []
    if not segments:
        return new_segments

    # Segments of the overlap group currently being built.
    group = [segments[0]]
    max_etime = segments[0].etime

    for seg in segments[1:]:
        if seg.stime >= max_etime:
            # doesn't overlap with previous segments
            if len(group) == 1:
                new_segments.append(group[0])
            # TODO: for multi-spkr asr, we can reset the stime/etime to the
            # group's min start / max end for generating a max length
            # mixture speech
            group = [seg]
            max_etime = seg.etime
        else:
            # overlap with previous segments
            group.append(seg)
            max_etime = max(max_etime, seg.etime)

    # Flush the last group; without this a trailing non-overlapping segment
    # would be dropped.
    if len(group) == 1:
        new_segments.append(group[0])

    return new_segments
|
||||
|
||||
|
||||
def main(args):
    """Turn the TextGrid annotations of a data directory into Kaldi-style files.

    Reads ``wav.scp`` and ``textgrid.flist`` from ``args.path``, builds one
    Segment per non-empty TextGrid interval, merges or filters overlapping
    segments, and writes ``segments_all``, ``utt2spk_all``, ``text_all`` and
    ``utt2spk_all_fifo`` back into ``args.path``.

    Args:
        args: namespace from ``get_args`` (``path``, ``no_overlap``,
            ``max_length``, ``overlap_length``, ``mars``).
    """
    data_dir = Path(args.path)

    # get the path of textgrid file for each utterance
    utt2textgrid = {}
    with codecs.open(data_dir / "textgrid.flist", "r", "utf-8") as textgrid_flist:
        for line in textgrid_flist:
            path = Path(line.strip())
            utt2textgrid[path.stem] = path

    # parse the textgrid file for each utterance
    all_segments = []
    with codecs.open(data_dir / "wav.scp", "r", "utf-8") as wav_scp:
        for line in wav_scp:
            uttid = line.strip().split(" ")[0]
            uttid_part = uttid
            if args.mars:
                # In this mode the TextGrid is looked up by the first two
                # underscore-separated fields of the wav id.
                uttid_list = uttid.split("_")
                uttid_part = uttid_list[0] + "_" + uttid_list[1]
            if uttid_part not in utt2textgrid:
                print("%s doesn't have transcription" % uttid)
                continue

            tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part])
            segments = [
                Segment(
                    uttid,
                    tier.name,
                    interval.minTime,
                    interval.maxTime,
                    interval.mark.strip(),
                )
                for tier in tg
                for interval in tier
                if interval.mark
            ]
            if not segments:
                # Nothing labelled in this recording; the overlap helpers
                # expect at least one segment.
                print("%s has no labelled intervals" % uttid)
                continue

            segments.sort(key=lambda seg: seg.stime)

            if args.no_overlap:
                segments = filter_overlap(segments)
            else:
                segments = preposs_overlap(segments, args.max_length, args.overlap_length)
            all_segments += segments

    with codecs.open(data_dir / "segments_all", "w", "utf-8") as segments_file, \
            codecs.open(data_dir / "utt2spk_all", "w", "utf-8") as utt2spk_file, \
            codecs.open(data_dir / "text_all", "w", "utf-8") as text_file, \
            codecs.open(data_dir / "utt2spk_all_fifo", "w", "utf-8") as utt2spk_file_fifo:
        for seg in all_segments:
            # NOTE(review): "%07d" truncates the float product, so e.g.
            # 0.29 * 100 is written as 28. Left unchanged because the ids may
            # be relied on elsewhere; confirm before switching to round().
            utt_name = "%s-%s-%07d-%07d" % (
                seg.uttid,
                seg.spkr,
                seg.stime * 100,
                seg.etime * 100,
            )

            segments_file.write(
                "%s %s %.2f %.2f\n" % (utt_name, seg.uttid, seg.stime, seg.etime)
            )
            utt2spk_file.write("%s %s-%s\n" % (utt_name, seg.uttid, seg.spkr))
            # Every speaker key that contributed to the utterance, joined by
            # the "src" marker.
            utt2spk_file_fifo.write("%s %s\n" % (utt_name, seg.spkr_all))
            text_file.write("%s %s\n" % (utt_name, seg.text))


if __name__ == "__main__":
    args = get_args()
    main(args)
|
||||
158
egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
Executable file
158
egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
Executable file
@ -0,0 +1,158 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Process the textgrid files
|
||||
"""
|
||||
import argparse
|
||||
import codecs
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
import textgrid
|
||||
import pdb
|
||||
|
||||
class Segment(object):
    """One labelled interval of a recording, attributed to a speaker."""

    def __init__(self, uttid, spkr, stime, etime, text):
        self.uttid, self.spkr, self.text = uttid, spkr, text
        # Times are kept at centisecond resolution.
        self.stime = round(stime, 2)
        self.etime = round(etime, 2)

    def change_stime(self, time):
        """Overwrite the start time (no rounding is applied)."""
        self.stime = time

    def change_etime(self, time):
        """Overwrite the end time (no rounding is applied)."""
        self.etime = time
|
||||
|
||||
|
||||
def _str2bool(value):
    """Parse a command-line truth value into a real ``bool``.

    Accepts the same spellings as ``distutils.util.strtobool``
    (y/yes/t/true/on/1 and n/no/f/false/off/0, case-insensitive) so existing
    invocations keep working. Parsing it here means this function does not
    need ``distutils``, which was removed from the standard library in
    Python 3.12.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a known spelling.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise argparse.ArgumentTypeError("invalid truth value %r" % (value,))


def get_args():
    """Parse the command-line options of the TextGrid processing script.

    Returns:
        argparse.Namespace with ``path`` (str), ``no_overlap`` (bool) and
        ``mars`` (bool).
    """
    parser = argparse.ArgumentParser(description="process the textgrid files")
    parser.add_argument("--path", type=str, required=True, help="Data path")
    parser.add_argument(
        "--no-overlap",
        type=_str2bool,
        default=False,
        help="Whether to ignore the overlapping utterances.",
    )
    parser.add_argument(
        "--mars",
        type=_str2bool,
        default=False,
        help="Whether to process mars data set.",
    )
    args = parser.parse_args()
    return args
|
||||
|
||||
|
||||
def filter_overlap(segments):
    """Return only the segments that overlap with no other segment.

    ``segments`` must be sorted by start time. Consecutive segments are
    grouped while each one starts before the latest end time seen in the
    group; a group consisting of a single segment is kept, larger groups
    (overlapping speech) are discarded.

    Args:
        segments: list of objects with ``stime`` and ``etime`` attributes.

    Returns:
        List of the non-overlapping segments, in input order; empty if
        ``segments`` is empty.
    """
    new_segments = []
    if not segments:
        return new_segments

    # Segments of the overlap group currently being built.
    group = [segments[0]]
    max_etime = segments[0].etime

    for seg in segments[1:]:
        if seg.stime >= max_etime:
            # doesn't overlap with previous segments
            if len(group) == 1:
                new_segments.append(group[0])
            # TODO: for multi-spkr asr, we can reset the stime/etime to the
            # group's min start / max end for generating a max length
            # mixture speech
            group = [seg]
            max_etime = seg.etime
        else:
            # overlap with previous segments
            group.append(seg)
            max_etime = max(max_etime, seg.etime)

    # Flush the last group; without this a trailing non-overlapping segment
    # would be dropped.
    if len(group) == 1:
        new_segments.append(group[0])

    return new_segments
|
||||
|
||||
|
||||
def main(args):
    """Turn the TextGrid annotations of a data directory into Kaldi-style files.

    Reads ``wav.scp`` and ``textgrid.flist`` from ``args.path``, builds one
    Segment per non-empty TextGrid interval, optionally drops overlapping
    segments, and writes ``segments_all``, ``utt2spk_all`` and ``text_all``
    back into ``args.path``.

    Args:
        args: namespace from ``get_args`` (``path``, ``no_overlap``, ``mars``).

    Raises:
        Exception: whatever the TextGrid parser raised for an unreadable
            file, re-raised after the offending path has been printed.
    """
    data_dir = Path(args.path)

    # get the path of textgrid file for each utterance
    utt2textgrid = {}
    with codecs.open(data_dir / "textgrid.flist", "r", "utf-8") as textgrid_flist:
        for line in textgrid_flist:
            path = Path(line.strip())
            utt2textgrid[path.stem] = path

    # parse the textgrid file for each utterance
    all_segments = []
    with codecs.open(data_dir / "wav.scp", "r", "utf-8") as wav_scp:
        for line in wav_scp:
            uttid = line.strip().split(" ")[0]
            uttid_part = uttid
            if args.mars:
                # In this mode the TextGrid is looked up by the first two
                # underscore-separated fields of the wav id.
                uttid_list = uttid.split("_")
                uttid_part = uttid_list[0] + "_" + uttid_list[1]
            if uttid_part not in utt2textgrid:
                print("%s doesn't have transcription" % uttid)
                continue

            try:
                tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part])
            except Exception:
                # Report which file is broken and stop. The previous debugger
                # breakpoint here could not work in a batch run and left `tg`
                # unbound or stale for the code below.
                print(
                    "failed to parse TextGrid %s (utterance %s)"
                    % (utt2textgrid[uttid_part], uttid)
                )
                raise

            segments = [
                Segment(
                    uttid,
                    tier.name,
                    interval.minTime,
                    interval.maxTime,
                    interval.mark.strip(),
                )
                for tier in tg
                for interval in tier
                if interval.mark
            ]
            segments.sort(key=lambda seg: seg.stime)

            # filter_overlap expects at least one segment.
            if args.no_overlap and segments:
                segments = filter_overlap(segments)

            all_segments += segments

    with codecs.open(data_dir / "segments_all", "w", "utf-8") as segments_file, \
            codecs.open(data_dir / "utt2spk_all", "w", "utf-8") as utt2spk_file, \
            codecs.open(data_dir / "text_all", "w", "utf-8") as text_file:
        for seg in all_segments:
            # NOTE(review): "%07d" truncates the float product, so e.g.
            # 0.29 * 100 is written as 28. Left unchanged because the ids may
            # be relied on elsewhere; confirm before switching to round().
            utt_name = "%s-%s-%07d-%07d" % (
                seg.uttid,
                seg.spkr,
                seg.stime * 100,
                seg.etime * 100,
            )

            segments_file.write(
                "%s %s %.2f %.2f\n" % (utt_name, seg.uttid, seg.stime, seg.etime)
            )
            utt2spk_file.write("%s %s-%s\n" % (utt_name, seg.uttid, seg.spkr))
            text_file.write("%s %s\n" % (utt_name, seg.text))


if __name__ == "__main__":
    args = get_args()
    main(args)
|
||||
91
egs/alimeeting/sa-asr/local/compute_cpcer.py
Normal file
91
egs/alimeeting/sa-asr/local/compute_cpcer.py
Normal file
@ -0,0 +1,91 @@
|
||||
import editdistance
|
||||
import sys
|
||||
import os
|
||||
from itertools import permutations
|
||||
|
||||
|
||||
def load_transcripts(file_path):
    """Read "<id> <transcript>" lines into a list of ``(id, transcript)`` tuples.

    The id is the first whitespace-delimited field; everything after it is the
    transcript. A line that holds only an id yields an empty transcript, and
    blank lines are skipped, so neither aborts the scoring run.

    Args:
        file_path: path of the transcript file.

    Returns:
        List of ``(meeting_id, transcript)`` tuples in file order.
    """
    trans_list = []
    with open(file_path, "rt") as reader:
        for one_line in reader:
            fields = one_line.strip().split(maxsplit=1)
            if not fields:
                continue
            meeting_id = fields[0]
            trans = fields[1].strip() if len(fields) > 1 else ""
            trans_list.append((meeting_id, trans))

    return trans_list
|
||||
|
||||
def calc_spk_trans(trans):
    """Split a "$"-separated transcript into per-speaker parts.

    Returns:
        List of ``(index_as_str, stripped_text)`` tuples, one per part, in
        order of appearance.
    """
    return [(str(idx), part.strip()) for idx, part in enumerate(trans.split("$"))]
|
||||
|
||||
def calc_cer(ref_trans, hyp_trans):
    """Find the speaker permutation with the fewest character errors.

    Both transcripts are split into per-speaker texts on "$". The shorter
    side is padded with empty texts so both have the same number of speakers,
    then every assignment of hypothesis speakers to reference speakers is
    scored by summed edit distance and the smallest sum is kept.

    Args:
        ref_trans: reference transcript, speakers separated by "$".
        hyp_trans: hypothesis transcript, speakers separated by "$".

    Returns:
        Tuple ``(errors, ref_length, ref_spk_num, hyp_spk_num)``: the minimum
        total edit distance, the total number of reference characters, and
        the speaker counts before padding. ``ref_length`` is 0 for an empty
        reference; no division happens here.
    """
    ref_spk_trans = calc_spk_trans(ref_trans)
    hyp_spk_trans = calc_spk_trans(hyp_trans)
    ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)
    num_spk = max(ref_spk_num, hyp_spk_num)

    ref_texts = [text for _, text in ref_spk_trans]
    hyp_texts = [text for _, text in hyp_spk_trans]
    ref_texts.extend([""] * (num_spk - ref_spk_num))
    hyp_texts.extend([""] * (num_spk - hyp_spk_num))

    # Every permutation covers all reference speakers, so the reference
    # length is the same for each of them. Selecting by error count is
    # therefore equivalent to selecting by error rate.
    total_ref_len = sum(len(text) for text in ref_texts)

    pair_cost = {}  # (ref index, hyp index) -> edit distance, computed once
    best_error = None
    for perm in permutations(range(num_spk)):
        perm_error = 0
        for ref_idx, hyp_idx in enumerate(perm):
            ref_text, hyp_text = ref_texts[ref_idx], hyp_texts[hyp_idx]
            # The length difference is a lower bound on the edit distance;
            # use it to drop hopeless permutations without scoring them.
            length_gap = abs(len(ref_text) - len(hyp_text))
            if best_error is not None and perm_error + length_gap > best_error:
                break
            key = (ref_idx, hyp_idx)
            if key not in pair_cost:
                pair_cost[key] = editdistance.eval(ref_text, hyp_text)
            perm_error += pair_cost[key]
            if best_error is not None and perm_error > best_error:
                break
        else:
            # Reached only when the permutation was scored completely.
            if best_error is None or perm_error < best_error:
                best_error = perm_error

    return best_error, total_ref_len, ref_spk_num, hyp_spk_num
|
||||
|
||||
|
||||
def main():
    """Compute the CP-CER of a hypothesis file against a reference file.

    Command line: ``<ref> <hyp> <result_path>``. Lines of the two files are
    paired in order and must carry the same id. The result file gets one line
    per meeting (``id rate ref_speakers hyp_speakers``) followed by the
    overall ``CP-CER``.

    Rates are error counts divided by the reference length; a reference
    length of 0 is treated as 1 so that an empty reference cannot abort the
    run with a division by zero.

    Raises:
        ValueError: if a reference line and its hypothesis line have
            different ids.
    """
    ref = sys.argv[1]
    hyp = sys.argv[2]
    result_path = sys.argv[3]
    ref_list = load_transcripts(ref)
    hyp_list = load_transcripts(hyp)

    error, count = 0, 0
    with open(result_path, 'w') as result_file:
        for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list):
            if ref_id != hyp_id:
                raise ValueError(
                    "id mismatch between reference and hypothesis: {} vs {}".format(ref_id, hyp_id)
                )
            dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
            error, count = error + dist, count + length
            result_file.write(
                "{} {:.2f} {} {}\n".format(
                    ref_id, dist / max(length, 1) * 100.0, ref_spk_num, hyp_spk_num
                )
            )

        result_file.write("CP-CER: {:.2f}\n".format(error / max(count, 1) * 100.0))


if __name__ == '__main__':
    main()
|
||||
157
egs/alimeeting/sa-asr/local/compute_wer.py
Executable file
157
egs/alimeeting/sa-asr/local/compute_wer.py
Executable file
@ -0,0 +1,157 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
def _load_token_dict(path):
    """Read "<key> <token> <token> ..." lines into ``{key: [tokens]}``; blank lines are skipped."""
    table = {}
    with open(path, 'r') as reader:
        for line in reader:
            fields = line.strip().split()
            if not fields:
                continue
            table[fields[0]] = fields[1:]
    return table


def compute_wer(ref_file,
                hyp_file,
                cer_detail_file):
    """Score a hypothesis file against a reference file and write a report.

    Every hypothesis key that also exists in the reference is aligned with
    ``compute_wer_by_line``. The report holds three lines per utterance
    (statistics, ``ref:``, ``hyp:``) followed by the ``%WER`` / ``%SER``
    summary.

    Args:
        ref_file: path of the reference transcripts.
        hyp_file: path of the hypothesis transcripts.
        cer_detail_file: path of the report to write.

    Returns:
        Dict with the accumulated totals (``Wrd``, ``Corr``, ``Ins``, ``Del``,
        ``Sub``, ``Snt``, ``Err``, ``S.Err``, ``wrong_words``,
        ``wrong_sentences``).
    """
    rst = {
        'Wrd': 0,
        'Corr': 0,
        'Ins': 0,
        'Del': 0,
        'Sub': 0,
        'Snt': 0,
        'Err': 0.0,
        'S.Err': 0.0,
        'wrong_words': 0,
        'wrong_sentences': 0
    }

    hyp_dict = _load_token_dict(hyp_file)
    ref_dict = _load_token_dict(ref_file)

    # The context manager guarantees the report is flushed and closed, also
    # when scoring fails part-way through.
    with open(cer_detail_file, 'w') as cer_detail_writer:
        for hyp_key, hyp_tokens in hyp_dict.items():
            if hyp_key not in ref_dict:
                continue
            ref_tokens = ref_dict[hyp_key]
            out_item = compute_wer_by_line(hyp_tokens, ref_tokens)
            # int(): the error count may arrive as a fixed-width NumPy
            # scalar; accumulating those can overflow, Python ints cannot.
            wrong = int(out_item['wrong'])
            rst['Wrd'] += out_item['nwords']
            rst['Corr'] += out_item['cor']
            rst['wrong_words'] += wrong
            rst['Ins'] += out_item['ins']
            rst['Del'] += out_item['del']
            rst['Sub'] += out_item['sub']
            rst['Snt'] += 1
            if wrong > 0:
                rst['wrong_sentences'] += 1
            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
            cer_detail_writer.write("ref:" + '\t' + "".join(ref_tokens) + '\n')
            cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_tokens) + '\n')

        if rst['Wrd'] > 0:
            rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
        if rst['Snt'] > 0:
            rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)

        cer_detail_writer.write('\n')
        cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
                                ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
        cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
        cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')

    return rst
|
||||
|
||||
|
||||
def compute_wer_by_line(hyp,
                        ref):
    """Align one hypothesis with its reference and count the edit operations.

    Tokens are compared case-insensitively. A Levenshtein cost matrix is
    filled first; the alignment is then traced back from the bottom-right
    corner to count correct tokens, insertions, deletions and substitutions.

    Args:
        hyp: list of hypothesis tokens.
        ref: list of reference tokens.

    Returns:
        Dict with ``nwords`` (reference length), ``cor``, ``ins``, ``del``,
        ``sub`` and ``wrong`` (the edit distance), all plain Python ints.
    """
    hyp = [token.lower() for token in hyp]
    ref = [token.lower() for token in ref]

    len_hyp = len(hyp)
    len_ref = len(ref)

    # int32 rather than int16: costs grow with the sequence length and would
    # wrap around for inputs longer than 32767 tokens.
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int32)

    # Operation that produced each cell: 0 = match (also the untouched first
    # row/column), 1 = substitution, 2 = insertion, 3 = deletion.
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)

    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j

    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1

                # On ties the first minimum wins: substitution, then
                # insertion, then deletion.
                compare_val = [substitution, insertion, deletion]
                min_val = min(compare_val)
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = compare_val.index(min_val) + 1

    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }

    i = len_hyp
    j = len_ref
    while i >= 0 or j >= 0:
        op = ops_matrix[max(0, i)][max(0, j)]

        if op == 0:  # correct (or a border cell)
            if i - 1 >= 0 and j - 1 >= 0:
                rst['cor'] += 1
            i -= 1
            j -= 1
        elif op == 2:  # insert
            i -= 1
            rst['ins'] += 1
        elif op == 3:  # delete
            j -= 1
            rst['del'] += 1
        elif op == 1:  # substitute
            i -= 1
            j -= 1
            rst['sub'] += 1

        # Once one sequence is exhausted, each remaining token of the other
        # is a deletion (reference left) or an insertion (hypothesis left).
        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1

    # Return a Python int: a fixed-width NumPy scalar would carry its width
    # into the caller's running totals and percentage arithmetic.
    rst['wrong'] = int(cost_matrix[len_hyp][len_ref])

    return rst
|
||||
|
||||
def print_cer_detail(rst):
    """Format an alignment-statistics dict as a one-line summary string.

    Args:
        rst: mapping with the keys 'nwords', 'cor', 'ins', 'del', 'sub'
            and 'wrong'.

    Returns:
        A string of the form
        "(nwords=N,cor=C,ins=I,del=D,sub=S) corr:xx.xx%,cer:yy.yy%",
        where both percentages are taken relative to rst['nwords'].
    """
    total = rst['nwords']
    counts = ",".join(
        key + "=" + str(rst[key]) for key in ('nwords', 'cor', 'ins', 'del', 'sub')
    )
    corr_rate = '{:.2%}'.format(rst['cor'] / total)
    cer_rate = '{:.2%}'.format(rst['wrong'] / total)
    return "(" + counts + ") corr:" + corr_rate + ",cer:" + cer_rate
|
||||
|
||||
if __name__ == '__main__':
    # Exactly three positional arguments are expected:
    # reference file, hypothesis file, and the detail output file.
    if len(sys.argv) != 4:
        print("usage : python compute-wer.py test.ref test.hyp test.wer", file=sys.stderr)
        # Exit with a non-zero status so a calling script can detect the misuse
        # instead of carrying on without the output file.
        sys.exit(1)

    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    cer_detail_file = sys.argv[3]
    compute_wer(ref_file, hyp_file, cer_detail_file)
|
||||
6
egs/alimeeting/sa-asr/local/download_xvector_model.py
Normal file
6
egs/alimeeting/sa-asr/local/download_xvector_model.py
Normal file
@ -0,0 +1,6 @@
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
import sys
|
||||
|
||||
|
||||
# The first command-line argument is passed to snapshot_download as its cache_dir.
cache_dir = sys.argv[1]
model_dir = snapshot_download(
    'damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch',
    cache_dir=cache_dir,
)
|
||||
22
egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
Normal file
22
egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
Normal file
@ -0,0 +1,22 @@
|
||||
import sys
|
||||
if __name__ == "__main__":
    # Usage: <uttid list file> <source dir> <target dir>
    # Copies the lines of <source dir>/utt2spk_all_fifo whose first
    # space-separated field appears in the uttid list to
    # <target dir>/utt2spk_all_fifo.
    uttid_path = sys.argv[1]
    src_path = sys.argv[2]
    tgt_path = sys.argv[3]

    with open(uttid_path, 'r') as uttid_file:
        # A set gives O(1) membership tests in the filtering loop below.
        wanted_uttids = {line.strip() for line in uttid_file}

    # Read the whole source first: source and target may be the same
    # directory, and opening the target for writing truncates it.
    with open(src_path + '/utt2spk_all_fifo', 'r') as ori_file:
        ori_lines = ori_file.readlines()

    with open(tgt_path + '/utt2spk_all_fifo', 'w') as new_file:
        for line in ori_lines:
            if line.strip().split(' ')[0] in wanted_uttids:
                new_file.write(line)
|
||||
167
egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
Normal file
167
egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
Normal file
@ -0,0 +1,167 @@
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
import soundfile
|
||||
from itertools import permutations
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from sklearn import cluster
|
||||
|
||||
|
||||
def custom_spectral_clustering(affinity, min_n_clusters=2, max_n_clusters=4, refine=True,
                               threshold=0.995, laplacian_type="graph_cut"):
    """Cluster samples from a pairwise affinity matrix with spectral clustering.

    Args:
        affinity: pairwise affinity matrix (assumed square, n x n — the code
            transposes it and takes its diagonal).
        min_n_clusters: smallest number of clusters to return.
        max_n_clusters: largest number of clusters to consider.
        refine: if True, symmetrize the affinity, multiply it by its own
            transpose and normalize every row by its maximum before building
            the Laplacian.
        threshold: the cluster count is the first k in
            [min_n_clusters, max_n_clusters] whose k-th smallest Laplacian
            eigenvalue (0-based) exceeds this value; max_n_clusters if none does.
        laplacian_type: "random_walk" left-multiplies the Laplacian by the
            inverse degree matrix; any other value scales it on both sides by
            the inverse square root of the degrees.

    Returns:
        The ``labels_`` array of the fitted KMeans solver, one label per row
        of ``affinity``.
    """
    if refine:
        # Symmetrization
        affinity = np.maximum(affinity, np.transpose(affinity))
        # Diffusion
        affinity = np.matmul(affinity, np.transpose(affinity))
        # Row-wise max normalization
        row_max = affinity.max(axis=1, keepdims=True)
        affinity = affinity / row_max

    # a) Construct S and set diagonal elements to 0
    affinity = affinity - np.diag(np.diag(affinity))
    # b) Compute Laplacian matrix L and perform normalization:
    degree = np.diag(np.sum(affinity, axis=1))
    laplacian = degree - affinity
    if laplacian_type == "random_walk":
        degree_norm = np.diag(1 / (np.diag(degree) + 1e-10))
        laplacian_norm = degree_norm.dot(laplacian)
    else:
        degree_half = np.diag(degree) ** 0.5 + 1e-15
        laplacian_norm = laplacian / degree_half[:, np.newaxis] / degree_half

    # c) Compute eigenvalues and eigenvectors of L_norm
    eigenvalues, eigenvectors = np.linalg.eig(laplacian_norm)
    eigenvalues = eigenvalues.real
    eigenvectors = eigenvectors.real
    index_array = np.argsort(eigenvalues)
    eigenvalues = eigenvalues[index_array]
    eigenvectors = eigenvectors[:, index_array]

    # d) Compute the number of clusters k
    n_samples = eigenvalues.shape[0]
    k = min_n_clusters
    for k in range(min_n_clusters, max_n_clusters + 1):
        # Stop once k reaches the number of samples: there is no k-th
        # eigenvalue to inspect, and indexing it would raise IndexError.
        if k >= n_samples or eigenvalues[k] > threshold:
            break
    k = max(k, min_n_clusters)
    spectral_embeddings = eigenvectors[:, :k]

    # Normalize every row of the spectral embedding to unit L2 norm.
    spectral_embeddings = spectral_embeddings / np.linalg.norm(spectral_embeddings, axis=1, ord=2, keepdims=True)
    solver = cluster.KMeans(n_clusters=k, max_iter=1000, random_state=42)
    solver.fit(spectral_embeddings)
    return solver.labels_
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage: <dump dir> <raw data dir> <eigenvalue threshold> <sv threshold>
    # For every meeting in <raw data dir>/wav_raw.scp, embeds overlapping
    # chunks of its segments, clusters them, saves the cluster means as a
    # profile, and writes cluster_spk_num and cluster_profile_infer.scp
    # under <dump dir>.
    path = sys.argv[1]  # dump2/raw/Eval_Ali_far
    raw_path = sys.argv[2]  # data/local/Eval_Ali_far
    threshold = float(sys.argv[3])  # 0.996
    sv_threshold = float(sys.argv[4])  # 0.815

    with open(path + '/wav.scp', 'r') as wav_scp_file:
        wav_scp = wav_scp_file.readlines()
    with open(raw_path + '/wav_raw.scp', 'r') as raw_meeting_scp_file:
        raw_meeting_scp = raw_meeting_scp_file.readlines()
    with open(raw_path + '/segments', 'r') as segments_scp_file:
        segments_scp = segments_scp_file.readlines()

    # meeting id -> list of (start, end) taken from the last two fields of
    # each segments line
    segments_map = {}
    for line in segments_scp:
        line_list = line.strip().split(' ')
        meeting = line_list[1]
        seg = (float(line_list[-2]), float(line_list[-1]))
        segments_map.setdefault(meeting, []).append(seg)

    inference_sv_pipline = pipeline(
        task=Tasks.speaker_verification,
        model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
    )

    chunk_len = int(1.5*16000)  # 1.5 seconds
    hop_len = int(0.75*16000)  # 0.75 seconds

    profile_dir = path + '/cluster_profile_infer'
    # Create the directory directly rather than through a shell command
    # built by string concatenation.
    os.makedirs(profile_dir, exist_ok=True)

    meeting_map = {}
    with open(path + '/cluster_spk_num', 'w') as cluster_spk_num_file:
        for line in raw_meeting_scp:
            fields = line.strip().split('\t')
            meeting = fields[0]
            wav_path = fields[1]
            wav = soundfile.read(wav_path)[0]
            # take the first channel
            if wav.ndim == 2:
                wav = wav[:, 0]

            # gen_seg_embedding
            all_seg_embedding_list = []
            for seg in segments_map[meeting]:
                wav_seg = wav[int(seg[0] * 16000): int(seg[1] * 16000)]
                wav_seg_len = wav_seg.shape[0]
                i = 0
                while i < wav_seg_len:
                    if i + chunk_len < wav_seg_len:
                        cur_wav_chunk = wav_seg[i: i + chunk_len]
                    else:
                        cur_wav_chunk = wav_seg[i:]
                    # chunks under 0.2s are ignored
                    if cur_wav_chunk.shape[0] >= 0.2 * 16000:
                        cur_chunk_embedding = inference_sv_pipline(audio_in=cur_wav_chunk)["spk_embedding"]
                        all_seg_embedding_list.append(cur_chunk_embedding)
                    i += hop_len
            # all_seg_embedding (n, dim)
            all_seg_embedding = np.vstack(all_seg_embedding_list)

            # compute affinity
            affinity = cosine_similarity(all_seg_embedding)
            # Shift by the verification threshold, floor at 0.0001 and
            # rescale by the shifted maximum.
            affinity = np.maximum(affinity - sv_threshold, 0.0001) / (affinity.max() - sv_threshold)

            # clustering
            labels = custom_spectral_clustering(
                affinity=affinity,
                min_n_clusters=2,
                max_n_clusters=4,
                refine=True,
                threshold=threshold,
                laplacian_type="graph_cut")

            # Group the embeddings by cluster label, keeping the order in
            # which each label first appears.
            cluster_rows = {}
            for j in range(labels.shape[0]):
                cluster_rows.setdefault(labels[j], []).append(all_seg_embedding[j])

            # get cluster center
            emb_list = [np.mean(np.vstack(rows), axis=0) for rows in cluster_rows.values()]

            spk_num = len(emb_list)
            profile_for_infer = np.vstack(emb_list)
            # save profile for each meeting
            profile_path = profile_dir + '/' + meeting + '.npy'
            np.save(profile_path, profile_for_infer)
            meeting_map[meeting] = (profile_path, spk_num)
            cluster_spk_num_file.write(meeting + ' ' + str(spk_num) + '\n')
            # Flush per meeting so the file reflects progress while running.
            cluster_spk_num_file.flush()

    with open(path + "/cluster_profile_infer.scp", 'w') as profile_scp:
        for line in wav_scp:
            uttid = line.strip().split(' ')[0]
            # The meeting id is the part of the utterance id before the first '-'.
            meeting = uttid.split('-')[0]
            profile_scp.write(uttid + ' ' + meeting_map[meeting][0] + '\n')
            profile_scp.flush()
|
||||
70
egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
Normal file
70
egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
Normal file
@ -0,0 +1,70 @@
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
import soundfile
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage: <dump dir> <raw data dir>
    # For every speaker found in <raw data dir>/segments, averages the
    # speaker-verification embeddings of that speaker's segments and saves
    # the result under <dump dir>/oracle_embedding/, listing the files in
    # <dump dir>/oracle_embedding.scp.
    path = sys.argv[1]  # dump2/raw/Eval_Ali_far
    raw_path = sys.argv[2]  # data/local/Eval_Ali_far_correct_single_speaker

    with open(raw_path + '/wav_raw.scp', 'r') as raw_meeting_scp_file:
        raw_meeting_scp = raw_meeting_scp_file.readlines()
    with open(raw_path + '/segments', 'r') as segments_scp_file:
        segments_scp = segments_scp_file.readlines()

    oracle_emb_dir = path + '/oracle_embedding/'
    # Create the directory directly rather than through a shell command
    # built by string concatenation.
    os.makedirs(oracle_emb_dir, exist_ok=True)

    # meeting id -> wav path (tab-separated columns of wav_raw.scp)
    raw_wav_map = {}
    for line in raw_meeting_scp:
        fields = line.strip().split('\t')
        raw_wav_map[fields[0]] = fields[1]

    # "<meeting>_<spk id>" -> list of (start_sample, end_sample)
    spk_map = {}
    for line in segments_scp:
        line_list = line.strip().split(' ')
        meeting = line_list[1]
        spk_id = line_list[0].split('_')[3]
        spk = meeting + '_' + spk_id
        time_start = float(line_list[-2])
        time_end = float(line_list[-1])
        # Only segments longer than 0.5 are used.
        if time_end - time_start > 0.5:
            sample_span = (int(time_start * 16000), int(time_end * 16000))
            spk_map.setdefault(spk, []).append(sample_span)

    inference_sv_pipline = pipeline(
        task=Tasks.speaker_verification,
        model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
    )

    with open(path + '/oracle_embedding.scp', 'w') as oracle_emb_scp_file:
        for spk, seg_times in spk_map.items():
            meeting = spk.split('_SPK')[0]
            wav_path = raw_wav_map[meeting]
            wav = soundfile.read(wav_path)[0]
            # take the first channel
            if wav.ndim == 2:
                wav = wav[:, 0]
            all_seg_embedding_list = []
            for seg_start, seg_end in seg_times:
                # Skip segments that start within the last 0.5 * 16000
                # samples of the recording.
                if seg_start < wav.shape[0] - 0.5 * 16000:
                    # Slicing clamps an end index beyond the array length,
                    # so a segment running past the end is simply truncated.
                    cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_start: seg_end])["spk_embedding"]
                    all_seg_embedding_list.append(cur_seg_embedding)
            all_seg_embedding = np.vstack(all_seg_embedding_list)
            spk_embedding = np.mean(all_seg_embedding, axis=0)
            np.save(oracle_emb_dir + spk + '.npy', spk_embedding)
            oracle_emb_scp_file.write(spk + ' ' + oracle_emb_dir + spk + '.npy' + '\n')
            # Flush per speaker so the file reflects progress while running.
            oracle_emb_scp_file.flush()
|
||||
59
egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
Normal file
59
egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
Normal file
@ -0,0 +1,59 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage: <dump dir>
    # Stacks the oracle embeddings of each meeting's speakers into one
    # profile matrix per meeting (no padding speakers) and writes
    # oracle_profile_nopadding.scp and oracle_profile_nopadding_spk_list.
    path = sys.argv[1]  # dump2/raw/Eval_Ali_far

    with open(path + "/wav.scp", 'r') as wav_scp_file:
        wav_scp = wav_scp_file.readlines()
    with open(path + "/spk2id", 'r') as spk2id_file:
        spk2id = spk2id_file.readlines()
    with open(path + "/oracle_embedding.scp", 'r') as embedding_scp_file:
        embedding_scp = embedding_scp_file.readlines()

    # speaker name -> embedding array; the first listing of a speaker wins
    embedding_map = {}
    for line in embedding_scp:
        fields = line.strip().split(' ')
        spk = fields[0]
        if spk not in embedding_map:
            embedding_map[spk] = np.load(fields[1])

    # meeting id -> speaker names in the order they appear in spk2id
    meeting_map_tmp = {}
    for line in spk2id:
        line_list = line.strip().split(' ')
        meeting = line_list[0].split('-')[0]
        spk_id = line_list[0].split('-')[-1].split('_')[-1]
        spk = meeting + '_' + spk_id
        meeting_map_tmp.setdefault(meeting, []).append(spk)

    profile_dir = path + '/oracle_profile_nopadding'
    # Create the directory directly rather than through a shell command
    # built by string concatenation.
    os.makedirs(profile_dir, exist_ok=True)

    # meeting id -> path of the saved profile matrix
    meeting_map = {}
    for meeting, spk_names in meeting_map_tmp.items():
        profile = np.vstack([embedding_map[spk] for spk in spk_names])
        profile_path = profile_dir + '/' + meeting + '.npy'
        np.save(profile_path, profile)
        meeting_map[meeting] = profile_path

    with open(path + '/oracle_profile_nopadding.scp', 'w') as profile_scp, \
            open(path + '/oracle_profile_nopadding_spk_list', 'w') as profile_map_scp:
        for line in wav_scp:
            uttid = line.strip().split(' ')[0]
            # The meeting id is the part of the utterance id before the first '-'.
            meeting = uttid.split('-')[0]
            profile_scp.write(uttid + ' ' + meeting_map[meeting] + '\n')
            profile_map_scp.write(uttid + ' ' + '$'.join(meeting_map_tmp[meeting]) + '\n')
|
||||
68
egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
Normal file
68
egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
Normal file
@ -0,0 +1,68 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage: <dump dir>
    # Builds one profile matrix per meeting from the oracle speaker
    # embeddings, padding meetings with fewer than 4 speakers with randomly
    # drawn speakers from the global list, and writes
    # oracle_profile_padding.scp and oracle_profile_padding_spk_list.
    path = sys.argv[1]  # dump2/raw/Train_Ali_far

    with open(path + "/wav.scp", 'r') as wav_scp_file:
        wav_scp = wav_scp_file.readlines()
    with open(path + "/spk2id", 'r') as spk2id_file:
        spk2id = spk2id_file.readlines()
    with open(path + "/oracle_embedding.scp", 'r') as embedding_scp_file:
        embedding_scp = embedding_scp_file.readlines()

    # speaker name -> embedding array; the first listing of a speaker wins
    embedding_map = {}
    for line in embedding_scp:
        fields = line.strip().split(' ')
        spk = fields[0]
        if spk not in embedding_map:
            embedding_map[spk] = np.load(fields[1])

    # meeting id -> speaker names in spk2id order; global_spk_list holds
    # every speaker name in the same order.
    meeting_map_tmp = {}
    global_spk_list = []
    for line in spk2id:
        line_list = line.strip().split(' ')
        meeting = line_list[0].split('-')[0]
        spk_id = line_list[0].split('-')[-1].split('_')[-1]
        spk = meeting + '_' + spk_id
        global_spk_list.append(spk)
        meeting_map_tmp.setdefault(meeting, []).append(spk)

    # Pad every meeting with fewer than 4 speakers.
    for meeting in meeting_map_tmp:
        num = len(meeting_map_tmp[meeting])
        if num < 4:
            # Candidates are all speakers except the meeting's own.
            global_spk_list_tmp = global_spk_list[:]
            for spk in meeting_map_tmp[meeting]:
                global_spk_list_tmp.remove(spk)
            padding_spk = random.sample(global_spk_list_tmp, 4 - num)
            meeting_map_tmp[meeting] = meeting_map_tmp[meeting] + padding_spk

    profile_dir = path + '/oracle_profile_padding'
    # Create the directory directly rather than through a shell command
    # built by string concatenation.
    os.makedirs(profile_dir, exist_ok=True)

    # meeting id -> path of the saved profile matrix
    meeting_map = {}
    for meeting, spk_names in meeting_map_tmp.items():
        profile = np.vstack([embedding_map[spk] for spk in spk_names])
        profile_path = profile_dir + '/' + meeting + '.npy'
        np.save(profile_path, profile)
        meeting_map[meeting] = profile_path

    with open(path + '/oracle_profile_padding.scp', 'w') as profile_scp, \
            open(path + '/oracle_profile_padding_spk_list', 'w') as profile_map_scp:
        for line in wav_scp:
            uttid = line.strip().split(' ')[0]
            # The meeting id is the part of the utterance id before the first '-'.
            meeting = uttid.split('-')[0]
            profile_scp.write(uttid + ' ' + meeting_map[meeting] + '\n')
            profile_map_scp.write(uttid + ' ' + '$'.join(meeting_map_tmp[meeting]) + '\n')
|
||||
32
egs/alimeeting/sa-asr/local/proce_text.py
Executable file
32
egs/alimeeting/sa-asr/local/proce_text.py
Executable file
@ -0,0 +1,32 @@
|
||||
|
||||
import sys
|
||||
import re
|
||||
|
||||
# Usage: <input text file> <output text file>
# Each input line is "<id> <text>". The listed tokens are removed from the
# text, which is lower-cased and re-emitted with one space between every
# remaining character.
in_f = sys.argv[1]
out_f = sys.argv[2]

# Removed in this order; all are literal strings, so plain str.replace is
# used instead of regular expressions.
removed_tokens = ("</s>", "<s>", "@@", "@", "<unk>", " ", "$")

with open(in_f, "r", encoding="utf-8") as f:
    lines = f.readlines()

with open(out_f, "w", encoding="utf-8") as f:
    for line in lines:
        outs = line.strip().split(" ", 1)
        if len(outs) == 2:
            idx, text = outs
            for token in removed_tokens:
                text = text.replace(token, "")
            text = text.lower()
        else:
            # A line with only an id gets a single-space transcript.
            idx = outs[0]
            text = " "

        # Joining a string inserts the separator between its characters.
        text = " ".join(text)
        f.write("{} {}\n".format(idx, text))
|
||||
86
egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
Executable file
86
egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
Executable file
@ -0,0 +1,86 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Process the textgrid files
|
||||
"""
|
||||
import argparse
|
||||
import codecs
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
import textgrid
|
||||
import pdb
|
||||
|
||||
def get_args():
    """Parse the command line and return the namespace; ``--path`` is required."""
    arg_parser = argparse.ArgumentParser(description="process the textgrid files")
    arg_parser.add_argument("--path", type=str, required=True, help="Data path")
    return arg_parser.parse_args()
|
||||
|
||||
class Segment(object):
    """Pair of an utterance id and the text associated with it."""

    def __init__(self, uttid, text):
        self.uttid, self.text = uttid, text
|
||||
|
||||
def main(args):
    """Write ``spk2id`` and ``text_id`` into the data directory ``args.path``.

    Reads ``text``, ``spk2utt`` and ``utt2spk_all_fifo`` from the directory.
    Every speaker in ``spk2utt`` gets a 1-based index counted within its
    meeting (the first three '_'-separated parts of the speaker id), written
    to ``spk2id``. ``text_id`` mirrors the second field of each ``text`` line
    character by character: every non-'$' character is replaced by the index
    of the current speaker and every '$' advances to the utterance's next
    speaker as listed ('$'-separated) in ``utt2spk_all_fifo``.

    Args:
        args: object with a ``path`` attribute naming the data directory.
    """
    data_dir = Path(args.path)

    # (uttid, speaker-index string) for every line of `text`, in file order
    all_segments = []
    # Context managers guarantee the four files are closed on error too.
    with codecs.open(data_dir / "text", "r", "utf-8") as text, \
            codecs.open(data_dir / "spk2utt", "r", "utf-8") as spk2utt, \
            codecs.open(data_dir / "utt2spk_all_fifo", "r", "utf-8") as utt2spk, \
            codecs.open(data_dir / "spk2id", "w", "utf-8") as spk2id:
        spkid_map = {}
        meetingid_map = {}
        for line in spk2utt:
            spkid = line.strip().split(" ")[0]
            meeting_id_list = spkid.split("_")[:3]
            meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
            # Running count of speakers seen so far in this meeting.
            meetingid_map[meeting_id] = meetingid_map.get(meeting_id, 0) + 1
            spkid_map[spkid] = meetingid_map[meeting_id]
            spk2id.write("%s %s\n" % (spkid, meetingid_map[meeting_id]))

        # uttid -> per-meeting speaker indices in the listed order
        utt2spklist = {}
        for line in utt2spk:
            fields = line.strip().split(" ")
            uttid = fields[0]
            utt2spklist[uttid] = [spkid_map[name] for name in fields[1].split("$")]

        for line in text:
            fields = line.strip().split(" ")
            uttid = fields[0]
            context = fields[1]
            spklist = utt2spklist[uttid]
            cnt = 0
            id_chars = []
            for char in context:
                if char != "$":
                    id_chars.append(str(spklist[cnt]))
                else:
                    # '$' is kept as-is and switches to the next speaker.
                    id_chars.append("$")
                    cnt += 1
            all_segments.append((uttid, "".join(id_chars)))

    with codecs.open(data_dir / "text_id", "w", "utf-8") as text_id:
        for uttid_tmp, text_tmp in all_segments:
            text_id.write("%s %s\n" % (uttid_tmp, text_tmp))
|
||||
|
||||
if __name__ == "__main__":
    # Parse the command line and hand the namespace straight to main().
    main(get_args())
|
||||
24
egs/alimeeting/sa-asr/local/process_text_id.py
Normal file
24
egs/alimeeting/sa-asr/local/process_text_id.py
Normal file
@ -0,0 +1,24 @@
|
||||
import sys
|
||||
if __name__ == "__main__":
    # Usage: <data dir>
    # Converts <data dir>/text_id into <data dir>/text_id_train: every digit
    # becomes that digit minus one, every '$' repeats the most recently
    # emitted id ('0' if none yet), and one extra copy of the last id is
    # appended; the ids are written space-separated.
    path = sys.argv[1]

    with open(path + "/text_id", 'r') as text_id_old_file:
        text_id_old = text_id_old_file.readlines()

    with open(path + "/text_id_train", 'w') as text_id:
        for line in text_id_old:
            fields = line.strip().split(' ')
            uttid = fields[0]
            old_id = fields[1]
            pre_id = '0'
            new_id_list = []
            for char in old_id:
                if char != '$':
                    pre_id = str(int(char) - 1)
                # For '$' pre_id is unchanged, so the previous id is repeated.
                new_id_list.append(pre_id)
            new_id_list.append(pre_id)
            text_id.write(uttid + ' ' + ' '.join(new_id_list) + '\n')
|
||||
55
egs/alimeeting/sa-asr/local/process_text_spk_merge.py
Normal file
55
egs/alimeeting/sa-asr/local/process_text_spk_merge.py
Normal file
@ -0,0 +1,55 @@
|
||||
import sys
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage: <data dir>
    # Merges the utterances of each meeting into one line per meeting in
    # <data dir>/text_spk_merge: the text of every speaker is concatenated
    # in order of utterance start time, and the per-speaker strings are
    # joined with '$' in ascending speaker-id order.
    path = sys.argv[1]

    with open(path + '/text', 'r') as text_scp_file:
        text_scp = text_scp_file.readlines()
    with open(path + '/text_id', 'r') as text_id_scp_file:
        text_id_scp = text_id_scp_file.readlines()
    assert len(text_scp) == len(text_id_scp)

    meeting_map = {}  # {meeting_id: [(start_time, text, text_id), (start_time, text, text_id), ...]}
    for text_raw, text_id_raw in zip(text_scp, text_id_scp):
        text_line = text_raw.strip().split(' ')
        text_id_line = text_id_raw.strip().split(' ')
        assert text_line[0] == text_id_line[0]
        # Lines that carry only an utterance id are skipped.
        if len(text_line) > 1:
            uttid = text_line[0]
            text = text_line[1]
            text_id = text_id_line[1]
            meeting_id = uttid.split('-')[0]
            # The second-to-last '-'-separated field of the utterance id is
            # used as the sort key within a meeting.
            start_time = int(uttid.split('-')[-2])
            meeting_map.setdefault(meeting_id, []).append((start_time, text, text_id))

    # The output is opened only after the inputs passed the checks above,
    # so a failed check does not leave a truncated file behind.
    with open(path + '/text_spk_merge', 'w') as text_spk_merge_file:
        for meeting_id in sorted(meeting_map):
            cur_meeting_list = sorted(meeting_map[meeting_id], key=lambda x: x[0])
            text_spk_merge_map = {}  # {1: text1, 2: text2, ...}
            for _, cur_text, cur_text_id in cur_meeting_list:
                assert len(cur_text) == len(cur_text_id)
                if len(cur_text) != 0:
                    cur_text_split = cur_text.split('$')
                    cur_text_id_split = cur_text_id.split('$')
                    assert len(cur_text_split) == len(cur_text_id_split)
                    for seg_text, seg_id in zip(cur_text_split, cur_text_id_split):
                        if len(seg_text) != 0:
                            # The first character of the id run names the speaker.
                            spk_id = int(seg_id[0])
                            text_spk_merge_map[spk_id] = text_spk_merge_map.get(spk_id, '') + seg_text
            text_spk_merge_list = [text_spk_merge_map[spk_id] for spk_id in sorted(text_spk_merge_map)]
            text_spk_merge_file.write(meeting_id + ' ' + '$'.join(text_spk_merge_list) + '\n')
            text_spk_merge_file.flush()
|
||||
127
egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
Executable file
127
egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
Executable file
@ -0,0 +1,127 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Process the textgrid files
|
||||
"""
|
||||
import argparse
|
||||
import codecs
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
import textgrid
|
||||
import pdb
|
||||
import numpy as np
|
||||
import sys
|
||||
import math
|
||||
|
||||
|
||||
class Segment(object):
    """One labelled interval of a speaker tier.

    The start and end times passed to the constructor are rounded to two
    decimals; the setters store their argument unchanged.
    """

    def __init__(self, uttid, spkr, stime, etime, text):
        self.uttid, self.spkr, self.text = uttid, spkr, text
        self.stime, self.etime = round(stime, 2), round(etime, 2)

    def change_stime(self, time):
        """Replace the start time with ``time`` (no rounding applied)."""
        self.stime = time

    def change_etime(self, time):
        """Replace the end time with ``time`` (no rounding applied)."""
        self.etime = time
|
||||
|
||||
|
||||
def get_args():
    """Parse the command line and return the namespace; ``--path`` is required."""
    arg_parser = argparse.ArgumentParser(description="process the textgrid files")
    arg_parser.add_argument("--path", type=str, required=True, help="Data path")
    return arg_parser.parse_args()
|
||||
|
||||
|
||||
|
||||
def _write_single_speaker_segment(segment_file, utt2spk, uttid, spk_name, start, end):
    """Write one segments/utt2spk entry for frames start..end, unless start and end round to the same time."""
    spkid = spk_name.split("_")[-1]
    #spkid = uttid+spkid
    if round(start*0.01, 2) != round(end*0.01, 2):
        segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(start)).zfill(7), str(int(end)).zfill(7), uttid, round(start*0.01, 2), round(end*0.01, 2)))
        utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(start)).zfill(7), str(int(end)).zfill(7), uttid+"_"+spkid))


def main(args):
    """Write ``segments`` and ``utt2spk`` for the single-speaker regions of every TextGrid.

    ``textgrid.flist`` in ``args.path`` lists "<uttid> <textgrid path>" per
    line. For each file, every tier is given a distinct power-of-two weight
    and a per-tier label row over 0.01-wide frames is filled with that weight
    wherever the tier has a non-empty mark. The rows are summed, so a frame
    whose sum equals one tier's weight is covered by that tier alone. Each
    maximal run of such frames is written as one segment when the run ends.

    NOTE(review): a run that extends to the very last frame is never written,
    because segments are only emitted when the frame label changes — confirm
    whether that is intended.

    Args:
        args: object with a ``path`` attribute naming the data directory.
    """
    data_dir = Path(args.path)
    # Context managers close all three files, including utt2spk, on any exit.
    with codecs.open(data_dir / "textgrid.flist", "r", "utf-8") as textgrid_flist, \
            codecs.open(data_dir / "segments", "w", "utf-8") as segment_file, \
            codecs.open(data_dir / "utt2spk", "w", "utf-8") as utt2spk:
        # get the path of textgrid file for each utterance
        for line in textgrid_flist:
            line_array = line.strip().split(" ")
            path = Path(line_array[1])
            uttid = line_array[0]

            try:
                tg = textgrid.TextGrid.fromFile(path)
            except Exception:
                # Report the offending file and stop, rather than opening a
                # debugger and continuing with a stale or undefined `tg`.
                print("Failed to parse TextGrid file: %s" % path, file=sys.stderr)
                raise

            num_spk = len(tg)
            spk2textgrid = {}
            spk2weight = {}
            weight2spk = {}
            cnt = 2
            xmax = 0
            for i in range(len(tg)):
                tier = tg[i]
                spk_name = tier.name
                if spk_name not in spk2weight:
                    # Weights are 2, 4, 8, ... in order of first appearance.
                    spk2weight[spk_name] = cnt
                    weight2spk[cnt] = spk_name
                    cnt = cnt * 2
                segments = []
                for j in range(len(tier)):
                    interval = tier[j]
                    if interval.mark:
                        # Track the latest end time over all marked intervals.
                        if xmax < interval.maxTime:
                            xmax = interval.maxTime
                        segments.append(
                            Segment(
                                uttid,
                                tier.name,
                                interval.minTime,
                                interval.maxTime,
                                interval.mark.strip(),
                            )
                        )
                segments = sorted(segments, key=lambda x: x.stime)
                spk2textgrid[spk_name] = segments

            # One row per tier, one column per 0.01-wide frame.
            olp_label = np.zeros((num_spk, int(xmax/0.01)), dtype=np.int32)
            for spkid in spk2weight.keys():
                weight = spk2weight[spkid]
                segments = spk2textgrid[spkid]
                # Weight 2 -> row 0, weight 4 -> row 1, ...
                idx = int(math.log2(weight)) - 1
                for seg in segments:
                    olp_label[idx, int(seg.stime/0.01): int(seg.etime/0.01)] = weight
            sum_label = olp_label.sum(axis=0)

            stime = 0
            pre_value = 0
            for pos in range(sum_label.shape[0]):
                cur_value = sum_label[pos]
                if cur_value in weight2spk:
                    if pre_value in weight2spk:
                        if cur_value != pre_value:
                            # One single-speaker run ends and another starts.
                            _write_single_speaker_segment(
                                segment_file, utt2spk, uttid, weight2spk[pre_value], stime, pos - 1)
                            stime = pos
                            pre_value = cur_value
                    else:
                        # A single-speaker run starts here.
                        stime = pos
                        pre_value = cur_value
                else:
                    if pre_value in weight2spk:
                        # A single-speaker run ends at the previous frame.
                        _write_single_speaker_segment(
                            segment_file, utt2spk, uttid, weight2spk[pre_value], stime, pos - 1)
                    stime = pos
                    pre_value = cur_value
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse the command line and hand the namespace straight to main().
    main(get_args())
|
||||
14
egs/alimeeting/sa-asr/local/text_format.pl
Executable file
14
egs/alimeeting/sa-asr/local/text_format.pl
Executable file
@ -0,0 +1,14 @@
|
||||
#!/usr/bin/env perl
|
||||
use warnings; #sed replacement for -w perl parameter
|
||||
# Copyright Chao Weng
|
||||
|
||||
# normalizations for hkust transcript
|
||||
# see the docs/trans-guidelines.pdf for details
|
||||
|
||||
# Copy STDIN to STDOUT, dropping every line that consists of exactly one
# whitespace-separated field.
while (my $line = <STDIN>) {
  my @fields = split(" ", $line);
  next if @fields == 1;
  print $line;
}
|
||||
38
egs/alimeeting/sa-asr/local/text_normalize.pl
Executable file
38
egs/alimeeting/sa-asr/local/text_normalize.pl
Executable file
@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env perl
|
||||
use warnings; #sed replacement for -w perl parameter
|
||||
# Copyright Chao Weng
|
||||
|
||||
# normalizations for hkust transcript
|
||||
# see the docs/trans-guidelines.pdf for details
|
||||
|
||||
# For every STDIN line, print the first field unchanged, then every
# following field after stripping annotation tags and punctuation and
# upper-casing Latin letters; fields are re-emitted separated by one space.
while (<STDIN>) {
  @A = split(" ", $_);
  print "$A[0] ";
  for ($n = 1; $n < @A; $n++) {
    $tmp = $A[$n];
    # Remove annotation tags and ASCII punctuation.
    $tmp =~ s:<sil>::g;
    $tmp =~ s:<%>::g;
    $tmp =~ s:<->::g;
    $tmp =~ s:<\$>::g;
    $tmp =~ s:<#>::g;
    $tmp =~ s:<_>::g;
    $tmp =~ s:<space>::g;
    $tmp =~ s:`::g;
    $tmp =~ s:&::g;
    $tmp =~ s:,::g;
    # Upper-case any field that contains an ASCII letter.
    if ($tmp =~ /[a-zA-Z]/) {$tmp=uc($tmp);}
    # Map full-width Latin letters to their ASCII upper-case form. In ASCII
    # these substitutions would be no-ops, so the full-width characters are
    # restored here.
    $tmp =~ s:A:A:g;
    $tmp =~ s:a:A:g;
    $tmp =~ s:b:B:g;
    $tmp =~ s:c:C:g;
    $tmp =~ s:k:K:g;
    $tmp =~ s:t:T:g;
    # Remove full-width and CJK punctuation. The question mark must be the
    # full-width character: a bare ASCII "?" is a quantifier and does not
    # compile as a pattern.
    $tmp =~ s:,::g;
    $tmp =~ s:丶::g;
    $tmp =~ s:。::g;
    $tmp =~ s:、::g;
    $tmp =~ s:?::g;
    print "$tmp ";
  }
  print "\n";
}
|
||||
6
egs/alimeeting/sa-asr/path.sh
Executable file
6
egs/alimeeting/sa-asr/path.sh
Executable file
@ -0,0 +1,6 @@
|
||||
# FunASR root, taken as three directory levels above the current working directory.
# Expansions are quoted so directories containing spaces are handled correctly.
export FUNASR_DIR="$PWD/../../.."

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH="$FUNASR_DIR/funasr/bin:$PATH"
export PATH="$PWD/utils/:$PATH"
|
||||
243
egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py
Executable file
243
egs/alimeeting/sa-asr/pyscripts/audio/format_wav_scp.py
Executable file
@ -0,0 +1,243 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import kaldiio
|
||||
import humanfriendly
|
||||
import numpy as np
|
||||
import resampy
|
||||
import soundfile
|
||||
from tqdm import tqdm
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.utils.cli_utils import get_commandline_args
|
||||
from funasr.fileio.read_text import read_2column_text
|
||||
from funasr.fileio.sound_scp import SoundScpWriter
|
||||
|
||||
|
||||
def humanfriendly_or_none(value: str):
    """Return None for "none"/"None"/"NONE", else ``humanfriendly.parse_size(value)``."""
    is_none_token = value in ("none", "None", "NONE")
    return None if is_none_token else humanfriendly.parse_size(value)
|
||||
|
||||
|
||||
def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
    """Convert a comma-separated list of integers into a tuple.

    >>> str2int_tuple('3,4,5')
    (3, 4, 5)
    >>> str2int_tuple('none') is None
    True

    Args:
        integers: Comma-separated integers, or a none-like keyword
            (``none``/``null`` in lower, title or upper case). Surrounding
            whitespace is ignored.

    Returns:
        The parsed integers as a tuple, or ``None`` for a none-like keyword.

    Raises:
        TypeError: If ``integers`` is not a ``str``.
        ValueError: If a component is not a valid integer literal.
    """
    # Explicit check instead of ``assert check_argument_types()``: an assert is
    # stripped under ``python -O`` and the typeguard helper is not available in
    # every typeguard release.
    if not isinstance(integers, str):
        raise TypeError(
            f'type of argument "integers" must be str; '
            f"got {type(integers).__name__} instead"
        )

    text = integers.strip()
    if text in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return tuple(int(item) for item in text.split(","))
|
||||
|
||||
|
||||
def _resample_to_int16(wave, rate, target_fs):
    """Return ``(wave, rate)``, resampled to ``target_fs`` when it is set and differs."""
    if target_fs is None or target_fs == rate:
        return wave, rate
    # FIXME(kamo): To use sox?
    wave = resampy.resample(wave.astype(np.float64), rate, target_fs, axis=0)
    return wave.astype(np.int16), target_fs


def _save_to_ark(fark, fscp, uttid, wave, rate, audio_format):
    """Append one utterance to the open ark file and record it in the scp file."""
    if "flac" in audio_format:
        suf = "flac"
    elif "wav" in audio_format:
        suf = "wav"
    else:
        raise RuntimeError("wav.ark or flac")

    # NOTE(kamo): Using extended ark format style here.
    # This format is incompatible with Kaldi
    kaldiio.save_ark(
        fark,
        {uttid: (wave, rate)},
        scp=fscp,
        append=True,
        write_function=f"soundfile_{suf}",
    )


def _format_with_segments(args, utt2ref_channels, out_wavscp, out_num_samples):
    """Cut the audio listed in ``args.scp`` by ``args.segments`` and write it out."""
    # Note: kaldiio supports only wav-pcm-int16le file.
    loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
    use_ark = args.audio_format.endswith("ark")
    fark = None
    fscp = None
    writer = None
    try:
        if use_ark:
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
            fscp = out_wavscp.open("w")
        else:
            # NOTE(review): the writer is not closed explicitly, exactly as
            # before; confirm whether SoundScpWriter needs a close() call.
            writer = SoundScpWriter(
                args.outdir,
                out_wavscp,
                format=args.audio_format,
            )

        with out_num_samples.open("w") as fnum_samples:
            for uttid, (rate, wave) in tqdm(loader):
                # wave: (Time,) or (Time, Nmic)
                if wave.ndim == 2 and utt2ref_channels is not None:
                    wave = wave[:, utt2ref_channels(uttid)]

                wave, rate = _resample_to_int16(wave, rate, args.fs)

                if use_ark:
                    _save_to_ark(fark, fscp, uttid, wave, rate, args.audio_format)
                else:
                    writer[uttid] = rate, wave
                fnum_samples.write(f"{uttid} {len(wave)}\n")
    finally:
        # Close the handles this function opened, even when an utterance fails.
        for handle in (fscp, fark):
            if handle is not None:
                handle.close()


def _format_without_segments(args, utt2ref_channels, out_wavscp, out_num_samples):
    """Convert every entry of ``args.scp`` (plain path or unix pipe) one by one."""
    use_ark = args.audio_format.endswith("ark")
    fark = None
    wavdir = None
    try:
        if use_ark:
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
        else:
            wavdir = Path(args.outdir) / f"data_{args.name}"
            wavdir.mkdir(parents=True, exist_ok=True)

        with Path(args.scp).open("r") as fscp, out_wavscp.open(
            "w"
        ) as fout, out_num_samples.open("w") as fnum_samples:
            for line in tqdm(fscp):
                uttid, wavpath = line.strip().split(None, 1)

                if wavpath.endswith("|"):
                    # Streaming input e.g. cat a.wav |
                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
                        with BytesIO(f.read()) as g:
                            wave, rate = soundfile.read(g, dtype=np.int16)
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                    # Piped input has no file on disk that could be reused.
                    save_asis = False
                else:
                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                        save_asis = False
                    elif use_ark:
                        save_asis = False
                    elif Path(wavpath).suffix == "." + args.audio_format and (
                        args.fs is None or args.fs == rate
                    ):
                        save_asis = True
                    else:
                        save_asis = False

                if save_asis:
                    # Neither --segments nor --fs are specified and
                    # the line doesn't end with "|",
                    # i.e. not using unix-pipe,
                    # only in this case,
                    # just using the original file as is.
                    fout.write(f"{uttid} {wavpath}\n")
                else:
                    wave, rate = _resample_to_int16(wave, rate, args.fs)

                    if use_ark:
                        _save_to_ark(
                            fark, fout, uttid, wave, rate, args.audio_format
                        )
                    else:
                        owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                        soundfile.write(owavpath, wave, rate)
                        fout.write(f"{uttid} {owavpath}\n")
                fnum_samples.write(f"{uttid} {len(wave)}\n")
    finally:
        if fark is not None:
            fark.close()


def main():
    """Create audio files (or an ark) plus a new scp file from a Kaldi-style ``wav.scp``.

    Positional arguments are the input scp file and the output directory.
    The output directory receives ``<name>.scp``, ``utt2num_samples`` and
    either ``data_<name>.ark`` (when ``--audio-format`` ends with ``ark``) or
    the individual audio files. Optional processing: cutting by
    ``--segments``, resampling to ``--fs`` and channel selection through
    ``--ref-channels`` or ``--utt2ref-channels``.
    """
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    parser = argparse.ArgumentParser(
        description='Create waves list from "wav.scp"',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("scp")
    parser.add_argument("outdir")
    parser.add_argument(
        "--name",
        default="wav",
        help="Specify the prefix word of output file name " 'such as "wav.scp"',
    )
    parser.add_argument("--segments", default=None)
    parser.add_argument(
        "--fs",
        type=humanfriendly_or_none,
        default=None,
        help="If the sampling rate specified, " "Change the sampling rate.",
    )
    parser.add_argument("--audio-format", default="wav")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
    group.add_argument("--utt2ref-channels", default=None, type=str)
    args = parser.parse_args()

    out_num_samples = Path(args.outdir) / "utt2num_samples"

    # Build a callable mapping an utterance id to the channels to keep,
    # or None when no channel selection was requested.
    if args.ref_channels is not None:

        def utt2ref_channels(x) -> Tuple[int, ...]:
            return args.ref_channels

    elif args.utt2ref_channels is not None:
        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)

        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
            chs_str = d[x]
            return tuple(map(int, chs_str.split()))

    else:
        utt2ref_channels = None

    Path(args.outdir).mkdir(parents=True, exist_ok=True)
    out_wavscp = Path(args.outdir) / f"{args.name}.scp"
    if args.segments is not None:
        _format_with_segments(args, utt2ref_channels, out_wavscp, out_num_samples)
    else:
        _format_without_segments(
            args, utt2ref_channels, out_wavscp, out_num_samples
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
45
egs/alimeeting/sa-asr/pyscripts/utils/print_args.py
Executable file
45
egs/alimeeting/sa-asr/pyscripts/utils/print_args.py
Executable file
@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python
|
||||
import sys
|
||||
|
||||
|
||||
def get_commandline_args(no_executable=True):
    """Return ``sys.argv`` as a single shell-escaped command-line string.

    Arguments containing shell-special characters are wrapped in single
    quotes, with embedded single quotes written as ``'\\''``. Empty arguments
    are written as ``''`` so they do not disappear from the result.

    Args:
        no_executable: If true, return only ``sys.argv[1:]`` (neither the
            interpreter nor ``sys.argv[0]``). Otherwise prefix the full
            ``sys.argv`` with ``sys.executable``.

    Returns:
        The escaped arguments joined with single spaces.
    """
    # Characters that make an unquoted word unsafe to paste into a shell.
    # "$", tab and newline are included so expansions and word splitting
    # cannot change the meaning of a logged argument.
    extra_chars = frozenset(" ;&|<>?*~`\"'\\{}()$\t\n")

    def quote(arg):
        if arg and extra_chars.isdisjoint(arg):
            return arg
        return "'" + arg.replace("'", "'\\''") + "'"

    # Escape the extra characters for shell
    argv = [quote(arg) for arg in sys.argv]

    if no_executable:
        return " ".join(argv[1:])
    return sys.executable + " " + " ".join(argv)
|
||||
|
||||
|
||||
def main():
    """Write the escaped command-line arguments to stdout, followed by a newline."""
    sys.stdout.write(get_commandline_args() + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
51
egs/alimeeting/sa-asr/run_m2met_2023.sh
Executable file
51
egs/alimeeting/sa-asr/run_m2met_2023.sh
Executable file
@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env bash
# Recipe driver: fixes the configuration below and hands it to ./asr_local.sh.
# Strict mode: abort on any error, on use of an undefined variable and on a
# failure anywhere in a pipeline.
set -euo pipefail

# GPU setup
ngpu=4
device="0,1,2,3"

# Range of asr_local.sh stages to run (stage 1 creates both near and far data).
stage=1
stop_stage=18

# Data sets
train_set="Train_Ali_far"
valid_set="Eval_Ali_far"
test_sets="Test_Ali_far"

# Model and decoding configurations
asr_config="conf/train_asr_conformer.yaml"
sa_asr_config="conf/train_sa_asr_conformer.yaml"
inference_config="conf/decode_asr_rnn.yaml"

# Language model
lm_config="conf/train_lm_transformer.yaml"
use_lm=false
use_wordlm=false

asr_local_opts=(
    --device "${device}"
    --ngpu "${ngpu}"
    --stage "${stage}"
    --stop_stage "${stop_stage}"
    --gpu_inference true
    --njob_infer 4
    --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting
    --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting
    --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting
    --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting
    --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting
    --lang zh
    --audio_format wav
    --feats_type raw
    --token_type char
    --use_lm "${use_lm}"
    --use_word_lm "${use_wordlm}"
    --lm_config "${lm_config}"
    --asr_config "${asr_config}"
    --sa_asr_config "${sa_asr_config}"
    --inference_config "${inference_config}"
    --train_set "${train_set}"
    --valid_set "${valid_set}"
    --test_sets "${test_sets}"
    --lm_train_text "data/${train_set}/text"
)

# Arguments given to this script are passed on after the options above.
./asr_local.sh "${asr_local_opts[@]}" "$@"
|
||||
50
egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
Executable file
50
egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
Executable file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env bash
# Inference driver: fixes the configuration below and hands it to
# ./asr_local_infer.sh.
# Strict mode: abort on any error, on use of an undefined variable and on a
# failure anywhere in a pipeline.
set -euo pipefail

# GPU setup
ngpu=4
device="0,1,2,3"

# Range of asr_local_infer.sh stages to run.
stage=1
stop_stage=4

# Data sets
train_set="Train_Ali_far"
valid_set="Eval_Ali_far"
test_sets="Test_2023_Ali_far"

# Model and decoding configurations
asr_config="conf/train_asr_conformer.yaml"
sa_asr_config="conf/train_sa_asr_conformer.yaml"
inference_config="conf/decode_asr_rnn.yaml"

# Language model
lm_config="conf/train_lm_transformer.yaml"
use_lm=false
use_wordlm=false

infer_opts=(
    --device "${device}"
    --ngpu "${ngpu}"
    --stage "${stage}"
    --stop_stage "${stop_stage}"
    --gpu_inference true
    --njob_infer 4
    --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting
    --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting
    --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting
    --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting
    --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting
    --lang zh
    --audio_format wav
    --feats_type raw
    --token_type char
    --use_lm "${use_lm}"
    --use_word_lm "${use_wordlm}"
    --lm_config "${lm_config}"
    --asr_config "${asr_config}"
    --sa_asr_config "${sa_asr_config}"
    --inference_config "${inference_config}"
    --train_set "${train_set}"
    --valid_set "${valid_set}"
    --test_sets "${test_sets}"
    --lm_train_text "data/${train_set}/text"
)

# Arguments given to this script are passed on after the options above.
./asr_local_infer.sh "${infer_opts[@]}" "$@"
|
||||
142
egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh
Executable file
142
egs/alimeeting/sa-asr/scripts/audio/format_wav_scp.sh
Executable file
@ -0,0 +1,142 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
SECONDS=0
|
||||
log() {
    # Print the message prefixed with a timestamp and the caller's
    # file name, line number and function name.
    local caller_file timestamp
    caller_file=${BASH_SOURCE[1]##*/}
    timestamp=$(date '+%Y-%m-%dT%H:%M:%S')
    echo -e "${timestamp} (${caller_file}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
|
||||
help_message=$(cat << EOF
|
||||
Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
|
||||
e.g.
|
||||
$0 data/test/wav.scp data/test_format/
|
||||
|
||||
Format 'wav.scp': In short words,
|
||||
changing "kaldi-datadir" to "modified-kaldi-datadir"
|
||||
|
||||
The 'wav.scp' format in kaldi is very flexible,
|
||||
e.g. It can use unix-pipe as describing that wav file,
|
||||
but it sometime looks confusing and make scripts more complex.
|
||||
This tools creates actual wav files from 'wav.scp'
|
||||
and also segments wav files using 'segments'.
|
||||
|
||||
Options
|
||||
--fs <fs>
|
||||
--segments <segments>
|
||||
--nj <nj>
|
||||
--cmd <cmd>
|
||||
EOF
|
||||
)
|
||||
|
||||
# Defaults; each can be overridden on the command line via parse_options.sh.
out_filename=wav.scp
cmd=utils/run.pl
nj=30
fs=none
segments=

ref_channels=
utt2ref_channels=

audio_format=wav
write_utt2num_samples=true

log "$0 $*"
. utils/parse_options.sh

if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
    log "${help_message}"
    log "Error: invalid command line arguments"
    exit 1
fi

. ./path.sh  # Setup the environment

scp=$1
if [ ! -f "${scp}" ]; then
    log "${help_message}"
    echo "$0: Error: No such file: ${scp}"
    exit 1
fi
dir=$2

# <logdir> and <outdir> are optional and default to sub-directories of <dir>.
if [ $# -eq 2 ]; then
    logdir=${dir}/logs
    outdir=${dir}/data

elif [ $# -eq 3 ]; then
    logdir=$3
    outdir=${dir}/data

elif [ $# -eq 4 ]; then
    logdir=$3
    outdir=$4
fi

mkdir -p "${logdir}"

rm -f "${dir}/${out_filename}"

# ${opts} is expanded unquoted below on purpose: it holds "option value".
opts=
if [ -n "${utt2ref_channels}" ]; then
    opts="--utt2ref-channels ${utt2ref_channels} "
elif [ -n "${ref_channels}" ]; then
    opts="--ref-channels ${ref_channels} "
fi

if [ -n "${segments}" ]; then
    log "[info]: using ${segments}"
    nutt=$(<"${segments}" wc -l)
    # Never start more jobs than there are segments.
    nj=$((nj<nutt?nj:nutt))

    split_segments=""
    for n in $(seq ${nj}); do
        split_segments="${split_segments} ${logdir}/segments.${n}"
    done

    utils/split_scp.pl "${segments}" ${split_segments}

    # The option is spelled out in full (--segments) instead of relying on
    # argparse accepting the unambiguous prefix --segment.
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        pyscripts/audio/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            "--segments=${logdir}/segments.JOB" \
            "${scp}" "${outdir}/format.JOB"

else
    log "[info]: without segments"
    nutt=$(<"${scp}" wc -l)
    # Never start more jobs than there are utterances.
    nj=$((nj<nutt?nj:nutt))

    split_scps=""
    for n in $(seq ${nj}); do
        split_scps="${split_scps} ${logdir}/wav.${n}.scp"
    done

    utils/split_scp.pl "${scp}" ${split_scps}
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        pyscripts/audio/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            "${logdir}/wav.JOB.scp" "${outdir}/format.JOB"
fi

# Workaround for the NFS problem
ls "${outdir}"/format.* > /dev/null

# concatenate the .scp files together.
for n in $(seq ${nj}); do
    cat "${outdir}/format.${n}/wav.scp" || exit 1;
done > "${dir}/${out_filename}" || exit 1

if "${write_utt2num_samples}"; then
    for n in $(seq ${nj}); do
        cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
    done > "${dir}/utt2num_samples" || exit 1
fi

log "Successfully finished. [elapsed=${SECONDS}s]"
|
||||
116
egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh
Executable file
116
egs/alimeeting/sa-asr/scripts/utils/perturb_data_dir_speed.sh
Executable file
@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# 2020 @kamo-naoyuki
|
||||
# This file was copied from Kaldi and
|
||||
# I deleted parts related to wav duration
|
||||
# because we shouldn't use kaldi's command here
|
||||
# and we don't need the files actually.
|
||||
|
||||
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
|
||||
# 2014 Tom Ko
|
||||
# 2018 Emotech LTD (author: Pawel Swietojanski)
|
||||
# Apache 2.0
|
||||
|
||||
# This script operates on a directory, such as in data/train/,
|
||||
# that contains some subset of the following files:
|
||||
# wav.scp
|
||||
# spk2utt
|
||||
# utt2spk
|
||||
# text
|
||||
#
|
||||
# It generates the files which are used for perturbing the speed of the original data.
|
||||
|
||||
export LC_ALL=C
set -euo pipefail

if [[ $# != 3 ]]; then
    echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
    echo "e.g.:"
    echo " $0 0.9 data/train_si284 data/train_si284p"
    exit 1
fi

factor=$1
srcdir=$2
destdir=$3
label="sp"
# Speaker, utterance and recording ids all receive the same "sp<factor>-" prefix.
spk_prefix="${label}${factor}-"
utt_prefix="${label}${factor}-"

# Read wav.scp lines on stdin and write them with a sox speed effect appended.
# Handle three cases of rxfilenames appropriately; "input piped command",
# "file offset" and "filename".
# After $1 is blanked, $0 is the rest of the line with a leading space
# (the original wrote this as $_, which awk evaluates as $0).
perturb_wav_scp() {
    sed 's/| *$/ |/' | \
        awk -v factor="${factor}" \
            '{wid=$1; $1=""; if ($NF=="|") {print wid $0 " sox -t wav - -t wav - speed " factor " |"}
              else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $0 " - | sox -t wav - -t wav - speed " factor " |" }
              else {print wid " sox" $0 " -t wav - speed " factor " |"}}'
}

#check is sox on the path
if ! command -v sox &>/dev/null; then
    echo "sox: command not found"
    exit 1
fi

if [[ ! -f ${srcdir}/utt2spk ]]; then
    echo "$0: no such file ${srcdir}/utt2spk"
    exit 1;
fi

if [[ ${destdir} == "${srcdir}" ]]; then
    echo "$0: this script requires <srcdir> and <destdir> to be different."
    exit 1
fi

mkdir -p "${destdir}"

# Temporary "old-id new-id" maps; removed again at the end of the script.
<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
if [[ ! -f ${srcdir}/utt2uniq ]]; then
    <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
else
    <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
fi

<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
    utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk

utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt

if [[ -f ${srcdir}/segments ]]; then
    # Segment boundaries are divided by the factor; segments that are not
    # longer than 0.01 afterwards are dropped.
    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
        utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
        awk -v factor="${factor}" \
            '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
            >"${destdir}"/segments

    # With segments, wav.scp is indexed by recording id.
    utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | \
        perturb_wav_scp > "${destdir}"/wav.scp
    if [[ -f ${srcdir}/reco2file_and_channel ]]; then
        utils/apply_map.pl -f 1 "${destdir}"/reco_map \
            <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
    fi

else # no segments->wav indexed by utterance.
    if [[ -f ${srcdir}/wav.scp ]]; then
        utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | \
            perturb_wav_scp > "${destdir}"/wav.scp
    fi
fi

if [[ -f ${srcdir}/text ]]; then
    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
fi
if [[ -f ${srcdir}/spk2gender ]]; then
    utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
fi
if [[ -f ${srcdir}/utt2lang ]]; then
    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
fi

rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"

utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
|
||||
97
egs/alimeeting/sa-asr/utils/apply_map.pl
Executable file
97
egs/alimeeting/sa-asr/utils/apply_map.pl
Executable file
@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env perl
use strict;
use warnings; #sed replacement for -w perl parameter
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
# Apache 2.0.

# This program is a bit like ./sym2int.pl in that it applies a map
# to things in a file, but it's a bit more general in that it doesn't
# assume the things being mapped to are single tokens, they could
# be sequences of tokens. See the usage message.

my $permissive = 0;
# Zero-based, inclusive bounds of the fields to map; undef means unbounded.
my ($field_begin, $field_end);

# Options may appear in either order, so scan the front of @ARGV a few times.
for (my $x = 0; $x <= 2; $x++) {

  if (@ARGV > 0 && $ARGV[0] eq "-f") {
    shift @ARGV;
    my $field_spec = shift @ARGV;
    defined $field_spec || die "Bad argument to -f option: missing field range";
    if ($field_spec =~ m/^\d+$/) {
      $field_begin = $field_spec - 1; $field_end = $field_spec - 1;
    }
    if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10)
      if ($1 ne "") {
        $field_begin = $1 - 1; # Change to zero-based indexing.
      }
      if ($2 ne "") {
        $field_end = $2 - 1; # Change to zero-based indexing.
      }
    }
    if (!defined $field_begin && !defined $field_end) {
      die "Bad argument to -f option: $field_spec";
    }
  }

  if (@ARGV > 0 && $ARGV[0] eq '--permissive') {
    shift @ARGV;
    # Mapping is optional (missing key is printed to output)
    $permissive = 1;
  }
}

if(@ARGV != 1) {
  print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n";
  print STDERR <<'EOF';
Usage: apply_map.pl [options] map <input >output
options: [-f <field-range> ] [--permissive]
This applies a map to some specified fields of some input text:
For each line in the map file: the first field is the thing we
map from, and the remaining fields are the sequence we map it to.
The -f (field-range) option says which fields of the input file the map
map should apply to.
If the --permissive option is supplied, fields which are not present
in the map will be left as they were.
Applies the map 'map' to all input text, where each line of the map
is interpreted as a map from the first field to the list of the other fields
Note: <field-range> can look like 4-5, or 4-, or 5-, or 1, it means the field
range in the input to apply the map to.
e.g.: echo A B | apply_map.pl a.txt
where a.txt is:
A a1 a2
B b
will produce:
a1 a2 b
EOF
  exit(1);
}

my ($map_file) = @ARGV;
# Three-argument open: the file name is never interpreted as a mode or a pipe.
open(my $map_fh, "<", $map_file) || die "Error opening map file $map_file: $!";

# Key: first field of a map line. Value: its remaining fields joined by spaces.
my %map;
while (my $line = <$map_fh>) {
  my @fields = split(" ", $line);
  @fields >= 1 || die "apply_map.pl: empty line.";
  my $key = shift @fields;
  $map{$key} = join(" ", @fields);
}
close($map_fh);

while (my $line = <STDIN>) {
  my @A = split(" ", $line);
  for (my $x = 0; $x < @A; $x++) {
    if ( (!defined $field_begin || $x >= $field_begin)
         && (!defined $field_end || $x <= $field_end)) {
      my $key = $A[$x];
      if (!defined $map{$key}) {
        if (!$permissive) {
          die "apply_map.pl: undefined key $key in $map_file\n";
        } else {
          print STDERR "apply_map.pl: warning! missing key $key in $map_file\n";
        }
      } else {
        $A[$x] = $map{$key};
      }
    }
  }
  print join(" ", @A) . "\n";
}
|
||||
146
egs/alimeeting/sa-asr/utils/combine_data.sh
Executable file
146
egs/alimeeting/sa-asr/utils/combine_data.sh
Executable file
@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0.
# 2014 David Snyder

# This script combines the data from multiple source directories into
# a single destination directory.

# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information
# about what these directories contain.

# Begin configuration section.
extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..."
skip_fix=false # skip the fix_data_dir.sh in the end
# End configuration section.

echo "$0 $*"  # Print the command line for logging

if [ -f path.sh ]; then . ./path.sh; fi
. parse_options.sh || exit 1;

if [ $# -lt 2 ]; then
  echo "Usage: combine_data.sh [--extra-files 'file1 file2'] <dest-data-dir> <src-data-dir1> <src-data-dir2> ..."
  echo "Note, files that don't appear in all source dirs will not be combined,"
  echo "with the exception of utt2uniq and segments, which are created where necessary."
  exit 1
fi

# Succeed when file $1 exists in at least one of the remaining arguments
# (directories).
exists_in_any() {
  local name=$1
  shift
  local d
  for d in "$@"; do
    if [ -f "$d/$name" ]; then
      return 0
    fi
  done
  return 1
}

dest=$1;
shift;

# From here on "$@" holds only the source directories.
first_src=$1;  # not used below; kept from the original script

rm -r "$dest" 2>/dev/null || true
mkdir -p "$dest";

export LC_ALL=C

for dir in "$@"; do
  if [ ! -f "$dir/utt2spk" ]; then
    echo "$0: no such file $dir/utt2spk"
    exit 1;
  fi
done

# Check that frame_shift are compatible, where present together with features.
dir_with_frame_shift=
for dir in "$@"; do
  if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then
    if [[ $dir_with_frame_shift ]] &&
       ! cmp -s "$dir_with_frame_shift/frame_shift" "$dir/frame_shift"; then
      echo "$0:error: different frame_shift in directories $dir and " \
           "$dir_with_frame_shift. Cannot combine features."
      exit 1;
    fi
    dir_with_frame_shift=$dir
  fi
done

# W.r.t. utt2uniq file the script has different behavior compared to other files
# it is not compulsory for it to exist in src directories, but if it exists in
# even one it should exist in all. We will create the files where necessary
if exists_in_any utt2uniq "$@"; then
  # we are going to create an utt2uniq file in the destdir
  for in_dir in "$@"; do
    if [ ! -f "$in_dir/utt2uniq" ]; then
      # we assume that utt2uniq is a one to one mapping
      cat "$in_dir/utt2spk" | awk '{printf("%s %s\n", $1, $1);}'
    else
      cat "$in_dir/utt2uniq"
    fi
  done | sort -k1 > "$dest/utt2uniq"
  echo "$0: combined utt2uniq"
else
  echo "$0 [info]: not combining utt2uniq as it does not exist"
fi
# some of the old scripts might provide utt2uniq as an extrafile, so just remove it
extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g")

# segments are treated similarly to utt2uniq. If it exists in some, but not all
# src directories, then we generate segments where necessary.
if exists_in_any segments "$@"; then
  for in_dir in "$@"; do
    if [ ! -f "$in_dir/segments" ]; then
      echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
      utils/data/get_segments_for_data.sh "$in_dir"
    else
      cat "$in_dir/segments"
    fi
  done | sort -k1 > "$dest/segments"
  echo "$0: combined segments"
else
  echo "$0 [info]: not combining segments as it does not exist"
fi

# $extra_files is expanded unquoted on purpose: it is a space-separated list.
for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do
  exists_somewhere=false
  absent_somewhere=false
  for d in "$@"; do
    if [ -f "$d/$file" ]; then
      exists_somewhere=true
    else
      absent_somewhere=true
    fi
  done

  if ! $absent_somewhere; then
    # pipefail is enabled only here so that a failing cat is not masked by sort.
    set -o pipefail
    ( for f in "$@"; do cat "$f/$file"; done ) | sort -k1 > "$dest/$file" || exit 1;
    set +o pipefail
    echo "$0: combined $file"
  else
    if ! $exists_somewhere; then
      echo "$0 [info]: not combining $file as it does not exist"
    else
      echo "$0 [info]: **not combining $file as it does not exist everywhere**"
    fi
  fi
done

utils/utt2spk_to_spk2utt.pl <"$dest/utt2spk" >"$dest/spk2utt"

if [[ $dir_with_frame_shift ]]; then
  cp "$dir_with_frame_shift/frame_shift" "$dest"
fi

if ! $skip_fix ; then
  utils/fix_data_dir.sh "$dest" || exit 1;
fi

exit 0
|
||||
145
egs/alimeeting/sa-asr/utils/copy_data_dir.sh
Executable file
145
egs/alimeeting/sa-asr/utils/copy_data_dir.sh
Executable file
@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
|
||||
# Apache 2.0
|
||||
|
||||
# This script operates on a directory, such as in data/train/,
|
||||
# that contains some subset of the following files:
|
||||
# feats.scp
|
||||
# wav.scp
|
||||
# vad.scp
|
||||
# spk2utt
|
||||
# utt2spk
|
||||
# text
|
||||
#
|
||||
# It copies to another directory, possibly adding a specified prefix or a suffix
|
||||
# to the utterance and/or speaker names. Note, the recording-ids stay the same.
|
||||
#
|
||||
|
||||
|
||||
# begin configuration section
|
||||
spk_prefix=
|
||||
utt_prefix=
|
||||
spk_suffix=
|
||||
utt_suffix=
|
||||
validate_opts= # should rarely be needed.
|
||||
# end configuration section
|
||||
|
||||
. utils/parse_options.sh
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: "
|
||||
echo " $0 [options] <srcdir> <destdir>"
|
||||
echo "e.g.:"
|
||||
echo " $0 --spk-prefix=1- --utt-prefix=1- data/train data/train_1"
|
||||
echo "Options"
|
||||
echo " --spk-prefix=<prefix> # Prefix for speaker ids, default empty"
|
||||
echo " --utt-prefix=<prefix> # Prefix for utterance ids, default empty"
|
||||
echo " --spk-suffix=<suffix> # Suffix for speaker ids, default empty"
|
||||
echo " --utt-suffix=<suffix> # Suffix for utterance ids, default empty"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
|
||||
export LC_ALL=C
|
||||
|
||||
srcdir=$1
|
||||
destdir=$2
|
||||
|
||||
if [ ! -f $srcdir/utt2spk ]; then
|
||||
echo "copy_data_dir.sh: no such file $srcdir/utt2spk"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if [ "$destdir" == "$srcdir" ]; then
|
||||
echo "$0: this script requires <srcdir> and <destdir> to be different."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
set -e;
|
||||
|
||||
mkdir -p $destdir
|
||||
|
||||
cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/utt_map
|
||||
cat $srcdir/spk2utt | awk -v p=$spk_prefix -v s=$spk_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/spk_map
|
||||
|
||||
if [ ! -f $srcdir/utt2uniq ]; then
|
||||
if [[ ! -z $utt_prefix || ! -z $utt_suffix ]]; then
|
||||
cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $1);}' > $destdir/utt2uniq
|
||||
fi
|
||||
else
|
||||
cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq
|
||||
fi
|
||||
|
||||
cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \
|
||||
utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
|
||||
|
||||
utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
|
||||
|
||||
if [ -f $srcdir/feats.scp ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
|
||||
fi
|
||||
|
||||
if [ -f $srcdir/vad.scp ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
|
||||
fi
|
||||
|
||||
if [ -f $srcdir/segments ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
|
||||
cp $srcdir/wav.scp $destdir
|
||||
else # no segments->wav indexed by utt.
|
||||
if [ -f $srcdir/wav.scp ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -f $srcdir/reco2file_and_channel ]; then
|
||||
cp $srcdir/reco2file_and_channel $destdir/
|
||||
fi
|
||||
|
||||
if [ -f $srcdir/text ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
|
||||
fi
|
||||
if [ -f $srcdir/utt2dur ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
|
||||
fi
|
||||
if [ -f $srcdir/utt2num_frames ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
|
||||
fi
|
||||
if [ -f $srcdir/reco2dur ]; then
|
||||
if [ -f $srcdir/segments ]; then
|
||||
cp $srcdir/reco2dur $destdir/reco2dur
|
||||
else
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
|
||||
fi
|
||||
fi
|
||||
if [ -f $srcdir/spk2gender ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
|
||||
fi
|
||||
if [ -f $srcdir/cmvn.scp ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
|
||||
fi
|
||||
for f in frame_shift stm glm ctm; do
|
||||
if [ -f $srcdir/$f ]; then
|
||||
cp $srcdir/$f $destdir
|
||||
fi
|
||||
done
|
||||
|
||||
rm $destdir/spk_map $destdir/utt_map
|
||||
|
||||
echo "$0: copied data from $srcdir to $destdir"
|
||||
|
||||
for f in feats.scp cmvn.scp vad.scp utt2lang utt2uniq utt2dur utt2num_frames text wav.scp reco2file_and_channel frame_shift stm glm ctm; do
|
||||
if [ -f $destdir/$f ] && [ ! -f $srcdir/$f ]; then
|
||||
echo "$0: file $f exists in dest $destdir but not in src $srcdir. Moving it to"
|
||||
echo " ... $destdir/.backup/$f"
|
||||
mkdir -p $destdir/.backup
|
||||
mv $destdir/$f $destdir/.backup/
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats"
|
||||
[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text"
|
||||
|
||||
utils/validate_data_dir.sh $validate_opts $destdir
|
||||
143
egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh
Executable file
143
egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh
Executable file
@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2016 Johns Hopkins University (author: Daniel Povey)
|
||||
# 2018 Andrea Carmantini
|
||||
# Apache 2.0
|
||||
|
||||
# This script operates on a data directory, such as in data/train/, and adds the
|
||||
# reco2dur file if it does not already exist. The file 'reco2dur' maps from
|
||||
# recording to the duration of the recording in seconds. This script works it
|
||||
# out from the 'wav.scp' file, or, if utterance-ids are the same as recording-ids, from the
|
||||
# utt2dur file (it first tries interrogating the headers, and if this fails, it reads the wave
|
||||
# files in entirely.)
|
||||
# We could use durations from segments file, but that's not the duration of the recordings
|
||||
# but the sum of utterance lengths (silence in between could be excluded from segments)
|
||||
# For sum of utterance lengths:
|
||||
# awk 'FNR==NR{uttdur[$1]=$2;next}
|
||||
# { for(i=2;i<=NF;i++){dur+=uttdur[$i];}
|
||||
# print $1 FS dur; dur=0 }' $data/utt2dur $data/reco2utt
|
||||
|
||||
|
||||
frame_shift=0.01
|
||||
cmd=run.pl
|
||||
nj=4
|
||||
|
||||
. utils/parse_options.sh
|
||||
. ./path.sh
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: $0 [options] <datadir>"
|
||||
echo "e.g.:"
|
||||
echo " $0 data/train"
|
||||
echo " Options:"
|
||||
echo " --frame-shift # frame shift in seconds. Only relevant when we are"
|
||||
echo " # getting duration from feats.scp (default: 0.01). "
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export LC_ALL=C
|
||||
|
||||
data=$1
|
||||
|
||||
|
||||
if [ -s $data/reco2dur ] && \
|
||||
[ $(wc -l < $data/wav.scp) -eq $(wc -l < $data/reco2dur) ]; then
|
||||
echo "$0: $data/reco2dur already exists with the expected length. We won't recompute it."
|
||||
exit 0;
|
||||
fi
|
||||
|
||||
if [ -s $data/utt2dur ] && \
|
||||
[ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ] && \
|
||||
[ ! -s $data/segments ]; then
|
||||
|
||||
echo "$0: $data/wav.scp indexed by utt-id; copying utt2dur to reco2dur"
|
||||
cp $data/utt2dur $data/reco2dur && exit 0;
|
||||
|
||||
elif [ -f $data/wav.scp ]; then
|
||||
echo "$0: obtaining durations from recordings"
|
||||
|
||||
# if the wav.scp contains only lines of the form
|
||||
# utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph |
|
||||
if cat $data/wav.scp | perl -e '
|
||||
while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space.
|
||||
@A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
|
||||
$A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
|
||||
$reco = $A[0]; $sphere_file = $A[4];
|
||||
|
||||
if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
|
||||
$sample_rate = -1; $sample_count = -1;
|
||||
for ($n = 0; $n <= 30; $n++) {
|
||||
$line = <F>;
|
||||
if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
|
||||
if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
|
||||
# Stop scanning once the header terminator is seen; Perl leaves a loop with "last" (a bare "break" is not loop control).
if ($line =~ m/end_head/) { last; }
|
||||
}
|
||||
close(F);
|
||||
if ($sample_rate == -1 || $sample_count == -1) {
|
||||
die "could not parse sphere header from $sphere_file";
|
||||
}
|
||||
$duration = $sample_count * 1.0 / $sample_rate;
|
||||
print "$reco $duration\n";
|
||||
} ' > $data/reco2dur; then
|
||||
echo "$0: successfully obtained recording lengths from sphere-file headers"
|
||||
else
|
||||
echo "$0: could not get recording lengths from sphere-file headers, using wav-to-duration"
|
||||
if ! command -v wav-to-duration >/dev/null; then
|
||||
echo "$0: wav-to-duration is not on your path"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
read_entire_file=false
|
||||
if grep -q 'sox.*speed' $data/wav.scp; then
|
||||
read_entire_file=true
|
||||
echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
|
||||
echo "... It is much faster if you call get_reco2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
|
||||
echo "... perturb_data_dir_speed_3way.sh."
|
||||
fi
|
||||
|
||||
num_recos=$(wc -l <$data/wav.scp)
|
||||
if [ $nj -gt $num_recos ]; then
|
||||
nj=$num_recos
|
||||
fi
|
||||
|
||||
temp_data_dir=$data/wav${nj}split
|
||||
wavscps=$(for n in `seq $nj`; do echo $temp_data_dir/$n/wav.scp; done)
|
||||
subdirs=$(for n in `seq $nj`; do echo $temp_data_dir/$n; done)
|
||||
|
||||
if ! mkdir -p $subdirs >&/dev/null; then
|
||||
for n in `seq $nj`; do
|
||||
mkdir -p $temp_data_dir/$n
|
||||
done
|
||||
fi
|
||||
|
||||
utils/split_scp.pl $data/wav.scp $wavscps
|
||||
|
||||
|
||||
$cmd JOB=1:$nj $data/log/get_reco_durations.JOB.log \
|
||||
wav-to-duration --read-entire-file=$read_entire_file \
|
||||
scp:$temp_data_dir/JOB/wav.scp ark,t:$temp_data_dir/JOB/reco2dur || \
|
||||
{ echo "$0: there was a problem getting the durations"; exit 1; } # This could
|
||||
|
||||
for n in `seq $nj`; do
|
||||
cat $temp_data_dir/$n/reco2dur
|
||||
done > $data/reco2dur
|
||||
fi
|
||||
# temp_data_dir is only set on the wav-to-duration path above; skip the cleanup when
# the durations came from the sphere-file headers (otherwise rm is called with no operand).
if [ -n "${temp_data_dir:-}" ]; then rm -r $temp_data_dir; fi
|
||||
else
|
||||
echo "$0: Expected $data/wav.scp to exist"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
len1=$(wc -l < $data/wav.scp)
|
||||
len2=$(wc -l < $data/reco2dur)
|
||||
if [ "$len1" != "$len2" ]; then
|
||||
echo "$0: warning: length of reco2dur does not equal that of wav.scp, $len2 != $len1"
|
||||
if [ $len1 -gt $[$len2*2] ]; then
|
||||
echo "$0: less than half of recordings got a duration: failing."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "$0: computed $data/reco2dur"
|
||||
|
||||
exit 0
|
||||
29
egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh
Executable file
29
egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh
Executable file
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This script operates on a data directory, such as in data/train/,
|
||||
# and writes new segments to stdout. The file 'segments' maps from
|
||||
# utterance to time offsets into a recording, with the format:
|
||||
# <utterance-id> <recording-id> <segment-begin> <segment-end>
|
||||
# This script assumes utterance and recording ids are the same (i.e., that
|
||||
# wav.scp is indexed by utterance), and uses durations from 'utt2dur',
|
||||
# created if necessary by get_utt2dur.sh.
|
||||
|
||||
. ./path.sh
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: $0 [options] <datadir>"
|
||||
echo "e.g.:"
|
||||
echo " $0 data/train > data/train/segments"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
data=$1
|
||||
|
||||
if [ ! -s $data/utt2dur ]; then
|
||||
utils/data/get_utt2dur.sh $data 1>&2 || exit 1;
|
||||
fi
|
||||
|
||||
# <utt-id> <utt-id> 0 <utt-dur>
|
||||
awk '{ print $1, $1, 0, $2 }' $data/utt2dur
|
||||
|
||||
exit 0
|
||||
135
egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh
Executable file
135
egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh
Executable file
@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2016 Johns Hopkins University (author: Daniel Povey)
|
||||
# Apache 2.0
|
||||
|
||||
# This script operates on a data directory, such as in data/train/, and adds the
|
||||
# utt2dur file if it does not already exist. The file 'utt2dur' maps from
|
||||
# utterance to the duration of the utterance in seconds. This script works it
|
||||
# out from the 'segments' file, or, if not present, from the wav.scp file (it
|
||||
# first tries interrogating the headers, and if this fails, it reads the wave
|
||||
# files in entirely.)
|
||||
|
||||
frame_shift=0.01
|
||||
cmd=run.pl
|
||||
nj=4
|
||||
read_entire_file=false
|
||||
|
||||
. utils/parse_options.sh
|
||||
. ./path.sh
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: $0 [options] <datadir>"
|
||||
echo "e.g.:"
|
||||
echo " $0 data/train"
|
||||
echo " Options:"
|
||||
echo " --frame-shift # frame shift in seconds. Only relevant when we are"
|
||||
echo " # getting duration from feats.scp, and only if the "
|
||||
echo " # file frame_shift does not exist (default: 0.01). "
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export LC_ALL=C
|
||||
|
||||
data=$1
|
||||
|
||||
if [ -s $data/utt2dur ] && \
|
||||
[ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ]; then
|
||||
echo "$0: $data/utt2dur already exists with the expected length. We won't recompute it."
|
||||
exit 0;
|
||||
fi
|
||||
|
||||
if [ -s $data/segments ]; then
|
||||
echo "$0: working out $data/utt2dur from $data/segments"
|
||||
awk '{len=$4-$3; print $1, len;}' < $data/segments > $data/utt2dur
|
||||
elif [[ -s $data/frame_shift && -f $data/utt2num_frames ]]; then
|
||||
echo "$0: computing $data/utt2dur from $data/{frame_shift,utt2num_frames}."
|
||||
frame_shift=$(cat $data/frame_shift) || exit 1
|
||||
# The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift.
|
||||
awk -v fs=$frame_shift '{ $2=($2+1.5)*fs; print }' <$data/utt2num_frames >$data/utt2dur
|
||||
elif [ -f $data/wav.scp ]; then
|
||||
echo "$0: segments file does not exist so getting durations from wave files"
|
||||
|
||||
# if the wav.scp contains only lines of the form
|
||||
# utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph |
|
||||
if perl <$data/wav.scp -e '
|
||||
while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space.
|
||||
@A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
|
||||
$A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
|
||||
$utt = $A[0]; $sphere_file = $A[4];
|
||||
|
||||
if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
|
||||
$sample_rate = -1; $sample_count = -1;
|
||||
for ($n = 0; $n <= 30; $n++) {
|
||||
$line = <F>;
|
||||
if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
|
||||
if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
|
||||
# Stop scanning once the header terminator is seen; Perl leaves a loop with "last" (a bare "break" is not loop control).
if ($line =~ m/end_head/) { last; }
|
||||
}
|
||||
close(F);
|
||||
if ($sample_rate == -1 || $sample_count == -1) {
|
||||
die "could not parse sphere header from $sphere_file";
|
||||
}
|
||||
$duration = $sample_count * 1.0 / $sample_rate;
|
||||
print "$utt $duration\n";
|
||||
} ' > $data/utt2dur; then
|
||||
echo "$0: successfully obtained utterance lengths from sphere-file headers"
|
||||
else
|
||||
echo "$0: could not get utterance lengths from sphere-file headers, using wav-to-duration"
|
||||
if ! command -v wav-to-duration >/dev/null; then
|
||||
echo "$0: wav-to-duration is not on your path"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if grep -q 'sox.*speed' $data/wav.scp; then
|
||||
read_entire_file=true
|
||||
echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
|
||||
echo "... It is much faster if you call get_utt2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
|
||||
echo "... perturb_data_dir_speed_3way.sh."
|
||||
fi
|
||||
|
||||
|
||||
num_utts=$(wc -l <$data/utt2spk)
|
||||
if [ $nj -gt $num_utts ]; then
|
||||
nj=$num_utts
|
||||
fi
|
||||
|
||||
utils/data/split_data.sh --per-utt $data $nj
|
||||
sdata=$data/split${nj}utt
|
||||
|
||||
$cmd JOB=1:$nj $data/log/get_durations.JOB.log \
|
||||
wav-to-duration --read-entire-file=$read_entire_file \
|
||||
scp:$sdata/JOB/wav.scp ark,t:$sdata/JOB/utt2dur || \
|
||||
{ echo "$0: there was a problem getting the durations"; exit 1; }
|
||||
|
||||
for n in `seq $nj`; do
|
||||
cat $sdata/$n/utt2dur
|
||||
done > $data/utt2dur
|
||||
fi
|
||||
elif [ -f $data/feats.scp ]; then
|
||||
echo "$0: wave file does not exist so getting durations from feats files"
|
||||
if [[ -s $data/frame_shift ]]; then
|
||||
frame_shift=$(cat $data/frame_shift) || exit 1
|
||||
echo "$0: using frame_shift=$frame_shift from file $data/frame_shift"
|
||||
fi
|
||||
# The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift.
|
||||
feat-to-len scp:$data/feats.scp ark,t:- |
|
||||
awk -v frame_shift=$frame_shift '{print $1, ($2+1.5)*frame_shift}' >$data/utt2dur
|
||||
else
|
||||
echo "$0: Expected $data/wav.scp, $data/segments or $data/feats.scp to exist"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
len1=$(wc -l < $data/utt2spk)
|
||||
len2=$(wc -l < $data/utt2dur)
|
||||
if [ "$len1" != "$len2" ]; then
|
||||
echo "$0: warning: length of utt2dur does not equal that of utt2spk, $len2 != $len1"
|
||||
if [ $len1 -gt $[$len2*2] ]; then
|
||||
echo "$0: less than half of utterances got a duration: failing."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "$0: computed $data/utt2dur"
|
||||
|
||||
exit 0
|
||||
160
egs/alimeeting/sa-asr/utils/data/split_data.sh
Executable file
160
egs/alimeeting/sa-asr/utils/data/split_data.sh
Executable file
@ -0,0 +1,160 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2010-2013 Microsoft Corporation
|
||||
# Johns Hopkins University (Author: Daniel Povey)
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
split_per_spk=true
|
||||
if [ "$1" == "--per-utt" ]; then
|
||||
split_per_spk=false
|
||||
shift
|
||||
fi
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: $0 [--per-utt] <data-dir> <num-to-split>"
|
||||
echo "E.g.: $0 data/train 50"
|
||||
echo "It creates its output in e.g. data/train/split50/{1,2,3,...50}, or if the "
|
||||
echo "--per-utt option was given, in e.g. data/train/split50utt/{1,2,3,...50}."
|
||||
echo ""
|
||||
echo "This script will not split the data-dir if it detects that the output is newer than the input."
|
||||
echo "By default it splits per speaker (so each speaker is in only one split dir),"
|
||||
echo "but with the --per-utt option it will ignore the speaker information while splitting."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
data=$1
|
||||
numsplit=$2
|
||||
|
||||
if ! [ "$numsplit" -gt 0 ]; then
|
||||
echo "Invalid num-split argument $numsplit";
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if $split_per_spk; then
|
||||
warning_opt=
|
||||
else
|
||||
# suppress warnings from filter_scps.pl about 'some input lines were output
|
||||
# to multiple files'.
|
||||
warning_opt="--no-warn"
|
||||
fi
|
||||
|
||||
n=0;
|
||||
feats=""
|
||||
wavs=""
|
||||
utt2spks=""
|
||||
texts=""
|
||||
|
||||
nu=`cat $data/utt2spk | wc -l`
|
||||
nf=`cat $data/feats.scp 2>/dev/null | wc -l`
|
||||
nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
|
||||
if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
|
||||
echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can "
|
||||
echo "** use utils/fix_data_dir.sh $data to fix this."
|
||||
fi
|
||||
if [ -f $data/text ] && [ $nu -ne $nt ]; then
|
||||
echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can "
|
||||
echo "** use utils/fix_data_dir.sh to fix this."
|
||||
fi
|
||||
|
||||
|
||||
if $split_per_spk; then
|
||||
utt2spk_opt="--utt2spk=$data/utt2spk"
|
||||
utt=""
|
||||
else
|
||||
utt2spk_opt=
|
||||
utt="utt"
|
||||
fi
|
||||
|
||||
s1=$data/split${numsplit}${utt}/1
|
||||
if [ ! -d $s1 ]; then
|
||||
need_to_split=true
|
||||
else
|
||||
need_to_split=false
|
||||
for f in utt2spk spk2utt spk2warp feats.scp text wav.scp cmvn.scp spk2gender \
|
||||
vad.scp segments reco2file_and_channel utt2lang; do
|
||||
if [[ -f $data/$f && ( ! -f $s1/$f || $s1/$f -ot $data/$f ) ]]; then
|
||||
need_to_split=true
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if ! $need_to_split; then
|
||||
exit 0;
|
||||
fi
|
||||
|
||||
utt2spks=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2spk; done)
|
||||
|
||||
directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n; done)
|
||||
|
||||
# if this mkdir fails due to argument-list being too long, iterate.
|
||||
if ! mkdir -p $directories >&/dev/null; then
|
||||
for n in `seq $numsplit`; do
|
||||
mkdir -p $data/split${numsplit}${utt}/$n
|
||||
done
|
||||
fi
|
||||
|
||||
# If lockfile is not installed, just don't lock it. It's not a big deal.
|
||||
which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock
|
||||
trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM
|
||||
|
||||
utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1
|
||||
|
||||
for n in `seq $numsplit`; do
|
||||
dsn=$data/split${numsplit}${utt}/$n
|
||||
utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
|
||||
done
|
||||
|
||||
maybe_wav_scp=
|
||||
if [ ! -f $data/segments ]; then
|
||||
maybe_wav_scp=wav.scp # If there is no segments file, then wav file is
|
||||
# indexed per utt.
|
||||
fi
|
||||
|
||||
# split some things that are indexed by utterance.
|
||||
for f in feats.scp text vad.scp utt2lang $maybe_wav_scp utt2dur utt2num_frames; do
|
||||
if [ -f $data/$f ]; then
|
||||
utils/filter_scps.pl JOB=1:$numsplit \
|
||||
$data/split${numsplit}${utt}/JOB/utt2spk $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1;
|
||||
fi
|
||||
done
|
||||
|
||||
# split some things that are indexed by speaker
|
||||
for f in spk2gender spk2warp cmvn.scp; do
|
||||
if [ -f $data/$f ]; then
|
||||
utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
|
||||
$data/split${numsplit}${utt}/JOB/spk2utt $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1;
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -f $data/segments ]; then
|
||||
utils/filter_scps.pl JOB=1:$numsplit \
|
||||
$data/split${numsplit}${utt}/JOB/utt2spk $data/segments $data/split${numsplit}${utt}/JOB/segments || exit 1
|
||||
for n in `seq $numsplit`; do
|
||||
dsn=$data/split${numsplit}${utt}/$n
|
||||
awk '{print $2;}' $dsn/segments | sort | uniq > $dsn/tmp.reco # recording-ids.
|
||||
done
|
||||
if [ -f $data/reco2file_and_channel ]; then
|
||||
utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
|
||||
$data/split${numsplit}${utt}/JOB/tmp.reco $data/reco2file_and_channel \
|
||||
$data/split${numsplit}${utt}/JOB/reco2file_and_channel || exit 1
|
||||
fi
|
||||
if [ -f $data/wav.scp ]; then
|
||||
utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
|
||||
$data/split${numsplit}${utt}/JOB/tmp.reco $data/wav.scp \
|
||||
$data/split${numsplit}${utt}/JOB/wav.scp || exit 1
|
||||
fi
|
||||
for f in $data/split${numsplit}${utt}/*/tmp.reco; do rm $f; done
|
||||
fi
|
||||
|
||||
exit 0
|
||||
87
egs/alimeeting/sa-asr/utils/filter_scp.pl
Executable file
87
egs/alimeeting/sa-asr/utils/filter_scp.pl
Executable file
@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env perl
|
||||
# Copyright 2010-2012 Microsoft Corporation
|
||||
# Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# This script takes a list of utterance-ids or any file whose first field
|
||||
# of each line is an utterance-id, and filters an scp
|
||||
# file (or any file whose "n-th" field is an utterance id), printing
|
||||
# out only those lines whose "n-th" field is in id_list. The index of
|
||||
# the "n-th" field is 1, by default, but can be changed by using
|
||||
# the -f <n> switch
|
||||
|
||||
$exclude = 0;
|
||||
$field = 1;
|
||||
$shifted = 0;
|
||||
|
||||
do {
|
||||
$shifted=0;
|
||||
if ($ARGV[0] eq "--exclude") {
|
||||
$exclude = 1;
|
||||
shift @ARGV;
|
||||
$shifted=1;
|
||||
}
|
||||
if ($ARGV[0] eq "-f") {
|
||||
$field = $ARGV[1];
|
||||
shift @ARGV; shift @ARGV;
|
||||
$shifted=1
|
||||
}
|
||||
} while ($shifted);
|
||||
|
||||
if(@ARGV < 1 || @ARGV > 2) {
|
||||
die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
|
||||
"Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
|
||||
"Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
|
||||
"only the lines that were *not* in id_list.\n" .
|
||||
"Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
|
||||
"If your older scripts (written before Oct 2014) stopped working and you used the\n" .
|
||||
"-f option, add 1 to the argument.\n" .
|
||||
"See also: utils/filter_scp.pl .\n";
|
||||
}
|
||||
|
||||
|
||||
$idlist = shift @ARGV;
|
||||
open(F, "<$idlist") || die "Could not open id-list file $idlist";
|
||||
while(<F>) {
|
||||
@A = split;
|
||||
@A>=1 || die "Invalid id-list file line $_";
|
||||
$seen{$A[0]} = 1;
|
||||
}
|
||||
|
||||
if ($field == 1) { # Treat this as special case, since it is common.
|
||||
while(<>) {
|
||||
$_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
|
||||
# $1 is what we filter on.
|
||||
if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
|
||||
print $_;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
while(<>) {
|
||||
@A = split;
|
||||
@A > 0 || die "Invalid scp file line $_";
|
||||
@A >= $field || die "Invalid scp file line $_";
|
||||
if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
|
||||
print $_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# tests:
|
||||
# the following should print "foo 1"
|
||||
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
|
||||
# the following should print "bar 2".
|
||||
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)
|
||||
215
egs/alimeeting/sa-asr/utils/fix_data_dir.sh
Executable file
215
egs/alimeeting/sa-asr/utils/fix_data_dir.sh
Executable file
@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This script makes sure that only the segments present in
|
||||
# all of "feats.scp", "wav.scp" [if present], segments [if present]
|
||||
# text, and utt2spk are present in any of them.
|
||||
# It puts the original contents of data-dir into
|
||||
# data-dir/.backup
|
||||
|
||||
cmd="$@"
|
||||
|
||||
utt_extra_files=
|
||||
spk_extra_files=
|
||||
|
||||
. utils/parse_options.sh
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: utils/data/fix_data_dir.sh <data-dir>"
|
||||
echo "e.g.: utils/data/fix_data_dir.sh data/train"
|
||||
echo "This script helps ensure that the various files in a data directory"
|
||||
echo "are correctly sorted and filtered, for example removing utterances"
|
||||
echo "that have no features (if feats.scp is present)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
data=$1
|
||||
|
||||
if [ -f $data/images.scp ]; then
|
||||
image/fix_data_dir.sh $cmd
|
||||
exit $?
|
||||
fi
|
||||
|
||||
# Only create the backup directory when $data already exists; an unconditional
# "mkdir -p" would create $data itself and defeat the directory check below.
if [ -d $data ]; then mkdir -p $data/.backup; fi
|
||||
|
||||
[ ! -d $data ] && echo "$0: no such directory $data" && exit 1;
|
||||
|
||||
[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1;
|
||||
|
||||
set -e -o pipefail -u
|
||||
|
||||
tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
|
||||
trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM
|
||||
|
||||
export LC_ALL=C
|
||||
|
||||
function check_sorted {
|
||||
file=$1
|
||||
sort -k1,1 -u <$file >$file.tmp
|
||||
if ! cmp -s $file $file.tmp; then
|
||||
echo "$0: file $1 is not in sorted order or not unique, sorting it"
|
||||
mv $file.tmp $file
|
||||
else
|
||||
rm $file.tmp
|
||||
fi
|
||||
}
|
||||
|
||||
for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \
|
||||
reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do
|
||||
if [ -f $data/$x ]; then
|
||||
cp $data/$x $data/.backup/$x
|
||||
check_sorted $data/$x
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
function filter_file {
|
||||
filter=$1
|
||||
file_to_filter=$2
|
||||
cp $file_to_filter ${file_to_filter}.tmp
|
||||
utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter
|
||||
if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then
|
||||
length1=$(cat ${file_to_filter}.tmp | wc -l)
|
||||
length2=$(cat ${file_to_filter} | wc -l)
|
||||
if [ $length1 -ne $length2 ]; then
|
||||
echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter."
|
||||
fi
|
||||
fi
|
||||
rm $file_to_filter.tmp
|
||||
}
|
||||
|
||||
function filter_recordings {
|
||||
# We call this once before the stage when we filter on utterance-id, and once
|
||||
# after.
|
||||
|
||||
if [ -f $data/segments ]; then
|
||||
# We have a segments file -> we need to filter this and the file wav.scp, and
|
||||
# reco2file_and_channel, if it exists, to make sure they have the same list of
|
||||
# recording-ids.
|
||||
|
||||
if [ ! -f $data/wav.scp ]; then
|
||||
echo "$0: $data/segments exists but not $data/wav.scp"
|
||||
exit 1;
|
||||
fi
|
||||
awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings
|
||||
n1=$(cat $tmpdir/recordings | wc -l)
|
||||
[ ! -s $tmpdir/recordings ] && \
|
||||
echo "Empty list of recordings (bad file $data/segments)?" && exit 1;
|
||||
utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp
|
||||
mv $tmpdir/recordings.tmp $tmpdir/recordings
|
||||
|
||||
|
||||
cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments
|
||||
filter_file $tmpdir/recordings $data/segments
|
||||
cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments
|
||||
rm $data/segments.tmp
|
||||
|
||||
filter_file $tmpdir/recordings $data/wav.scp
|
||||
[ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel
|
||||
[ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur
|
||||
true
|
||||
fi
|
||||
}
|
||||
|
||||
function filter_speakers {
|
||||
# throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
|
||||
utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
|
||||
|
||||
cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
|
||||
for s in cmvn.scp spk2gender; do
|
||||
f=$data/$s
|
||||
if [ -f $f ]; then
|
||||
filter_file $f $tmpdir/speakers
|
||||
fi
|
||||
done
|
||||
|
||||
filter_file $tmpdir/speakers $data/spk2utt
|
||||
utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
|
||||
|
||||
for s in cmvn.scp spk2gender $spk_extra_files; do
|
||||
f=$data/$s
|
||||
if [ -f $f ]; then
|
||||
filter_file $tmpdir/speakers $f
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
function filter_utts {
  # Candidate utterance list: the first field of every utt2spk line.
  awk '{print $1}' $data/utt2spk > $tmpdir/utts

  # The index files must already be sorted; this script does not re-sort them.
  if ! sort $data/utt2spk | cmp - $data/utt2spk; then
    echo "utt2spk is not in sorted order (fix this yourself)"
    exit 1
  fi

  if ! sort -k2 $data/utt2spk | cmp - $data/utt2spk; then
    echo "utt2spk is not in sorted order when sorted first on speaker-id "
    echo "(fix this by making speaker-ids prefixes of utt-ids)"
    exit 1
  fi

  if ! sort $data/spk2utt | cmp - $data/spk2utt; then
    echo "spk2utt is not in sorted order (fix this yourself)"
    exit 1
  fi

  if [ -f $data/utt2uniq ] && ! sort $data/utt2uniq | cmp - $data/utt2uniq; then
    echo "utt2uniq is not in sorted order (fix this yourself)"
    exit 1
  fi

  # wav.scp and reco2dur are indexed by utterance only when there is no
  # segments file; reco2dur additionally has to be non-empty.
  maybe_wav=
  maybe_reco2dur=
  if [ ! -f $data/segments ]; then
    maybe_wav=wav.scp
    if [ -s $data/reco2dur ]; then
      maybe_reco2dur=reco2dur
    fi
  fi

  # Keep only well-formed utt2dur entries (two fields, positive value) in a
  # temporary ".ok" copy that takes part in the filtering below.
  maybe_utt2dur=
  if [ -f $data/utt2dur ]; then
    cat $data/utt2dur | awk 'NF == 2 && $2 > 0' > $data/utt2dur.ok || exit 1
    maybe_utt2dur=utt2dur.ok
  fi

  # Same treatment for utt2num_frames.
  maybe_utt2num_frames=
  if [ -f $data/utt2num_frames ]; then
    cat $data/utt2num_frames | awk 'NF == 2 && $2 > 0' > $data/utt2num_frames.ok || exit 1
    maybe_utt2num_frames=utt2num_frames.ok
  fi

  # Pass the utterance list through utils/filter_scp.pl once per existing
  # per-utterance file.
  # NOTE(review): presumably this keeps only ids present in every such file --
  # confirm against utils/filter_scp.pl.
  for x in feats.scp text segments utt2lang $maybe_wav $maybe_utt2dur $maybe_utt2num_frames; do
    [ -f $data/$x ] || continue
    utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp
    mv $tmpdir/utts.tmp $tmpdir/utts
  done
  # The ".ok" copies were only needed for the loop above.
  rm $data/utt2dur.ok 2>/dev/null || true
  rm $data/utt2num_frames.ok 2>/dev/null || true

  if [ ! -s $tmpdir/utts ]; then
    echo "fix_data_dir.sh: no utterances remained: not proceeding further."
    rm $tmpdir/utts && exit 1
  fi

  # Report how many utterances survived.
  if [ -f $data/utt2spk ]; then
    new_nutts=$(wc -l < $tmpdir/utts)
    old_nutts=$(wc -l < $data/utt2spk)
    if [ $new_nutts -eq $old_nutts ]; then
      echo "fix_data_dir.sh: kept all $old_nutts utterances."
    else
      echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts"
    fi
  fi

  # Back up every per-utterance file, then rewrite it from the backup only
  # when filtering would actually change its contents.
  for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do
    [ -f $data/$x ] || continue
    cp $data/$x $data/.backup/$x
    cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) || \
      utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x
  done
}
|
||||
|
||||
# Run the filtering passes in a fixed order.  The speaker and recording
# passes are repeated after the utterance pass.
for _fix_pass in filter_recordings filter_speakers filter_utts filter_speakers filter_recordings; do
  $_fix_pass
done

# Regenerate spk2utt from the final utt2spk.
utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt

echo "fix_data_dir.sh: old files are kept in $data/.backup"
|
||||
97
egs/alimeeting/sa-asr/utils/parse_options.sh
Executable file
97
egs/alimeeting/sa-asr/utils/parse_options.sh
Executable file
@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
|
||||
# Arnab Ghoshal, Karel Vesely
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Parse command-line options.
|
||||
# To be sourced by another script (as in ". parse_options.sh").
|
||||
# Option format is: --option-name arg
|
||||
# and shell variable "option_name" gets set to value "arg."
|
||||
# The exception is --help, which takes no arguments, but prints the
|
||||
# $help_message variable (if defined).
|
||||
|
||||
|
||||
###
|
||||
### The --config file options have lower priority to command line
|
||||
### options, so we need to import them first...
|
||||
###
|
||||
|
||||
# Now import all the configs specified by command-line, in left-to-right order
|
||||
# Scan every positional parameter except the last one; whenever one is
# "--config", source the file named by the parameter that follows it.
argpos=1
while [ $argpos -lt $# ]; do
  if [ "${!argpos}" == "--config" ]; then
    argpos_plus1=$((argpos+1))
    config=${!argpos_plus1}
    if [ ! -r $config ]; then
      echo "$0: missing config '$config'"
      exit 1
    fi
    . $config  # source the config file.
  fi
  argpos=$((argpos+1))
done
|
||||
|
||||
|
||||
###
|
||||
### Now we process the command line options
|
||||
###
|
||||
# Consume leading "--name value" pairs from the positional parameters and
# assign each value to the shell variable "name" (dashes become underscores).
# Stops at the first argument that is empty or does not start with "--".
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    # If the enclosing script is called with --help option, print the help
    # message and exit.  Scripts should put help messages in $help_message
    # (${help_message:-} keeps this safe when the caller uses "set -u").
    # Note that the message is used as a printf format string.
    --help|-h) if [ -z "${help_message:-}" ]; then echo "No help found." 1>&2;
      else printf "$help_message\n" 1>&2 ; fi;
      exit 0 ;;
    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
      exit 1 ;;
    # If the first command-line argument begins with "--" (e.g. --foo-bar),
    # then work out the variable name as $name, which will equal "foo_bar".
    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
      # Next we test whether the variable in question is undefined-- if so it's
      # an invalid option and we die.  Note: $0 evaluates to the name of the
      # enclosing script.
      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
      # is undefined.  We then have to wrap this test inside "eval" because
      # foo_bar is itself inside a variable ($name).
      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;

      # Every option needs a value.  Without this check a trailing "--name"
      # would make "shift 2" fail without shifting anything, so the loop
      # would never terminate (or $2 would be unbound under "set -u").
      if [ $# -lt 2 ]; then
        echo "$0: option $1 requires an argument" 1>&2
        exit 1;
      fi

      oldval="`eval echo \\$$name`";
      # Work out whether we seem to be expecting a Boolean argument.
      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval $name=\"$2\";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;
    *) break;
  esac
done
|
||||
|
||||
|
||||
# An empty value for --cmd is almost always the result of a scripting error,
# so reject it: $cmd is set (possibly by the loop above) but has no content.
if [ -n "${cmd+xxx}" ] && [ -z "$cmd" ]; then
  echo "$0: empty argument to --cmd option" 1>&2
  exit 1
fi


true; # so this script returns exit code 0.
|
||||
27
egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl
Executable file
27
egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl
Executable file
@ -0,0 +1,27 @@
|
||||
#!/usr/bin/env perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Each input line has the form "<speaker> <utt1> <utt2> ...".  For every
# utterance print one "<utt> <speaker>" line, keeping the input order.
while (my $line = <>) {
  my ($spk, @utts) = split(" ", $line);
  # A valid line needs a speaker and at least one utterance.
  @utts or die "Invalid line in spk2utt file: $line";
  print "$_ $spk\n" for @utts;
}
|
||||
|
||||
|
||||
246
egs/alimeeting/sa-asr/utils/split_scp.pl
Executable file
246
egs/alimeeting/sa-asr/utils/split_scp.pl
Executable file
@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env perl
|
||||
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# See ../../COPYING for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# This program splits up any kind of .scp or archive-type file.
|
||||
# If there is no utt2spk option it will work on any text file and
|
||||
# will split it up with an approximately equal number of lines in
|
||||
# each output file.
|
||||
# With the --utt2spk option it will work on anything that has the
|
||||
# utterance-id as the first entry on each line; the utt2spk file is
|
||||
# of the form "utterance speaker" (on each line).
|
||||
# It splits it into equal size chunks as far as it can. If you use the utt2spk
|
||||
# option it will make sure these chunks coincide with speaker boundaries. In
|
||||
# this case, if there are more chunks than speakers (and in some other
|
||||
# circumstances), some of the resulting chunks will be empty and it will print
|
||||
# an error message and exit with nonzero status.
|
||||
# You will normally call this like:
|
||||
# split_scp.pl scp scp.1 scp.2 scp.3 ...
|
||||
# or
|
||||
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
|
||||
# Note that you can use this script to split the utt2spk file itself,
|
||||
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
|
||||
|
||||
# You can also call the scripts like:
|
||||
# split_scp.pl -j 3 0 scp scp.0
|
||||
# [note: with this option, it assumes zero-based indexing of the split parts,
|
||||
# i.e. the second number must be 0 <= n < num-jobs.]
|
||||
|
||||
use warnings;
|
||||
|
||||
# Option state:
#   $num_jobs      number of jobs given with -j (0 means -j was not used)
#   $job_id        index of the job whose part is written (with -j)
#   $utt2spk_file  value of --utt2spk=<file>, or "" if not given
#   $one_based     1 if --one-based was given (job ids then start at 1)
$num_jobs = 0;
$job_id = 0;
$utt2spk_file = "";
$one_based = 0;

# The three options may appear in any order, so make up to three passes over
# the front of @ARGV.  Each later test re-checks that an argument is left,
# because an earlier branch of the same pass may have consumed the last one.
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
  if ($ARGV[0] eq "-j") {
    shift @ARGV;
    $num_jobs = shift @ARGV;
    $job_id = shift @ARGV;
    (defined $num_jobs && defined $job_id) ||
      die "$0: option -j requires two arguments: num-jobs job-id\n";
  }
  if (@ARGV > 0 && $ARGV[0] =~ /--utt2spk=(.+)/) {
    $utt2spk_file=$1;
    shift @ARGV;
  }
  if (@ARGV > 0 && $ARGV[0] eq '--one-based') {
    $one_based = 1;
    shift @ARGV;
  }
}

# With -j, the job id (after removing the one-based offset) must lie in
# [0, num-jobs).
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
                       $job_id - $one_based >= $num_jobs)) {
  die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
      ($one_based ? " --one-based" : "") . "'\n";
}

# From here on $job_id is always zero-based.
$one_based
  and $job_id--;

# Without -j we need the input plus at least one output; with -j we need the
# input and at most one output.
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
  die
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
... where 0 <= job-id < num-jobs, or 1 <= job-id <= num-jobs if --one-based.\n";
}
|
||||
|
||||
$error = 0;             # becomes the script's exit status
$inscp = shift @ARGV;   # the input scp file
if ($num_jobs != 0) {
  # With -j: build one slot per job.  Only this job's slot is a real
  # destination (the named file, or "-" for stdout when none was given);
  # all other slots go to /dev/null.
  my $wanted = @ARGV > 0 ? $ARGV[0] : "-";
  for (my $slot = 0; $slot < $num_jobs; $slot++) {
    push @OUTPUTS, ($slot == $job_id ? $wanted : "/dev/null");
  }
} else {
  # Without -j: every remaining argument is an output file.
  @OUTPUTS = @ARGV;
}
|
||||
|
||||
if ($utt2spk_file ne "") {  # We have the --utt2spk option...
  # Read the utterance -> speaker map.
  open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
  while(<$u_fh>) {
    @A = split;
    @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
    ($u,$s) = @A;
    $utt2spk{$u} = $s;
  }
  close $u_fh;

  # Group the input lines by speaker, remembering the order in which
  # speakers first appear (@spkrs) and how many lines each has (%spk_count).
  open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
  @spkrs = ();
  while(<$i_fh>) {
    @A = split;
    if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
    $u = $A[0];
    $s = $utt2spk{$u};
    defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
    if(!defined $spk_count{$s}) {
      push @spkrs, $s;
      $spk_count{$s} = 0;
      $spk_data{$s} = [];  # ref to new empty array.
    }
    $spk_count{$s}++;
    push @{$spk_data{$s}}, $_;
  }
  # Now split as equally as possible ..
  # First allocate spks to files by allocating an approximately
  # equal number of speakers.
  $numspks = @spkrs;  # number of speakers.
  $numscps = @OUTPUTS; # number of output files.
  if ($numspks < $numscps) {
    die "$0: Refusing to split data because number of speakers $numspks " .
        "is less than the number of output .scp files $numscps\n";
  }
  for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
    $scparray[$scpidx] = []; # [] is array reference.
  }
  for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
    $scpidx = int(($spkidx*$numscps) / $numspks);
    $spk = $spkrs[$spkidx];
    push @{$scparray[$scpidx]}, $spk;
    $scpcount[$scpidx] += $spk_count{$spk};
  }

  # Now will try to reassign beginning + ending speakers
  # to different scp's and see if it gets more balanced.
  # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
  # We can show that if considering changing just 2 scp's, we minimize
  # this by minimizing the squared difference in sizes.  This is
  # equivalent to minimizing the absolute difference in sizes.  This
  # shows this method is bound to converge.

  $changed = 1;
  while($changed) {
    $changed = 0;
    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
      # First try to reassign ending spk of this scp.
      if($scpidx < $numscps-1) {
        $sz = @{$scparray[$scpidx]};
        if($sz > 0) {
          $spk = $scparray[$scpidx]->[$sz-1];
          $count = $spk_count{$spk};
          $nutt1 = $scpcount[$scpidx];
          $nutt2 = $scpcount[$scpidx+1];
          if( abs( ($nutt2+$count) - ($nutt1-$count))
              < abs($nutt2 - $nutt1))  { # Would decrease
            # size-diff by reassigning spk...
            $scpcount[$scpidx+1] += $count;
            $scpcount[$scpidx] -= $count;
            pop @{$scparray[$scpidx]};
            unshift @{$scparray[$scpidx+1]}, $spk;
            $changed = 1;
          }
        }
      }
      # Then try to move the first spk of this scp to the previous scp.
      if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
        $spk = $scparray[$scpidx]->[0];
        $count = $spk_count{$spk};
        $nutt1 = $scpcount[$scpidx-1];
        $nutt2 = $scpcount[$scpidx];
        if( abs( ($nutt2-$count) - ($nutt1+$count))
            < abs($nutt2 - $nutt1))  { # Would decrease
          # size-diff by reassigning spk...
          $scpcount[$scpidx-1] += $count;
          $scpcount[$scpidx] -= $count;
          shift @{$scparray[$scpidx]};
          push @{$scparray[$scpidx-1]}, $spk;
          $changed = 1;
        }
      }
    }
  }
  # Now print out the files...
  for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
    $scpfile = $OUTPUTS[$scpidx];
    # "-" means standard output.
    ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
                     : open($f_fh, '>&', \*STDOUT)) ||
      die "$0: Could not open scp file $scpfile for writing: $!\n";
    $count = 0;
    if(@{$scparray[$scpidx]} == 0) {
      print STDERR "$0: error: split_scp.pl producing empty .scp file " .
                   "$scpfile (too many splits and too few speakers?)\n";
      $error = 1;
    } else {
      foreach $spk ( @{$scparray[$scpidx]} ) {
        print $f_fh @{$spk_data{$spk}};
        $count += $spk_count{$spk};
      }
      $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
    }
    # Check close() as the non-utt2spk branch does, so that write errors
    # (e.g. a full disk) are not silently ignored.
    close($f_fh) || die "$0: Error closing scp file $scpfile: $!\n";
  }
} else {
  # This block is the "normal" case where there is no --utt2spk
  # option and we just break into equal size chunks.

  open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";

  $numscps = @OUTPUTS;  # size of array.
  @F = ();
  while(<$i_fh>) {
    push @F, $_;
  }
  $numlines = @F;
  if($numlines == 0) {
    print STDERR "$0: error: empty input scp file $inscp\n";
    $error = 1;
  }
  $linesperscp = int( $numlines / $numscps); # the "whole part"..
  $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
  $remainder = $numlines - ($linesperscp * $numscps);
  ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
  # [just doing int() rounds down].
  $n = 0;
  for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
    $scpfile = $OUTPUTS[$scpidx];
    ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
                     : open($o_fh, '>&', \*STDOUT)) ||
      die "$0: Could not open scp file $scpfile for writing: $!\n";
    # The first $remainder outputs each get one extra line.
    for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
      print $o_fh $F[$n++];
    }
    close($o_fh) || die "$0: Error closing scp file $scpfile: $!\n";
  }
  $n == $numlines || die "$n != $numlines [code error]";
}

exit ($error);
|
||||
38
egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl
Executable file
38
egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl
Executable file
@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# converts an utt2spk file to a spk2utt file.
|
||||
# Takes input from the stdin or from a file argument;
|
||||
# output goes to the standard out.
|
||||
|
||||
# At most one file argument is accepted; with none, input comes from stdin.
@ARGV <= 1 or die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";

# Collect the utterances of each speaker, remembering the order in which
# speakers first appear so that the output keeps that order.
while (my $line = <>) {
  my @fields = split(" ", $line);
  @fields == 2 or die "Invalid line in utt2spk file: $line";
  my ($utt, $spk) = @fields;
  push @spklist, $spk unless exists $spk_hash{$spk};
  push @{$spk_hash{$spk}}, $utt;
}

# One output line per speaker: "<speaker> <utt1> <utt2> ...".
for my $spk (@spklist) {
  print join(' ', $spk, @{$spk_hash{$spk}}), "\n";
}
|
||||
404
egs/alimeeting/sa-asr/utils/validate_data_dir.sh
Executable file
404
egs/alimeeting/sa-asr/utils/validate_data_dir.sh
Executable file
@ -0,0 +1,404 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Keep the original argument list so it can be forwarded unchanged to
# image/validate_data_dir.sh further down.
cmd="$@"

# Defaults for the command-line switches.
no_feats=false
no_wav=false
no_text=false
no_spk_sort=false
non_print=false
# Start from an empty value so that a 'data' variable inherited from the
# environment is not mistaken for an already-seen <data-dir> argument.
data=


# Print the usage message to standard output.
function show_help
{
  echo "Usage: $0 [--no-feats] [--no-text] [--non-print] [--no-wav] [--no-spk-sort] <data-dir>"
  echo "The --no-xxx options mean that the script does not require "
  echo "xxx.scp to be present, but it will check it if it is present."
  echo "--no-spk-sort means that the script does not require the utt2spk to be "
  echo "sorted by the speaker-id in addition to being sorted by utterance-id."
  echo "--non-print ignore the presence of non-printable characters."
  echo "By default, utt2spk is expected to be sorted by both, which can be "
  echo "achieved by making the speaker-id prefixes of the utterance-ids"
  echo "e.g.: $0 data/train"
}

# Switches may appear in any position; exactly one non-switch argument
# (the data directory) is accepted.
while [ $# -ne 0 ] ; do
  case "$1" in
    "--no-feats")
      no_feats=true;
      ;;
    "--no-text")
      no_text=true;
      ;;
    "--non-print")
      non_print=true;
      ;;
    "--no-wav")
      no_wav=true;
      ;;
    "--no-spk-sort")
      no_spk_sort=true;
      ;;
    *)
      # A second positional argument is a usage error.
      if ! [ -z "$data" ] ; then
        show_help;
        exit 1
      fi
      data=$1
      ;;
  esac
  shift
done

# The data directory is mandatory.
if [ -z "$data" ]; then
  show_help
  exit 1
fi
|
||||
|
||||
|
||||
|
||||
# Quoted so that an empty $data is reported here instead of slipping through
# (an unquoted "[ ! -d ]" evaluates to false).
if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi

# Image data directories are validated by a separate script.
if [ -f $data/images.scp ]; then
  cmd=${cmd/--no-wav/}  # remove --no-wav if supplied
  image/validate_data_dir.sh $cmd
  exit $?
fi

# spk2utt and utt2spk are always required and must be non-empty.
for f in spk2utt utt2spk; do
  if [ ! -f $data/$f ]; then
    echo "$0: no such file $f"
    exit 1;
  fi
  if [ ! -s $data/$f ]; then
    echo "$0: empty file $f"
    exit 1;
  fi
done

# Every utt2spk line must have exactly two fields.  Exit explicitly with
# status 1: a bare "exit" would return the status of the preceding echo (0)
# and make a failed validation look successful.
! cat $data/utt2spk | awk '{if (NF != 2) exit(1); }' && \
  echo "$0: $data/utt2spk has wrong format." && exit 1;

ns=$(wc -l < $data/spk2utt)
if [ "$ns" == 1 ]; then
  echo "$0: WARNING: you have only one speaker.  This probably a bad idea."
  echo "   Search for the word 'bold' in http://kaldi-asr.org/doc/data_prep.html"
  echo "   for more information."
fi


# Scratch directory for the id lists compared below; removed on exit or on
# any of the listed signals.
tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM

# Use the C locale for the sorting checks that follow.
export LC_ALL=C
|
||||
|
||||
# Exit the script with status 1 unless file $1 ends every line with a
# newline and its first column is sorted with no duplicate keys.
# Returns 0 when the file passes, so callers may test the status or run
# under "set -e".
function check_sorted_and_uniq {
  # perl dies (non-zero status) on the first line that lacks a trailing "\n".
  if ! perl -ne '((substr $_,-1) eq "\n") or die "file $ARGV has invalid newline";' "$1"; then
    exit 1
  fi
  # "sort -uC" silently checks that the keys are sorted and unique.
  if ! awk '{print $1}' < "$1" | sort -uC; then
    echo "$0: file $1 is not sorted or has duplicates"
    exit 1
  fi
}
|
||||
|
||||
# Print an abbreviated unified diff of files $1 and $2: the first 6 lines,
# a "..." marker, then the last 6 lines of whatever input head left unread,
# followed by the line counts of both files.
function partial_diff {
  diff -U1 $1 $2 | { head -n 6; echo "..."; tail -n 6; }
  n1=$(wc -l < $1)
  n2=$(wc -l < $2)
  echo "[Lengths are $1=$n1 versus $2=$n2]"
}
|
||||
|
||||
check_sorted_and_uniq $data/utt2spk

# Unless --no-spk-sort was given, utt2spk must also be in order when sorted
# from the speaker-id field onwards.
if ! $no_spk_sort && ! sort -k2 -C $data/utt2spk; then
  echo "$0: utt2spk is not in sorted order when sorted first on speaker-id "
  echo "(fix this by making speaker-ids prefixes of utt-ids)"
  exit 1
fi

check_sorted_and_uniq $data/spk2utt

# spk2utt, converted back to utt2spk form, must reproduce utt2spk exactly.
if ! cmp -s <(awk '{print $1, $2;}' $data/utt2spk) \
            <(utils/spk2utt_to_utt2spk.pl $data/spk2utt); then
  echo "$0: spk2utt and utt2spk do not seem to match"
  exit 1
fi

# Reference utterance list used by the checks that follow.
awk '{print $1;}' $data/utt2spk > $tmpdir/utts
|
||||
|
||||
# The text file is required unless --no-text was given.
if ! $no_text && [ ! -f $data/text ]; then
  echo "$0: no such file $data/text (if this is by design, specify --no-text)"
  exit 1;
fi

num_utts=$(wc -l < $tmpdir/utts)
if ! $no_text; then
  if ! $non_print; then
    # Prefer the C.UTF-8 locale when it is installed, otherwise en_US.UTF-8.
    if locale -a | grep "C.UTF-8" >/dev/null; then
      L=C.UTF-8
    else
      L=en_US.UTF-8
    fi
    # grep -c exits with status 0 only when at least one line matched, i.e.
    # when the text contains non-printable characters.
    if n_non_print=$(LC_ALL="$L" grep -c '[^[:print:][:space:]]' $data/text); then
      echo "$0: text contains $n_non_print lines with non-printable characters"
      exit 1
    fi
  fi
  utils/validate_text.pl $data/text || exit 1;
  check_sorted_and_uniq $data/text
  text_len=$(wc -l < $data/text)

  # These symbols must not occur as whole words in the transcripts.
  illegal_sym_list="<s> </s> #0"
  for x in $illegal_sym_list; do
    if grep -w "$x" $data/text > /dev/null; then
      echo "$0: Error: in $data, text contains illegal symbol $x"
      exit 1;
    fi
  done

  # The utterance ids in text must be identical to those in utt2spk.
  awk '{print $1}' < $data/text > $tmpdir/utts.txt
  if ! cmp -s $tmpdir/utts $tmpdir/utts.txt; then
    echo "$0: Error: in $data, utterance lists extracted from utt2spk and text"
    echo "$0: differ, partial diff is:"
    partial_diff $tmpdir/utts $tmpdir/utts.txt
    exit 1;
  fi
fi
|
||||
|
||||
if [ -f $data/segments ] && [ ! -f $data/wav.scp ]; then
  echo "$0: in directory $data, segments file exists but no wav.scp"
  exit 1;
fi


if [ ! -f $data/wav.scp ] && ! $no_wav; then
  echo "$0: no such file $data/wav.scp (if this is by design, specify --no-wav)"
  exit 1;
fi

# Validate $data/reco2file_and_channel (needed only for ctm scoring).
#   $1  file holding the reference id list the keys must match
#   $2  kind of id in that list ("recording" or "utterance"), for messages
#   $3  name of the file the reference list was extracted from, for messages
# Exits the script with status 1 on any problem.
function check_reco2file_and_channel {
  local id_list=$1 id_kind=$2 id_source=$3
  check_sorted_and_uniq $data/reco2file_and_channel
  # Each line needs three fields with channel "A" or "B"; channel "1" is
  # tolerated but triggers a warning printed from the END block.
  ! cat $data/reco2file_and_channel | \
    awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) {
            if ( NF == 3 && $3 == "1" ) {
              warning_issued = 1;
            } else {
              print "Bad line ", $0; exit 1;
            }
          }
        }
        END {
          if (warning_issued == 1) {
            print "The channel should be marked as A or B, not 1! You should change it ASAP! "
          }
        }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1;
  cat $data/reco2file_and_channel | awk '{print $1}' > $id_list.r2fc
  if ! cmp -s $id_list $id_list.r2fc; then
    echo "$0: Error: in $data, ${id_kind}-ids extracted from ${id_source} and reco2file_and_channel"
    echo "$0: differ, partial diff is:"
    partial_diff $id_list $id_list.r2fc
    exit 1;
  fi
}

if [ -f $data/wav.scp ]; then
  check_sorted_and_uniq $data/wav.scp

  if grep -E -q '^\S+\s+~' $data/wav.scp; then
    # note: it's not a good idea to have any kind of tilde in wav.scp, even if
    # part of a command, as it would cause compatibility problems if run by
    # other users, but this used to be not checked for so we let it slide unless
    # it's something of the form "foo ~/foo.wav" (i.e. a plain file name) which
    # would definitely cause problems as the fopen system call does not do
    # tilde expansion.
    echo "$0: Please do not use tilde (~) in your wav.scp."
    exit 1;
  fi

  if [ -f $data/segments ]; then

    check_sorted_and_uniq $data/segments
    # We have a segments file -> interpret wav file as "recording-ids" not utterance-ids.
    # Each segments line needs four fields with an end time after the start time.
    ! cat $data/segments | \
      awk '{if (NF != 4 || $4 <= $3) { print "Bad line in segments file", $0; exit(1); }}' && \
      echo "$0: badly formatted segments file" && exit 1;

    segments_len=`cat $data/segments | wc -l`
    # NOTE(review): this utt2spk-vs-segments comparison only runs when a text
    # file exists, although it does not read text -- confirm this is intended.
    if [ -f $data/text ]; then
      ! cmp -s $tmpdir/utts <(awk '{print $1}' <$data/segments) && \
        echo "$0: Utterance list differs between $data/utt2spk and $data/segments " && \
        echo "$0: Lengths are $segments_len vs $num_utts" && \
        exit 1
    fi

    # The recording-ids used by segments must equal the keys of wav.scp.
    # $tmpdir/recordings is also consulted by the reco2dur check later on.
    cat $data/segments | awk '{print $2}' | sort | uniq > $tmpdir/recordings
    awk '{print $1}' $data/wav.scp > $tmpdir/recordings.wav
    if ! cmp -s $tmpdir/recordings{,.wav}; then
      echo "$0: Error: in $data, recording-ids extracted from segments and wav.scp"
      echo "$0: differ, partial diff is:"
      partial_diff $tmpdir/recordings{,.wav}
      exit 1;
    fi
    if [ -f $data/reco2file_and_channel ]; then
      # this file is needed only for ctm scoring; it's indexed by recording-id.
      check_reco2file_and_channel $tmpdir/recordings recording segments
    fi
  else
    # No segments file -> assume wav.scp indexed by utterance.
    cat $data/wav.scp | awk '{print $1}' > $tmpdir/utts.wav
    if ! cmp -s $tmpdir/utts{,.wav}; then
      echo "$0: Error: in $data, utterance lists extracted from utt2spk and wav.scp"
      echo "$0: differ, partial diff is:"
      partial_diff $tmpdir/utts{,.wav}
      exit 1;
    fi

    if [ -f $data/reco2file_and_channel ]; then
      # Without segments the keys are utterance-ids, and the reference list
      # ($tmpdir/utts) comes from utt2spk.
      check_reco2file_and_channel $tmpdir/utts utterance utt2spk
    fi
  fi
fi
|
||||
|
||||
# feats.scp is required unless --no-feats was given.
if ! $no_feats && [ ! -f $data/feats.scp ]; then
  echo "$0: no such file $data/feats.scp (if this is by design, specify --no-feats)"
  exit 1;
fi

# feats.scp must list exactly the utterances of utt2spk.
if [ -f $data/feats.scp ]; then
  check_sorted_and_uniq $data/feats.scp
  awk '{print $1}' $data/feats.scp > $tmpdir/utts.feats
  if ! cmp -s $tmpdir/utts $tmpdir/utts.feats; then
    echo "$0: Error: in $data, utterance-ids extracted from utt2spk and features"
    echo "$0: differ, partial diff is:"
    partial_diff $tmpdir/utts $tmpdir/utts.feats
    exit 1;
  fi
fi


# cmvn.scp must list exactly the speakers of spk2utt.
if [ -f $data/cmvn.scp ]; then
  check_sorted_and_uniq $data/cmvn.scp
  awk '{print $1}' $data/cmvn.scp > $tmpdir/speakers.cmvn
  awk '{print $1}' $data/spk2utt > $tmpdir/speakers
  if ! cmp -s $tmpdir/speakers $tmpdir/speakers.cmvn; then
    echo "$0: Error: in $data, speaker lists extracted from spk2utt and cmvn"
    echo "$0: differ, partial diff is:"
    partial_diff $tmpdir/speakers $tmpdir/speakers.cmvn
    exit 1;
  fi
fi

# spk2gender: two fields per line, gender "m" or "f", same speakers as spk2utt.
if [ -f $data/spk2gender ]; then
  check_sorted_and_uniq $data/spk2gender
  if ! cat $data/spk2gender | awk 'NF != 2 || ($2 != "m" && $2 != "f") { exit 1 }'; then
    echo "$0: Mal-formed spk2gender file"
    exit 1
  fi
  awk '{print $1}' $data/spk2gender > $tmpdir/speakers.spk2gender
  awk '{print $1}' $data/spk2utt > $tmpdir/speakers
  if ! cmp -s $tmpdir/speakers $tmpdir/speakers.spk2gender; then
    echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2gender"
    echo "$0: differ, partial diff is:"
    partial_diff $tmpdir/speakers $tmpdir/speakers.spk2gender
    exit 1;
  fi
fi

# spk2warp: two fields per line, warp factor strictly between 0.5 and 1.5,
# same speakers as spk2utt.  The offending line is printed on failure.
if [ -f $data/spk2warp ]; then
  check_sorted_and_uniq $data/spk2warp
  if ! cat $data/spk2warp | awk 'NF != 2 || !($2 > 0.5 && $2 < 1.5) { print; exit 1 }'; then
    echo "$0: Mal-formed spk2warp file"
    exit 1
  fi
  awk '{print $1}' $data/spk2warp > $tmpdir/speakers.spk2warp
  awk '{print $1}' $data/spk2utt > $tmpdir/speakers
  if ! cmp -s $tmpdir/speakers $tmpdir/speakers.spk2warp; then
    echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2warp"
    echo "$0: differ, partial diff is:"
    partial_diff $tmpdir/speakers $tmpdir/speakers.spk2warp
    exit 1;
  fi
fi

# utt2warp: same format rule as spk2warp, but keyed by the utterances of utt2spk.
if [ -f $data/utt2warp ]; then
  check_sorted_and_uniq $data/utt2warp
  if ! cat $data/utt2warp | awk 'NF != 2 || !($2 > 0.5 && $2 < 1.5) { print; exit 1 }'; then
    echo "$0: Mal-formed utt2warp file"
    exit 1
  fi
  awk '{print $1}' $data/utt2warp > $tmpdir/utts.utt2warp
  awk '{print $1}' $data/utt2spk > $tmpdir/utts
  if ! cmp -s $tmpdir/utts $tmpdir/utts.utt2warp; then
    echo "$0: Error: in $data, utterance lists extracted from utt2spk and utt2warp"
    echo "$0: differ, partial diff is:"
    partial_diff $tmpdir/utts $tmpdir/utts.utt2warp
    exit 1;
  fi
fi
|
||||
|
||||
# Optional per-utterance files: when present, each must be sorted, free of
# duplicates, and keyed by exactly the utterance-ids of utt2spk.
for f in vad.scp utt2lang utt2uniq; do
  [ -f $data/$f ] || continue
  check_sorted_and_uniq $data/$f
  if ! cmp -s <( awk '{print $1}' $data/utt2spk ) \
              <( awk '{print $1}' $data/$f ); then
    echo "$0: error: in $data, $f and utt2spk do not have identical utterance-id list"
    exit 1;
  fi
done


# utt2dur: same utterances as utt2spk; two fields per line with a positive
# duration.
if [ -f $data/utt2dur ]; then
  check_sorted_and_uniq $data/utt2dur
  awk '{print $1}' $data/utt2dur > $tmpdir/utts.utt2dur
  if ! cmp -s $tmpdir/utts $tmpdir/utts.utt2dur; then
    echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2dur file"
    echo "$0: differ, partial diff is:"
    partial_diff $tmpdir/utts $tmpdir/utts.utt2dur
    exit 1;
  fi
  cat $data/utt2dur | \
    awk 'NF != 2 || !($2 > 0) { print "Bad line utt2dur:" NR ":" $0; exit(1) }' || exit 1
fi

# utt2num_frames: same utterances as utt2spk; two fields per line with a
# positive integer frame count.
if [ -f $data/utt2num_frames ]; then
  check_sorted_and_uniq $data/utt2num_frames
  awk '{print $1}' $data/utt2num_frames > $tmpdir/utts.utt2num_frames
  if ! cmp -s $tmpdir/utts $tmpdir/utts.utt2num_frames; then
    echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2num_frames file"
    echo "$0: differ, partial diff is:"
    partial_diff $tmpdir/utts $tmpdir/utts.utt2num_frames
    exit 1
  fi
  awk <$data/utt2num_frames 'NF != 2 || !($2 > 0) || $2 != int($2) {
    print "Bad line utt2num_frames:" NR ":" $0
    exit 1 }' || exit 1
fi

# reco2dur: its keys must match the recording list written to
# $tmpdir/recordings if that file exists, otherwise the utterance list.
if [ -f $data/reco2dur ]; then
  check_sorted_and_uniq $data/reco2dur
  awk '{print $1}' $data/reco2dur > $tmpdir/recordings.reco2dur
  if [ -f $tmpdir/recordings ]; then
    reco_ref=$tmpdir/recordings
    reco_src=segments
  else
    reco_ref=$tmpdir/utts
    reco_src=wav.scp
  fi
  if ! cmp -s $reco_ref $tmpdir/recordings.reco2dur; then
    echo "$0: Error: in $data, recording-ids extracted from $reco_src and reco2dur file"
    echo "$0: differ, partial diff is:"
    partial_diff $reco_ref $tmpdir/recordings.reco2dur
    exit 1;
  fi
  cat $data/reco2dur | \
    awk 'NF != 2 || !($2 > 0) { print "Bad line : " $0; exit(1) }' || exit 1
fi


echo "$0: Successfully validated data-directory $data"
|
||||
136
egs/alimeeting/sa-asr/utils/validate_text.pl
Executable file
136
egs/alimeeting/sa-asr/utils/validate_text.pl
Executable file
@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env perl
|
||||
#
|
||||
#===============================================================================
|
||||
# Copyright 2017 Johns Hopkins University (author: Yenda Trmal <jtrmal@gmail.com>)
|
||||
# Johns Hopkins University (author: Daniel Povey)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#===============================================================================
|
||||
|
||||
# validation script for data/<dataset>/text
|
||||
# to be called (preferably) from utils/validate_data_dir.sh
|
||||
use strict;
|
||||
use warnings;
|
||||
use utf8;
|
||||
use Fcntl qw< SEEK_SET >;
|
||||
|
||||
# this function reads the opened file (supplied as a first
|
||||
# parameter) into an array of lines. For each
|
||||
# line, it tests whether it's a valid utf-8 compatible
|
||||
# line. If all lines are valid utf-8, it returns the lines
|
||||
# decoded as utf-8, otherwise it assumes the file's encoding
|
||||
# is one of those 1-byte encodings, such as ISO-8859-x
|
||||
# or Windows CP-X.
|
||||
# Please recall we do not really care about
|
||||
# the actual encoding, we just need to
|
||||
# make sure the length of the (decoded) string
|
||||
# is correct (to make the output formatting looking right).
|
||||
sub get_utf8_or_bytestream {
  # Reads all remaining lines from the filehandle passed as the first argument.
  # Returns a list: a flag followed by the lines.
  #   (1, @decoded_lines)  when every line decodes as strict UTF-8
  #   (0, @raw_lines)      otherwise; the lines are returned undecoded
  use Encode qw(decode encode);

  my $file = shift;
  my $is_utf_compatible = 1;
  my @unicode_lines;
  my @raw_lines;

  # Test definedness rather than truth, so that a final unterminated line
  # consisting only of "0" is not mistaken for end-of-file and dropped.
  while (defined(my $raw_text = <$file>)) {
    if ($is_utf_compatible) {
      # LEAVE_SRC stops decode() from overwriting $raw_text in place (which it
      # does whenever CHECK is set), so the raw copy pushed below stays intact.
      my $decoded_text = eval {
        decode("UTF-8", $raw_text, Encode::FB_CROAK() | Encode::LEAVE_SRC())
      };
      if (defined($decoded_text)) {
        push @unicode_lines, $decoded_text;
      } else {
        # First undecodable line: stop decoding, keep collecting raw lines.
        $is_utf_compatible = 0;
      }
    }
    push @raw_lines, $raw_text;
  }

  return $is_utf_compatible ? (1, @unicode_lines) : (0, @raw_lines);
}
|
||||
|
||||
# check if the given unicode string contains unicode whitespaces
|
||||
# other than the usual four: TAB, LF, CR and SPACE
|
||||
sub validate_utf8_whitespaces {
  # Takes a reference to an array of decoded text lines.
  # Returns 1 (after printing a diagnostic to STDERR) for the first line that
  #   - does not end in "\n",
  #   - contains a CR character, or
  #   - contains any whitespace other than TAB, LF and SPACE;
  # returns 0 when no line is rejected.
  my $unicode_lines = shift;
  use feature 'unicode_strings';

  for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) {
    my $current_line = $unicode_lines->[$i];
    my $line_nr = $i + 1;  # report 1-based line numbers

    if ((substr $current_line, -1) ne "\n") {
      print STDERR "$0: The current line (nr. $line_nr) has invalid newline\n";
      return 1;
    }

    my @A = split(" ", $current_line);
    # A line without any field has no utterance id; fall back to the line
    # number so the messages below never interpolate an undefined value.
    my $utt_id = defined($A[0]) ? $A[0] : "<line $line_nr>";

    # CR is rejected outright.
    if ($current_line =~ /\x{000d}/) {
      print STDERR "$0: The line for utterance $utt_id contains CR (0x0D) character\n";
      return 1;
    }

    # Mask the allowed whitespace (TAB, LF, SPACE) so that any remaining \s
    # match must be a disallowed whitespace character.
    $current_line =~ s/[\x{0009}\x{000a}\x{0020}]/./g;
    if ($current_line =~ /\s/) {
      print STDERR "$0: The line for utterance $utt_id contains disallowed Unicode whitespaces\n";
      return 1;
    }
  }
  return 0;
}
|
||||
|
||||
# checks if the text in the file (supplied as the argument) is utf-8 compatible
|
||||
# if yes, checks if it contains only allowed whitespaces. If no, then does not
|
||||
# do anything. The function seeks to the original position in the file after
|
||||
# reading the text.
|
||||
sub check_allowed_whitespace {
  # Arguments: an open filehandle and the file name (used only in messages).
  # Returns 1 when the text is acceptable (or is not UTF-8 at all, in which
  # case nothing is checked), and 0 when a UTF-8 text contains disallowed
  # whitespace. The handle is rewound to where it was on entry.
  my $file = shift;
  my $filename = shift;

  my $pos = tell($file);
  my ($is_utf, @lines) = get_utf8_or_bytestream($file);
  seek($file, $pos, SEEK_SET);

  # Non-UTF-8 input is deliberately not inspected.
  return 1 unless $is_utf;

  if (validate_utf8_whitespaces(\@lines)) {
    print STDERR "$0: ERROR: text file '$filename' contains disallowed UTF-8 whitespace character(s)\n";
    return 0;
  }
  return 1;
}
|
||||
|
||||
if (@ARGV != 1) {
  die "Usage: validate_text.pl <text-file>\n" .
      "e.g.: validate_text.pl data/train/text\n";
}

my $text = shift @ARGV;

# -z alone is false for a missing file, so test existence explicitly to make
# the message below accurate in both cases.
if (! -e $text || -z $text) {
  print STDERR "$0: ERROR: file '$text' is empty or does not exist\n";
  exit 1;
}

# Three-argument open with a lexical handle: the file name is never parsed
# for mode characters, and the handle is not a global bareword.
my $fh;
if (!open($fh, "<", $text)) {
  print STDERR "$0: ERROR: failed to open $text: $!\n";
  exit 1;
}

my $ok = check_allowed_whitespace($fh, $text);
close($fh);
exit 1 unless $ok;
|
||||
@ -40,7 +40,6 @@ from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||
|
||||
|
||||
header_colors = '\033[95m'
|
||||
@ -91,8 +90,6 @@ class Speech2Text:
|
||||
asr_train_config, asr_model_file, cmvn_file, device
|
||||
)
|
||||
frontend = None
|
||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
||||
|
||||
logging.info("asr_model: {}".format(asr_model))
|
||||
logging.info("asr_train_args: {}".format(asr_train_args))
|
||||
@ -111,7 +108,7 @@ class Speech2Text:
|
||||
# 2. Build Language model
|
||||
if lm_train_config is not None:
|
||||
lm, lm_train_args = LMTask.build_model_from_file(
|
||||
lm_train_config, lm_file, device
|
||||
lm_train_config, lm_file, None, device
|
||||
)
|
||||
scorers["lm"] = lm.lm
|
||||
|
||||
@ -142,6 +139,13 @@ class Speech2Text:
|
||||
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
|
||||
)
|
||||
|
||||
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
|
||||
for scorer in scorers.values():
|
||||
if isinstance(scorer, torch.nn.Module):
|
||||
scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
|
||||
logging.info(f"Beam_search: {beam_search}")
|
||||
logging.info(f"Decoding device={device}, dtype={dtype}")
|
||||
|
||||
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
|
||||
if token_type is None:
|
||||
token_type = asr_train_args.token_type
|
||||
@ -198,16 +202,7 @@ class Speech2Text:
|
||||
if isinstance(speech, np.ndarray):
|
||||
speech = torch.tensor(speech)
|
||||
|
||||
if self.frontend is not None:
|
||||
feats, feats_len = self.frontend.forward(speech, speech_lengths)
|
||||
feats = to_device(feats, device=self.device)
|
||||
feats_len = feats_len.int()
|
||||
self.asr_model.frontend = None
|
||||
else:
|
||||
feats = speech
|
||||
feats_len = speech_lengths
|
||||
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
|
||||
batch = {"speech": feats, "speech_lengths": feats_len}
|
||||
batch = {"speech": speech, "speech_lengths": speech_lengths}
|
||||
|
||||
# a. To device
|
||||
batch = to_device(batch, device=self.device)
|
||||
@ -355,6 +350,9 @@ def inference_modelscope(
|
||||
if ngpu > 1:
|
||||
raise NotImplementedError("only single GPU decoding is supported")
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
@ -408,6 +406,7 @@ def inference_modelscope(
|
||||
data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
fs=fs,
|
||||
mc=True,
|
||||
batch_size=batch_size,
|
||||
key_file=key_file,
|
||||
num_workers=num_workers,
|
||||
@ -452,7 +451,7 @@ def inference_modelscope(
|
||||
|
||||
# Write the result to each file
|
||||
ibest_writer["token"][key] = " ".join(token)
|
||||
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
|
||||
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
|
||||
ibest_writer["score"][key] = str(hyp.score)
|
||||
|
||||
if text is not None:
|
||||
@ -463,6 +462,9 @@ def inference_modelscope(
|
||||
asr_utils.print_progress(finish_count / file_count)
|
||||
if writer is not None:
|
||||
ibest_writer["text"][key] = text
|
||||
|
||||
logging.info("uttid: {}".format(key))
|
||||
logging.info("text predictions: {}\n".format(text))
|
||||
return asr_result_list
|
||||
|
||||
return _forward
|
||||
@ -637,4 +639,4 @@ def main(cmd=None):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
@ -288,6 +288,9 @@ def inference_launch_funasr(**kwargs):
|
||||
if mode == "asr":
|
||||
from funasr.bin.asr_inference import inference
|
||||
return inference(**kwargs)
|
||||
elif mode == "sa_asr":
|
||||
from funasr.bin.sa_asr_inference import inference
|
||||
return inference(**kwargs)
|
||||
elif mode == "uniasr":
|
||||
from funasr.bin.asr_inference_uniasr import inference
|
||||
return inference(**kwargs)
|
||||
@ -342,4 +345,4 @@ def main(cmd=None):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
@ -2,6 +2,14 @@
|
||||
|
||||
import os
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level='INFO',
|
||||
format=f"[{os.uname()[1].split('.')[0]}]"
|
||||
f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
|
||||
from funasr.tasks.asr import ASRTask
|
||||
|
||||
|
||||
@ -27,7 +35,8 @@ if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
# setup local gpu_id
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
|
||||
if args.ngpu > 0:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
|
||||
|
||||
# DDP settings
|
||||
if args.ngpu > 1:
|
||||
@ -38,9 +47,9 @@ if __name__ == '__main__':
|
||||
|
||||
# re-compute batch size: when dataset type is small
|
||||
if args.dataset_type == "small":
|
||||
if args.batch_size is not None:
|
||||
if args.batch_size is not None and args.ngpu > 0:
|
||||
args.batch_size = args.batch_size * args.ngpu
|
||||
if args.batch_bins is not None:
|
||||
if args.batch_bins is not None and args.ngpu > 0:
|
||||
args.batch_bins = args.batch_bins * args.ngpu
|
||||
|
||||
main(args=args)
|
||||
|
||||
674
funasr/bin/sa_asr_inference.py
Normal file
674
funasr/bin/sa_asr_inference.py
Normal file
@ -0,0 +1,674 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
from typeguard import check_return_type
|
||||
|
||||
from funasr.fileio.datadir_writer import DatadirWriter
|
||||
from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
|
||||
from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
|
||||
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis
|
||||
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
||||
from funasr.modules.scorers.length_bonus import LengthBonus
|
||||
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.tasks.sa_asr import ASRTask
|
||||
from funasr.tasks.lm import LMTask
|
||||
from funasr.text.build_tokenizer import build_tokenizer
|
||||
from funasr.text.token_id_converter import TokenIDConverter
|
||||
from funasr.torch_utils.device_funcs import to_device
|
||||
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.utils import config_argparse
|
||||
from funasr.utils.cli_utils import get_commandline_args
|
||||
from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
||||
|
||||
|
||||
header_colors = '\033[95m'
|
||||
end_colors = '\033[0m'
|
||||
|
||||
|
||||
class Speech2Text:
    """Speaker-attributed speech-to-text decoder.

    Wraps an SA-ASR model, a beam search and a token/text converter.
    Calling an instance decodes one utterance and returns, per hypothesis,
    the recognized text together with a per-character speaker-id string.

    Examples:
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
        >>> speech2text(speech, speech_lengths, profile, profile_lengths)
        [(text, text_id, token, token_int, hypothesis object), ...]

    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        frontend_conf: dict = None,
        **kwargs,
    ):
        """Build the model, scorers, beam search and text converter.

        Args:
            asr_train_config: ASR training configuration file.
            asr_model_file: ASR model parameter file.
            cmvn_file: Passed through to ``ASRTask.build_model_from_file``.
            lm_train_config: LM training configuration; when given, an LM
                scorer is added to the beam search.
            lm_file: LM parameter file.
            token_type: Token type; taken from the training args when None.
            bpemodel: BPE model path; taken from the training args when None.
            device: Device the beam search and scorers are moved to.
            maxlenratio: Forwarded to the beam search on every call.
            minlenratio: Forwarded to the beam search on every call.
            dtype: Name of the torch dtype used for the model and scorers.
            beam_size: Beam size.
            ctc_weight: Weight of the CTC scorer; the decoder gets
                ``1.0 - ctc_weight``.
            lm_weight: Weight of the LM scorer.
            ngram_weight: Weight of the n-gram scorer (no n-gram is built).
            penalty: Weight of the length bonus.
            nbest: Number of hypotheses kept per utterance.
            batch_size, streaming, frontend_conf, **kwargs: Accepted but not
                used by this constructor.
        """
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )

        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        decoder = asr_model.decoder

        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        scorers["ngram"] = None

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None

        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )

        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Beam_search: {beam_search}")
        logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        # No separate frontend is built here.
        self.frontend = None

    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        speech_lengths: Union[torch.Tensor, np.ndarray],
        profile: Union[torch.Tensor, np.ndarray],
        profile_lengths: Union[torch.Tensor, np.ndarray],
    ) -> List[
        Tuple[
            Optional[str],
            Optional[str],
            List[str],
            List[int],
            Hypothesis,
        ]
    ]:
        """Inference

        Args:
            speech: Input speech data (batch of exactly one utterance).
            speech_lengths: Lengths passed to the encoder with ``speech``.
            profile: Speaker profile; only ``profile[0]`` is used.
            profile_lengths: Accepted for interface compatibility; unused.
        Returns:
            A list of at most ``nbest`` tuples
            ``(text, text_id, token, token_int, hyp)``. ``text_id`` has the
            same length as ``text`` and holds the chosen 1-based speaker
            index for every character, with ``$`` kept as separator. Both
            ``text`` and ``text_id`` are None when no tokenizer is available.

        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        if isinstance(profile, np.ndarray):
            profile = torch.tensor(profile)

        batch = {"speech": speech, "speech_lengths": speech_lengths}

        # a. To device
        batch = to_device(batch, device=self.device)
        # The profile is consumed together with the encoder outputs, so it
        # must live on the same device as them.
        profile = profile.to(self.device)

        # b. Forward Encoder
        asr_enc, _, spk_enc = self.asr_model.encode(**batch)
        if isinstance(asr_enc, tuple):
            asr_enc = asr_enc[0]
        if isinstance(spk_enc, tuple):
            spk_enc = spk_enc[0]
        assert len(asr_enc) == 1, len(asr_enc)
        assert len(spk_enc) == 1, len(spk_enc)

        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            asr_enc[0],
            spk_enc[0],
            profile[0],
            maxlenratio=self.maxlenratio,
            minlenratio=self.minlenratio,
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, Hypothesis), type(hyp)

            # remove sos/eos and get results
            if isinstance(hyp.yseq, list):
                token_int_ori = hyp.yseq[1:-1]
            else:
                token_int_ori = hyp.yseq[1:-1].tolist()

            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int_ori))

            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is None:
                # Speaker attribution below works on the detokenized text, so
                # without a tokenizer neither text nor text_id can be built.
                text = None
                text_id = None
            else:
                text, text_id = self._attribute_speakers(
                    hyp, token_int_ori, token
                )

            results.append((text, text_id, token, token_int, hyp))

        assert check_return_type(results)
        return results

    def _attribute_speakers(self, hyp, token_int_ori, token):
        """Return (text, text_id) for one hypothesis; requires a tokenizer."""
        spk_weights = torch.stack(hyp.spk_weigths, dim=0)

        # Text built from the unfiltered ids (blank still included) is what
        # the rows of spk_weights are indexed against.
        token_ori = self.converter.ids2tokens(token_int_ori)
        text_ori = self.tokenizer.tokens2text(token_ori)

        # One speaker per '$'-separated segment: average the weights over the
        # segment and take the best-scoring profile entry (1-based).
        # NOTE(review): a speaker index >= 10 would make text_id longer than
        # text and trip the length assertion below -- confirm max speakers.
        spk_choose = []
        cur_index = 0
        for segment in text_ori.split("$"):
            n = len(segment)
            local_weights = spk_weights[cur_index: cur_index + n]
            cur_index = cur_index + n + 1  # +1 skips the '$' separator
            spk_choose.append(local_weights.mean(dim=0).argmax(-1).item() + 1)

        text = self.tokenizer.tokens2text(token)
        text_spklist = text.split("$")
        assert len(spk_choose) == len(text_spklist)

        text_id = "$".join(
            str(spk) * len(segment)
            for spk, segment in zip(spk_choose, text_spklist)
        )
        assert len(text) == len(text_id)
        return text, text_id
|
||||
|
||||
def inference(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    raw_inputs: Union[np.ndarray, torch.Tensor] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    **kwargs,
):
    """Build the decoding pipeline and run it once.

    Every argument except ``data_path_and_name_and_type`` is forwarded to
    ``inference_modelscope`` to build the pipeline; the pipeline is then
    called with ``data_path_and_name_and_type`` and ``raw_inputs``.

    Returns:
        Whatever the pipeline returns for this input.
    """
    inference_pipeline = inference_modelscope(
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        batch_size=batch_size,
        beam_size=beam_size,
        ngpu=ngpu,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        penalty=penalty,
        log_level=log_level,
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        raw_inputs=raw_inputs,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        key_file=key_file,
        word_lm_train_config=word_lm_train_config,
        bpemodel=bpemodel,
        allow_variable_data_keys=allow_variable_data_keys,
        streaming=streaming,
        output_dir=output_dir,
        dtype=dtype,
        seed=seed,
        ngram_weight=ngram_weight,
        nbest=nbest,
        num_workers=num_workers,
        **kwargs,
    )
    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
|
||||
|
||||
def inference_modelscope(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    param_dict: dict = None,
    **kwargs,
):
    """Build a ``Speech2Text`` decoder once and return a decoding function.

    Resets the root logging configuration, selects ``"cuda"`` when
    ``ngpu >= 1`` and CUDA is available (``"cpu"`` otherwise), seeds the
    random generators and builds the decoder.

    Returns:
        ``_forward(data_path_and_name_and_type, raw_inputs=None,
        output_dir_v2=None, fs=None, param_dict=None, **kwargs)``, which
        decodes the given data and returns a list of
        ``{'key': ..., 'value': ...}`` items.

    Raises:
        NotImplementedError: for ``batch_size > 1``, a word LM, or
            ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    # basicConfig is a no-op when the root logger already has handlers, so
    # drop them first to make the requested level and format take effect.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2Text(**speech2text_kwargs)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
    ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=True,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, text_id, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                # Placeholder must have the same five fields the loop below
                # unpacks; text_id mirrors the one-character text.
                results = [[" ", " ", ["sil"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]

                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                    # text_id is Optional in Speech2Text's return type.
                    if text_id is not None:
                        ibest_writer["text_id"][key] = text_id

                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text

                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}".format(text))
                logging.info("text_id predictions: {}\n".format(text_id))
        return asr_result_list

    return _forward
|
||||
|
||||
def get_parser():
    """Build the command-line parser for SA-ASR decoding.

    Returns:
        A ``config_argparse.ArgumentParser`` with the logging, device, input
        data, model, beam-search and text-converter options defined below.
    """
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    # NOTE(review): type=list splits a command-line string into single
    # characters -- confirm whether this option is meant to be set from the CLI.
    group.add_argument("--raw_inputs", type=list, default=None)
    # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
    )
    group.add_argument(
        "--asr_model_file",
        type=str,
        help="ASR model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global cmvn file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
    )
    group.add_argument(
        "--lm_file",
        type=str,
        help="LM parameter file",
    )
    group.add_argument(
        "--word_lm_train_config",
        type=str,
        help="Word LM training configuration",
    )
    group.add_argument(
        "--word_lm_file",
        type=str,
        help="Word LM parameter file",
    )
    group.add_argument(
        "--ngram_file",
        type=str,
        help="N-gram parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
             "*_file will be overwritten",
    )

    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        # Adjacent literals are concatenated verbatim, so each piece carries
        # its own trailing space.
        help="Input length ratio to obtain max output length. "
             "If maxlenratio=0.0 (default), it uses a end-detect "
             "function "
             "to automatically find maximum hypothesis lengths. "
             "If maxlenratio<0.0, its absolute value is interpreted "
             "as a constant max output length",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.5,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)

    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
        type=str_or_none,
        default=None,
        choices=["char", "bpe", None],
        help="The token type for ASR model. "
             "If not given, refers from the training args",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model path of sentencepiece. "
             "If not given, refers from the training args",
    )

    return parser
|
||||
|
||||
|
||||
def main(cmd=None):
    """Parse command-line options and run ``inference`` with them.

    Args:
        cmd: Optional list of argument strings handed to ``parse_args``;
            ``None`` lets argparse fall back to ``sys.argv``.
    """
    # Report the invoking command line on stderr.
    print(get_commandline_args(), file=sys.stderr)
    options = vars(get_parser().parse_args(cmd))
    # "config" is removed (when present) before the options are forwarded
    # as keyword arguments.
    options.pop("config", None)
    inference(**options)


if __name__ == "__main__":
    main()
|
||||
55
funasr/bin/sa_asr_train.py
Executable file
55
funasr/bin/sa_asr_train.py
Executable file
@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
|
||||
import logging
|
||||
|
||||
# Root logger setup: INFO level, each record prefixed with the first
# dot-separated label of the node name reported by os.uname().
logging.basicConfig(
    level='INFO',
    format="[{}] %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s".format(
        os.uname()[1].split('.')[0]
    ),
)
|
||||
|
||||
from funasr.tasks.sa_asr import ASRTask
|
||||
|
||||
|
||||
# for ASR Training
|
||||
def parse_args():
    """Parse the command line with the ASRTask parser plus ``--gpu_id``.

    Returns:
        The namespace returned by ``parser.parse_args()``.
    """
    parser = ASRTask.get_parser()
    # Extra option on top of the task parser: the local GPU index.
    parser.add_argument("--gpu_id", type=int, default=0, help="local gpu id.")
    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args=None, cmd=None):
    """Run ASR training by delegating to ``ASRTask.main``.

    Args:
        args: Parsed argument namespace, forwarded unchanged.
        cmd: Command-line argument list, forwarded unchanged.
    """
    ASRTask.main(args=args, cmd=cmd)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    args = parse_args()

    # setup local gpu_id: only applied when GPUs are requested
    if args.ngpu > 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP settings: distributed mode is enabled only for more than one GPU
    args.distributed = args.ngpu > 1
    # NOTE(review): the source lost its indentation; this assert is kept at
    # the outer level (checked for every ngpu value) -- confirm.
    assert args.num_worker_count == 1

    # re-compute batch size: when dataset type is small, the configured
    # sizes are multiplied by the number of GPUs
    if args.dataset_type == "small" and args.ngpu > 0:
        if args.batch_size is not None:
            args.batch_size *= args.ngpu
        if args.batch_bins is not None:
            args.batch_bins *= args.ngpu

    main(args=args)
|
||||
@ -46,13 +46,15 @@ class SoundScpReader(collections.abc.Mapping):
|
||||
if self.normalize:
|
||||
# soundfile.read normalizes data to [-1,1] if dtype is not given
|
||||
array, rate = librosa.load(
|
||||
wav, sr=self.dest_sample_rate, mono=not self.always_2d
|
||||
wav, sr=self.dest_sample_rate, mono=self.always_2d
|
||||
)
|
||||
else:
|
||||
array, rate = librosa.load(
|
||||
wav, sr=self.dest_sample_rate, mono=not self.always_2d, dtype=self.dtype
|
||||
wav, sr=self.dest_sample_rate, mono=self.always_2d, dtype=self.dtype
|
||||
)
|
||||
|
||||
if array.ndim==2:
|
||||
array=array.transpose((1, 0))
|
||||
return rate, array
|
||||
|
||||
def get_path(self, key):
|
||||
|
||||
47
funasr/losses/nll_loss.py
Normal file
47
funasr/losses/nll_loss.py
Normal file
@ -0,0 +1,47 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class NllLoss(nn.Module):
    """Nll loss.

    Applies ``criterion`` element-wise to flattened predictions and targets,
    zeroes the entries whose target equals ``padding_idx`` and returns the
    sum divided by either the number of non-padded targets or the batch size.

    :param int size: the number of class
    :param int padding_idx: ignored class id
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=nn.NLLLoss(reduction='none'),
    ):
        """Construct an NllLoss object."""
        super(NllLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.size = size
        # Never read in this class; kept so the attribute set is unchanged.
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_id (batch, seqlen)
        :return: scalar float value
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        # reshape instead of view: view raises on non-contiguous tensors
        # (e.g. the result of a transpose), reshape copies only when needed.
        x = x.reshape(-1, self.size)
        target = target.reshape(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
        nll = self.criterion(x, target)
        denom = total if self.normalize_length else batch_size
        # Padded positions were evaluated against class 0 above; drop them here.
        return nll.masked_fill(ignore, 0).sum() / denom
|
||||
169
funasr/models/decoder/decoder_layer_sa_asr.py
Normal file
169
funasr/models/decoder/decoder_layer_sa_asr.py
Normal file
@ -0,0 +1,169 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from funasr.modules.layer_norm import LayerNorm
|
||||
|
||||
|
||||
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
    """First layer of the speaker decoder.

    Runs three sub-layers in sequence: self-attention over the target,
    source attention called as ``src_attn(x, asr_memory, spk_memory, mask)``
    (presumably query / key / value / mask -- confirm against the attention
    module), and a feed-forward block.  The output of the self-attention
    sub-layer is additionally returned as ``z``.

    Args:
        size: feature dimension handed to the LayerNorm / Linear modules
        self_attn: self-attention module
        src_attn: source-attention module
        feed_forward: feed-forward module
        dropout_rate: dropout probability applied to sub-layer outputs
        normalize_before: apply layer norm before (True) or after (False)
            each sub-layer
        concat_after: if True, concatenate sub-layer input and attention
            output and project with a linear layer instead of using dropout
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
        """Apply the layer.

        Args:
            tgt: target features; indexed as (batch, time, feature)
            tgt_mask: mask for the self-attention, or None
            asr_memory: second argument of the source attention
            spk_memory: third argument of the source attention
            memory_mask: mask for the source attention
            cache: previously computed outputs of this layer with shape
                (batch, time - 1, size); when given, only the last target
                frame is used as query

        Returns:
            Tuple ``(x, tgt_mask, asr_memory, spk_memory, memory_mask, z)``
            where ``x`` is the layer output (prefixed with ``cache`` when a
            cache was given) and ``z`` is the self-attention sub-layer output.
        """
        pre_norm = self.normalize_before

        # ---- self-attention sub-layer ----
        residual = tgt
        if pre_norm:
            tgt = self.norm1(tgt)

        if cache is not None:
            # compute only the last frame query keeping dim: max_time_out -> 1
            expected_shape = (tgt.shape[0], tgt.shape[1] - 1, self.size)
            assert cache.shape == expected_shape, f"{cache.shape} == {expected_shape}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = tgt_mask[:, -1:, :] if tgt_mask is not None else None
        else:
            tgt_q, tgt_q_mask = tgt, tgt_mask

        self_att = self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
        if self.concat_after:
            x = residual + self.concat_linear1(torch.cat((tgt_q, self_att), dim=-1))
        else:
            x = residual + self.dropout(self_att)
        if not pre_norm:
            x = self.norm1(x)
        z = x

        # ---- source-attention sub-layer (norm1 is applied here as well) ----
        residual = x
        if pre_norm:
            x = self.norm1(x)
        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
        if self.concat_after:
            x = residual + self.concat_linear2(torch.cat((x, skip), dim=-1))
        else:
            x = residual + self.dropout(skip)
        if not pre_norm:
            x = self.norm1(x)

        # ---- feed-forward sub-layer ----
        residual = x
        if pre_norm:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not pre_norm:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
|
||||
|
||||
class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
    """First layer of the ASR decoder, with optional speaker-profile injection.

    Runs source attention over ``memory`` followed by a feed-forward block.
    When ``dn`` is given, ``spk_linear(dn)`` is added to the features that
    enter the feed-forward block.

    Args:
        size: feature dimension handed to the LayerNorm / Linear modules
        d_size: input dimension of ``spk_linear`` (dimension of ``dn``)
        src_attn: source-attention module
        feed_forward: feed-forward module
        dropout_rate: dropout probability applied to sub-layer outputs
        normalize_before: apply layer norm before (True) or after (False)
            each sub-layer
        concat_after: if True, concatenate attention input and output and
            project with ``concat_linear2`` instead of using dropout
    """

    def __init__(
        self,
        size,
        d_size,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        self.spk_linear = nn.Linear(d_size, size, bias=False)
        if self.concat_after:
            # concat_linear1 is not used by forward(); it is still created so
            # the module's parameter set stays the same.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
        """Apply the layer.

        Args:
            tgt: target features; indexed as (batch, time, feature)
            tgt_mask: returned unchanged (this layer has no self-attention)
            memory: key and value of the source attention
            memory_mask: mask for the source attention
            dn: tensor passed through ``spk_linear`` and added before the
                feed-forward block, or None to skip that step
            cache: previously computed outputs of this layer; when given,
                only the last target frame is processed and the result is
                appended to ``cache`` along dim 1

        Returns:
            Tuple ``(x, tgt_mask, memory, memory_mask)``.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
        else:
            # Incremental decoding: only the newest frame is computed.
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]

        # Source attention sub-layer.
        x = tgt_q
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)
        residual = x

        # Identity test instead of `!=`: comparing a tensor with None through
        # `!=` only works via Python's comparison fallback.
        if dn is not None:
            x = x + self.spk_linear(dn)

        # Feed-forward sub-layer; the residual excludes the spk_linear term.
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
|
||||
|
||||
|
||||
|
||||
291
funasr/models/decoder/transformer_decoder_sa_asr.py
Normal file
291
funasr/models/decoder/transformer_decoder_sa_asr.py
Normal file
@ -0,0 +1,291 @@
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.attention import MultiHeadedAttention
|
||||
from funasr.modules.attention import CosineDistanceAttention
|
||||
from funasr.models.decoder.transformer_decoder import DecoderLayer
|
||||
from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer
|
||||
from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer
|
||||
from funasr.modules.dynamic_conv import DynamicConvolution
|
||||
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
|
||||
from funasr.modules.embedding import PositionalEncoding
|
||||
from funasr.modules.layer_norm import LayerNorm
|
||||
from funasr.modules.lightconv import LightweightConvolution
|
||||
from funasr.modules.lightconv2d import LightweightConvolution2D
|
||||
from funasr.modules.mask import subsequent_mask
|
||||
from funasr.modules.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr.modules.repeat import repeat
|
||||
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
|
||||
class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Shared forward / scoring logic of the speaker-attributed ASR decoder.

    The four decoder stages ``decoder1`` .. ``decoder4`` are initialised to
    ``None`` here; ``forward`` and ``forward_one_step`` call them, so a
    subclass has to assign them.

    Args:
        vocab_size: output vocabulary size
        encoder_output_size: encoder output dimension, used as attention dim
        spker_embedding_dim: output dimension of the speaker output layer
        dropout_rate: dropout rate of the "linear" input layer
        positional_dropout_rate: dropout rate passed to ``pos_enc_class``
        input_layer: "embed" or "linear"
        use_asr_output_layer: create the linear layer projecting to the vocabulary
        use_spk_output_layer: create the linear layer projecting to
            ``spker_embedding_dim``
        pos_enc_class: positional-encoding class
        normalize_before: if True, an ``after_norm`` LayerNorm is created and
            applied to the decoder outputs

    Raises:
        ValueError: if ``input_layer`` is neither "embed" nor "linear".
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None

        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None

        self.cos_distance_att = CosineDistanceAttention()

        # Placeholders; see the class docstring.
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None

    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the speaker decoder and the ASR decoder on a padded batch.

        Args:
            asr_hs_pad: ASR encoder output used as ASR memory
            spk_hs_pad: speaker encoder output used as speaker memory
            hlens: lengths of the encoder outputs
            ys_in_pad: padded decoder input tokens
            ys_in_lens: lengths of the decoder inputs
            profile: speaker profile passed to the cosine-distance attention
            profile_lens: lengths passed with ``profile``

        Returns:
            Tuple ``(x, weights, olens)``: the ASR decoder output (after
            ``asr_output_layer`` when it exists), the weights returned by the
            cosine-distance attention, and ``tgt_mask.sum(1)``.
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)

        # Spk decoder
        x = self.embed(tgt)

        # z is the extra output of decoder1; it is the input of the ASR decoder.
        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        dn, weights = self.cos_distance_att(x, profile, profile_lens)

        # Asr decoder
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )

        # The same after_norm module is used for both decoder branches.
        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)

        olens = tgt_mask.sum(1)
        return x, weights, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        asr_memory: torch.Tensor,
        spk_memory: torch.Tensor,
        profile: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Run one decoding step.

        Args:
            tgt: decoder input tokens so far
            tgt_mask: mask for the decoder input
            asr_memory: ASR encoder output
            spk_memory: speaker encoder output
            profile: speaker profile passed to the cosine-distance attention
            cache: per-layer cached outputs, ordered as decoder1, the layers
                of decoder2, decoder3, the layers of decoder4; None means no
                cache for any layer

        Returns:
            Tuple ``(y, weights, new_cache)``: scores for the last position
            (log-softmax of ``asr_output_layer`` when it exists), the
            attention weights, and the updated cache list.
        """
        x = self.embed(tgt)

        if cache is None:
            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
        new_cache = []
        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
            x, tgt_mask, spk_memory, _ = decoder(
                x, tgt_mask, spk_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        dn, weights = self.cos_distance_att(x, profile, None)

        x, tgt_mask, asr_memory, _ = self.decoder3(
            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
        )
        new_cache.append(x)

        for c, decoder in zip(cache[len(self.decoder2) + 2:], self.decoder4):
            x, tgt_mask, asr_memory, _ = decoder(
                x, tgt_mask, asr_memory, None, cache=c
            )
            new_cache.append(x)

        # Only the newest position is scored.
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.asr_output_layer is not None:
            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)

        return y, weights, new_cache

    def score(self, ys, state, asr_enc, spk_enc, profile):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
        # A batch dimension of 1 is added to every input and removed again
        # from the returned scores.
        logp, weights, state = self.forward_one_step(
            ys.unsqueeze(0),
            ys_mask,
            asr_enc.unsqueeze(0),
            spk_enc.unsqueeze(0),
            profile.unsqueeze(0),
            cache=state,
        )
        return logp.squeeze(0), weights.squeeze(), state
|
||||
|
||||
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
    """Concrete SA-ASR decoder that builds the four decoder stages.

    ``decoder1`` / ``decoder2`` form the speaker branch (``spk_num_blocks``
    layers in total) and ``decoder3`` / ``decoder4`` the ASR branch
    (``asr_num_blocks`` layers in total).

    Args:
        vocab_size, encoder_output_size, spker_embedding_dim, dropout_rate,
        positional_dropout_rate, input_layer, use_asr_output_layer,
        use_spk_output_layer, pos_enc_class, normalize_before: forwarded to
            ``BaseSAAsrTransformerDecoder``
        attention_heads: number of heads of every attention module
        linear_units: hidden units of every feed-forward module
        asr_num_blocks: number of layers in the ASR branch
        spk_num_blocks: number of layers in the speaker branch
        self_attention_dropout_rate: dropout rate of the self-attention modules
        src_attention_dropout_rate: dropout rate of the source-attention modules
        concat_after: forwarded to every decoder layer
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        asr_num_blocks: int = 6,
        spk_num_blocks: int = 3,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            spker_embedding_dim=spker_embedding_dim,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_asr_output_layer=use_asr_output_layer,
            use_spk_output_layer=use_spk_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size

        # Small factories so every layer gets its own freshly built modules;
        # they are always invoked in the order self-attn, src-attn, feed-forward.
        def new_self_attn():
            return MultiHeadedAttention(
                attention_heads, attention_dim, self_attention_dropout_rate
            )

        def new_src_attn():
            return MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            )

        def new_feed_forward():
            return PositionwiseFeedForward(attention_dim, linear_units, dropout_rate)

        def new_decoder_layer(lnum):
            return DecoderLayer(
                attention_dim,
                new_self_attn(),
                new_src_attn(),
                new_feed_forward(),
                dropout_rate,
                normalize_before,
                concat_after,
            )

        # Speaker branch: one special first layer plus spk_num_blocks - 1 plain layers.
        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
            attention_dim,
            new_self_attn(),
            new_src_attn(),
            new_feed_forward(),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder2 = repeat(spk_num_blocks - 1, new_decoder_layer)

        # ASR branch: one special first layer plus asr_num_blocks - 1 plain layers.
        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
            attention_dim,
            spker_embedding_dim,
            new_src_attn(),
            new_feed_forward(),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder4 = repeat(asr_num_blocks - 1, new_decoder_layer)
|
||||
521
funasr/models/e2e_sa_asr.py
Normal file
521
funasr/models/e2e_sa_asr.py
Normal file
@ -0,0 +1,521 @@
|
||||
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.losses.label_smoothing_loss import (
|
||||
LabelSmoothingLoss, # noqa: H301
|
||||
)
|
||||
from funasr.losses.nll_loss import NllLoss
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.modules.add_sos_eos import add_sos_eos
|
||||
from funasr.modules.e2e_asr_common import ErrorCalculator
|
||||
from funasr.modules.nets_utils import th_accuracy
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
# Nothing to do if torch<1.6.0
|
||||
@contextmanager
|
||||
def autocast(enabled=True):
|
||||
yield
|
||||
|
||||
|
||||
class ESPnetASRModel(AbsESPnetModel):
|
||||
"""CTC-attention hybrid Encoder-Decoder model"""
|
||||
|
||||
def __init__(
    self,
    vocab_size: int,
    max_spk_num: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    preencoder: Optional[AbsPreEncoder],
    asr_encoder: AbsEncoder,
    spk_encoder: torch.nn.Module,
    postencoder: Optional[AbsPostEncoder],
    decoder: AbsDecoder,
    ctc: CTC,
    spk_weight: float = 0.5,
    ctc_weight: float = 0.5,
    interctc_weight: float = 0.0,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    extract_feats_in_collect_stats: bool = True,
):
    """Store the sub-modules and build the loss functions.

    Args:
        vocab_size: size of the token vocabulary
        max_spk_num: number of classes of the speaker loss
        token_list: vocabulary symbols; a list copy is stored on the model
        frontend / specaug / normalize / preencoder / postencoder:
            optional processing modules, stored as given
        asr_encoder: ASR encoder
        spk_encoder: speaker encoder
        decoder: decoder; dropped (set to None) when ``ctc_weight == 1.0``
        ctc: CTC module; dropped (set to None) when ``ctc_weight == 0.0``
        spk_weight: weight of the speaker loss in the total loss
        ctc_weight: weight of the CTC loss in the ASR loss, in [0, 1]
        interctc_weight: weight of the intermediate CTC loss, in [0, 1)
        ignore_id: padding id ignored by the losses
        lsm_weight: label-smoothing weight of the attention loss
        length_normalized_loss: normalize the losses by length
        report_cer / report_wer: enable the error calculator
        sym_space / sym_blank: symbols handed to the error calculator
        extract_feats_in_collect_stats: stored for ``collect_feats``
    """
    assert check_argument_types()
    assert 0.0 <= ctc_weight <= 1.0, ctc_weight
    assert 0.0 <= interctc_weight < 1.0, interctc_weight

    super().__init__()
    # Special token ids: blank, start-of-sentence and end-of-sentence are
    # three distinct ids in this model.
    self.blank_id = 0
    self.sos = 1
    self.eos = 2
    self.vocab_size = vocab_size
    self.max_spk_num = max_spk_num
    self.ignore_id = ignore_id
    self.spk_weight = spk_weight
    self.ctc_weight = ctc_weight
    self.interctc_weight = interctc_weight
    # list(...) instead of .copy(): the signature also accepts a tuple,
    # which has no copy() method.
    self.token_list = list(token_list)

    self.frontend = frontend
    self.specaug = specaug
    self.normalize = normalize
    self.preencoder = preencoder
    self.postencoder = postencoder
    self.asr_encoder = asr_encoder
    self.spk_encoder = spk_encoder

    if not hasattr(self.asr_encoder, "interctc_use_conditioning"):
        self.asr_encoder.interctc_use_conditioning = False
    if self.asr_encoder.interctc_use_conditioning:
        self.asr_encoder.conditioning_layer = torch.nn.Linear(
            vocab_size, self.asr_encoder.output_size()
        )

    self.error_calculator = None

    # we set self.decoder = None in the CTC mode since
    # self.decoder parameters were never used and PyTorch complained
    # and threw an Exception in the multi-GPU experiment.
    # thanks Jeff Farris for pointing out the issue.
    if ctc_weight == 1.0:
        self.decoder = None
    else:
        self.decoder = decoder

    self.criterion_att = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )

    self.criterion_spk = NllLoss(
        size=max_spk_num,
        padding_idx=ignore_id,
        normalize_length=length_normalized_loss,
    )

    if report_cer or report_wer:
        self.error_calculator = ErrorCalculator(
            token_list, sym_space, sym_blank, report_cer, report_wer
        )

    if ctc_weight == 0.0:
        self.ctc = None
    else:
        self.ctc = ctc

    self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
|
||||
|
||||
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    profile: torch.Tensor,
    profile_lengths: torch.Tensor,
    text_id: torch.Tensor,
    text_id_lengths: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
        profile: (Batch, Length, Dim)
        profile_lengths: (Batch,)
        text_id: forwarded to ``_calc_att_loss`` (presumably per-token
            speaker labels -- TODO confirm)
        text_id_lengths: forwarded to ``_calc_att_loss``

    Returns:
        Tuple ``(loss, stats, weight)`` as produced by ``force_gatherable``;
        ``weight`` is derived from the batch size.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (
        speech.shape[0]
        == speech_lengths.shape[0]
        == text.shape[0]
        == text_lengths.shape[0]
    ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
    batch_size = speech.shape[0]

    # for data-parallel
    text = text[:, : text_lengths.max()]

    # 1. Encoder
    asr_encoder_out, encoder_out_lens, spk_encoder_out = self.encode(speech, speech_lengths)
    intermediate_outs = None
    if isinstance(asr_encoder_out, tuple):
        intermediate_outs = asr_encoder_out[1]
        asr_encoder_out = asr_encoder_out[0]

    loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = None, None, None, None, None, None
    loss_ctc, cer_ctc = None, None
    stats = dict()

    # 2a. CTC branch
    if self.ctc_weight != 0.0:
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            asr_encoder_out, encoder_out_lens, text, text_lengths
        )

    # Intermediate CTC (optional)
    loss_interctc = 0.0
    if self.interctc_weight != 0.0 and intermediate_outs is not None:
        for layer_idx, intermediate_out in intermediate_outs:
            # we assume intermediate_out has the same length & padding
            # as those of encoder_out
            loss_ic, cer_ic = self._calc_ctc_loss(
                intermediate_out, encoder_out_lens, text, text_lengths
            )
            loss_interctc = loss_interctc + loss_ic

            # Collect Intermediate CTC stats
            stats["loss_interctc_layer{}".format(layer_idx)] = (
                loss_ic.detach() if loss_ic is not None else None
            )
            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

        loss_interctc = loss_interctc / len(intermediate_outs)

        # calculate whole encoder loss
        loss_ctc = (
            1 - self.interctc_weight
        ) * loss_ctc + self.interctc_weight * loss_interctc

    # 2b. Attention decoder branch
    if self.ctc_weight != 1.0:
        loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss(
            asr_encoder_out,
            spk_encoder_out,
            encoder_out_lens,
            text,
            text_lengths,
            profile,
            profile_lengths,
            text_id,
            text_id_lengths,
        )

    # 3. CTC-Att loss definition
    if self.ctc_weight == 0.0:
        loss_asr = loss_att
    elif self.ctc_weight == 1.0:
        loss_asr = loss_ctc
    else:
        loss_asr = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

    # The speaker loss only exists when the attention branch ran; in pure
    # CTC mode (ctc_weight == 1.0) it is None and cannot be mixed in.
    if self.spk_weight == 0.0 or loss_spk is None:
        loss = loss_asr
    else:
        loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr

    # update() instead of rebinding `stats`, so the intermediate-CTC
    # entries collected above are reported too.
    stats.update(
        loss=loss.detach(),
        loss_asr=loss_asr.detach(),
        loss_att=loss_att.detach() if loss_att is not None else None,
        loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
        loss_spk=loss_spk.detach() if loss_spk is not None else None,
        acc=acc_att,
        acc_spk=acc_spk,
        cer=cer_att,
        wer=wer_att,
        cer_ctc=cer_ctc,
    )

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
|
||||
|
||||
def collect_feats(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Return the features used for statistics collection.

    Args:
        speech: input speech batch
        speech_lengths: lengths of ``speech``
        text: not used by this method
        text_lengths: not used by this method

    Returns:
        Dict with keys "feats" and "feats_lengths".  When
        ``extract_feats_in_collect_stats`` is False the raw inputs are
        returned as dummy features and a warning is logged.
    """
    if not self.extract_feats_in_collect_stats:
        # Generate dummy stats if extract_feats_in_collect_stats is False
        logging.warning(
            "Generating dummy stats for feats and feats_lengths, "
            "because encoder_conf.extract_feats_in_collect_stats is "
            f"{self.extract_feats_in_collect_stats}"
        )
        return {"feats": speech, "feats_lengths": speech_lengths}

    extracted, extracted_lengths = self._extract_feats(speech, speech_lengths)
    return {"feats": extracted, "feats_lengths": extracted_lengths}
|
||||
|
||||
def encode(
    self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Frontend + Encoder. Note that this method is used by asr_inference.py

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )

    Returns:
        Tuple ``(encoder_out, encoder_out_lens, encoder_out_spk)``.  When the
        ASR encoder also yields intermediate outputs, the first element is
        the pair ``(encoder_out, intermediate_outs)`` instead of a tensor.
    """
    with autocast(False):
        # 1. Extract feats
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)

        # 2. Data augmentation
        # The speaker encoder gets a copy taken before augmentation and
        # normalization.
        feats_raw = feats.clone()
        if self.specaug is not None and self.training:
            feats, feats_lengths = self.specaug(feats, feats_lengths)

        # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            feats, feats_lengths = self.normalize(feats, feats_lengths)

    # Pre-encoder, e.g. used for raw input data
    if self.preencoder is not None:
        feats, feats_lengths = self.preencoder(feats, feats_lengths)

    # 4. Forward encoder
    # feats: (Batch, Length, Dim)
    # -> encoder_out: (Batch, Length2, Dim2)
    if self.asr_encoder.interctc_use_conditioning:
        encoder_out, encoder_out_lens, _ = self.asr_encoder(
            feats, feats_lengths, ctc=self.ctc
        )
    else:
        encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths)
    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        intermediate_outs = encoder_out[1]
        encoder_out = encoder_out[0]

    encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0]
    # Bring the speaker encoder output to the same number of frames as the
    # ASR encoder output (nearest-neighbour resampling along dim 1).
    if encoder_out_spk_ori.size(1) != encoder_out.size(1):
        encoder_out_spk = F.interpolate(
            encoder_out_spk_ori.transpose(-2, -1),
            size=(encoder_out.size(1)),
            mode='nearest',
        ).transpose(-2, -1)
    else:
        encoder_out_spk = encoder_out_spk_ori

    # Post-encoder, e.g. NLU
    if self.postencoder is not None:
        encoder_out, encoder_out_lens = self.postencoder(
            encoder_out, encoder_out_lens
        )

    assert encoder_out.size(0) == speech.size(0), (
        encoder_out.size(),
        speech.size(0),
    )
    assert encoder_out.size(1) <= encoder_out_lens.max(), (
        encoder_out.size(),
        encoder_out_lens.max(),
    )
    assert encoder_out_spk.size(0) == speech.size(0), (
        encoder_out_spk.size(),
        speech.size(0),
    )

    if intermediate_outs is not None:
        # Same arity as the regular return below: forward() unpacks three
        # values and then splits the tuple in the first slot.
        return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk

    return encoder_out, encoder_out_lens, encoder_out_spk
|
||||
|
||||
def _extract_feats(
    self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the optional frontend on a padded input batch.

    Args:
        speech: Padded input batch. Its second axis is cut down to the
            largest value found in ``speech_lengths``.
        speech_lengths: 1-D tensor with one length per batch entry.

    Returns:
        ``(feats, feats_lengths)``. When ``self.frontend`` is ``None`` this
        is simply the trimmed ``speech`` together with ``speech_lengths``.
    """
    assert speech_lengths.dim() == 1, speech_lengths.shape

    # Needed for data-parallel (per the original comment): drop everything
    # past the longest entry of this batch.
    longest = speech_lengths.max()
    trimmed = speech[:, :longest]

    if self.frontend is None:
        # No frontend and no feature extraction: pass the input through.
        return trimmed, speech_lengths

    # Frontend, e.g. STFT and feature extraction; the data loader may send a
    # time-domain signal here.
    # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
    feats, feats_lengths = self.frontend(trimmed, speech_lengths)
    return feats, feats_lengths
|
||||
|
||||
def nll(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
    """Return one negative log-likelihood value per utterance.

    Normally reached through ``batchify_nll``.

    Args:
        encoder_out: (Batch, Length, Dim)
        encoder_out_lens: (Batch,)
        ys_pad: (Batch, Length)
        ys_pad_lens: (Batch,)

    Returns:
        Tensor of size (Batch,): token-level cross entropy summed over the
        target axis, with ``self.ignore_id`` positions contributing zero.
    """
    decoder_in, decoder_target = add_sos_eos(
        ys_pad, self.sos, self.eos, self.ignore_id
    )
    # One extra position for the prepended <sos>.
    decoder_in_lens = ys_pad_lens + 1

    # 1. Forward decoder
    # NOTE(review): `_calc_att_loss` calls the same `self.decoder` with
    # speaker inputs and unpacks three outputs; confirm this four-argument,
    # two-output call matches the decoder actually configured.
    logits, _ = self.decoder(
        encoder_out, encoder_out_lens, decoder_in, decoder_in_lens
    )  # [batch, seqlen, dim]
    num_utts = logits.size(0)
    num_classes = logits.size(2)

    # Unreduced cross entropy so it can be summed per utterance below.
    token_nll = torch.nn.functional.cross_entropy(
        logits.view(-1, num_classes),
        decoder_target.view(-1),
        ignore_index=self.ignore_id,
        reduction="none",
    )
    utt_nll = token_nll.view(num_utts, -1).sum(dim=1)
    assert utt_nll.size(0) == num_utts
    return utt_nll
|
||||
|
||||
def batchify_nll(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
    batch_size: int = 100,
):
    """Compute negative log likelihood (nll) from the transformer decoder.

    To avoid OOM, this function separates the input into batches, calls
    ``self.nll`` on each batch, then concatenates and returns the results.

    Args:
        encoder_out: (Batch, Length, Dim)
        encoder_out_lens: (Batch,)
        ys_pad: (Batch, Length)
        ys_pad_lens: (Batch,)
        batch_size: int, number of samples per batch when computing nll;
            lower it to avoid OOM or raise it to use more GPU memory.

    Returns:
        Tensor whose first dimension equals ``encoder_out.size(0)``.

    Raises:
        ValueError: If the input has to be split but ``batch_size`` is not
            positive (the previous loop never terminated in that case).
    """
    total_num = encoder_out.size(0)
    if total_num <= batch_size:
        # Everything fits in one call.
        nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
    else:
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        chunks = []
        for start_idx in range(0, total_num, batch_size):
            # The last chunk may be shorter than batch_size.
            end_idx = min(start_idx + batch_size, total_num)
            chunks.append(
                self.nll(
                    encoder_out[start_idx:end_idx, :, :],
                    encoder_out_lens[start_idx:end_idx],
                    ys_pad[start_idx:end_idx, :],
                    ys_pad_lens[start_idx:end_idx],
                )
            )
        nll = torch.cat(chunks)
    assert nll.size(0) == total_num
    return nll
|
||||
|
||||
def _calc_att_loss(
    self,
    asr_encoder_out: torch.Tensor,
    spk_encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
    profile: torch.Tensor,
    profile_lens: torch.Tensor,
    text_id: torch.Tensor,
    text_id_lengths: torch.Tensor
):
    """Compute the attention-decoder losses and accuracies.

    Args:
        asr_encoder_out: ASR encoder output passed to the decoder.
        spk_encoder_out: Speaker encoder output passed to the decoder.
        encoder_out_lens: Lengths passed to the decoder with the encoder
            outputs.
        ys_pad: Padded target token ids.
        ys_pad_lens: Lengths of ``ys_pad``.
        profile: Speaker profiles passed to the decoder.
        profile_lens: Lengths passed to the decoder with ``profile``.
        text_id: Target for the speaker criterion and speaker accuracy.
        text_id_lengths: Not used by this method.

    Returns:
        ``(loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att)``; the
        last two are ``None`` while training or when no error calculator is
        configured.
    """
    decoder_in, decoder_target = add_sos_eos(
        ys_pad, self.sos, self.eos, self.ignore_id
    )
    decoder_in_lens = ys_pad_lens + 1

    # 1. Forward decoder
    decoder_out, weights_no_pad, _ = self.decoder(
        asr_encoder_out,
        spk_encoder_out,
        encoder_out_lens,
        decoder_in,
        decoder_in_lens,
        profile,
        profile_lens,
    )

    # Zero-pad the last axis of the speaker weights up to max_spk_num.
    missing_spk = self.max_spk_num - weights_no_pad.size(-1)
    weights = F.pad(weights_no_pad, (0, missing_spk), mode='constant', value=0)

    # Disabled speaker-dependent accuracy, kept from the original:
    # pre_id=weights.argmax(-1)
    # pre_text=decoder_out.argmax(-1)
    # id_mask=(pre_id==text_id).to(dtype=text_id.dtype)
    # pre_text_mask=pre_text*id_mask+1-id_mask  # keep tokens where the speaker id matches, set the others to 1 (<unk>)
    # padding_mask= ys_out_pad != self.ignore_id
    # numerator = torch.sum(pre_text_mask.masked_select(padding_mask) == ys_out_pad.masked_select(padding_mask))
    # denominator = torch.sum(padding_mask)
    # sd_acc = float(numerator) / float(denominator)

    # 2. Compute attention loss
    loss_att = self.criterion_att(decoder_out, decoder_target)
    # NOTE(review): the zero padding above becomes -inf under log; confirm
    # criterion_spk tolerates that for the padded speaker slots.
    loss_spk = self.criterion_spk(torch.log(weights), text_id)

    acc_spk = th_accuracy(
        weights.view(-1, self.max_spk_num),
        text_id,
        ignore_label=self.ignore_id,
    )
    acc_att = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        decoder_target,
        ignore_label=self.ignore_id,
    )

    # Compute cer/wer using attention-decoder (evaluation only).
    cer_att, wer_att = None, None
    if not self.training and self.error_calculator is not None:
        best_tokens = decoder_out.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(best_tokens.cpu(), ys_pad.cpu())

    return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att
|
||||
|
||||
def _calc_ctc_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Compute the CTC loss and, outside training, the CTC-based CER.

    Args:
        encoder_out: Encoder output handed to ``self.ctc``.
        encoder_out_lens: Lengths belonging to ``encoder_out``.
        ys_pad: Padded target token ids.
        ys_pad_lens: Lengths of ``ys_pad``.

    Returns:
        ``(loss_ctc, cer_ctc)``; ``cer_ctc`` is ``None`` while training or
        when ``self.error_calculator`` is ``None``.
    """
    # Calc CTC loss
    loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

    if self.training or self.error_calculator is None:
        return loss_ctc, None

    # Calc CER using CTC
    ctc_hyp = self.ctc.argmax(encoder_out).data
    cer_ctc = self.error_calculator(ctc_hyp.cpu(), ys_pad.cpu(), is_ctc=True)
    return loss_ctc, cer_ctc
|
||||
@ -38,6 +38,7 @@ class DefaultFrontend(AbsFrontend):
|
||||
htk: bool = False,
|
||||
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
|
||||
apply_stft: bool = True,
|
||||
use_channel: int = None,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
@ -77,6 +78,7 @@ class DefaultFrontend(AbsFrontend):
|
||||
)
|
||||
self.n_mels = n_mels
|
||||
self.frontend_type = "default"
|
||||
self.use_channel = use_channel
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self.n_mels
|
||||
@ -100,9 +102,12 @@ class DefaultFrontend(AbsFrontend):
|
||||
if input_stft.dim() == 4:
|
||||
# h: (B, T, C, F) -> h: (B, T, F)
|
||||
if self.training:
|
||||
# Select 1ch randomly
|
||||
ch = np.random.randint(input_stft.size(2))
|
||||
input_stft = input_stft[:, :, ch, :]
|
||||
if self.use_channel == None:
|
||||
input_stft = input_stft[:, :, 0, :]
|
||||
else:
|
||||
# Select 1ch randomly
|
||||
ch = np.random.randint(input_stft.size(2))
|
||||
input_stft = input_stft[:, :, ch, :]
|
||||
else:
|
||||
# Use the first channel
|
||||
input_stft = input_stft[:, :, 0, :]
|
||||
|
||||
@ -83,9 +83,9 @@ def windowed_statistic_pooling(
|
||||
num_chunk = int(math.ceil(tt / pooling_stride))
|
||||
pad = pooling_size // 2
|
||||
if len(xs_pad.shape) == 4:
|
||||
features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
|
||||
features = F.pad(xs_pad, (0, 0, pad, pad), "replicate")
|
||||
else:
|
||||
features = F.pad(xs_pad, (pad, pad), "reflect")
|
||||
features = F.pad(xs_pad, (pad, pad), "replicate")
|
||||
stat_list = []
|
||||
|
||||
for i in range(num_chunk):
|
||||
|
||||
@ -13,6 +13,9 @@ import torch
|
||||
from torch import nn
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch.nn.functional as F
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
@ -959,3 +962,37 @@ class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
|
||||
return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
|
||||
|
||||
|
||||
class CosineDistanceAttention(nn.Module):
    """Attend over speaker profiles with cosine similarity.

    Each speaker-decoder output vector is compared with every speaker
    profile by cosine similarity. A softmax over the profile axis turns the
    similarities into weights, and the weighted sum of the profiles is
    returned as the speaker embedding. The module has no constructor
    arguments and no learnable parameters.
    """

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, spk_decoder_out, profile, profile_lens=None):
        """Compute the profile-weighted speaker embedding.

        Args:
            spk_decoder_out (torch.Tensor): (B, L, D)
            profile (torch.Tensor): (B, N, D)
            profile_lens: Optional per-batch lengths of ``profile``. When
                given, padded profile slots get zero weight. When ``None``,
                only the last position of ``spk_decoder_out`` is scored and
                no masking is applied.

        Returns:
            Tuple ``(spk_embedding, weights)``: ``(B, L, D)`` and
            ``(B, L, N)`` when ``profile_lens`` is given, ``(B, 1, D)`` and
            ``(B, 1, N)`` otherwise.
        """
        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
        if profile_lens is not None:
            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
            # Most negative finite value of the working dtype. torch.finfo
            # avoids the CPU tensor -> numpy round trip and also covers
            # dtypes numpy does not know (e.g. bfloat16).
            min_value = torch.finfo(x.dtype).min
            weights_not_softmax = F.cosine_similarity(
                x, profile.unsqueeze(1), dim=-1
            ).masked_fill(mask, min_value)
            # Zero the padded slots again so they carry exactly no weight.
            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
        else:
            # Keep only the last decoder position.
            x = x[:, -1:, :, :]
            weights_not_softmax = F.cosine_similarity(
                x, profile.unsqueeze(1).to(x.device), dim=-1
            )
            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)

        return spk_embedding, weights
|
||||
|
||||
525
funasr/modules/beam_search/beam_search_sa_asr.py
Executable file
525
funasr/modules/beam_search/beam_search_sa_asr.py
Executable file
@ -0,0 +1,525 @@
|
||||
"""Beam search module."""
|
||||
|
||||
from itertools import chain
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import NamedTuple
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from funasr.modules.e2e_asr_common import end_detect
|
||||
from funasr.modules.scorers.scorer_interface import PartialScorerInterface
|
||||
from funasr.modules.scorers.scorer_interface import ScorerInterface
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
|
||||
|
||||
class Hypothesis(NamedTuple):
    """One beam-search hypothesis.

    NOTE(review): the ``dict()`` defaults are created once and shared by
    every instance that does not pass its own ``scores`` / ``states``.
    """

    # Token ids decoded so far.
    yseq: torch.Tensor
    # Per-step speaker weights; one entry is appended for every decoded token.
    spk_weigths : List
    # Total weighted score.
    score: Union[float, torch.Tensor] = 0
    # Accumulated score per scorer name.
    scores: Dict[str, Union[float, torch.Tensor]] = dict()
    # State per scorer name.
    states: Dict[str, Any] = dict()

    def asdict(self) -> dict:
        """Return the fields as a dict with plain Python tokens and scores.

        ``yseq`` becomes a list, ``score`` and every value in ``scores``
        become ``float``; ``spk_weigths`` and ``states`` are passed through
        unchanged.
        """
        token_ids = self.yseq.tolist()
        total = float(self.score)
        per_scorer = {}
        for name, value in self.scores.items():
            per_scorer[name] = float(value)
        plain = self._replace(yseq=token_ids, score=total, scores=per_scorer)
        return plain._asdict()
|
||||
|
||||
|
||||
class BeamSearch(torch.nn.Module):
    """Beam search over an ASR encoding plus speaker encoding and profiles.

    Full scorers that are ``AbsDecoder`` instances are additionally given the
    speaker encoder output and the speaker profiles, and the speaker weights
    they return are collected on every hypothesis (``Hypothesis.spk_weigths``).
    """

    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.

        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorers
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The number of vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            token_list (list[str]): List of tokens for debug log
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`

        Raises:
            KeyError: If ``pre_beam_score_key`` is neither ``None``, ``"full"``
                nor the name of an active full scorer.

        """
        super().__init__()
        # set scorers
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v

        # set configurations
        self.sos = sos
        self.eos = eos
        self.token_list = token_list
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
        self.pre_beam_score_key = pre_beam_score_key
        # Pre-beam only pays off when it prunes and partial scorers exist.
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )

    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            List[Hypothesis]: A single-element list with the initial
                hypothesis (``<sos>`` only, zero scores, no speaker weights).

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
                spk_weigths=[],
            )
        ]

    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.

        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append

        Returns:
            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device

        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))

    def score_full(
        self, hyp: Hypothesis, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor,
    ) -> Tuple[Dict[str, torch.Tensor], Any, Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            asr_enc (torch.Tensor): ASR encoder output, passed to every scorer
            spk_enc (torch.Tensor): Speaker encoder output, passed only to
                `AbsDecoder` scorers
            profile (torch.Tensor): Speaker profiles, passed only to
                `AbsDecoder` scorers

        Returns:
            Tuple[Dict[str, torch.Tensor], Any, Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                the speaker weights returned by the last `AbsDecoder` scorer
                (`None` when there is no such scorer),
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        # Stays None when no full scorer is an AbsDecoder; previously the
        # name was unbound in that case and the return statement raised.
        spk_weigths = None
        for k, d in self.full_scorers.items():
            if isinstance(d, AbsDecoder):
                scores[k], spk_weigths, states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc, spk_enc, profile)
            else:
                scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc)
        return scores, spk_weigths, states

    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            asr_enc (torch.Tensor): ASR encoder output, passed to every scorer
            spk_enc (torch.Tensor): Speaker encoder output, passed only to
                `AbsDecoder` scorers
            profile (torch.Tensor): Speaker profiles, passed only to
                `AbsDecoder` scorers

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            if isinstance(d, AbsDecoder):
                scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc, spk_enc, profile)
            else:
                scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc)
        return scores, states

    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.

        Note that `weighted_scores` is modified in place when pre-beam
        pruning was applied (pruned entries are set to `-inf`).

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
            Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`

        """
        # no pre beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids

        # mask pruned in pre-beam not to select in topk
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids

    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.

        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are scalar tensors by the scorers.

        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new state dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states

    def search(
        self, running_hyps: List[Hypothesis], asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor
    ) -> List[Hypothesis]:
        """Search new tokens for running hypotheses and the encoded speech.

        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            asr_enc (torch.Tensor): ASR encoder output; its dtype and device
                are used for the score buffer
            spk_enc (torch.Tensor): Speaker encoder output, forwarded to the
                scorers
            profile (torch.Tensor): Speaker profiles, forwarded to the scorers

        Returns:
            List[Hypotheses]: Best sorted hypotheses

        """
        best_hyps = []
        part_ids = torch.arange(self.n_vocab, device=asr_enc.device)  # no pre-beam
        for hyp in running_hyps:
            # scoring
            weighted_scores = torch.zeros(self.n_vocab, dtype=asr_enc.dtype, device=asr_enc.device)
            scores, spk_weigths, states = self.score_full(hyp, asr_enc, spk_enc, profile)
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
            if self.do_pre_beam:
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, asr_enc, spk_enc, profile)
            for k in self.part_scorers:
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
            # add previous hyp score
            weighted_scores += hyp.score

            # update hyps
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                # will be (2 x beam at most)
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        states=self.merge_states(states, part_states, part_j),
                        # One speaker-weight entry per decoded token.
                        spk_weigths=hyp.spk_weigths+[spk_weigths],
                    )
                )

            # sort and prune 2 x beam -> beam
            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
                : min(len(best_hyps), self.beam_size)
            ]
        return best_hyps

    def forward(
        self, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            asr_enc (torch.Tensor): ASR encoder output; its first dimension
                is the input length used for the length bounds
            spk_enc (torch.Tensor): Speaker encoder output, forwarded to the
                scorers
            profile (torch.Tensor): Speaker profiles, forwarded to the scorers
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        # set length bounds
        if maxlenratio == 0:
            maxlen = asr_enc.shape[0]
        else:
            maxlen = max(1, int(maxlenratio * asr_enc.size(0)))
        minlen = int(minlenratio * asr_enc.size(0))
        logging.info("decoder input length: " + str(asr_enc.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(asr_enc)
        ended_hyps = []
        for i in range(maxlen):
            logging.debug("position " + str(i))
            best = self.search(running_hyps, asr_enc, spk_enc, profile)
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                logging.info(f"end detected at {i}")
                break
            if len(running_hyps) == 0:
                logging.info("no hypothesis. Finish decoding.")
                break
            else:
                logging.debug(f"remained hypotheses: {len(running_hyps)}")

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            # Retry with a smaller minlenratio until it drops below 0.1.
            return (
                []
                if minlenratio < 0.1
                else self.forward(asr_enc, spk_enc, profile, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.

        Hypotheses ending in `<eos>` receive the scorers' final scores and are
        appended to `ended_hyps` (modified in place).

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (int): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            List[Hypothesis]: The new running hypotheses.

        """
        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
            ]

        # add ended hypotheses to a final list, and removed them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                # e.g., Word LM needs to add final <eos> score
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    hyp.scores[k] += s
                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        return remained_hyps
|
||||
|
||||
|
||||
def beam_search(
    x: torch.Tensor,
    sos: int,
    eos: int,
    beam_size: int,
    vocab_size: int,
    scorers: Dict[str, ScorerInterface],
    weights: Dict[str, float],
    token_list: List[str] = None,
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    pre_beam_ratio: float = 1.5,
    pre_beam_score_key: str = "full",
    spk_enc: torch.Tensor = None,
    profile: torch.Tensor = None,
) -> list:
    """Perform beam search with scorers.

    Args:
        x (torch.Tensor): Encoded speech feature (T, D); handed to
            `BeamSearch.forward` as `asr_enc`
        sos (int): Start of sequence id
        eos (int): End of sequence id
        beam_size (int): The number of hypotheses kept during search
        vocab_size (int): The number of vocabulary
        scorers (dict[str, ScorerInterface]): Dict of decoder modules
            e.g., Decoder, CTCPrefixScorer, LM
            The scorer will be ignored if it is `None`
        weights (dict[str, float]): Dict of weights for each scorers
            The scorer will be ignored if its weight is 0
        token_list (list[str]): List of tokens for debug log
        maxlenratio (float): Input length ratio to obtain max output length.
            If maxlenratio=0.0 (default), it uses a end-detect function
            to automatically find maximum hypothesis lengths
        minlenratio (float): Input length ratio to obtain min output length.
        pre_beam_score_key (str): key of scores to perform pre-beam search
        pre_beam_ratio (float): beam size in the pre-beam search
            will be `int(pre_beam_ratio * beam_size)`
        spk_enc (torch.Tensor): Speaker encoder output forwarded to
            `BeamSearch.forward`. Optional only so that existing call
            signatures keep working.
        profile (torch.Tensor): Speaker profiles forwarded to
            `BeamSearch.forward`. Optional for the same reason.

    Returns:
        list: N-best decoding results

    """
    searcher = BeamSearch(
        scorers,
        weights,
        beam_size=beam_size,
        vocab_size=vocab_size,
        pre_beam_ratio=pre_beam_ratio,
        pre_beam_score_key=pre_beam_score_key,
        sos=sos,
        eos=eos,
        token_list=token_list,
    )
    # BeamSearch.forward takes (asr_enc, spk_enc, profile, ...); the former
    # `forward(x=...)` call raised TypeError because it has no `x` parameter.
    ret = searcher.forward(
        asr_enc=x,
        spk_enc=spk_enc,
        profile=profile,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
    )
    return [h.asdict() for h in ret]
|
||||
@ -444,6 +444,12 @@ class AbsTask(ABC):
|
||||
default=False,
|
||||
help='Perform on "collect stats" mode',
|
||||
)
|
||||
group.add_argument(
|
||||
"--mc",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="MultiChannel input",
|
||||
)
|
||||
group.add_argument(
|
||||
"--write_collected_feats",
|
||||
type=str2bool,
|
||||
@ -635,8 +641,8 @@ class AbsTask(ABC):
|
||||
group.add_argument(
|
||||
"--init_param",
|
||||
type=str,
|
||||
action="append",
|
||||
default=[],
|
||||
nargs="*",
|
||||
help="Specify the file path used for initialization of parameters. "
|
||||
"The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
|
||||
"where file_path is the model file path, "
|
||||
@ -662,7 +668,7 @@ class AbsTask(ABC):
|
||||
"--freeze_param",
|
||||
type=str,
|
||||
default=[],
|
||||
nargs="*",
|
||||
action="append",
|
||||
help="Freeze parameters",
|
||||
)
|
||||
|
||||
@ -1153,10 +1159,10 @@ class AbsTask(ABC):
|
||||
elif args.distributed and args.simple_ddp:
|
||||
distributed_option.init_torch_distributed_pai(args)
|
||||
args.ngpu = dist.get_world_size()
|
||||
if args.dataset_type == "small":
|
||||
if args.dataset_type == "small" and args.ngpu > 0:
|
||||
if args.batch_size is not None:
|
||||
args.batch_size = args.batch_size * args.ngpu
|
||||
if args.batch_bins is not None:
|
||||
if args.batch_bins is not None and args.ngpu > 0:
|
||||
args.batch_bins = args.batch_bins * args.ngpu
|
||||
|
||||
# filter samples if wav.scp and text are mismatch
|
||||
@ -1316,6 +1322,7 @@ class AbsTask(ABC):
|
||||
data_path_and_name_and_type=args.train_data_path_and_name_and_type,
|
||||
key_file=train_key_file,
|
||||
batch_size=args.batch_size,
|
||||
mc=args.mc,
|
||||
dtype=args.train_dtype,
|
||||
num_workers=args.num_workers,
|
||||
allow_variable_data_keys=args.allow_variable_data_keys,
|
||||
@ -1327,6 +1334,7 @@ class AbsTask(ABC):
|
||||
data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
|
||||
key_file=valid_key_file,
|
||||
batch_size=args.valid_batch_size,
|
||||
mc=args.mc,
|
||||
dtype=args.train_dtype,
|
||||
num_workers=args.num_workers,
|
||||
allow_variable_data_keys=args.allow_variable_data_keys,
|
||||
|
||||
623
funasr/tasks/sa_asr.py
Normal file
623
funasr/tasks/sa_asr.py
Normal file
@ -0,0 +1,623 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from typing import Collection
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from typeguard import check_argument_types
|
||||
from typeguard import check_return_type
|
||||
|
||||
from funasr.datasets.collate_fn import CommonCollateFn
|
||||
from funasr.datasets.preprocessor import CommonPreprocessor
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.layers.global_mvn import GlobalMVN
|
||||
from funasr.layers.utterance_mvn import UtteranceMVN
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.decoder.rnn_decoder import RNNDecoder
|
||||
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
|
||||
from funasr.models.decoder.transformer_decoder import (
|
||||
DynamicConvolution2DTransformerDecoder, # noqa: H301
|
||||
)
|
||||
from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder
|
||||
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
|
||||
from funasr.models.decoder.transformer_decoder import (
|
||||
LightweightConvolution2DTransformerDecoder, # noqa: H301
|
||||
)
|
||||
from funasr.models.decoder.transformer_decoder import (
|
||||
LightweightConvolutionTransformerDecoder, # noqa: H301
|
||||
)
|
||||
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
|
||||
from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
||||
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
|
||||
from funasr.models.e2e_sa_asr import ESPnetASRModel
|
||||
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
|
||||
from funasr.models.e2e_tp import TimestampPredictor
|
||||
from funasr.models.e2e_asr_mfcca import MFCCA
|
||||
from funasr.models.e2e_uni_asr import UniASR
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
||||
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
||||
from funasr.models.encoder.rnn_encoder import RNNEncoder
|
||||
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
|
||||
from funasr.models.encoder.transformer_encoder import TransformerEncoder
|
||||
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
|
||||
from funasr.models.encoder.resnet34_encoder import ResNet34,ResNet34Diar
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.frontend.default import DefaultFrontend
|
||||
from funasr.models.frontend.default import MultiChannelFrontend
|
||||
from funasr.models.frontend.fused import FusedFrontends
|
||||
from funasr.models.frontend.s3prl import S3prlFrontend
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||
from funasr.models.frontend.windowing import SlidingWindow
|
||||
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
|
||||
HuggingFaceTransformersPostEncoder, # noqa: H301
|
||||
)
|
||||
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.preencoder.linear import LinearProjection
|
||||
from funasr.models.preencoder.sinc import LightweightSincConvs
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.models.specaug.specaug import SpecAug
|
||||
from funasr.models.specaug.specaug import SpecAugLFR
|
||||
from funasr.modules.subsampling import Conv1dSubsampling
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
from funasr.train.trainer import Trainer
|
||||
from funasr.utils.get_default_kwargs import get_default_kwargs
|
||||
from funasr.utils.nested_dict_action import NestedDictAction
|
||||
from funasr.utils.types import float_or_none
|
||||
from funasr.utils.types import int_or_none
|
||||
from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str_or_none
|
||||
|
||||
# ---------------------------------------------------------------------------
# Component registries.
#
# Each ClassChoices instance maps a short configuration name to the class
# that implements it.  The task class below exposes every registry listed in
# its ``class_choices_list`` as a ``--<name>`` / ``--<name>_conf`` option pair.
# ---------------------------------------------------------------------------

# Feature-extraction frontends.
frontend_choices = ClassChoices(
    name="frontend",
    classes={
        "default": DefaultFrontend,
        "sliding_window": SlidingWindow,
        "s3prl": S3prlFrontend,
        "fused": FusedFrontends,
        "wav_frontend": WavFrontend,
        "multichannelfrontend": MultiChannelFrontend,
    },
    type_check=AbsFrontend,
    default="default",
)

# Spectrogram augmentation (optional).
specaug_choices = ClassChoices(
    name="specaug",
    classes={
        "specaug": SpecAug,
        "specaug_lfr": SpecAugLFR,
    },
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)

# Feature normalization (optional).
normalize_choices = ClassChoices(
    "normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    type_check=AbsNormalize,
    default=None,
    optional=True,
)

# Top-level model classes; "asr" resolves to the model imported from
# funasr.models.e2e_sa_asr.
model_choices = ClassChoices(
    "model",
    classes={
        "asr": ESPnetASRModel,
        "uniasr": UniASR,
        "paraformer": Paraformer,
        "paraformer_bert": ParaformerBert,
        "bicif_paraformer": BiCifParaformer,
        "contextual_paraformer": ContextualParaformer,
        "mfcca": MFCCA,
        "timestamp_prediction": TimestampPredictor,
    },
    type_check=AbsESPnetModel,
    default="asr",
)

# Pre-encoder input blocks (optional).
preencoder_choices = ClassChoices(
    name="preencoder",
    classes={
        "sinc": LightweightSincConvs,
        "linear": LinearProjection,
    },
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)

# Encoders selectable through --asr_encoder.
asr_encoder_choices = ClassChoices(
    "asr_encoder",
    classes={
        "conformer": ConformerEncoder,
        "transformer": TransformerEncoder,
        "rnn": RNNEncoder,
        "sanm": SANMEncoder,
        "sanm_chunk_opt": SANMEncoderChunkOpt,
        "data2vec_encoder": Data2VecEncoder,
        "mfcca_enc": MFCCAEncoder,
    },
    type_check=AbsEncoder,
    default="rnn",
)

# Encoders selectable through --spk_encoder (no type_check is configured).
spk_encoder_choices = ClassChoices(
    "spk_encoder",
    classes={
        "resnet34_diar": ResNet34Diar,
    },
    default="resnet34_diar",
)

# Secondary encoder registry.
encoder_choices2 = ClassChoices(
    "encoder2",
    classes={
        "conformer": ConformerEncoder,
        "transformer": TransformerEncoder,
        "rnn": RNNEncoder,
        "sanm": SANMEncoder,
        "sanm_chunk_opt": SANMEncoderChunkOpt,
    },
    type_check=AbsEncoder,
    default="rnn",
)

# Post-encoder blocks (optional).
postencoder_choices = ClassChoices(
    name="postencoder",
    classes={
        "hugging_face_transformers": HuggingFaceTransformersPostEncoder,
    },
    type_check=AbsPostEncoder,
    default=None,
    optional=True,
)

# Decoders; the speaker-attributed decoder is the default.
decoder_choices = ClassChoices(
    "decoder",
    classes={
        "transformer": TransformerDecoder,
        "lightweight_conv": LightweightConvolutionTransformerDecoder,
        "lightweight_conv2d": LightweightConvolution2DTransformerDecoder,
        "dynamic_conv": DynamicConvolutionTransformerDecoder,
        "dynamic_conv2d": DynamicConvolution2DTransformerDecoder,
        "rnn": RNNDecoder,
        "fsmn_scama_opt": FsmnDecoderSCAMAOpt,
        "paraformer_decoder_sanm": ParaformerSANMDecoder,
        "paraformer_decoder_san": ParaformerDecoderSAN,
        "contextual_paraformer_decoder": ContextualParaformerDecoder,
        "sa_decoder": SAAsrTransformerDecoder,
    },
    type_check=AbsDecoder,
    default="sa_decoder",
)

# Secondary decoder registry.
decoder_choices2 = ClassChoices(
    "decoder2",
    classes={
        "transformer": TransformerDecoder,
        "lightweight_conv": LightweightConvolutionTransformerDecoder,
        "lightweight_conv2d": LightweightConvolution2DTransformerDecoder,
        "dynamic_conv": DynamicConvolutionTransformerDecoder,
        "dynamic_conv2d": DynamicConvolution2DTransformerDecoder,
        "rnn": RNNDecoder,
        "fsmn_scama_opt": FsmnDecoderSCAMAOpt,
        "paraformer_decoder_sanm": ParaformerSANMDecoder,
    },
    type_check=AbsDecoder,
    default="rnn",
)

# Predictors (optional); "ctc_predictor" is registered without a class.
predictor_choices = ClassChoices(
    name="predictor",
    classes={
        "cif_predictor": CifPredictor,
        "ctc_predictor": None,
        "cif_predictor_v2": CifPredictorV2,
        "cif_predictor_v3": CifPredictorV3,
    },
    type_check=None,
    default="cif_predictor",
    optional=True,
)

# Secondary predictor registry (optional).
predictor_choices2 = ClassChoices(
    name="predictor2",
    classes={
        "cif_predictor": CifPredictor,
        "ctc_predictor": None,
        "cif_predictor_v2": CifPredictorV2,
    },
    type_check=None,
    default="cif_predictor",
    optional=True,
)

# Strided convolution subsampling (optional).
stride_conv_choices = ClassChoices(
    name="stride_conv",
    classes={
        "stride_conv1d": Conv1dSubsampling,
    },
    type_check=None,
    default="stride_conv1d",
    optional=True,
)
|
||||
|
||||
|
||||
class ASRTask(AbsTask):
    """Task that wires together the speaker-attributed ASR model.

    The class registers the command-line options for the task, builds the
    collate and preprocessing callables used by the data loaders, and
    assembles the model from a frontend, optional specaug / normalization /
    pre-encoder / post-encoder blocks, an ASR encoder, a speaker encoder,
    a decoder and a CTC module.
    """

    # If you need more than one optimizers, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --asr_encoder and --asr_encoder_conf
        asr_encoder_choices,
        # --spk_encoder and --spk_encoder_conf
        spk_encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register the task- and preprocessing-related options on ``parser``.

        Args:
            parser: The argument parser to extend in place.
        """
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="whether to split text using <space>",
        )
        group.add_argument(
            "--max_spk_num",
            type=int_or_none,
            default=None,
            help="The maximum number of speakers passed to the model",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )

        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )

        group.add_argument(
            "--ctc_conf",
            action=NestedDictAction,
            default=get_default_kwargs(CTC),
            help="The keyword arguments for CTC class.",
        )
        group.add_argument(
            "--joint_net_conf",
            action=NestedDictAction,
            default=None,
            help="The keyword arguments for joint network class.",
        )

        # Everything below is registered on this group so that it is listed
        # under the "Preprocess related" section of --help.
        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized " "in the specified level token",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        group.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="non_linguistic_symbols file path",
        )
        group.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        group.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify g2p method if --token_type=phn",
        )
        group.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        group.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of rir scp file.",
        )
        group.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability for applying RIR convolution.",
        )
        group.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of cmvn file.",
        )
        group.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        group.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability applying Noise adding.",
        )
        group.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return the collate function used to batch samples.

        Args:
            args: Parsed task arguments (not read here).
            train: Whether the loader is for training (not read here).

        Returns:
            A ``CommonCollateFn`` padding floats with 0.0 and ints with -1.
        """
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        """Return the per-sample preprocessing callable, if enabled.

        Args:
            args: Parsed task arguments.
            train: Forwarded to ``CommonPreprocessor`` as ``train``.

        Returns:
            A ``CommonPreprocessor`` when ``args.use_preprocessor`` is true,
            otherwise ``None``.
        """
        assert check_argument_types()
        if args.use_preprocessor:
            # NOTE(kamo): Check attribute existence for backward compatibility
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                split_with_space=getattr(args, "split_with_space", False),
                seg_dict_file=getattr(args, "seg_dict_file", None),
                rir_scp=getattr(args, "rir_scp", None),
                rir_apply_prob=getattr(args, "rir_apply_prob", 1.0),
                noise_scp=getattr(args, "noise_scp", None),
                noise_apply_prob=getattr(args, "noise_apply_prob", 1.0),
                noise_db_range=getattr(args, "noise_db_range", "13_15"),
                # Guard on the attribute that is actually read; the previous
                # check looked at "rir_scp" and could raise AttributeError.
                speech_volume_normalize=getattr(
                    args, "speech_volume_normalize", None
                ),
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data keys every sample must provide.

        Args:
            train: Not read here.
            inference: When true, only ``"speech"`` is required.

        Returns:
            ``("speech", "text")`` outside inference, ``("speech",)`` in
            inference.
        """
        if not inference:
            retval = ("speech", "text")
        else:
            # Recognition mode
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the optional data keys; this task declares none."""
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        """Assemble the model described by ``args``.

        Args:
            args: Parsed task arguments. ``args.token_list`` may be a path to
                a token file or an in-memory list/tuple; when it is a path it
                is replaced by the loaded list.

        Returns:
            The constructed model instance.

        Raises:
            RuntimeError: If ``token_list`` is neither a path nor a sequence,
                or if ``max_spk_num`` is not set.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # Fail early with a clear message instead of the TypeError that
        # int(None) would raise further down.
        max_spk_num = getattr(args, "max_spk_num", None)
        if max_spk_num is None:
            raise RuntimeError("max_spk_num must be specified")
        max_spk_num = int(max_spk_num)

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder (both encoders receive the same input size)
        asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
        asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
        spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
        spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        asr_encoder_output_size = asr_encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=asr_encoder_output_size, **args.postencoder_conf
            )
            asr_encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=asr_encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=asr_encoder_output_size, **args.ctc_conf
        )

        # 9. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            # Fall back to the default model, e.g. when ``args`` has no
            # ``model`` attribute.
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            max_spk_num=max_spk_num,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            asr_encoder=asr_encoder,
            spk_encoder=spk_encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            **args.model_conf,
        )

        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
|
||||
@ -106,18 +106,17 @@ def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
|
||||
if num in abbr_begin:
|
||||
if time_stamp is not None:
|
||||
begin = time_stamp[ts_nums[num]][0]
|
||||
abbr_word = words[num].upper()
|
||||
word_lists.append(words[num].upper())
|
||||
num += 1
|
||||
while num < words_size:
|
||||
if num in abbr_end:
|
||||
abbr_word += words[num].upper()
|
||||
word_lists.append(words[num].upper())
|
||||
last_num = num
|
||||
break
|
||||
else:
|
||||
if words[num].encode('utf-8').isalpha():
|
||||
abbr_word += words[num].upper()
|
||||
word_lists.append(words[num].upper())
|
||||
num += 1
|
||||
word_lists.append(abbr_word)
|
||||
if time_stamp is not None:
|
||||
end = time_stamp[ts_nums[num]][1]
|
||||
ts_lists.append([begin, end])
|
||||
|
||||
7
setup.py
7
setup.py
@ -13,7 +13,7 @@ requirements = {
|
||||
"install": [
|
||||
"setuptools>=38.5.1",
|
||||
# "configargparse>=1.2.1",
|
||||
"typeguard<=2.13.3",
|
||||
"typeguard==2.13.3",
|
||||
"humanfriendly",
|
||||
"scipy>=1.4.1",
|
||||
# "filelock",
|
||||
@ -42,7 +42,10 @@ requirements = {
|
||||
"oss2",
|
||||
# "kaldi-native-fbank",
|
||||
# timestamp
|
||||
"edit-distance"
|
||||
"edit-distance",
|
||||
# textgrid
|
||||
"textgrid",
|
||||
"protobuf==3.20.0",
|
||||
],
|
||||
# train: The modules invoked when training only.
|
||||
"train": [
|
||||
|
||||
Loading…
Reference in New Issue
Block a user