Merge pull request #473 from alibaba-damo-academy/dev_smohan

Add speaker-attributed ASR task for alimeeting (baseline for m2met2.0).
This commit is contained in:
jmwang66 2023-05-09 10:58:33 +08:00 committed by GitHub
commit 8dab6d184a
60 changed files with 9252 additions and 26 deletions


@@ -0,0 +1,79 @@
# Get Started
Speaker-Attributed Automatic Speech Recognition (SA-ASR) addresses the problem of "who spoke what": the goal is not only to produce multi-speaker transcriptions, but also to identify the speaker of each utterance. The method used in this example follows the paper [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
To run this recipe, first install FunASR and ModelScope ([installation guide](https://alibaba-damo-academy.github.io/FunASR/en/installation.html)).
There are two startup scripts, `run.sh` for training and evaluating on the old eval and test sets, and `run_m2met_2023_infer.sh` for inference on the new test set of the Multi-Channel Multi-Party Meeting Transcription 2.0 ([M2MET2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)) Challenge.
Before running `run.sh`, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory:
```shell
dataset
├── Eval_Ali_far
├── Eval_Ali_near
├── Test_Ali_far
├── Test_Ali_near
├── Train_Ali_far
└── Train_Ali_near
```
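Before launching `run.sh`, it can be worth sanity-checking this layout. The snippet below is a minimal sketch (not part of the recipe) that verifies the six expected sub-directories exist:
```python
from pathlib import Path

# The six sub-directories listed above.
EXPECTED = [
    "Eval_Ali_far", "Eval_Ali_near",
    "Test_Ali_far", "Test_Ali_near",
    "Train_Ali_far", "Train_Ali_near",
]

dataset = Path("./dataset")
missing = [d for d in EXPECTED if not (dataset / d).is_dir()]
if missing:
    raise SystemExit(f"Missing sub-directories under {dataset}: {missing}")
print("Dataset layout looks complete.")
```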
There are 18 stages in `run.sh`:
```shell
stage 1 - 5: Data preparation and processing.
stage 6: Generate speaker profiles (this stage is time-consuming).
stage 7 - 9: Language model training (Optional).
stage 10 - 11: ASR training (SA-ASR requires loading the pre-trained ASR model).
stage 12: SA-ASR training.
stage 13 - 18: Inference and evaluation.
```
Before running `run_m2met_2023_infer.sh`, you need to place the new test set `Test_2023_Ali_far` (to be released after the challenge starts) in the `./dataset` directory; it contains only the raw audio. Then put the provided `wav.scp`, `wav_raw.scp`, `segments`, `utt2spk` and `spk2utt` in the `./data/Test_2023_Ali_far` directory:
```shell
data/Test_2023_Ali_far
├── wav.scp
├── wav_raw.scp
├── segments
├── utt2spk
└── spk2utt
```
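These files follow the standard Kaldi conventions: `wav.scp` maps a recording id to an audio path (or a pipe command), `segments` lists `<utterance-id> <recording-id> <start-time> <end-time>` with times in seconds, and `utt2spk`/`spk2utt` map utterances to speakers and back. As a rough illustration (a sketch assuming the standard formats, not code from the recipe), they can be read like this:
```python
def read_scp(path):
    # Kaldi-style "key value..." file -> dict
    table = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            key, value = line.strip().split(" ", 1)
            table[key] = value
    return table

root = "data/Test_2023_Ali_far"
wav = read_scp(f"{root}/wav.scp")        # recording-id -> audio path or pipe
utt2spk = read_scp(f"{root}/utt2spk")    # utterance-id -> speaker-id

# segments: <utterance-id> <recording-id> <start-time> <end-time> (seconds)
with open(f"{root}/segments", encoding="utf-8") as f:
    for line in f:
        utt, rec, start, end = line.split()
        assert rec in wav and utt in utt2spk
```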
There are 4 stages in `run_m2met_2023_infer.sh`:
```shell
stage 1: Data preparation and processing.
stage 2: Generate speaker profiles for inference.
stage 3: Inference.
stage 4: Generation of SA-ASR results required for final submission.
```
# Format of Final Submission
Finally, you need to submit a file called `text_spk_merge` with the following format:
```shell
Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
...
```
Here, `text_spk_1_A` denotes the full transcription of speaker A in Meeting_1 (utterances merged in chronological order), and `$` is the separator symbol. You do not need to worry about the speaker permutation: the optimal permutation is computed during scoring. For more information, refer to the results generated after running the baseline code.
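For intuition, cpCER-style scoring can be sketched as follows: split both the reference and hypothesis lines on `$`, pad the shorter speaker list with empty strings, and take the permutation of hypothesis speakers that minimizes the total character edit distance. The snippet below is a simplified illustration (plain Levenshtein distance, brute-force permutation search), not the official scorer:
```python
from itertools import permutations

def edit_distance(a, b):
    # character-level Levenshtein distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]

def cp_cer(ref_line, hyp_line):
    ref = [t.strip() for t in ref_line.split("$")]
    hyp = [t.strip() for t in hyp_line.split("$")]
    n = max(len(ref), len(hyp))
    ref += [""] * (n - len(ref))
    hyp += [""] * (n - len(hyp))
    best = min(
        sum(edit_distance(r, h) for r, h in zip(ref, perm))
        for perm in permutations(hyp)
    )
    return best / sum(len(r) for r in ref)
```
The brute-force search over permutations is affordable here because each meeting has only a handful of speakers.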
# Baseline Results
The results of the baseline system are shown below. Two metrics are reported: the speaker-independent character error rate (SI-CER) and the concatenated minimum-permutation character error rate (cpCER); the former ignores speaker identity, while the latter is speaker-dependent. During training, the speaker profile uses oracle speaker embeddings. Since oracle speaker labels are unavailable during evaluation, the speaker profile is instead obtained by spectral clustering. Results with the oracle speaker profile on the Eval and Test sets are also reported, to show the impact of speaker-profile accuracy.
<table>
<tr >
<td rowspan="2"></td>
<td colspan="2">SI-CER(%)</td>
<td colspan="2">cpCER(%)</td>
</tr>
<tr>
<td>Eval</td>
<td>Test</td>
<td>Eval</td>
<td>Test</td>
</tr>
<tr>
<td>oracle profile</td>
<td>31.93</td>
<td>32.75</td>
<td>48.56</td>
<td>53.33</td>
</tr>
<tr>
<td>cluster profile</td>
<td>31.94</td>
<td>32.77</td>
<td>55.49</td>
<td>58.17</td>
</tr>
</table>
# Reference
N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Proc. Interspeech, 2021, pp. 4413–4417.

egs/alimeeting/sa-asr/asr_local.sh (executable, 1572 lines)

File diff suppressed because it is too large


@@ -0,0 +1,591 @@
#!/usr/bin/env bash
# Set bash to 'strict' mode: it will exit on
# -e 'error', -u 'undefined variable', and -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail
log() {
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
min() {
local a b
a=$1
for b in "$@"; do
if [ "${b}" -le "${a}" ]; then
a="${b}"
fi
done
echo "${a}"
}
SECONDS=0
# General configuration
stage=1 # Processing starts from the specified stage.
stop_stage=10000 # Processing stops at the specified stage.
skip_data_prep=false # Skip data preparation stages.
skip_train=false # Skip training stages.
skip_eval=false # Skip decoding and evaluation stages.
skip_upload=true # Skip packing and uploading stages.
ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu).
num_nodes=1 # The number of nodes.
nj=16 # The number of parallel jobs.
inference_nj=16 # The number of parallel jobs in decoding.
gpu_inference=false # Whether to perform gpu decoding.
njob_infer=4 # The number of decoding jobs per GPU (used when gpu_inference=true).
dumpdir=dump2 # Directory to dump features.
expdir=exp # Directory to save experiments.
python=python3 # Specify python to execute funasr commands.
device=0 # GPU device id(s) for decoding, passed to --gpuid_list.
# Data preparation related
local_data_opts= # The options given to local/data.sh.
# Speed perturbation related
speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
# Feature extraction related
feats_type=raw # Feature type (raw, fbank_pitch, fbank, or extracted).
audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw).
fs=16000 # Sampling rate.
min_wav_duration=0.1 # Minimum duration in second.
max_wav_duration=20 # Maximum duration in second.
# Tokenization related
token_type=bpe # Tokenization type (char or bpe).
nbpe=30 # The number of BPE vocabulary.
bpemode=unigram # Mode of BPE (unigram or bpe).
oov="<unk>" # Out of vocabulary symbol.
blank="<blank>" # CTC blank symbol
sos_eos="<sos/eos>" # sos and eos symbol
bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE
bpe_char_cover=1.0 # character coverage when modeling BPE
# Language model related
use_lm=true # Use language model for ASR decoding.
lm_tag= # Suffix to the result dir for language model training.
lm_exp= # Specify the directory path for LM experiment.
# If this option is specified, lm_tag is ignored.
lm_stats_dir= # Specify the directory path for LM statistics.
lm_config= # Config for language model training.
lm_args= # Arguments for language model training, e.g., "--max_epoch 10".
# Note that it will overwrite args in lm config.
use_word_lm=false # Whether to use word language model.
num_splits_lm=1 # Number of splitting for lm corpus.
# shellcheck disable=SC2034
word_vocab_size=10000 # Size of word vocabulary.
# ASR model related
asr_tag= # Suffix to the result dir for asr model training.
asr_exp= # Specify the directory path for ASR experiment.
# If this option is specified, asr_tag is ignored.
sa_asr_exp= # Specify the directory path for SA-ASR experiment.
asr_stats_dir= # Specify the directory path for ASR statistics.
asr_config= # Config for asr model training.
sa_asr_config= # Config for sa-asr model training.
asr_args= # Arguments for asr model training, e.g., "--max_epoch 10".
# Note that it will overwrite args in asr config.
feats_normalize=global_mvn # Normalization layer type.
num_splits_asr=1 # Number of splitting for asr corpus.
# Decoding related
inference_tag= # Suffix to the result dir for decoding.
inference_config= # Config for decoding.
inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1".
# Note that it will overwrite args in inference config.
sa_asr_inference_tag= # Suffix to the result dir for SA-ASR decoding.
sa_asr_inference_args= # Arguments for SA-ASR decoding.
inference_lm=valid.loss.ave.pb # Language model path for decoding.
inference_asr_model=valid.acc.ave.pb # ASR model path for decoding.
# e.g.
# inference_asr_model=train.loss.best.pth
# inference_asr_model=3epoch.pth
# inference_asr_model=valid.acc.best.pth
# inference_asr_model=valid.loss.ave.pth
inference_sa_asr_model=valid.acc_spk.ave.pb
download_model= # Download a model from Model Zoo and use it for decoding.
# [Task dependent] Set the datadir name created by local/data.sh
train_set= # Name of training set.
valid_set= # Name of validation set used for monitoring/tuning network training.
test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
bpe_train_text= # Text file path of bpe training set.
lm_train_text= # Text file path of language model training set.
lm_dev_text= # Text file path of language model development set.
lm_test_text= # Text file path of language model evaluation set.
nlsyms_txt=none # Non-linguistic symbol list if existing.
cleaner=none # Text cleaner.
g2p=none # g2p method (needed if token_type=phn).
lang=zh # The language type of corpus.
score_opts= # The options given to sclite scoring
local_score_opts= # The options given to local/score.sh.
help_message=$(cat << EOF
Usage: $0 --train_set "<train_set_name>" --valid_set "<valid_set_name>" --test_sets "<test_set_names>"
Options:
# General configuration
--stage # Processing starts from the specified stage (default="${stage}").
--stop_stage # Processing stops at the specified stage (default="${stop_stage}").
--skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
--skip_train # Skip training stages (default="${skip_train}").
--skip_eval # Skip decoding and evaluation stages (default="${skip_eval}").
--skip_upload # Skip packing and uploading stages (default="${skip_upload}").
--ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
--num_nodes # The number of nodes (default="${num_nodes}").
--nj # The number of parallel jobs (default="${nj}").
--inference_nj # The number of parallel jobs in decoding (default="${inference_nj}").
--gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}").
--dumpdir # Directory to dump features (default="${dumpdir}").
--expdir # Directory to save experiments (default="${expdir}").
--python # Specify python to execute espnet commands (default="${python}").
--device # Which GPU ids are used for decoding (default="${device}").
# Data preparation related
--local_data_opts # The options given to local/data.sh (default="${local_data_opts}").
# Speed perturbation related
--speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").
# Feature extraction related
--feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
--audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}").
--fs # Sampling rate (default="${fs}").
--min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
--max_wav_duration # Maximum duration in second (default="${max_wav_duration}").
# Tokenization related
--token_type # Tokenization type (char or bpe, default="${token_type}").
--nbpe # The number of BPE vocabulary (default="${nbpe}").
--bpemode # Mode of BPE (unigram or bpe, default="${bpemode}").
--oov # Out of vocabulary symbol (default="${oov}").
--blank # CTC blank symbol (default="${blank}").
--sos_eos # sos and eos symbol (default="${sos_eos}").
--bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
--bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
--bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}").
# Language model related
--lm_tag # Suffix to the result dir for language model training (default="${lm_tag}").
--lm_exp # Specify the directory path for LM experiment.
# If this option is specified, lm_tag is ignored (default="${lm_exp}").
--lm_stats_dir # Specify the directory path for LM statistics (default="${lm_stats_dir}").
--lm_config # Config for language model training (default="${lm_config}").
--lm_args # Arguments for language model training (default="${lm_args}").
# e.g., --lm_args "--max_epoch 10"
# Note that it will overwrite args in lm config.
--use_word_lm # Whether to use word language model (default="${use_word_lm}").
--word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
--num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}").
# ASR model related
--asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}").
--asr_exp # Specify the directory path for ASR experiment.
# If this option is specified, asr_tag is ignored (default="${asr_exp}").
--asr_stats_dir # Specify the directory path for ASR statistics (default="${asr_stats_dir}").
--asr_config # Config for asr model training (default="${asr_config}").
--asr_args # Arguments for asr model training (default="${asr_args}").
# e.g., --asr_args "--max_epoch 10"
# Note that it will overwrite args in asr config.
--feats_normalize # Normalization layer type (default="${feats_normalize}").
--num_splits_asr # Number of splitting for asr corpus (default="${num_splits_asr}").
# Decoding related
--inference_tag # Suffix to the result dir for decoding (default="${inference_tag}").
--inference_config # Config for decoding (default="${inference_config}").
--inference_args # Arguments for decoding (default="${inference_args}").
# e.g., --inference_args "--lm_weight 0.1"
# Note that it will overwrite args in inference config.
--inference_lm # Language model path for decoding (default="${inference_lm}").
--inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
--download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}").
# [Task dependent] Set the datadir name created by local/data.sh
--train_set # Name of training set (required).
--valid_set # Name of validation set used for monitoring/tuning network training (required).
--test_sets # Names of test sets.
# Multiple items (e.g., both dev and eval sets) can be specified (required).
--bpe_train_text # Text file path of bpe training set.
--lm_train_text # Text file path of language model training set.
--lm_dev_text # Text file path of language model development set (default="${lm_dev_text}").
--lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}").
--nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
--cleaner # Text cleaner (default="${cleaner}").
--g2p # g2p method (default="${g2p}").
--lang # The language type of corpus (default=${lang}).
--score_opts # The options given to sclite scoring (default="${score_opts}").
--local_score_opts # The options given to local/score.sh (default="${local_score_opts}").
EOF
)
log "$0 $*"
# Save command line args for logging (they will be lost after utils/parse_options.sh)
run_args=$(python -m funasr.utils.cli_utils $0 "$@")
. utils/parse_options.sh
if [ $# -ne 0 ]; then
log "${help_message}"
log "Error: No positional arguments are required."
exit 2
fi
. ./path.sh
# Check required arguments
[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; };
[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; };
[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
# Check feature type
if [ "${feats_type}" = raw ]; then
data_feats=${dumpdir}/raw
elif [ "${feats_type}" = fbank_pitch ]; then
data_feats=${dumpdir}/fbank_pitch
elif [ "${feats_type}" = fbank ]; then
data_feats=${dumpdir}/fbank
elif [ "${feats_type}" == extracted ]; then
data_feats=${dumpdir}/extracted
else
log "${help_message}"
log "Error: not supported: --feats_type ${feats_type}"
exit 2
fi
# Use the same text as ASR for bpe training if not specified.
[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text"
# Use the same text as ASR for lm training if not specified.
[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text"
# Use the same text as ASR for lm training if not specified.
[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text"
# Use the text of the 1st evaldir if lm_test is not specified
[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
# Check tokenization type
if [ "${lang}" != noinfo ]; then
token_listdir=data/${lang}_token_list
else
token_listdir=data/token_list
fi
bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
bpeprefix="${bpedir}"/bpe
bpemodel="${bpeprefix}".model
bpetoken_list="${bpedir}"/tokens.txt
chartoken_list="${token_listdir}"/char/tokens.txt
# NOTE: keep for future development.
# shellcheck disable=SC2034
wordtoken_list="${token_listdir}"/word/tokens.txt
if [ "${token_type}" = bpe ]; then
token_list="${bpetoken_list}"
elif [ "${token_type}" = char ]; then
token_list="${chartoken_list}"
bpemodel=none
elif [ "${token_type}" = word ]; then
token_list="${wordtoken_list}"
bpemodel=none
else
log "Error: not supported --token_type '${token_type}'"
exit 2
fi
if ${use_word_lm}; then
log "Error: Word LM is not supported yet"
exit 2
lm_token_list="${wordtoken_list}"
lm_token_type=word
else
lm_token_list="${token_list}"
lm_token_type="${token_type}"
fi
# Set tag for naming of model directory
if [ -z "${asr_tag}" ]; then
if [ -n "${asr_config}" ]; then
asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
else
asr_tag="train_${feats_type}"
fi
if [ "${lang}" != noinfo ]; then
asr_tag+="_${lang}_${token_type}"
else
asr_tag+="_${token_type}"
fi
if [ "${token_type}" = bpe ]; then
asr_tag+="${nbpe}"
fi
# Add overwritten arg's info
if [ -n "${asr_args}" ]; then
asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
fi
if [ -n "${speed_perturb_factors}" ]; then
asr_tag+="_sp"
fi
fi
if [ -z "${lm_tag}" ]; then
if [ -n "${lm_config}" ]; then
lm_tag="$(basename "${lm_config}" .yaml)"
else
lm_tag="train"
fi
if [ "${lang}" != noinfo ]; then
lm_tag+="_${lang}_${lm_token_type}"
else
lm_tag+="_${lm_token_type}"
fi
if [ "${lm_token_type}" = bpe ]; then
lm_tag+="${nbpe}"
fi
# Add overwritten arg's info
if [ -n "${lm_args}" ]; then
lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
fi
fi
# The directory used for collect-stats mode
if [ -z "${asr_stats_dir}" ]; then
if [ "${lang}" != noinfo ]; then
asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
else
asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
fi
if [ "${token_type}" = bpe ]; then
asr_stats_dir+="${nbpe}"
fi
if [ -n "${speed_perturb_factors}" ]; then
asr_stats_dir+="_sp"
fi
fi
if [ -z "${lm_stats_dir}" ]; then
if [ "${lang}" != noinfo ]; then
lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
else
lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
fi
if [ "${lm_token_type}" = bpe ]; then
lm_stats_dir+="${nbpe}"
fi
fi
# The directory used for training commands
if [ -z "${asr_exp}" ]; then
asr_exp="${expdir}/asr_${asr_tag}"
fi
if [ -z "${lm_exp}" ]; then
lm_exp="${expdir}/lm_${lm_tag}"
fi
if [ -z "${inference_tag}" ]; then
if [ -n "${inference_config}" ]; then
inference_tag="$(basename "${inference_config}" .yaml)"
else
inference_tag=inference
fi
# Add overwritten arg's info
if [ -n "${inference_args}" ]; then
inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
fi
if "${use_lm}"; then
inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
fi
inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
fi
if [ -z "${sa_asr_inference_tag}" ]; then
if [ -n "${inference_config}" ]; then
sa_asr_inference_tag="$(basename "${inference_config}" .yaml)"
else
sa_asr_inference_tag=sa_asr_inference
fi
# Add overwritten arg's info
if [ -n "${sa_asr_inference_args}" ]; then
sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
fi
if "${use_lm}"; then
sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
fi
sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
fi
train_cmd="run.pl"
cuda_cmd="run.pl"
decode_cmd="run.pl"
# ========================== Main stages start from here. ==========================
if ! "${skip_data_prep}"; then
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
if [ "${feats_type}" = raw ]; then
log "Stage 1: Format wav.scp: data/ -> ${data_feats}"
# ====== Recreating "wav.scp" ======
# Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |",
# shouldn't be used in training process.
# "format_wav_scp.sh" dumps such pipe-style-wav to real audio file
# and it can also change the audio-format and sampling rate.
# If nothing is need, then format_wav_scp.sh does nothing:
# i.e. the input file format and rate is same as the output.
for dset in ${test_sets}; do
_suf=""
local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
_opts=
if [ -e data/"${dset}"/segments ]; then
# "segments" is used for splitting wav files which are written in "wav".scp
# into utterances. The file format of segments:
# <segment_id> <record_id> <start_time> <end_time>
# "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
# Where the time is written in seconds.
_opts+="--segments data/${dset}/segments "
fi
# shellcheck disable=SC2086
local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
--audio-format "${audio_format}" --fs "${fs}" ${_opts} \
"data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
done
else
log "Error: not supported: --feats_type ${feats_type}"
exit 2
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
log "Stage 2: Generate speaker profile by spectral-cluster"
mkdir -p "profile_log"
for dset in ${test_sets}; do
# generate cluster_profile with spectral clustering directly (for inference, without oracle information)
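# (note) the two positional numbers below are clustering hyper-parameters
# consumed by local/gen_cluster_profile_infer.py; see that script for their
# exact meaning (they appear to be similarity thresholds).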
python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
done
fi
else
log "Skip the stages for data preparation"
fi
# ========================== Data preparation is done here. ==========================
if ! "${skip_eval}"; then
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
log "Stage 3: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
if ${gpu_inference}; then
_cmd="${cuda_cmd}"
inference_nj=$((ngpu * njob_infer))
_ngpu=1
else
_cmd="${decode_cmd}"
inference_nj=$njob_infer
_ngpu=0
fi
_opts=
if [ -n "${inference_config}" ]; then
_opts+="--config ${inference_config} "
fi
if "${use_lm}"; then
if "${use_word_lm}"; then
_opts+="--word_lm_train_config ${lm_exp}/config.yaml "
_opts+="--word_lm_file ${lm_exp}/${inference_lm} "
else
_opts+="--lm_train_config ${lm_exp}/config.yaml "
_opts+="--lm_file ${lm_exp}/${inference_lm} "
fi
fi
# 2. Generate run.sh
log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 17 using this script"
mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"
for dset in ${test_sets}; do
_data="${data_feats}/${dset}"
_dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
_logdir="${_dir}/logdir"
mkdir -p "${_logdir}"
_feats_type="$(<${_data}/feats_type)"
if [ "${_feats_type}" = raw ]; then
_scp=wav.scp
if [[ "${audio_format}" == *ark* ]]; then
_type=kaldi_ark
else
_type=sound
fi
else
_scp=feats.scp
_type=kaldi_ark
fi
# 1. Split the key file
key_file=${_data}/${_scp}
split_scps=""
_nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
# 2. Submit decoding jobs
log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'"
# shellcheck disable=SC2086
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
--mc True \
--nbest 1 \
--ngpu "${_ngpu}" \
--njob ${njob_infer} \
--gpuid_list ${device} \
--data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
--data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
--key_file "${_logdir}"/keys.JOB.scp \
--allow_variable_data_keys true \
--asr_train_config "${sa_asr_exp}"/config.yaml \
--asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
--output_dir "${_logdir}"/output.JOB \
--mode sa_asr \
${_opts}
# 3. Concatenate the output files from each job
for f in token token_int score text text_id; do
for i in $(seq "${_nj}"); do
cat "${_logdir}/output.${i}/1best_recog/${f}"
done | LC_ALL=C sort -k1 >"${_dir}/${f}"
done
done
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
log "Stage 4: Generate SA-ASR results (cluster profile)"
for dset in ${test_sets}; do
_dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
python local/process_text_spk_merge.py ${_dir}
done
fi
else
log "Skip the evaluation stages"
fi
log "Successfully finished. [elapsed=${SECONDS}s]"


@@ -0,0 +1,6 @@
beam_size: 20
penalty: 0.0
maxlenratio: 0.0
minlenratio: 0.0
ctc_weight: 0.6
lm_weight: 0.3


@@ -0,0 +1,87 @@
# network architecture
frontend: default
frontend_conf:
n_fft: 400
win_length: 400
hop_length: 160
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder architecture type
normalize_before: true
rel_pos_type: latest
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
activation_type: swish
macaron_style: true
use_cnn_module: true
cnn_module_kernel: 15
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# ctc related
ctc_conf:
ignore_nan_grad: true
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
# minibatch related
batch_type: numel
batch_bins: 10000000 # reduce/increase this number according to your GPU memory
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 100
val_scheduler_criterion:
- valid
- acc
best_model_criterion:
- - valid
- acc
- max
keep_nbest_models: 10
optim: adam
optim_conf:
lr: 0.001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
specaug: specaug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 30
num_freq_mask: 2
apply_time_mask: true
time_mask_width_range:
- 0
- 40
num_time_mask: 2


@@ -0,0 +1,29 @@
lm: transformer
lm_conf:
pos_enc: null
embed_unit: 128
att_unit: 512
head: 8
unit: 2048
layer: 16
dropout_rate: 0.1
# optimization related
grad_clip: 5.0
batch_type: numel
batch_bins: 500000 # 4gpus * 500000
accum_grad: 1
max_epoch: 15 # 15 epochs are enough
optim: adam
optim_conf:
lr: 0.001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
best_model_criterion:
- - valid
- loss
- min
keep_nbest_models: 10 # 10 is good.


@@ -0,0 +1,115 @@
# network architecture
frontend: default
frontend_conf:
n_fft: 400
win_length: 400
hop_length: 160
# encoder related
asr_encoder: conformer
asr_encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder architecture type
normalize_before: true
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
activation_type: swish
macaron_style: true
use_cnn_module: true
cnn_module_kernel: 15
spk_encoder: resnet34_diar
spk_encoder_conf:
use_head_conv: true
batchnorm_momentum: 0.5
use_head_maxpool: false
num_nodes_pooling_layer: 256
layers_in_block:
- 3
- 4
- 6
- 3
filters_in_block:
- 32
- 64
- 128
- 256
pooling_type: statistic
num_nodes_resnet1: 256
num_nodes_last_layer: 256
# decoder related
decoder: sa_decoder
decoder_conf:
attention_heads: 4
linear_units: 2048
asr_num_blocks: 6
spk_num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
spk_weight: 0.5
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
ctc_conf:
ignore_nan_grad: true
# minibatch related
batch_type: numel
batch_bins: 10000000
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 60
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
- acc
- max
- - valid
- acc_spk
- max
- - valid
- loss
- min
keep_nbest_models: 10
optim: adam
optim_conf:
lr: 0.0005
scheduler: warmuplr
scheduler_conf:
warmup_steps: 8000
specaug: specaug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 30
num_freq_mask: 2
apply_time_mask: true
time_mask_width_range:
- 0
- 40
num_time_mask: 2


@@ -0,0 +1,162 @@
#!/usr/bin/env bash
# Set bash to 'strict' mode: it will exit on
# -e 'error', -u 'undefined variable', and -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail
log() {
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
help_message=$(cat << EOF
Usage: $0
Options:
--no_overlap (bool): Whether to ignore the overlapping utterances in the training set.
--tgt (string): Which set to process (Train or Eval).
EOF
)
SECONDS=0
tgt=Train # Train or Eval
log "$0 $*"
echo $tgt
. ./utils/parse_options.sh
. ./path.sh
AliMeeting="${PWD}/dataset"
if [ $# -gt 2 ]; then
log "${help_message}"
exit 2
fi
if [ ! -d "${AliMeeting}" ]; then
log "Error: ${AliMeeting} is empty."
exit 2
fi
# To absolute path
AliMeeting=$(cd ${AliMeeting}; pwd)
echo $AliMeeting
far_raw_dir=${AliMeeting}/${tgt}_Ali_far/
near_raw_dir=${AliMeeting}/${tgt}_Ali_near/
far_dir=data/local/${tgt}_Ali_far
near_dir=data/local/${tgt}_Ali_near
far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker
mkdir -p $far_single_speaker_dir
stage=1
stop_stage=4
mkdir -p $far_dir
mkdir -p $near_dir
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
log "stage 1:process alimeeting near dir"
find -L $near_raw_dir/audio_dir -iname "*.wav" > $near_dir/wavlist
awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid
find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist
n1_wav=$(wc -l < $near_dir/wavlist)
n2_text=$(wc -l < $near_dir/textgrid.flist)
log "near dir: found ${n1_wav} wav files and ${n2_text} TextGrid files."
paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
# cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
#sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
sed -e 's///g' $near_dir/tmp1> $near_dir/tmp2
sed -e 's///g' $near_dir/tmp2> $near_dir/text
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
log "stage 2:process alimeeting far dir"
find -L $far_raw_dir/audio_dir -iname "*.wav" > $far_dir/wavlist
awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid
find -L $far_raw_dir/textgrid_dir -iname "*.TextGrid" > $far_dir/textgrid.flist
n1_wav=$(wc -l < $far_dir/wavlist)
n2_text=$(wc -l < $far_dir/textgrid.flist)
log "far dir: found ${n1_wav} wav files and ${n2_text} TextGrid files."
paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
python local/alimeeting_process_overlap_force.py --path $far_dir \
--no-overlap false --mars True \
--overlap_length 0.8 --max_length 7
cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
sed -e 's///g' $far_dir/tmp2> $far_dir/tmp3
sed -e 's///g' $far_dir/tmp3> $far_dir/text
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
log "stage 3: finali data process"
local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
# remove space in text
for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
cp data/${x}/text data/${x}/text.org
paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
> data/${x}/text
rm data/${x}/text.org
done
log "Successfully finished. [elapsed=${SECONDS}s]"
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
cp -r $far_dir/* $far_single_speaker_dir
mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
./local/fix_data_dir.sh $far_single_speaker_dir
local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
# remove space in text
for x in ${tgt}_Ali_far_single_speaker; do
cp data/${x}/text data/${x}/text.org
paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
> data/${x}/text
rm data/${x}/text.org
done
log "Successfully finished. [elapsed=${SECONDS}s]"
fi


@@ -0,0 +1,129 @@
#!/usr/bin/env bash
# Set bash to 'strict' mode: it will exit on
# -e 'error', -u 'undefined variable', and -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail
log() {
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
help_message=$(cat << EOF
Usage: $0
Options:
--no_overlap (bool): Whether to ignore the overlapping utterances in the training set.
--tgt (string): Which set to process (Train or Eval).
EOF
)
SECONDS=0
tgt=Train # Train or Eval
log "$0 $*"
echo $tgt
. ./utils/parse_options.sh
. ./path.sh
AliMeeting="${PWD}/dataset"
if [ $# -gt 2 ]; then
log "${help_message}"
exit 2
fi
if [ ! -d "${AliMeeting}" ]; then
log "Error: ${AliMeeting} is empty."
exit 2
fi
# To absolute path
AliMeeting=$(cd ${AliMeeting}; pwd)
echo $AliMeeting
far_raw_dir=${AliMeeting}/${tgt}_Ali_far/
far_dir=data/local/${tgt}_Ali_far
far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker
mkdir -p $far_single_speaker_dir
stage=1
stop_stage=3
mkdir -p $far_dir
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
log "stage 1:process alimeeting far dir"
find -L $far_raw_dir/audio_dir -iname "*.wav" > $far_dir/wavlist
awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid
find -L $far_raw_dir/textgrid_dir -iname "*.TextGrid" > $far_dir/textgrid.flist
n1_wav=$(wc -l < $far_dir/wavlist)
n2_text=$(wc -l < $far_dir/textgrid.flist)
log "far dir: found ${n1_wav} wav files and ${n2_text} TextGrid files."
paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
python local/alimeeting_process_overlap_force.py --path $far_dir \
--no-overlap false --mars True \
--overlap_length 0.8 --max_length 7
cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
sed -e 's///g' $far_dir/tmp2> $far_dir/tmp3
sed -e 's///g' $far_dir/tmp3> $far_dir/text
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
log "stage 2: finali data process"
local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
# remove space in text
for x in ${tgt}_Ali_far; do
cp data/${x}/text data/${x}/text.org
paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
> data/${x}/text
rm data/${x}/text.org
done
log "Successfully finished. [elapsed=${SECONDS}s]"
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
log "stage 3:process alimeeting far dir (single speaker by oracal time strap)"
cp -r $far_dir/* $far_single_speaker_dir
mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
./local/fix_data_dir.sh $far_single_speaker_dir
local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
# remove space in text
for x in ${tgt}_Ali_far_single_speaker; do
cp data/${x}/text data/${x}/text.org
paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
> data/${x}/text
rm data/${x}/text.org
done
log "Successfully finished. [elapsed=${SECONDS}s]"
fi


@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
"""
Process the textgrid files
"""
import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
class Segment(object):
def __init__(self, uttid, spkr, stime, etime, text):
self.uttid = uttid
self.spkr = spkr
self.spkr_all = uttid+"-"+spkr
self.stime = round(stime, 2)
self.etime = round(etime, 2)
self.text = text
self.spk_text = {uttid+"-"+spkr: text}
def change_stime(self, time):
self.stime = time
def change_etime(self, time):
self.etime = time
def get_args():
parser = argparse.ArgumentParser(description="process the textgrid files")
parser.add_argument("--path", type=str, required=True, help="Data path")
parser.add_argument(
"--no-overlap",
type=strtobool,
default=False,
help="Whether to ignore the overlapping utterances.",
)
parser.add_argument(
"--max_length",
default=100000,
type=float,
help="overlap speech max time,if longger than max length should cut",
)
parser.add_argument(
"--overlap_length",
default=1,
type=float,
help="if length longer than max length, speech overlength shorter, is cut",
)
parser.add_argument(
"--mars",
type=strtobool,
default=False,
help="Whether to process mars data set.",
)
args = parser.parse_args()
return args
def preposs_overlap(segments, max_length, overlap_length):
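    # Merge temporally overlapping segments (the input is sorted by start
    # time) into multi-speaker segments whose texts are joined with the "src"
    # separator. A merged span keeps absorbing overlapping segments while it
    # is shorter than max_length; once longer, a new segment is absorbed only
    # if it overlaps the span by more than overlap_length (overlap_length_big
    # applies once the span exceeds max_length_big); otherwise the span is
    # emitted and a new one starts.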
new_segments = []
    # the current merged segment (absorbs overlapping segments)
tmp_segments = segments[0]
min_stime = segments[0].stime
max_etime = segments[0].etime
overlap_length_big = 1.5
max_length_big = 15
for i in range(1, len(segments)):
if segments[i].stime >= max_etime:
            # doesn't overlap with previous segments
new_segments.append(tmp_segments)
tmp_segments = segments[i]
min_stime = segments[i].stime
max_etime = segments[i].etime
else:
# overlap with previous segments
dur_time = max_etime - min_stime
if dur_time < max_length:
if min_stime > segments[i].stime:
min_stime = segments[i].stime
if max_etime < segments[i].etime:
max_etime = segments[i].etime
tmp_segments.stime = min_stime
tmp_segments.etime = max_etime
tmp_segments.text = tmp_segments.text + "src" + segments[i].text
spk_name =segments[i].uttid +"-" + segments[i].spkr
if spk_name in tmp_segments.spk_text:
tmp_segments.spk_text[spk_name] += segments[i].text
else:
tmp_segments.spk_text[spk_name] = segments[i].text
tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name
else:
overlap_time = max_etime - segments[i].stime
if dur_time < max_length_big:
overlap_length_option = overlap_length
else:
overlap_length_option = overlap_length_big
if overlap_time > overlap_length_option:
if min_stime > segments[i].stime:
min_stime = segments[i].stime
if max_etime < segments[i].etime:
max_etime = segments[i].etime
tmp_segments.stime = min_stime
tmp_segments.etime = max_etime
tmp_segments.text = tmp_segments.text + "src" + segments[i].text
spk_name =segments[i].uttid +"-" + segments[i].spkr
if spk_name in tmp_segments.spk_text:
tmp_segments.spk_text[spk_name] += segments[i].text
else:
tmp_segments.spk_text[spk_name] = segments[i].text
tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name
else:
new_segments.append(tmp_segments)
tmp_segments = segments[i]
min_stime = segments[i].stime
max_etime = segments[i].etime
    # flush the last merged segment
    new_segments.append(tmp_segments)
    return new_segments
def filter_overlap(segments):
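    # Keep only utterances that overlap nothing: consecutive overlapping
    # segments are grouped, and a group is emitted only when it contains a
    # single segment.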
new_segments = []
# init a helper list to store all overlap segments
tmp_segments = [segments[0]]
min_stime = segments[0].stime
max_etime = segments[0].etime
for i in range(1, len(segments)):
if segments[i].stime >= max_etime:
            # doesn't overlap with previous segments
if len(tmp_segments) == 1:
new_segments.append(tmp_segments[0])
                # TODO: for multi-spkr asr, we can reset the stime/etime to
                # min_stime/max_etime for generating a max length mixture speech
tmp_segments = [segments[i]]
min_stime = segments[i].stime
max_etime = segments[i].etime
else:
# overlap with previous segments
tmp_segments.append(segments[i])
if min_stime > segments[i].stime:
min_stime = segments[i].stime
if max_etime < segments[i].etime:
max_etime = segments[i].etime
    # flush the last group if it is a single non-overlapping segment
    if len(tmp_segments) == 1:
        new_segments.append(tmp_segments[0])
    return new_segments
def main(args):
wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8")
textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
# get the path of textgrid file for each utterance
utt2textgrid = {}
for line in textgrid_flist:
path = Path(line.strip())
uttid = path.stem
utt2textgrid[uttid] = path
# parse the textgrid file for each utterance
all_segments = []
for line in wav_scp:
uttid = line.strip().split(" ")[0]
        uttid_part = uttid
        if args.mars:
            uttid_list = uttid.split("_")
            uttid_part = uttid_list[0] + "_" + uttid_list[1]
if uttid_part not in utt2textgrid:
print("%s doesn't have transcription" % uttid)
continue
segments = []
tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part])
        for i in range(len(tg)):
            for j in range(len(tg[i])):
if tg[i][j].mark:
segments.append(
Segment(
uttid,
tg[i].name,
tg[i][j].minTime,
tg[i][j].maxTime,
tg[i][j].mark.strip(),
)
)
segments = sorted(segments, key=lambda x: x.stime)
if args.no_overlap:
segments = filter_overlap(segments)
else:
            segments = preposs_overlap(segments, args.max_length, args.overlap_length)
all_segments += segments
wav_scp.close()
textgrid_flist.close()
segments_file = codecs.open(Path(args.path) / "segments_all", "w", "utf-8")
utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8")
text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8")
utt2spk_file_fifo = codecs.open(Path(args.path) / "utt2spk_all_fifo", "w", "utf-8")
for i in range(len(all_segments)):
utt_name = "%s-%s-%07d-%07d" % (
all_segments[i].uttid,
all_segments[i].spkr,
all_segments[i].stime * 100,
all_segments[i].etime * 100,
)
segments_file.write(
"%s %s %.2f %.2f\n"
% (
utt_name,
all_segments[i].uttid,
all_segments[i].stime,
all_segments[i].etime,
)
)
utt2spk_file.write(
"%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr)
)
utt2spk_file_fifo.write(
"%s %s\n" % (utt_name, all_segments[i].spkr_all)
)
text_file.write("%s %s\n" % (utt_name, all_segments[i].text))
segments_file.close()
utt2spk_file.close()
text_file.close()
utt2spk_file_fifo.close()
if __name__ == "__main__":
args = get_args()
main(args)


@@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
"""
Process the textgrid files
"""
import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
import pdb
class Segment(object):
def __init__(self, uttid, spkr, stime, etime, text):
self.uttid = uttid
self.spkr = spkr
self.stime = round(stime, 2)
self.etime = round(etime, 2)
self.text = text
def change_stime(self, time):
self.stime = time
def change_etime(self, time):
self.etime = time
def get_args():
parser = argparse.ArgumentParser(description="process the textgrid files")
parser.add_argument("--path", type=str, required=True, help="Data path")
parser.add_argument(
"--no-overlap",
type=strtobool,
default=False,
help="Whether to ignore the overlapping utterances.",
)
parser.add_argument(
"--mars",
type=strtobool,
default=False,
help="Whether to process mars data set.",
)
args = parser.parse_args()
return args
def filter_overlap(segments):
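    # Keep only utterances that overlap nothing: consecutive overlapping
    # segments are grouped, and a group is emitted only when it contains a
    # single segment.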
new_segments = []
# init a helper list to store all overlap segments
tmp_segments = [segments[0]]
min_stime = segments[0].stime
max_etime = segments[0].etime
for i in range(1, len(segments)):
if segments[i].stime >= max_etime:
            # doesn't overlap with previous segments
if len(tmp_segments) == 1:
new_segments.append(tmp_segments[0])
                # TODO: for multi-spkr asr, we can reset the stime/etime to
                # min_stime/max_etime for generating a max length mixture speech
tmp_segments = [segments[i]]
min_stime = segments[i].stime
max_etime = segments[i].etime
else:
# overlap with previous segments
tmp_segments.append(segments[i])
if min_stime > segments[i].stime:
min_stime = segments[i].stime
if max_etime < segments[i].etime:
max_etime = segments[i].etime
    # flush the last group if it is a single non-overlapping segment
    if len(tmp_segments) == 1:
        new_segments.append(tmp_segments[0])
    return new_segments
def main(args):
wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8")
textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
# get the path of textgrid file for each utterance
utt2textgrid = {}
for line in textgrid_flist:
path = Path(line.strip())
uttid = path.stem
utt2textgrid[uttid] = path
# parse the textgrid file for each utterance
all_segments = []
for line in wav_scp:
uttid = line.strip().split(" ")[0]
        uttid_part = uttid
        if args.mars:
            uttid_list = uttid.split("_")
            uttid_part = uttid_list[0] + "_" + uttid_list[1]
if uttid_part not in utt2textgrid:
print("%s doesn't have transcription" % uttid)
continue
#pdb.set_trace()
segments = []
try:
tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part])
except:
pdb.set_trace()
        for i in range(len(tg)):
            for j in range(len(tg[i])):
if tg[i][j].mark:
segments.append(
Segment(
uttid,
tg[i].name,
tg[i][j].minTime,
tg[i][j].maxTime,
tg[i][j].mark.strip(),
)
)
segments = sorted(segments, key=lambda x: x.stime)
if args.no_overlap:
segments = filter_overlap(segments)
all_segments += segments
wav_scp.close()
textgrid_flist.close()
segments_file = codecs.open(Path(args.path) / "segments_all", "w", "utf-8")
utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8")
text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8")
for i in range(len(all_segments)):
utt_name = "%s-%s-%07d-%07d" % (
all_segments[i].uttid,
all_segments[i].spkr,
all_segments[i].stime * 100,
all_segments[i].etime * 100,
)
segments_file.write(
"%s %s %.2f %.2f\n"
% (
utt_name,
all_segments[i].uttid,
all_segments[i].stime,
all_segments[i].etime,
)
)
utt2spk_file.write(
"%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr)
)
text_file.write("%s %s\n" % (utt_name, all_segments[i].text))
segments_file.close()
utt2spk_file.close()
text_file.close()
if __name__ == "__main__":
args = get_args()
main(args)


@@ -0,0 +1,97 @@
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
# Apache 2.0.
# This program is a bit like ./sym2int.pl in that it applies a map
# to things in a file, but it's a bit more general in that it doesn't
# assume the things being mapped to are single tokens, they could
# be sequences of tokens. See the usage message.
$permissive = 0;
for ($x = 0; $x <= 2; $x++) {
if (@ARGV > 0 && $ARGV[0] eq "-f") {
shift @ARGV;
$field_spec = shift @ARGV;
if ($field_spec =~ m/^\d+$/) {
$field_begin = $field_spec - 1; $field_end = $field_spec - 1;
}
if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10)
if ($1 ne "") {
$field_begin = $1 - 1; # Change to zero-based indexing.
}
if ($2 ne "") {
$field_end = $2 - 1; # Change to zero-based indexing.
}
}
if (!defined $field_begin && !defined $field_end) {
die "Bad argument to -f option: $field_spec";
}
}
if (@ARGV > 0 && $ARGV[0] eq '--permissive') {
shift @ARGV;
# Mapping is optional (missing key is printed to output)
$permissive = 1;
}
}
if(@ARGV != 1) {
print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n";
print STDERR <<'EOF';
Usage: apply_map.pl [options] map <input >output
options: [-f <field-range> ] [--permissive]
This applies a map to some specified fields of some input text:
For each line in the map file: the first field is the thing we
map from, and the remaining fields are the sequence we map it to.
The -f (field-range) option says which fields of the input file the map
map should apply to.
If the --permissive option is supplied, fields which are not present
in the map will be left as they were.
Applies the map 'map' to all input text, where each line of the map
is interpreted as a map from the first field to the list of the other fields
Note: <field-range> can look like 4-5, or 4-, or 5-, or 1, it means the field
range in the input to apply the map to.
e.g.: echo A B | apply_map.pl a.txt
where a.txt is:
A a1 a2
B b
will produce:
a1 a2 b
EOF
exit(1);
}
($map_file) = @ARGV;
open(M, "<$map_file") || die "Error opening map file $map_file: $!";
while (<M>) {
@A = split(" ", $_);
@A >= 1 || die "apply_map.pl: empty line.";
$i = shift @A;
$o = join(" ", @A);
$map{$i} = $o;
}
while(<STDIN>) {
@A = split(" ", $_);
for ($x = 0; $x < @A; $x++) {
if ( (!defined $field_begin || $x >= $field_begin)
&& (!defined $field_end || $x <= $field_end)) {
$a = $A[$x];
if (!defined $map{$a}) {
if (!$permissive) {
die "apply_map.pl: undefined key $a in $map_file\n";
} else {
print STDERR "apply_map.pl: warning! missing key $a in $map_file\n";
}
} else {
$A[$x] = $map{$a};
}
}
}
print join(" ", @A) . "\n";
}


@@ -0,0 +1,146 @@
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0.
# 2014 David Snyder
# This script combines the data from multiple source directories into
# a single destination directory.
# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information
# about what these directories contain.
# Begin configuration section.
extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..."
skip_fix=false # skip the fix_data_dir.sh in the end
# End configuration section.
echo "$0 $@" # Print the command line for logging
if [ -f path.sh ]; then . ./path.sh; fi
. parse_options.sh || exit 1;
if [ $# -lt 2 ]; then
echo "Usage: combine_data.sh [--extra-files 'file1 file2'] <dest-data-dir> <src-data-dir1> <src-data-dir2> ..."
echo "Note, files that don't appear in all source dirs will not be combined,"
echo "with the exception of utt2uniq and segments, which are created where necessary."
exit 1
fi
dest=$1;
shift;
first_src=$1;
rm -r $dest 2>/dev/null || true
mkdir -p $dest;
export LC_ALL=C
for dir in $*; do
if [ ! -f $dir/utt2spk ]; then
echo "$0: no such file $dir/utt2spk"
exit 1;
fi
done
# Check that frame_shift are compatible, where present together with features.
dir_with_frame_shift=
for dir in $*; do
if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then
if [[ $dir_with_frame_shift ]] &&
! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then
echo "$0:error: different frame_shift in directories $dir and " \
"$dir_with_frame_shift. Cannot combine features."
exit 1;
fi
dir_with_frame_shift=$dir
fi
done
# W.r.t. the utt2uniq file this script behaves differently compared to other files:
# it is not compulsory for it to exist in the src directories, but if it exists in
# even one, it should exist in all. We will create the files where necessary.
has_utt2uniq=false
for in_dir in $*; do
if [ -f $in_dir/utt2uniq ]; then
has_utt2uniq=true
break
fi
done
if $has_utt2uniq; then
# we are going to create an utt2uniq file in the destdir
for in_dir in $*; do
if [ ! -f $in_dir/utt2uniq ]; then
# we assume that utt2uniq is a one to one mapping
cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}'
else
cat $in_dir/utt2uniq
fi
done | sort -k1 > $dest/utt2uniq
echo "$0: combined utt2uniq"
else
echo "$0 [info]: not combining utt2uniq as it does not exist"
fi
# some of the old scripts might provide utt2uniq as an extrafile, so just remove it
extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g")
# segments are treated similarly to utt2uniq. If it exists in some, but not all
# src directories, then we generate segments where necessary.
has_segments=false
for in_dir in $*; do
if [ -f $in_dir/segments ]; then
has_segments=true
break
fi
done
if $has_segments; then
for in_dir in $*; do
if [ ! -f $in_dir/segments ]; then
echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
local/data/get_segments_for_data.sh $in_dir
else
cat $in_dir/segments
fi
done | sort -k1 > $dest/segments
echo "$0: combined segments"
else
echo "$0 [info]: not combining segments as it does not exist"
fi
for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do
exists_somewhere=false
absent_somewhere=false
for d in $*; do
if [ -f $d/$file ]; then
exists_somewhere=true
else
absent_somewhere=true
fi
done
if ! $absent_somewhere; then
set -o pipefail
( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1;
set +o pipefail
echo "$0: combined $file"
else
if ! $exists_somewhere; then
echo "$0 [info]: not combining $file as it does not exist"
else
echo "$0 [info]: **not combining $file as it does not exist everywhere**"
fi
fi
done
local/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
if [[ $dir_with_frame_shift ]]; then
cp $dir_with_frame_shift/frame_shift $dest
fi
if ! $skip_fix ; then
local/fix_data_dir.sh $dest || exit 1;
fi
exit 0


@@ -0,0 +1,91 @@
import editdistance
import sys
import os
from itertools import permutations
def load_transcripts(file_path):
trans_list = []
for one_line in open(file_path, "rt"):
        meeting_id, trans = one_line.strip().split(" ", 1)
trans_list.append((meeting_id.strip(), trans.strip()))
return trans_list
def calc_spk_trans(trans):
spk_trans_ = [x.strip() for x in trans.split("$")]
spk_trans = []
for i in range(len(spk_trans_)):
spk_trans.append((str(i), spk_trans_[i]))
return spk_trans
def calc_cer(ref_trans, hyp_trans):
ref_spk_trans = calc_spk_trans(ref_trans)
hyp_spk_trans = calc_spk_trans(hyp_trans)
ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)
num_spk = max(len(ref_spk_trans), len(hyp_spk_trans))
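# Pad the shorter side with empty transcripts so both sides have the same
# number of speakers before enumerating permutations.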
ref_spk_trans.extend([("", "")] * (num_spk - len(ref_spk_trans)))
hyp_spk_trans.extend([("", "")] * (num_spk - len(hyp_spk_trans)))
errors, counts, permutes = [], [], []
min_error = 0
cost_dict = {}
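# Exhaustively search speaker permutations for the one that minimizes the
# total edit distance. min_error tracks the best total found so far; once it
# is non-zero, any permutation whose per-pair length gap or accumulated edit
# distance already exceeds it is pruned early. cost_dict caches pairwise
# edit distances so each (ref, hyp) pair is evaluated at most once.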
for perm in permutations(range(num_spk)):
flag = True
p_err, p_count = 0, 0
for idx, p in enumerate(perm):
if abs(len(ref_spk_trans[idx][1]) - len(hyp_spk_trans[p][1])) > min_error > 0:
flag = False
break
cost_key = "{}-{}".format(idx, p)
if cost_key in cost_dict:
_e = cost_dict[cost_key]
else:
_e = editdistance.eval(ref_spk_trans[idx][1], hyp_spk_trans[p][1])
cost_dict[cost_key] = _e
if _e > min_error > 0:
flag = False
break
p_err += _e
p_count += len(ref_spk_trans[idx][1])
if flag:
if p_err < min_error or min_error == 0:
min_error = p_err
errors.append(p_err)
counts.append(p_count)
permutes.append(perm)
sd_cer = [(err, cnt, err/cnt, permute)
for err, cnt, permute in zip(errors, counts, permutes)]
best_rst = min(sd_cer, key=lambda x: x[2])
return best_rst[0], best_rst[1], ref_spk_num, hyp_spk_num
def main():
ref=sys.argv[1]
hyp=sys.argv[2]
result_path=sys.argv[3]
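# argv: <ref text_spk_merge> <hyp text_spk_merge> <result file>; both inputs
# use the "meeting-id trans_A$trans_B$..." format, with '$' as the speaker
# separator.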
ref_list = load_transcripts(ref)
hyp_list = load_transcripts(hyp)
result_file = open(result_path,'w')
error, count = 0, 0
for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list):
assert ref_id == hyp_id
mid = ref_id
dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
error, count = error + dist, count + length
result_file.write("{} {:.2f} {} {}\n".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
# print("{} {:.2f} {} {}".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
result_file.write("CP-CER: {:.2f}\n".format(error / count * 100.0))
result_file.close()
# print("Sum/Avg: {:.2f}".format(error / count * 100.0))
if __name__ == '__main__':
main()

View File

@ -0,0 +1,145 @@
#!/usr/bin/env bash
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
# Apache 2.0
# This script operates on a directory, such as in data/train/,
# that contains some subset of the following files:
# feats.scp
# wav.scp
# vad.scp
# spk2utt
# utt2spk
# text
#
# It copies to another directory, possibly adding a specified prefix or a suffix
# to the utterance and/or speaker names. Note, the recording-ids stay the same.
#
# begin configuration section
spk_prefix=
utt_prefix=
spk_suffix=
utt_suffix=
validate_opts= # should rarely be needed.
# end configuration section
. utils/parse_options.sh
if [ $# != 2 ]; then
echo "Usage: "
echo " $0 [options] <srcdir> <destdir>"
echo "e.g.:"
echo " $0 --spk-prefix=1- --utt-prefix=1- data/train data/train_1"
echo "Options"
echo " --spk-prefix=<prefix> # Prefix for speaker ids, default empty"
echo " --utt-prefix=<prefix> # Prefix for utterance ids, default empty"
echo " --spk-suffix=<suffix> # Suffix for speaker ids, default empty"
echo " --utt-suffix=<suffix> # Suffix for utterance ids, default empty"
exit 1;
fi
export LC_ALL=C
srcdir=$1
destdir=$2
if [ ! -f $srcdir/utt2spk ]; then
echo "copy_data_dir.sh: no such file $srcdir/utt2spk"
exit 1;
fi
if [ "$destdir" == "$srcdir" ]; then
echo "$0: this script requires <srcdir> and <destdir> to be different."
exit 1
fi
set -e;
mkdir -p $destdir
cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/utt_map
cat $srcdir/spk2utt | awk -v p=$spk_prefix -v s=$spk_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/spk_map
if [ ! -f $srcdir/utt2uniq ]; then
if [[ ! -z $utt_prefix || ! -z $utt_suffix ]]; then
cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $1);}' > $destdir/utt2uniq
fi
else
cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq
fi
cat $srcdir/utt2spk | local/apply_map.pl -f 1 $destdir/utt_map | \
local/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
local/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
if [ -f $srcdir/feats.scp ]; then
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
fi
if [ -f $srcdir/vad.scp ]; then
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
fi
if [ -f $srcdir/segments ]; then
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
cp $srcdir/wav.scp $destdir
else # no segments->wav indexed by utt.
if [ -f $srcdir/wav.scp ]; then
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
fi
fi
if [ -f $srcdir/reco2file_and_channel ]; then
cp $srcdir/reco2file_and_channel $destdir/
fi
if [ -f $srcdir/text ]; then
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
fi
if [ -f $srcdir/utt2dur ]; then
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
fi
if [ -f $srcdir/utt2num_frames ]; then
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
fi
if [ -f $srcdir/reco2dur ]; then
if [ -f $srcdir/segments ]; then
cp $srcdir/reco2dur $destdir/reco2dur
else
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
fi
fi
if [ -f $srcdir/spk2gender ]; then
local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
fi
if [ -f $srcdir/cmvn.scp ]; then
local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
fi
for f in frame_shift stm glm ctm; do
if [ -f $srcdir/$f ]; then
cp $srcdir/$f $destdir
fi
done
rm $destdir/spk_map $destdir/utt_map
echo "$0: copied data from $srcdir to $destdir"
for f in feats.scp cmvn.scp vad.scp utt2lang utt2uniq utt2dur utt2num_frames text wav.scp reco2file_and_channel frame_shift stm glm ctm; do
if [ -f $destdir/$f ] && [ ! -f $srcdir/$f ]; then
echo "$0: file $f exists in dest $destdir but not in src $srcdir. Moving it to"
echo " ... $destdir/.backup/$f"
mkdir -p $destdir/.backup
mv $destdir/$f $destdir/.backup/
fi
done
[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats"
[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text"
local/validate_data_dir.sh $validate_opts $destdir

View File

@ -0,0 +1,143 @@
#!/usr/bin/env bash
# Copyright 2016 Johns Hopkins University (author: Daniel Povey)
# 2018 Andrea Carmantini
# Apache 2.0
# This script operates on a data directory, such as in data/train/, and adds the
# reco2dur file if it does not already exist. The file 'reco2dur' maps from
# recording to the duration of the recording in seconds. This script works it
# out from the 'wav.scp' file, or, if utterance-ids are the same as recording-ids, from the
# utt2dur file (it first tries interrogating the headers, and if this fails, it reads the wave
# files in their entirety).
# We could use durations from the segments file, but that's not the duration of the recordings
# but the sum of utterance lengths (silence in between could be excluded from segments).
# For the sum of utterance lengths:
# awk 'FNR==NR{uttdur[$1]=$2;next}
# { for(i=2;i<=NF;i++){dur+=uttdur[$i];}
# print $1 FS dur; dur=0 }' $data/utt2dur $data/reco2utt
frame_shift=0.01
cmd=run.pl
nj=4
. utils/parse_options.sh
. ./path.sh
if [ $# != 1 ]; then
echo "Usage: $0 [options] <datadir>"
echo "e.g.:"
echo " $0 data/train"
echo " Options:"
echo " --frame-shift # frame shift in seconds. Only relevant when we are"
echo " # getting duration from feats.scp (default: 0.01). "
exit 1
fi
export LC_ALL=C
data=$1
if [ -s $data/reco2dur ] && \
[ $(wc -l < $data/wav.scp) -eq $(wc -l < $data/reco2dur) ]; then
echo "$0: $data/reco2dur already exists with the expected length. We won't recompute it."
exit 0;
fi
if [ -s $data/utt2dur ] && \
[ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ] && \
[ ! -s $data/segments ]; then
echo "$0: $data/wav.scp indexed by utt-id; copying utt2dur to reco2dur"
cp $data/utt2dur $data/reco2dur && exit 0;
elif [ -f $data/wav.scp ]; then
echo "$0: obtaining durations from recordings"
# if the wav.scp contains only lines of the form
# utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph |
if cat $data/wav.scp | perl -e '
while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space.
@A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
$A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
$reco = $A[0]; $sphere_file = $A[4];
if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
$sample_rate = -1; $sample_count = -1;
for ($n = 0; $n <= 30; $n++) {
$line = <F>;
if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
if ($line =~ m/end_head/) { last; }
}
close(F);
if ($sample_rate == -1 || $sample_count == -1) {
die "could not parse sphere header from $sphere_file";
}
$duration = $sample_count * 1.0 / $sample_rate;
print "$reco $duration\n";
} ' > $data/reco2dur; then
echo "$0: successfully obtained recording lengths from sphere-file headers"
else
echo "$0: could not get recording lengths from sphere-file headers, using wav-to-duration"
if ! command -v wav-to-duration >/dev/null; then
echo "$0: wav-to-duration is not on your path"
exit 1;
fi
read_entire_file=false
if grep -q 'sox.*speed' $data/wav.scp; then
read_entire_file=true
echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
echo "... It is much faster if you call get_reco2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
echo "... perturb_data_dir_speed_3way.sh."
fi
num_recos=$(wc -l <$data/wav.scp)
if [ $nj -gt $num_recos ]; then
nj=$num_recos
fi
temp_data_dir=$data/wav${nj}split
wavscps=$(for n in `seq $nj`; do echo $temp_data_dir/$n/wav.scp; done)
subdirs=$(for n in `seq $nj`; do echo $temp_data_dir/$n; done)
if ! mkdir -p $subdirs >&/dev/null; then
for n in `seq $nj`; do
mkdir -p $temp_data_dir/$n
done
fi
utils/split_scp.pl $data/wav.scp $wavscps
$cmd JOB=1:$nj $data/log/get_reco_durations.JOB.log \
wav-to-duration --read-entire-file=$read_entire_file \
scp:$temp_data_dir/JOB/wav.scp ark,t:$temp_data_dir/JOB/reco2dur || \
{ echo "$0: there was a problem getting the durations"; exit 1; } # This could
for n in `seq $nj`; do
cat $temp_data_dir/$n/reco2dur
done > $data/reco2dur
fi
if [ -n "$temp_data_dir" ]; then rm -r $temp_data_dir; fi
else
echo "$0: Expected $data/wav.scp to exist"
exit 1
fi
len1=$(wc -l < $data/wav.scp)
len2=$(wc -l < $data/reco2dur)
if [ "$len1" != "$len2" ]; then
echo "$0: warning: length of reco2dur does not equal that of wav.scp, $len2 != $len1"
if [ $len1 -gt $[$len2*2] ]; then
echo "$0: less than half of recordings got a duration: failing."
exit 1
fi
fi
echo "$0: computed $data/reco2dur"
exit 0

View File

@ -0,0 +1,29 @@
#!/usr/bin/env bash
# This script operates on a data directory, such as in data/train/,
# and writes new segments to stdout. The file 'segments' maps from
# utterance to time offsets into a recording, with the format:
# <utterance-id> <recording-id> <segment-begin> <segment-end>
# This script assumes utterance and recording ids are the same (i.e., that
# wav.scp is indexed by utterance), and uses durations from 'utt2dur',
# created if necessary by get_utt2dur.sh.
. ./path.sh
if [ $# != 1 ]; then
echo "Usage: $0 [options] <datadir>"
echo "e.g.:"
echo " $0 data/train > data/train/segments"
exit 1
fi
data=$1
if [ ! -s $data/utt2dur ]; then
local/data/get_utt2dur.sh $data 1>&2 || exit 1;
fi
# <utt-id> <utt-id> 0 <utt-dur>
awk '{ print $1, $1, 0, $2 }' $data/utt2dur
exit 0

View File

@ -0,0 +1,135 @@
#!/usr/bin/env bash
# Copyright 2016 Johns Hopkins University (author: Daniel Povey)
# Apache 2.0
# This script operates on a data directory, such as in data/train/, and adds the
# utt2dur file if it does not already exist. The file 'utt2dur' maps from
# utterance to the duration of the utterance in seconds. This script works it
# out from the 'segments' file, or, if not present, from the wav.scp file (it
# first tries interrogating the headers, and if this fails, it reads the wave
# files in their entirety).
frame_shift=0.01
cmd=run.pl
nj=4
read_entire_file=false
. utils/parse_options.sh
. ./path.sh
if [ $# != 1 ]; then
echo "Usage: $0 [options] <datadir>"
echo "e.g.:"
echo " $0 data/train"
echo " Options:"
echo " --frame-shift # frame shift in seconds. Only relevant when we are"
echo " # getting duration from feats.scp, and only if the "
echo " # file frame_shift does not exist (default: 0.01). "
exit 1
fi
export LC_ALL=C
data=$1
if [ -s $data/utt2dur ] && \
[ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ]; then
echo "$0: $data/utt2dur already exists with the expected length. We won't recompute it."
exit 0;
fi
if [ -s $data/segments ]; then
echo "$0: working out $data/utt2dur from $data/segments"
awk '{len=$4-$3; print $1, len;}' < $data/segments > $data/utt2dur
elif [[ -s $data/frame_shift && -f $data/utt2num_frames ]]; then
echo "$0: computing $data/utt2dur from $data/{frame_shift,utt2num_frames}."
frame_shift=$(cat $data/frame_shift) || exit 1
# The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift.
awk -v fs=$frame_shift '{ $2=($2+1.5)*fs; print }' <$data/utt2num_frames >$data/utt2dur
elif [ -f $data/wav.scp ]; then
echo "$0: segments file does not exist so getting durations from wave files"
# if the wav.scp contains only lines of the form
# utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph |
if perl <$data/wav.scp -e '
while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space.
@A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
$A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
$utt = $A[0]; $sphere_file = $A[4];
if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
$sample_rate = -1; $sample_count = -1;
for ($n = 0; $n <= 30; $n++) {
$line = <F>;
if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
if ($line =~ m/end_head/) { last; }
}
close(F);
if ($sample_rate == -1 || $sample_count == -1) {
die "could not parse sphere header from $sphere_file";
}
$duration = $sample_count * 1.0 / $sample_rate;
print "$utt $duration\n";
} ' > $data/utt2dur; then
echo "$0: successfully obtained utterance lengths from sphere-file headers"
else
echo "$0: could not get utterance lengths from sphere-file headers, using wav-to-duration"
if ! command -v wav-to-duration >/dev/null; then
echo "$0: wav-to-duration is not on your path"
exit 1;
fi
if grep -q 'sox.*speed' $data/wav.scp; then
read_entire_file=true
echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
echo "... It is much faster if you call get_utt2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
echo "... perturb_data_dir_speed_3way.sh."
fi
num_utts=$(wc -l <$data/utt2spk)
if [ $nj -gt $num_utts ]; then
nj=$num_utts
fi
local/data/split_data.sh --per-utt $data $nj
sdata=$data/split${nj}utt
$cmd JOB=1:$nj $data/log/get_durations.JOB.log \
wav-to-duration --read-entire-file=$read_entire_file \
scp:$sdata/JOB/wav.scp ark,t:$sdata/JOB/utt2dur || \
{ echo "$0: there was a problem getting the durations"; exit 1; }
for n in `seq $nj`; do
cat $sdata/$n/utt2dur
done > $data/utt2dur
fi
elif [ -f $data/feats.scp ]; then
echo "$0: wave file does not exist so getting durations from feats files"
if [[ -s $data/frame_shift ]]; then
frame_shift=$(cat $data/frame_shift) || exit 1
echo "$0: using frame_shift=$frame_shift from file $data/frame_shift"
fi
# The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift.
feat-to-len scp:$data/feats.scp ark,t:- |
awk -v frame_shift=$frame_shift '{print $1, ($2+1.5)*frame_shift}' >$data/utt2dur
else
echo "$0: Expected $data/wav.scp, $data/segments or $data/feats.scp to exist"
exit 1
fi
len1=$(wc -l < $data/utt2spk)
len2=$(wc -l < $data/utt2dur)
if [ "$len1" != "$len2" ]; then
echo "$0: warning: length of utt2dur does not equal that of utt2spk, $len2 != $len1"
if [ $len1 -gt $[$len2*2] ]; then
echo "$0: less than half of utterances got a duration: failing."
exit 1
fi
fi
echo "$0: computed $data/utt2dur"
exit 0

View File

@ -0,0 +1,160 @@
#!/usr/bin/env bash
# Copyright 2010-2013 Microsoft Corporation
# Johns Hopkins University (Author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
split_per_spk=true
if [ "$1" == "--per-utt" ]; then
split_per_spk=false
shift
fi
if [ $# != 2 ]; then
echo "Usage: $0 [--per-utt] <data-dir> <num-to-split>"
echo "E.g.: $0 data/train 50"
echo "It creates its output in e.g. data/train/split50/{1,2,3,...50}, or if the "
echo "--per-utt option was given, in e.g. data/train/split50utt/{1,2,3,...50}."
echo ""
echo "This script will not split the data-dir if it detects that the output is newer than the input."
echo "By default it splits per speaker (so each speaker is in only one split dir),"
echo "but with the --per-utt option it will ignore the speaker information while splitting."
exit 1
fi
data=$1
numsplit=$2
if ! [ "$numsplit" -gt 0 ]; then
echo "Invalid num-split argument $numsplit";
exit 1;
fi
if $split_per_spk; then
warning_opt=
else
# suppress warnings from filter_scps.pl about 'some input lines were output
# to multiple files'.
warning_opt="--no-warn"
fi
n=0;
feats=""
wavs=""
utt2spks=""
texts=""
nu=`cat $data/utt2spk | wc -l`
nf=`cat $data/feats.scp 2>/dev/null | wc -l`
nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can "
echo "** use local/fix_data_dir.sh $data to fix this."
fi
if [ -f $data/text ] && [ $nu -ne $nt ]; then
echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can "
echo "** use local/fix_data_dir.sh to fix this."
fi
if $split_per_spk; then
utt2spk_opt="--utt2spk=$data/utt2spk"
utt=""
else
utt2spk_opt=
utt="utt"
fi
s1=$data/split${numsplit}${utt}/1
if [ ! -d $s1 ]; then
need_to_split=true
else
need_to_split=false
for f in utt2spk spk2utt spk2warp feats.scp text wav.scp cmvn.scp spk2gender \
vad.scp segments reco2file_and_channel utt2lang; do
if [[ -f $data/$f && ( ! -f $s1/$f || $s1/$f -ot $data/$f ) ]]; then
need_to_split=true
fi
done
fi
if ! $need_to_split; then
exit 0;
fi
utt2spks=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2spk; done)
directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n; done)
# if this mkdir fails due to argument-list being too long, iterate.
if ! mkdir -p $directories >&/dev/null; then
for n in `seq $numsplit`; do
mkdir -p $data/split${numsplit}${utt}/$n
done
fi
# If lockfile is not installed, just don't lock it. It's not a big deal.
which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock
trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM
utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1
for n in `seq $numsplit`; do
dsn=$data/split${numsplit}${utt}/$n
local/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
done
maybe_wav_scp=
if [ ! -f $data/segments ]; then
maybe_wav_scp=wav.scp # If there is no segments file, then wav file is
# indexed per utt.
fi
# split some things that are indexed by utterance.
for f in feats.scp text vad.scp utt2lang $maybe_wav_scp utt2dur utt2num_frames; do
if [ -f $data/$f ]; then
utils/filter_scps.pl JOB=1:$numsplit \
$data/split${numsplit}${utt}/JOB/utt2spk $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1;
fi
done
# split some things that are indexed by speaker
for f in spk2gender spk2warp cmvn.scp; do
if [ -f $data/$f ]; then
utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
$data/split${numsplit}${utt}/JOB/spk2utt $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1;
fi
done
if [ -f $data/segments ]; then
utils/filter_scps.pl JOB=1:$numsplit \
$data/split${numsplit}${utt}/JOB/utt2spk $data/segments $data/split${numsplit}${utt}/JOB/segments || exit 1
for n in `seq $numsplit`; do
dsn=$data/split${numsplit}${utt}/$n
awk '{print $2;}' $dsn/segments | sort | uniq > $dsn/tmp.reco # recording-ids.
done
if [ -f $data/reco2file_and_channel ]; then
utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
$data/split${numsplit}${utt}/JOB/tmp.reco $data/reco2file_and_channel \
$data/split${numsplit}${utt}/JOB/reco2file_and_channel || exit 1
fi
if [ -f $data/wav.scp ]; then
utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
$data/split${numsplit}${utt}/JOB/tmp.reco $data/wav.scp \
$data/split${numsplit}${utt}/JOB/wav.scp || exit 1
fi
for f in $data/split${numsplit}${utt}/*/tmp.reco; do rm $f; done
fi
exit 0

View File

@ -0,0 +1,6 @@
from modelscope.hub.snapshot_download import snapshot_download
import sys
cache_dir = sys.argv[1]
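# Download the ModelScope x-vector speaker-verification model used for
# speaker embeddings; the cache directory is taken from the command line.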
model_dir = snapshot_download('damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch', cache_dir=cache_dir)

View File

@ -0,0 +1,22 @@
import sys
if __name__=="__main__":
uttid_path=sys.argv[1]
src_path=sys.argv[2]
tgt_path=sys.argv[3]
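# Keep only the utt2spk_all_fifo entries whose utterance id appears in the
# given uttid list, writing the filtered file under tgt_path.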
uttid_file=open(uttid_path,'r')
uttid_line=uttid_file.readlines()
uttid_file.close()
ori_utt2spk_all_fifo_file=open(src_path+'/utt2spk_all_fifo','r')
ori_utt2spk_all_fifo_line=ori_utt2spk_all_fifo_file.readlines()
ori_utt2spk_all_fifo_file.close()
new_utt2spk_all_fifo_file=open(tgt_path+'/utt2spk_all_fifo','w')
uttid_list=[]
for line in uttid_line:
uttid_list.append(line.strip())
for line in ori_utt2spk_all_fifo_line:
if line.strip().split(' ')[0] in uttid_list:
new_utt2spk_all_fifo_file.write(line)
new_utt2spk_all_fifo_file.close()

View File

@ -0,0 +1,215 @@
#!/usr/bin/env bash
# This script makes sure that only the segments present in
# all of "feats.scp", "wav.scp" [if present], segments [if present]
# text, and utt2spk are present in any of them.
# It puts the original contents of data-dir into
# data-dir/.backup
cmd="$@"
utt_extra_files=
spk_extra_files=
. utils/parse_options.sh
if [ $# != 1 ]; then
echo "Usage: utils/data/fix_data_dir.sh <data-dir>"
echo "e.g.: utils/data/fix_data_dir.sh data/train"
echo "This script helps ensure that the various files in a data directory"
echo "are correctly sorted and filtered, for example removing utterances"
echo "that have no features (if feats.scp is present)"
exit 1
fi
data=$1
if [ -f $data/images.scp ]; then
image/fix_data_dir.sh $cmd
exit $?
fi
mkdir -p $data/.backup
[ ! -d $data ] && echo "$0: no such directory $data" && exit 1;
[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1;
set -e -o pipefail -u
tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM
export LC_ALL=C
function check_sorted {
file=$1
sort -k1,1 -u <$file >$file.tmp
if ! cmp -s $file $file.tmp; then
echo "$0: file $1 is not in sorted order or not unique, sorting it"
mv $file.tmp $file
else
rm $file.tmp
fi
}
for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \
reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do
if [ -f $data/$x ]; then
cp $data/$x $data/.backup/$x
check_sorted $data/$x
fi
done
function filter_file {
filter=$1
file_to_filter=$2
cp $file_to_filter ${file_to_filter}.tmp
utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter
if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then
length1=$(cat ${file_to_filter}.tmp | wc -l)
length2=$(cat ${file_to_filter} | wc -l)
if [ $length1 -ne $length2 ]; then
echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter."
fi
fi
rm $file_to_filter.tmp
}
function filter_recordings {
# We call this once before the stage when we filter on utterance-id, and once
# after.
if [ -f $data/segments ]; then
# We have a segments file -> we need to filter this and the file wav.scp, and
# reco2file_and_utt, if it exists, to make sure they have the same list of
# recording-ids.
if [ ! -f $data/wav.scp ]; then
echo "$0: $data/segments exists but not $data/wav.scp"
exit 1;
fi
awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings
n1=$(cat $tmpdir/recordings | wc -l)
[ ! -s $tmpdir/recordings ] && \
echo "Empty list of recordings (bad file $data/segments)?" && exit 1;
utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp
mv $tmpdir/recordings.tmp $tmpdir/recordings
cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments
filter_file $tmpdir/recordings $data/segments
cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments
rm $data/segments.tmp
filter_file $tmpdir/recordings $data/wav.scp
[ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel
[ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur
true
fi
}
function filter_speakers {
# throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
for s in cmvn.scp spk2gender; do
f=$data/$s
if [ -f $f ]; then
filter_file $f $tmpdir/speakers
fi
done
filter_file $tmpdir/speakers $data/spk2utt
local/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
for s in cmvn.scp spk2gender $spk_extra_files; do
f=$data/$s
if [ -f $f ]; then
filter_file $tmpdir/speakers $f
fi
done
}
function filter_utts {
cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts
! cat $data/utt2spk | sort | cmp - $data/utt2spk && \
echo "utt2spk is not in sorted order (fix this yourself)" && exit 1;
! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \
echo "utt2spk is not in sorted order when sorted first on speaker-id " && \
echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1;
! cat $data/spk2utt | sort | cmp - $data/spk2utt && \
echo "spk2utt is not in sorted order (fix this yourself)" && exit 1;
if [ -f $data/utt2uniq ]; then
! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \
echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1;
fi
maybe_wav=
maybe_reco2dur=
[ ! -f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist.
[ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts
maybe_utt2dur=
if [ -f $data/utt2dur ]; then
cat $data/utt2dur | \
awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2dur.ok || exit 1
maybe_utt2dur=utt2dur.ok
fi
maybe_utt2num_frames=
if [ -f $data/utt2num_frames ]; then
cat $data/utt2num_frames | \
awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2num_frames.ok || exit 1
maybe_utt2num_frames=utt2num_frames.ok
fi
for x in feats.scp text segments utt2lang $maybe_wav $maybe_utt2dur $maybe_utt2num_frames; do
if [ -f $data/$x ]; then
utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp
mv $tmpdir/utts.tmp $tmpdir/utts
fi
done
rm $data/utt2dur.ok 2>/dev/null || true
rm $data/utt2num_frames.ok 2>/dev/null || true
[ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \
rm $tmpdir/utts && exit 1;
if [ -f $data/utt2spk ]; then
new_nutts=$(cat $tmpdir/utts | wc -l)
old_nutts=$(cat $data/utt2spk | wc -l)
if [ $new_nutts -ne $old_nutts ]; then
echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts"
else
echo "fix_data_dir.sh: kept all $old_nutts utterances."
fi
fi
for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do
if [ -f $data/$x ]; then
cp $data/$x $data/.backup/$x
if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then
utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x
fi
fi
done
}
filter_recordings
filter_speakers
filter_utts
filter_speakers
filter_recordings
local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
echo "fix_data_dir.sh: old files are kept in $data/.backup"

View File

@ -0,0 +1,243 @@
#!/usr/bin/env python3
import argparse
import logging
from io import BytesIO
from pathlib import Path
from typing import Tuple, Optional
import kaldiio
import humanfriendly
import numpy as np
import resampy
import soundfile
from tqdm import tqdm
from typeguard import check_argument_types
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.read_text import read_2column_text
from funasr.fileio.sound_scp import SoundScpWriter
def humanfriendly_or_none(value: str):
if value in ("none", "None", "NONE"):
return None
return humanfriendly.parse_size(value)
def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
"""
>>> str2int_tuple('3,4,5')
(3, 4, 5)
"""
assert check_argument_types()
if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
return None
return tuple(map(int, integers.strip().split(",")))
def main():
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=logfmt)
logging.info(get_commandline_args())
parser = argparse.ArgumentParser(
description='Create waves list from "wav.scp"',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("scp")
parser.add_argument("outdir")
parser.add_argument(
"--name",
default="wav",
help="Specify the prefix word of output file name " 'such as "wav.scp"',
)
parser.add_argument("--segments", default=None)
parser.add_argument(
"--fs",
type=humanfriendly_or_none,
default=None,
help="If the sampling rate specified, " "Change the sampling rate.",
)
parser.add_argument("--audio-format", default="wav")
group = parser.add_mutually_exclusive_group()
group.add_argument("--ref-channels", default=None, type=str2int_tuple)
group.add_argument("--utt2ref-channels", default=None, type=str)
args = parser.parse_args()
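# Three code paths below: (1) --segments given: cut segments via kaldiio and
# re-save them; (2) pipe-style wav.scp entries ("... |"): decode in memory;
# (3) plain files: keep the original file as-is when the format and rate
# already match, otherwise resample and re-save.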
out_num_samples = Path(args.outdir) / f"utt2num_samples"
if args.ref_channels is not None:
def utt2ref_channels(x) -> Tuple[int, ...]:
return args.ref_channels
elif args.utt2ref_channels is not None:
utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)
def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
chs_str = d[x]
return tuple(map(int, chs_str.split()))
else:
utt2ref_channels = None
Path(args.outdir).mkdir(parents=True, exist_ok=True)
out_wavscp = Path(args.outdir) / f"{args.name}.scp"
if args.segments is not None:
# Note: kaldiio supports only wav-pcm-int16le file.
loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
if args.audio_format.endswith("ark"):
fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
fscp = out_wavscp.open("w")
else:
writer = SoundScpWriter(
args.outdir,
out_wavscp,
format=args.audio_format,
)
with out_num_samples.open("w") as fnum_samples:
for uttid, (rate, wave) in tqdm(loader):
# wave: (Time,) or (Time, Nmic)
if wave.ndim == 2 and utt2ref_channels is not None:
wave = wave[:, utt2ref_channels(uttid)]
if args.fs is not None and args.fs != rate:
# FIXME(kamo): To use sox?
wave = resampy.resample(
wave.astype(np.float64), rate, args.fs, axis=0
)
wave = wave.astype(np.int16)
rate = args.fs
if args.audio_format.endswith("ark"):
if "flac" in args.audio_format:
suf = "flac"
elif "wav" in args.audio_format:
suf = "wav"
else:
raise RuntimeError("wav.ark or flac")
# NOTE(kamo): Using extended ark format style here.
# This format is incompatible with Kaldi
kaldiio.save_ark(
fark,
{uttid: (wave, rate)},
scp=fscp,
append=True,
write_function=f"soundfile_{suf}",
)
else:
writer[uttid] = rate, wave
fnum_samples.write(f"{uttid} {len(wave)}\n")
else:
if args.audio_format.endswith("ark"):
fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
else:
wavdir = Path(args.outdir) / f"data_{args.name}"
wavdir.mkdir(parents=True, exist_ok=True)
with Path(args.scp).open("r") as fscp, out_wavscp.open(
"w"
) as fout, out_num_samples.open("w") as fnum_samples:
for line in tqdm(fscp):
uttid, wavpath = line.strip().split(None, 1)
if wavpath.endswith("|"):
# Streaming input e.g. cat a.wav |
with kaldiio.open_like_kaldi(wavpath, "rb") as f:
with BytesIO(f.read()) as g:
wave, rate = soundfile.read(g, dtype=np.int16)
if wave.ndim == 2 and utt2ref_channels is not None:
wave = wave[:, utt2ref_channels(uttid)]
if args.fs is not None and args.fs != rate:
# FIXME(kamo): To use sox?
wave = resampy.resample(
wave.astype(np.float64), rate, args.fs, axis=0
)
wave = wave.astype(np.int16)
rate = args.fs
if args.audio_format.endswith("ark"):
if "flac" in args.audio_format:
suf = "flac"
elif "wav" in args.audio_format:
suf = "wav"
else:
raise RuntimeError("wav.ark or flac")
# NOTE(kamo): Using extended ark format style here.
# This format is incompatible with Kaldi
kaldiio.save_ark(
fark,
{uttid: (wave, rate)},
scp=fout,
append=True,
write_function=f"soundfile_{suf}",
)
else:
owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
soundfile.write(owavpath, wave, rate)
fout.write(f"{uttid} {owavpath}\n")
else:
wave, rate = soundfile.read(wavpath, dtype=np.int16)
if wave.ndim == 2 and utt2ref_channels is not None:
wave = wave[:, utt2ref_channels(uttid)]
save_asis = False
elif args.audio_format.endswith("ark"):
save_asis = False
elif Path(wavpath).suffix == "." + args.audio_format and (
args.fs is None or args.fs == rate
):
save_asis = True
else:
save_asis = False
if save_asis:
# Neither --segments nor --fs are specified and
# the line doesn't end with "|",
# i.e. not using unix-pipe,
# only in this case,
# just using the original file as is.
fout.write(f"{uttid} {wavpath}\n")
else:
if args.fs is not None and args.fs != rate:
# FIXME(kamo): To use sox?
wave = resampy.resample(
wave.astype(np.float64), rate, args.fs, axis=0
)
wave = wave.astype(np.int16)
rate = args.fs
if args.audio_format.endswith("ark"):
if "flac" in args.audio_format:
suf = "flac"
elif "wav" in args.audio_format:
suf = "wav"
else:
raise RuntimeError("wav.ark or flac")
# NOTE(kamo): Using extended ark format style here.
# This format is not supported in Kaldi.
kaldiio.save_ark(
fark,
{uttid: (wave, rate)},
scp=fout,
append=True,
write_function=f"soundfile_{suf}",
)
else:
owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
soundfile.write(owavpath, wave, rate)
fout.write(f"{uttid} {owavpath}\n")
fnum_samples.write(f"{uttid} {len(wave)}\n")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,142 @@
#!/usr/bin/env bash
set -euo pipefail
SECONDS=0
log() {
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
help_message=$(cat << EOF
Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
e.g.
$0 data/test/wav.scp data/test_format/
Format 'wav.scp': in short, convert a "kaldi-datadir" into a "modified-kaldi-datadir".
The 'wav.scp' format in kaldi is very flexible,
e.g. it can use a unix pipe to describe the wav file,
but this sometimes looks confusing and makes scripts more complex.
This tool creates actual wav files from 'wav.scp'
and also segments the wav files using 'segments'.
Options
--fs <fs>
--segments <segments>
--nj <nj>
--cmd <cmd>
EOF
)
out_filename=wav.scp
cmd=utils/run.pl
nj=30
fs=none
segments=
ref_channels=
utt2ref_channels=
audio_format=wav
write_utt2num_samples=true
log "$0 $*"
. utils/parse_options.sh
if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
log "${help_message}"
log "Error: invalid command line arguments"
exit 1
fi
. ./path.sh # Setup the environment
scp=$1
if [ ! -f "${scp}" ]; then
log "${help_message}"
echo "$0: Error: No such file: ${scp}"
exit 1
fi
dir=$2
if [ $# -eq 2 ]; then
logdir=${dir}/logs
outdir=${dir}/data
elif [ $# -eq 3 ]; then
logdir=$3
outdir=${dir}/data
elif [ $# -eq 4 ]; then
logdir=$3
outdir=$4
fi
mkdir -p ${logdir}
rm -f "${dir}/${out_filename}"
opts=
if [ -n "${utt2ref_channels}" ]; then
opts="--utt2ref-channels ${utt2ref_channels} "
elif [ -n "${ref_channels}" ]; then
opts="--ref-channels ${ref_channels} "
fi
if [ -n "${segments}" ]; then
log "[info]: using ${segments}"
nutt=$(<${segments} wc -l)
nj=$((nj<nutt?nj:nutt))
split_segments=""
for n in $(seq ${nj}); do
split_segments="${split_segments} ${logdir}/segments.${n}"
done
utils/split_scp.pl "${segments}" ${split_segments}
${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
local/format_wav_scp.py \
${opts} \
--fs ${fs} \
--audio-format "${audio_format}" \
"--segment=${logdir}/segments.JOB" \
"${scp}" "${outdir}/format.JOB"
else
log "[info]: without segments"
nutt=$(<${scp} wc -l)
nj=$((nj<nutt?nj:nutt))
split_scps=""
for n in $(seq ${nj}); do
split_scps="${split_scps} ${logdir}/wav.${n}.scp"
done
utils/split_scp.pl "${scp}" ${split_scps}
${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
local/format_wav_scp.py \
${opts} \
--fs "${fs}" \
--audio-format "${audio_format}" \
"${logdir}/wav.JOB.scp" ${outdir}/format.JOB""
fi
# Workaround for the NFS problem
ls ${outdir}/format.* > /dev/null
# concatenate the .scp files together.
for n in $(seq ${nj}); do
cat "${outdir}/format.${n}/wav.scp" || exit 1;
done > "${dir}/${out_filename}" || exit 1
if "${write_utt2num_samples}"; then
for n in $(seq ${nj}); do
cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
done > "${dir}/utt2num_samples" || exit 1
fi
log "Successfully finished. [elapsed=${SECONDS}s]"

View File

@ -0,0 +1,167 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import numpy as np
import sys
import os
import soundfile
from itertools import permutations
from sklearn.metrics.pairwise import cosine_similarity
from sklearn import cluster
def custom_spectral_clustering(affinity, min_n_clusters=2, max_n_clusters=4, refine=True,
threshold=0.995, laplacian_type="graph_cut"):
if refine:
# Symmetrization
affinity = np.maximum(affinity, np.transpose(affinity))
# Diffusion
affinity = np.matmul(affinity, np.transpose(affinity))
# Row-wise max normalization
row_max = affinity.max(axis=1, keepdims=True)
affinity = affinity / row_max
# a) Construct S and set diagonal elements to 0
affinity = affinity - np.diag(np.diag(affinity))
# b) Compute Laplacian matrix L and perform normalization:
degree = np.diag(np.sum(affinity, axis=1))
laplacian = degree - affinity
if laplacian_type == "random_walk":
degree_norm = np.diag(1 / (np.diag(degree) + 1e-10))
laplacian_norm = degree_norm.dot(laplacian)
else:
degree_half = np.diag(degree) ** 0.5 + 1e-15
laplacian_norm = laplacian / degree_half[:, np.newaxis] / degree_half
# c) Compute eigenvalues and eigenvectors of L_norm
eigenvalues, eigenvectors = np.linalg.eig(laplacian_norm)
eigenvalues = eigenvalues.real
eigenvectors = eigenvectors.real
index_array = np.argsort(eigenvalues)
eigenvalues = eigenvalues[index_array]
eigenvectors = eigenvectors[:, index_array]
# d) Compute the number of clusters k
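# Eigen-gap heuristic: pick the smallest k in [min_n_clusters, max_n_clusters]
# for which eigenvalues[k] (sorted ascending) exceeds the threshold.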
k = min_n_clusters
for k in range(min_n_clusters, max_n_clusters + 1):
if eigenvalues[k] > threshold:
break
k = max(k, min_n_clusters)
spectral_embeddings = eigenvectors[:, :k]
spectral_embeddings = spectral_embeddings / np.linalg.norm(spectral_embeddings, axis=1, ord=2, keepdims=True)
solver = cluster.KMeans(n_clusters=k, max_iter=1000, random_state=42)
solver.fit(spectral_embeddings)
return solver.labels_
if __name__ == "__main__":
path = sys.argv[1] # dump2/raw/Eval_Ali_far
raw_path = sys.argv[2] # data/local/Eval_Ali_far
threshold = float(sys.argv[3]) # 0.996
sv_threshold = float(sys.argv[4]) # 0.815
wav_scp_file = open(path+'/wav.scp', 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
raw_meeting_scp = raw_meeting_scp_file.readlines()
raw_meeting_scp_file.close()
segments_scp_file = open(raw_path + '/segments', 'r')
segments_scp = segments_scp_file.readlines()
segments_scp_file.close()
segments_map = {}
for line in segments_scp:
line_list = line.strip().split(' ')
meeting = line_list[1]
seg = (float(line_list[-2]), float(line_list[-1]))
if meeting not in segments_map.keys():
segments_map[meeting] = [seg]
else:
segments_map[meeting].append(seg)
inference_sv_pipline = pipeline(
task=Tasks.speaker_verification,
model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
)
chunk_len = int(1.5*16000) # 1.5 seconds
hop_len = int(0.75*16000) # 0.75 seconds
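# Sliding-window x-vector extraction: each annotated segment is cut into
# 1.5 s chunks with a 0.75 s hop, every chunk is embedded, and the chunk
# embeddings are clustered per meeting.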
os.system("mkdir -p " + path + "/cluster_profile_infer")
cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
meeting_map = {}
for line in raw_meeting_scp:
meeting = line.strip().split('\t')[0]
wav_path = line.strip().split('\t')[1]
wav = soundfile.read(wav_path)[0]
# take the first channel
if wav.ndim == 2:
wav=wav[:, 0]
# gen_seg_embedding
segments_list = segments_map[meeting]
all_seg_embedding_list = []
for seg in segments_list:
wav_seg = wav[int(seg[0] * 16000): int(seg[1] * 16000)]
wav_seg_len = wav_seg.shape[0]
i = 0
while i < wav_seg_len:
if i + chunk_len < wav_seg_len:
cur_wav_chunk = wav_seg[i: i+chunk_len]
else:
cur_wav_chunk=wav_seg[i: ]
# chunks under 0.2s are ignored
if cur_wav_chunk.shape[0] >= 0.2 * 16000:
cur_chunk_embedding = inference_sv_pipline(audio_in=cur_wav_chunk)["spk_embedding"]
all_seg_embedding_list.append(cur_chunk_embedding)
i += hop_len
all_seg_embedding = np.vstack(all_seg_embedding_list)
# all_seg_embedding (n, dim)
# compute affinity
affinity=cosine_similarity(all_seg_embedding)
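# Shift and rescale the cosine similarities so values below the
# speaker-verification threshold are clipped to a small positive floor
# before spectral clustering.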
affinity = np.maximum(affinity - sv_threshold, 0.0001) / (affinity.max() - sv_threshold)
# clustering
labels = custom_spectral_clustering(
affinity=affinity,
min_n_clusters=2,
max_n_clusters=4,
refine=True,
threshold=threshold,
laplacian_type="graph_cut")
cluster_dict={}
for j in range(labels.shape[0]):
if labels[j] not in cluster_dict.keys():
cluster_dict[labels[j]] = np.atleast_2d(all_seg_embedding[j])
else:
cluster_dict[labels[j]] = np.concatenate((cluster_dict[labels[j]], np.atleast_2d(all_seg_embedding[j])))
emb_list = []
# get cluster center
for k in cluster_dict.keys():
cluster_dict[k] = np.mean(cluster_dict[k], axis=0)
emb_list.append(cluster_dict[k])
spk_num = len(emb_list)
profile_for_infer = np.vstack(emb_list)
# save profile for each meeting
np.save(path + '/cluster_profile_infer/' + meeting + '.npy', profile_for_infer)
meeting_map[meeting] = (path + '/cluster_profile_infer/' + meeting + '.npy', spk_num)
cluster_spk_num_file.write(meeting + ' ' + str(spk_num) + '\n')
cluster_spk_num_file.flush()
cluster_spk_num_file.close()
profile_scp = open(path + "/cluster_profile_infer.scp", 'w')
for line in wav_scp:
uttid = line.strip().split(' ')[0]
meeting = uttid.split('-')[0]
profile_scp.write(uttid + ' ' + meeting_map[meeting][0] + '\n')
profile_scp.flush()
profile_scp.close()

View File

@ -0,0 +1,70 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import numpy as np
import sys
import os
import soundfile
if __name__=="__main__":
path = sys.argv[1] # dump2/raw/Eval_Ali_far
raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
raw_meeting_scp = raw_meeting_scp_file.readlines()
raw_meeting_scp_file.close()
segments_scp_file = open(raw_path + '/segments', 'r')
segments_scp = segments_scp_file.readlines()
segments_scp_file.close()
oracle_emb_dir = path + '/oracle_embedding/'
os.system("mkdir -p " + oracle_emb_dir)
oracle_emb_scp_file = open(path+'/oracle_embedding.scp', 'w')
raw_wav_map = {}
for line in raw_meeting_scp:
meeting = line.strip().split('\t')[0]
wav_path = line.strip().split('\t')[1]
raw_wav_map[meeting] = wav_path
spk_map = {}
for line in segments_scp:
line_list = line.strip().split(' ')
meeting = line_list[1]
spk_id = line_list[0].split('_')[3]
spk = meeting + '_' + spk_id
time_start = float(line_list[-2])
time_end = float(line_list[-1])
if time_end - time_start > 0.5:
if spk not in spk_map.keys():
spk_map[spk] = [(int(time_start * 16000), int(time_end * 16000))]
else:
spk_map[spk].append((int(time_start * 16000), int(time_end * 16000)))
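# Only segments longer than 0.5 s contribute; each speaker's oracle
# embedding is the mean x-vector over all of that speaker's segments.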
inference_sv_pipline = pipeline(
task=Tasks.speaker_verification,
model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
)
for spk in spk_map.keys():
meeting = spk.split('_SPK')[0]
wav_path = raw_wav_map[meeting]
wav = soundfile.read(wav_path)[0]
# take the first channel
if wav.ndim == 2:
wav = wav[:, 0]
all_seg_embedding_list=[]
for seg_time in spk_map[spk]:
if seg_time[0] < wav.shape[0] - 0.5 * 16000:
if seg_time[1] > wav.shape[0]:
cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_time[0]: ])["spk_embedding"]
else:
cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_time[0]: seg_time[1]])["spk_embedding"]
all_seg_embedding_list.append(cur_seg_embedding)
all_seg_embedding = np.vstack(all_seg_embedding_list)
spk_embedding = np.mean(all_seg_embedding, axis=0)
np.save(oracle_emb_dir + spk + '.npy', spk_embedding)
oracle_emb_scp_file.write(spk + ' ' + oracle_emb_dir + spk + '.npy' + '\n')
oracle_emb_scp_file.flush()
oracle_emb_scp_file.close()

View File

@ -0,0 +1,59 @@
import random
import numpy as np
import os
import sys
if __name__=="__main__":
path = sys.argv[1] # dump2/raw/Eval_Ali_far
wav_scp_file = open(path+"/wav.scp", 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
spk2id_file = open(path + "/spk2id", 'r')
spk2id = spk2id_file.readlines()
spk2id_file.close()
embedding_scp_file = open(path + "/oracle_embedding.scp", 'r')
embedding_scp = embedding_scp_file.readlines()
embedding_scp_file.close()
embedding_map = {}
for line in embedding_scp:
spk = line.strip().split(' ')[0]
if spk not in embedding_map.keys():
emb=np.load(line.strip().split(' ')[1])
embedding_map[spk] = emb
meeting_map_tmp = {}
global_spk_list = []
for line in spk2id:
line_list = line.strip().split(' ')
meeting = line_list[0].split('-')[0]
spk_id = line_list[0].split('-')[-1].split('_')[-1]
spk = meeting + '_' + spk_id
global_spk_list.append(spk)
if meeting in meeting_map_tmp.keys():
meeting_map_tmp[meeting].append(spk)
else:
meeting_map_tmp[meeting] = [spk]
meeting_map = {}
os.system('mkdir -p ' + path + '/oracle_profile_nopadding')
for meeting in meeting_map_tmp.keys():
emb_list = []
for i in range(len(meeting_map_tmp[meeting])):
spk = meeting_map_tmp[meeting][i]
emb_list.append(embedding_map[spk])
profile = np.vstack(emb_list)
np.save(path + '/oracle_profile_nopadding/' + meeting + '.npy', profile)
meeting_map[meeting] = path + '/oracle_profile_nopadding/' + meeting + '.npy'
profile_scp = open(path + '/oracle_profile_nopadding.scp', 'w')
profile_map_scp = open(path + '/oracle_profile_nopadding_spk_list', 'w')
for line in wav_scp:
uttid = line.strip().split(' ')[0]
meeting = uttid.split('-')[0]
profile_scp.write(uttid + ' ' + meeting_map[meeting] + '\n')
profile_map_scp.write(uttid + ' ' + '$'.join(meeting_map_tmp[meeting]) + '\n')
profile_scp.close()
profile_map_scp.close()

View File

@ -0,0 +1,68 @@
import random
import numpy as np
import os
import sys
if __name__=="__main__":
path = sys.argv[1] # dump2/raw/Train_Ali_far
wav_scp_file = open(path+"/wav.scp", 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
spk2id_file = open(path+"/spk2id", 'r')
spk2id = spk2id_file.readlines()
spk2id_file.close()
embedding_scp_file = open(path + "/oracle_embedding.scp", 'r')
embedding_scp = embedding_scp_file.readlines()
embedding_scp_file.close()
embedding_map = {}
for line in embedding_scp:
spk = line.strip().split(' ')[0]
if spk not in embedding_map.keys():
emb = np.load(line.strip().split(' ')[1])
embedding_map[spk] = emb
meeting_map_tmp = {}
global_spk_list = []
for line in spk2id:
line_list = line.strip().split(' ')
meeting = line_list[0].split('-')[0]
spk_id = line_list[0].split('-')[-1].split('_')[-1]
spk = meeting+'_' + spk_id
global_spk_list.append(spk)
if meeting in meeting_map_tmp.keys():
meeting_map_tmp[meeting].append(spk)
else:
meeting_map_tmp[meeting] = [spk]
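# Meetings with fewer than 4 speakers are padded with speakers randomly
# sampled from other meetings, so every training profile has exactly 4 rows.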
for meeting in meeting_map_tmp.keys():
num = len(meeting_map_tmp[meeting])
if num < 4:
global_spk_list_tmp = global_spk_list[: ]
for spk in meeting_map_tmp[meeting]:
global_spk_list_tmp.remove(spk)
padding_spk = random.sample(global_spk_list_tmp, 4 - num)
meeting_map_tmp[meeting] = meeting_map_tmp[meeting] + padding_spk
meeting_map = {}
os.system('mkdir -p ' + path + '/oracle_profile_padding')
for meeting in meeting_map_tmp.keys():
emb_list = []
for i in range(len(meeting_map_tmp[meeting])):
spk = meeting_map_tmp[meeting][i]
emb_list.append(embedding_map[spk])
profile = np.vstack(emb_list)
np.save(path + '/oracle_profile_padding/' + meeting + '.npy',profile)
meeting_map[meeting] = path + '/oracle_profile_padding/' + meeting + '.npy'
profile_scp = open(path + '/oracle_profile_padding.scp', 'w')
profile_map_scp = open(path + '/oracle_profile_padding_spk_list', 'w')
for line in wav_scp:
uttid = line.strip().split(' ')[0]
meeting = uttid.split('-')[0]
profile_scp.write(uttid+' ' + meeting_map[meeting] + '\n')
profile_map_scp.write(uttid+' ' + '$'.join(meeting_map_tmp[meeting]) + '\n')
profile_scp.close()
profile_map_scp.close()

View File

@ -0,0 +1,116 @@
#!/usr/bin/env bash
# 2020 @kamo-naoyuki
# This file was copied from Kaldi;
# the parts related to wav duration were deleted
# because we shouldn't use kaldi's commands here
# and we don't actually need those files.
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
# 2014 Tom Ko
# 2018 Emotech LTD (author: Pawel Swietojanski)
# Apache 2.0
# This script operates on a directory, such as in data/train/,
# that contains some subset of the following files:
# wav.scp
# spk2utt
# utt2spk
# text
#
# It generates the files which are used for perturbing the speed of the original data.
export LC_ALL=C
set -euo pipefail
if [[ $# != 3 ]]; then
echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
echo "e.g.:"
echo " $0 0.9 data/train_si284 data/train_si284p"
exit 1
fi
factor=$1
srcdir=$2
destdir=$3
label="sp"
spk_prefix="${label}${factor}-"
utt_prefix="${label}${factor}-"
# check that sox is on the path
! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
if [[ ! -f ${srcdir}/utt2spk ]]; then
echo "$0: no such file ${srcdir}/utt2spk"
exit 1;
fi
if [[ ${destdir} == "${srcdir}" ]]; then
echo "$0: this script requires <srcdir> and <destdir> to be different."
exit 1
fi
mkdir -p "${destdir}"
<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
if [[ ! -f ${srcdir}/utt2uniq ]]; then
<"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
else
<"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
fi
<"${srcdir}"/utt2spk local/apply_map.pl -f 1 "${destdir}"/utt_map | \
local/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
local/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
if [[ -f ${srcdir}/segments ]]; then
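# Speed perturbation by factor f scales every timestamp by 1/f; segments
# that would become shorter than 10 ms after scaling are dropped.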
local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
local/apply_map.pl -f 2 "${destdir}"/reco_map | \
awk -v factor="${factor}" \
'{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, s, e);} }' \
>"${destdir}"/segments
local/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
# Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
awk -v factor="${factor}" \
'{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
> "${destdir}"/wav.scp
if [[ -f ${srcdir}/reco2file_and_channel ]]; then
local/apply_map.pl -f 1 "${destdir}"/reco_map \
<"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
fi
else # no segments->wav indexed by utterance.
if [[ -f ${srcdir}/wav.scp ]]; then
local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
# Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
awk -v factor="${factor}" \
'{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
> "${destdir}"/wav.scp
fi
fi
if [[ -f ${srcdir}/text ]]; then
local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
fi
if [[ -f ${srcdir}/spk2gender ]]; then
local/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
fi
if [[ -f ${srcdir}/utt2lang ]]; then
local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
fi
rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
local/validate_data_dir.sh --no-feats --no-text "${destdir}"

View File

@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""
Generate per-meeting speaker ids (spk2id) and speaker-id-labeled text (text_id)
"""
import argparse
import codecs
from pathlib import Path
def get_args():
parser = argparse.ArgumentParser(description="generate spk2id and text_id from the data directory")
parser.add_argument("--path", type=str, required=True, help="Data path")
args = parser.parse_args()
return args
class Segment(object):
def __init__(self, uttid, text):
self.uttid = uttid
self.text = text
def main(args):
text = codecs.open(Path(args.path) / "text", "r", "utf-8")
spk2utt = codecs.open(Path(args.path) / "spk2utt", "r", "utf-8")
utt2spk = codecs.open(Path(args.path) / "utt2spk_all_fifo", "r", "utf-8")
spk2id = codecs.open(Path(args.path) / "spk2id", "w", "utf-8")
spkid_map = {}
meetingid_map = {}
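# Assign each speaker a per-meeting index (1, 2, ...) in order of first
# appearance in spk2utt, and record the mapping in spk2id.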
for line in spk2utt:
spkid = line.strip().split(" ")[0]
meeting_id_list = spkid.split("_")[:3]
meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
if meeting_id not in meetingid_map:
meetingid_map[meeting_id] = 1
else:
meetingid_map[meeting_id] += 1
spkid_map[spkid] = meetingid_map[meeting_id]
spk2id.write("%s %s\n" % (spkid, meetingid_map[meeting_id]))
utt2spklist = {}
for line in utt2spk:
uttid = line.strip().split(" ")[0]
spkid = line.strip().split(" ")[1]
spklist = spkid.split("$")
tmp = []
for index in range(len(spklist)):
tmp.append(spkid_map[spklist[index]])
utt2spklist[uttid] = tmp
# label each character of the transcription with its per-meeting speaker id
all_segments = []
for line in text:
uttid = line.strip().split(" ")[0]
context = line.strip().split(" ")[1]
spklist = utt2spklist[uttid]
length_text = len(context)
cnt = 0
tmp_text = ""
for index in range(length_text):
if context[index] != "$":
tmp_text += str(spklist[cnt])
else:
tmp_text += "$"
cnt += 1
tmp_seg = Segment(uttid,tmp_text)
all_segments.append(tmp_seg)
text.close()
utt2spk.close()
spk2utt.close()
spk2id.close()
text_id = codecs.open(Path(args.path) / "text_id", "w", "utf-8")
for i in range(len(all_segments)):
uttid_tmp = all_segments[i].uttid
text_tmp = all_segments[i].text
text_id.write("%s %s\n" % (uttid_tmp, text_tmp))
text_id.close()
if __name__ == "__main__":
args = get_args()
main(args)

View File

@ -0,0 +1,24 @@
import sys
if __name__=="__main__":
path=sys.argv[1]
text_id_old_file=open(path+"/text_id",'r')
text_id_old=text_id_old_file.readlines()
text_id_old_file.close()
text_id=open(path+"/text_id_train",'w')
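# text_id carries 1-based speaker ids with '$' marking a change of speaker.
# Convert to the space-separated, 0-based id sequence used for training: a
# '$' is replaced by the previous speaker's id, and a trailing copy of the
# last id is appended (presumably for the end-of-sentence position).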
for line in text_id_old:
uttid=line.strip().split(' ')[0]
old_id=line.strip().split(' ')[1]
pre_id='0'
new_id_list=[]
for i in old_id:
if i == '$':
new_id_list.append(pre_id)
else:
new_id_list.append(str(int(i)-1))
pre_id=str(int(i)-1)
new_id_list.append(pre_id)
new_id=' '.join(new_id_list)
text_id.write(uttid+' '+new_id+'\n')
text_id.close()

View File

@ -0,0 +1,55 @@
import sys
if __name__ == "__main__":
path=sys.argv[1]
text_scp_file = open(path + '/text', 'r')
text_scp = text_scp_file.readlines()
text_scp_file.close()
text_id_scp_file = open(path + '/text_id', 'r')
text_id_scp = text_id_scp_file.readlines()
text_id_scp_file.close()
text_spk_merge_file = open(path + '/text_spk_merge', 'w')
assert len(text_scp) == len(text_id_scp)
meeting_map = {} # {meeting_id: [(start_time, text, text_id), (start_time, text, text_id), ...]}
for i in range(len(text_scp)):
text_line = text_scp[i].strip().split(' ')
text_id_line = text_id_scp[i].strip().split(' ')
assert text_line[0] == text_id_line[0]
if len(text_line) > 1:
uttid = text_line[0]
text = text_line[1]
text_id = text_id_line[1]
meeting_id = uttid.split('-')[0]
start_time = int(uttid.split('-')[-2])
if meeting_id not in meeting_map:
meeting_map[meeting_id] = [(start_time,text,text_id)]
else:
meeting_map[meeting_id].append((start_time,text,text_id))
for meeting_id in sorted(meeting_map.keys()):
cur_meeting_list = sorted(meeting_map[meeting_id], key=lambda x: x[0])
text_spk_merge_map = {} #{1: text1, 2: text2, ...}
for cur_utt in cur_meeting_list:
cur_text = cur_utt[1]
cur_text_id = cur_utt[2]
assert len(cur_text)==len(cur_text_id)
if len(cur_text) != 0:
cur_text_split = cur_text.split('$')
cur_text_id_split = cur_text_id.split('$')
assert len(cur_text_split) == len(cur_text_id_split)
for i in range(len(cur_text_split)):
if len(cur_text_split[i]) != 0:
spk_id = int(cur_text_id_split[i][0])
if spk_id not in text_spk_merge_map.keys():
text_spk_merge_map[spk_id] = cur_text_split[i]
else:
text_spk_merge_map[spk_id] += cur_text_split[i]
text_spk_merge_list = []
for spk_id in sorted(text_spk_merge_map.keys()):
text_spk_merge_list.append(text_spk_merge_map[spk_id])
text_spk_merge_file.write(meeting_id + ' ' + '$'.join(text_spk_merge_list) + '\n')
text_spk_merge_file.flush()
text_spk_merge_file.close()
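For example (hypothetical ids), utterances of the same meeting are sorted by the start time encoded in the utterance id, split at `$`, and each piece is appended to the transcription of the speaker given by its `text_id`; the speakers are then joined with `$` into one line per meeting:
```shell
# text          : Meeting_1-...-0000100-0000300 你好$在吗
#                 Meeting_1-...-0000400-0000500 很好
# text_id       : Meeting_1-...-0000100-0000300 11$22
#                 Meeting_1-...-0000400-0000500 11
# text_spk_merge: Meeting_1 你好很好$在吗
```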

View File

@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
"""
Process the textgrid files
"""
import argparse
import codecs
from pathlib import Path
import textgrid
import pdb
import numpy as np
import math
class Segment(object):
def __init__(self, uttid, spkr, stime, etime, text):
self.uttid = uttid
self.spkr = spkr
self.stime = round(stime, 2)
self.etime = round(etime, 2)
self.text = text
def change_stime(self, time):
self.stime = time
def change_etime(self, time):
self.etime = time
def get_args():
parser = argparse.ArgumentParser(description="process the textgrid files")
parser.add_argument("--path", type=str, required=True, help="Data path")
args = parser.parse_args()
return args
def main(args):
textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
segment_file = codecs.open(Path(args.path)/"segments", "w", "utf-8")
utt2spk = codecs.open(Path(args.path)/"utt2spk", "w", "utf-8")
# get the path of textgrid file for each utterance
for line in textgrid_flist:
line_array = line.strip().split(" ")
path = Path(line_array[1])
uttid = line_array[0]
try:
tg = textgrid.TextGrid.fromFile(path)
except:
pdb.set_trace()
num_spk = tg.__len__()
spk2textgrid = {}
spk2weight = {}
weight2spk = {}
cnt = 2
xmax = 0
for i in range(tg.__len__()):
spk_name = tg[i].name
if spk_name not in spk2weight:
spk2weight[spk_name] = cnt
weight2spk[cnt] = spk_name
cnt = cnt * 2
segments = []
for j in range(tg[i].__len__()):
if tg[i][j].mark:
if xmax < tg[i][j].maxTime:
xmax = tg[i][j].maxTime
segments.append(
Segment(
uttid,
tg[i].name,
tg[i][j].minTime,
tg[i][j].maxTime,
tg[i][j].mark.strip(),
)
)
segments = sorted(segments, key=lambda x: x.stime)
spk2textgrid[spk_name] = segments
olp_label = np.zeros((num_spk, int(xmax/0.01)), dtype=np.int32)
for spkid in spk2weight.keys():
weight = spk2weight[spkid]
segments = spk2textgrid[spkid]
idx = int(math.log2(weight)) - 1
for i in range(len(segments)):
stime = segments[i].stime
etime = segments[i].etime
olp_label[idx, int(stime/0.01): int(etime/0.01)] = weight
sum_label = olp_label.sum(axis=0)
stime = 0
pre_value = 0
for pos in range(sum_label.shape[0]):
if sum_label[pos] in weight2spk:
if pre_value in weight2spk:
if sum_label[pos] != pre_value:
spkids = weight2spk[pre_value]
spkid_array = spkids.split("_")
spkid = spkid_array[-1]
#spkid = uttid+spkid
if round(stime*0.01, 2) != round((pos-1)*0.01, 2):
segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2)))
utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid))
stime = pos
pre_value = sum_label[pos]
else:
stime = pos
pre_value = sum_label[pos]
else:
if pre_value in weight2spk:
spkids = weight2spk[pre_value]
spkid_array = spkids.split("_")
spkid = spkid_array[-1]
#spkid = uttid+spkid
if round(stime*0.01, 2) != round((pos-1)*0.01, 2):
segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2)))
utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid))
stime = pos
pre_value = sum_label[pos]
textgrid_flist.close()
segment_file.close()
if __name__ == "__main__":
args = get_args()
main(args)
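The power-of-two weights above make overlap detection a frame-wise sum: because each speaker's weight is a distinct power of two, the summed label equals a single weight exactly when one speaker is active, and any other non-zero value marks an overlapped frame, which the loop skips. A minimal sketch of the idea (times and weights hypothetical):
```python
import numpy as np

# Two speakers with weights 2 and 4; 10 ms frames over 1 s (hypothetical).
olp = np.zeros((2, 100), dtype=np.int32)
olp[0, 0:50] = 2    # speaker A active in 0.0-0.5 s
olp[1, 30:80] = 4   # speaker B active in 0.3-0.8 s
s = olp.sum(axis=0)
# s == 2 -> only A, s == 4 -> only B, s == 6 -> overlapped, s == 0 -> silence
print((s == 2).sum(), (s == 4).sum(), (s == 6).sum(), (s == 0).sum())
# -> 30 30 20 20 frames
```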

View File

@ -0,0 +1,27 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
while(<>){
@A = split(" ", $_);
@A > 1 || die "Invalid line in spk2utt file: $_";
$s = shift @A;
foreach $u ( @A ) {
print "$u $s\n";
}
}

View File

@ -0,0 +1,14 @@
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright Chao Weng
# normalizations for hkust transcript
# see the docs/trans-guidelines.pdf for details
while (<STDIN>) {
@A = split(" ", $_);
if (@A == 1) {
next;
}
print $_
}

View File

@ -0,0 +1,38 @@
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright Chao Weng
# normalizations for hkust transcript
# see the docs/trans-guidelines.pdf for details
while (<STDIN>) {
@A = split(" ", $_);
print "$A[0] ";
for ($n = 1; $n < @A; $n++) {
$tmp = $A[$n];
if ($tmp =~ /<sil>/) {$tmp =~ s:<sil>::g;}
if ($tmp =~ /<%>/) {$tmp =~ s:<%>::g;}
if ($tmp =~ /<->/) {$tmp =~ s:<->::g;}
if ($tmp =~ /<\$>/) {$tmp =~ s:<\$>::g;}
if ($tmp =~ /<#>/) {$tmp =~ s:<#>::g;}
if ($tmp =~ /<_>/) {$tmp =~ s:<_>::g;}
if ($tmp =~ /<space>/) {$tmp =~ s:<space>::g;}
if ($tmp =~ /`/) {$tmp =~ s:`::g;}
if ($tmp =~ /&/) {$tmp =~ s:&::g;}
if ($tmp =~ /,/) {$tmp =~ s:,::g;}
if ($tmp =~ /[a-zA-Z]/) {$tmp=uc($tmp);}
# NOTE: the full-width characters in the patterns below were lost in text
# extraction; they are reconstructed here as the usual full-width letter
# and punctuation normalizations (an assumption: e.g. Ａ -> A, strip ，？。、！).
if ($tmp =~ /Ａ/) {$tmp =~ s:Ａ:A:g;}
if ($tmp =~ /ａ/) {$tmp =~ s:ａ:A:g;}
if ($tmp =~ /ｂ/) {$tmp =~ s:ｂ:B:g;}
if ($tmp =~ /ｃ/) {$tmp =~ s:ｃ:C:g;}
if ($tmp =~ /ｋ/) {$tmp =~ s:ｋ:K:g;}
if ($tmp =~ /ｔ/) {$tmp =~ s:ｔ:T:g;}
if ($tmp =~ /，/) {$tmp =~ s:，::g;}
if ($tmp =~ /？/) {$tmp =~ s:？::g;}
if ($tmp =~ /。/) {$tmp =~ s:。::g;}
if ($tmp =~ /、/) {$tmp =~ s:、::g;}
if ($tmp =~ /！/) {$tmp =~ s:！::g;}
print "$tmp ";
}
print "\n";
}

View File

@ -0,0 +1,38 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# converts an utt2spk file to a spk2utt file.
# Takes input from the stdin or from a file argument;
# output goes to the standard out.
if ( @ARGV > 1 ) {
die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";
}
while(<>){
@A = split(" ", $_);
@A == 2 || die "Invalid line in utt2spk file: $_";
($u,$s) = @A;
if(!$seen_spk{$s}) {
$seen_spk{$s} = 1;
push @spklist, $s;
}
push (@{$spk_hash{$s}}, "$u");
}
foreach $s (@spklist) {
$l = join(' ',@{$spk_hash{$s}});
print "$s $l\n";
}
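The two map scripts are inverses of each other, so a quick round-trip check is possible (paths hypothetical; exact equality assumes speaker-ids are prefixes of utterance-ids, the sort convention `validate_data_dir.sh` enforces):
```shell
local/utt2spk_to_spk2utt.pl data/Train_Ali_far/utt2spk > data/Train_Ali_far/spk2utt
local/spk2utt_to_utt2spk.pl data/Train_Ali_far/spk2utt | cmp - data/Train_Ali_far/utt2spk
```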

View File

@ -0,0 +1,404 @@
#!/usr/bin/env bash
cmd="$@"
no_feats=false
no_wav=false
no_text=false
no_spk_sort=false
non_print=false
function show_help
{
echo "Usage: $0 [--no-feats] [--no-text] [--non-print] [--no-wav] [--no-spk-sort] <data-dir>"
echo "The --no-xxx options mean that the script does not require "
echo "xxx.scp to be present, but it will check it if it is present."
echo "--no-spk-sort means that the script does not require the utt2spk to be "
echo "sorted by the speaker-id in addition to being sorted by utterance-id."
echo "--non-print ignore the presence of non-printable characters."
echo "By default, utt2spk is expected to be sorted by both, which can be "
echo "achieved by making the speaker-id prefixes of the utterance-ids"
echo "e.g.: $0 data/train"
}
while [ $# -ne 0 ] ; do
case "$1" in
"--no-feats")
no_feats=true;
;;
"--no-text")
no_text=true;
;;
"--non-print")
non_print=true;
;;
"--no-wav")
no_wav=true;
;;
"--no-spk-sort")
no_spk_sort=true;
;;
*)
if ! [ -z "$data" ] ; then
show_help;
exit 1
fi
data=$1
;;
esac
shift
done
if [ ! -d $data ]; then
echo "$0: no such directory $data"
exit 1;
fi
if [ -f $data/images.scp ]; then
cmd=${cmd/--no-wav/} # remove --no-wav if supplied
image/validate_data_dir.sh $cmd
exit $?
fi
for f in spk2utt utt2spk; do
if [ ! -f $data/$f ]; then
echo "$0: no such file $f"
exit 1;
fi
if [ ! -s $data/$f ]; then
echo "$0: empty file $f"
exit 1;
fi
done
! cat $data/utt2spk | awk '{if (NF != 2) exit(1); }' && \
echo "$0: $data/utt2spk has wrong format." && exit;
ns=$(wc -l < $data/spk2utt)
if [ "$ns" == 1 ]; then
echo "$0: WARNING: you have only one speaker. This probably a bad idea."
echo " Search for the word 'bold' in http://kaldi-asr.org/doc/data_prep.html"
echo " for more information."
fi
tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM
export LC_ALL=C
function check_sorted_and_uniq {
! perl -ne '((substr $_,-1) eq "\n") or die "file $ARGV has invalid newline";' $1 && exit 1;
! awk '{print $1}' < $1 | sort -uC && echo "$0: file $1 is not sorted or has duplicates" && exit 1;
}
function partial_diff {
diff -U1 $1 $2 | (head -n 6; echo "..."; tail -n 6)
n1=`cat $1 | wc -l`
n2=`cat $2 | wc -l`
echo "[Lengths are $1=$n1 versus $2=$n2]"
}
check_sorted_and_uniq $data/utt2spk
if ! $no_spk_sort; then
! sort -k2 -C $data/utt2spk && \
echo "$0: utt2spk is not in sorted order when sorted first on speaker-id " && \
echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1;
fi
check_sorted_and_uniq $data/spk2utt
! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \
<(local/spk2utt_to_utt2spk.pl $data/spk2utt) && \
echo "$0: spk2utt and utt2spk do not seem to match" && exit 1;
cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts
if [ ! -f $data/text ] && ! $no_text; then
echo "$0: no such file $data/text (if this is by design, specify --no-text)"
exit 1;
fi
num_utts=`cat $tmpdir/utts | wc -l`
if ! $no_text; then
if ! $non_print; then
if locale -a | grep "C.UTF-8" >/dev/null; then
L=C.UTF-8
else
L=en_US.UTF-8
fi
n_non_print=$(LC_ALL="$L" grep -c '[^[:print:][:space:]]' $data/text) && \
echo "$0: text contains $n_non_print lines with non-printable characters" &&\
exit 1;
fi
local/validate_text.pl $data/text || exit 1;
check_sorted_and_uniq $data/text
text_len=`cat $data/text | wc -l`
illegal_sym_list="<s> </s> #0"
for x in $illegal_sym_list; do
if grep -w "$x" $data/text > /dev/null; then
echo "$0: Error: in $data, text contains illegal symbol $x"
exit 1;
fi
done
awk '{print $1}' < $data/text > $tmpdir/utts.txt
if ! cmp -s $tmpdir/utts{,.txt}; then
echo "$0: Error: in $data, utterance lists extracted from utt2spk and text"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/utts{,.txt}
exit 1;
fi
fi
if [ -f $data/segments ] && [ ! -f $data/wav.scp ]; then
echo "$0: in directory $data, segments file exists but no wav.scp"
exit 1;
fi
if [ ! -f $data/wav.scp ] && ! $no_wav; then
echo "$0: no such file $data/wav.scp (if this is by design, specify --no-wav)"
exit 1;
fi
if [ -f $data/wav.scp ]; then
check_sorted_and_uniq $data/wav.scp
if grep -E -q '^\S+\s+~' $data/wav.scp; then
# note: it's not a good idea to have any kind of tilde in wav.scp, even if
# part of a command, as it would cause compatibility problems if run by
# other users, but this used to be not checked for so we let it slide unless
# it's something of the form "foo ~/foo.wav" (i.e. a plain file name) which
# would definitely cause problems as the fopen system call does not do
# tilde expansion.
echo "$0: Please do not use tilde (~) in your wav.scp."
exit 1;
fi
if [ -f $data/segments ]; then
check_sorted_and_uniq $data/segments
# We have a segments file -> interpret wav file as "recording-ids" not utterance-ids.
! cat $data/segments | \
awk '{if (NF != 4 || $4 <= $3) { print "Bad line in segments file", $0; exit(1); }}' && \
echo "$0: badly formatted segments file" && exit 1;
segments_len=`cat $data/segments | wc -l`
if [ -f $data/text ]; then
! cmp -s $tmpdir/utts <(awk '{print $1}' <$data/segments) && \
echo "$0: Utterance list differs between $data/utt2spk and $data/segments " && \
echo "$0: Lengths are $segments_len vs $num_utts" && \
exit 1
fi
cat $data/segments | awk '{print $2}' | sort | uniq > $tmpdir/recordings
awk '{print $1}' $data/wav.scp > $tmpdir/recordings.wav
if ! cmp -s $tmpdir/recordings{,.wav}; then
echo "$0: Error: in $data, recording-ids extracted from segments and wav.scp"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/recordings{,.wav}
exit 1;
fi
if [ -f $data/reco2file_and_channel ]; then
# this file is needed only for ctm scoring; it's indexed by recording-id.
check_sorted_and_uniq $data/reco2file_and_channel
! cat $data/reco2file_and_channel | \
awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) {
if ( NF == 3 && $3 == "1" ) {
warning_issued = 1;
} else {
print "Bad line ", $0; exit 1;
}
}
}
END {
if (warning_issued == 1) {
print "The channel should be marked as A or B, not 1! You should change it ASAP! "
}
}' && echo "$0: badly formatted reco2file_and_channel file" && exit 1;
cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/recordings.r2fc
if ! cmp -s $tmpdir/recordings{,.r2fc}; then
echo "$0: Error: in $data, recording-ids extracted from segments and reco2file_and_channel"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/recordings{,.r2fc}
exit 1;
fi
fi
else
# No segments file -> assume wav.scp indexed by utterance.
cat $data/wav.scp | awk '{print $1}' > $tmpdir/utts.wav
if ! cmp -s $tmpdir/utts{,.wav}; then
echo "$0: Error: in $data, utterance lists extracted from utt2spk and wav.scp"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/utts{,.wav}
exit 1;
fi
if [ -f $data/reco2file_and_channel ]; then
# this file is needed only for ctm scoring; it's indexed by recording-id.
check_sorted_and_uniq $data/reco2file_and_channel
! cat $data/reco2file_and_channel | \
awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) {
if ( NF == 3 && $3 == "1" ) {
warning_issued = 1;
} else {
print "Bad line ", $0; exit 1;
}
}
}
END {
if (warning_issued == 1) {
print "The channel should be marked as A or B, not 1! You should change it ASAP! "
}
}' && echo "$0: badly formatted reco2file_and_channel file" && exit 1;
cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/utts.r2fc
if ! cmp -s $tmpdir/utts{,.r2fc}; then
echo "$0: Error: in $data, utterance-ids extracted from segments and reco2file_and_channel"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/utts{,.r2fc}
exit 1;
fi
fi
fi
fi
if [ ! -f $data/feats.scp ] && ! $no_feats; then
echo "$0: no such file $data/feats.scp (if this is by design, specify --no-feats)"
exit 1;
fi
if [ -f $data/feats.scp ]; then
check_sorted_and_uniq $data/feats.scp
cat $data/feats.scp | awk '{print $1}' > $tmpdir/utts.feats
if ! cmp -s $tmpdir/utts{,.feats}; then
echo "$0: Error: in $data, utterance-ids extracted from utt2spk and features"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/utts{,.feats}
exit 1;
fi
fi
if [ -f $data/cmvn.scp ]; then
check_sorted_and_uniq $data/cmvn.scp
cat $data/cmvn.scp | awk '{print $1}' > $tmpdir/speakers.cmvn
cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
if ! cmp -s $tmpdir/speakers{,.cmvn}; then
echo "$0: Error: in $data, speaker lists extracted from spk2utt and cmvn"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/speakers{,.cmvn}
exit 1;
fi
fi
if [ -f $data/spk2gender ]; then
check_sorted_and_uniq $data/spk2gender
! cat $data/spk2gender | awk '{if (!((NF == 2 && ($2 == "m" || $2 == "f")))) exit 1; }' && \
echo "$0: Mal-formed spk2gender file" && exit 1;
cat $data/spk2gender | awk '{print $1}' > $tmpdir/speakers.spk2gender
cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
if ! cmp -s $tmpdir/speakers{,.spk2gender}; then
echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2gender"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/speakers{,.spk2gender}
exit 1;
fi
fi
if [ -f $data/spk2warp ]; then
check_sorted_and_uniq $data/spk2warp
! cat $data/spk2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \
echo "$0: Mal-formed spk2warp file" && exit 1;
cat $data/spk2warp | awk '{print $1}' > $tmpdir/speakers.spk2warp
cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
if ! cmp -s $tmpdir/speakers{,.spk2warp}; then
echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2warp"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/speakers{,.spk2warp}
exit 1;
fi
fi
if [ -f $data/utt2warp ]; then
check_sorted_and_uniq $data/utt2warp
! cat $data/utt2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \
echo "$0: Mal-formed utt2warp file" && exit 1;
cat $data/utt2warp | awk '{print $1}' > $tmpdir/utts.utt2warp
cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts
if ! cmp -s $tmpdir/utts{,.utt2warp}; then
echo "$0: Error: in $data, utterance lists extracted from utt2spk and utt2warp"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/utts{,.utt2warp}
exit 1;
fi
fi
# check some optionally-required things
for f in vad.scp utt2lang utt2uniq; do
if [ -f $data/$f ]; then
check_sorted_and_uniq $data/$f
if ! cmp -s <( awk '{print $1}' $data/utt2spk ) \
<( awk '{print $1}' $data/$f ); then
echo "$0: error: in $data, $f and utt2spk do not have identical utterance-id list"
exit 1;
fi
fi
done
if [ -f $data/utt2dur ]; then
check_sorted_and_uniq $data/utt2dur
cat $data/utt2dur | awk '{print $1}' > $tmpdir/utts.utt2dur
if ! cmp -s $tmpdir/utts{,.utt2dur}; then
echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2dur file"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/utts{,.utt2dur}
exit 1;
fi
cat $data/utt2dur | \
awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line utt2dur:" NR ":" $0; exit(1) }}' || exit 1
fi
if [ -f $data/utt2num_frames ]; then
check_sorted_and_uniq $data/utt2num_frames
cat $data/utt2num_frames | awk '{print $1}' > $tmpdir/utts.utt2num_frames
if ! cmp -s $tmpdir/utts{,.utt2num_frames}; then
echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2num_frames file"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/utts{,.utt2num_frames}
exit 1
fi
awk <$data/utt2num_frames '{
if (NF != 2 || !($2 > 0) || $2 != int($2)) {
print "Bad line utt2num_frames:" NR ":" $0
exit 1 } }' || exit 1
fi
if [ -f $data/reco2dur ]; then
check_sorted_and_uniq $data/reco2dur
cat $data/reco2dur | awk '{print $1}' > $tmpdir/recordings.reco2dur
if [ -f $tmpdir/recordings ]; then
if ! cmp -s $tmpdir/recordings{,.reco2dur}; then
echo "$0: Error: in $data, recording-ids extracted from segments and reco2dur file"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/recordings{,.reco2dur}
exit 1;
fi
else
if ! cmp -s $tmpdir/{utts,recordings.reco2dur}; then
echo "$0: Error: in $data, recording-ids extracted from wav.scp and reco2dur file"
echo "$0: differ, partial diff is:"
partial_diff $tmpdir/{utts,recordings.reco2dur}
exit 1;
fi
fi
cat $data/reco2dur | \
awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line : " $0; exit(1) }}' || exit 1
fi
echo "$0: Successfully validated data-directory $data"

View File

@ -0,0 +1,136 @@
#!/usr/bin/env perl
#
#===============================================================================
# Copyright 2017 Johns Hopkins University (author: Yenda Trmal <jtrmal@gmail.com>)
# Johns Hopkins University (author: Daniel Povey)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
#===============================================================================
# validation script for data/<dataset>/text
# to be called (preferably) from utils/validate_data_dir.sh
use strict;
use warnings;
use utf8;
use Fcntl qw< SEEK_SET >;
# this function reads the opened file (supplied as a first
# parameter) into an array of lines. For each
# line, it tests whether it's a valid utf-8 compatible
# line. If all lines are valid utf-8, it returns the lines
# decoded as utf-8, otherwise it assumes the file's encoding
# is one of those 1-byte encodings, such as ISO-8859-x
# or Windows CP-X.
# Please recall we do not really care about
# the actual encoding; we just need to
# make sure the length of the (decoded) string
# is correct (to make the output formatting look right).
sub get_utf8_or_bytestream {
use Encode qw(decode encode);
my $is_utf_compatible = 1;
my @unicode_lines;
my @raw_lines;
my $raw_text;
my $lineno = 0;
my $file = shift;
while (<$file>) {
$raw_text = $_;
last unless $raw_text;
if ($is_utf_compatible) {
my $decoded_text = eval { decode("UTF-8", $raw_text, Encode::FB_CROAK) } ;
$is_utf_compatible = $is_utf_compatible && defined($decoded_text);
push @unicode_lines, $decoded_text;
} else {
#print STDERR "WARNING: the line $raw_text cannot be interpreted as UTF-8: $decoded_text\n";
;
}
push @raw_lines, $raw_text;
$lineno += 1;
}
if (!$is_utf_compatible) {
return (0, @raw_lines);
} else {
return (1, @unicode_lines);
}
}
# check if the given unicode string contains unicode whitespaces
# other than the usual four: TAB, LF, CR and SPACE
sub validate_utf8_whitespaces {
my $unicode_lines = shift;
use feature 'unicode_strings';
for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) {
my $current_line = $unicode_lines->[$i];
if ((substr $current_line, -1) ne "\n"){
print STDERR "$0: The current line (nr. $i) has invalid newline\n";
return 1;
}
my @A = split(" ", $current_line);
my $utt_id = $A[0];
# we replace TAB, LF, CR, and SPACE
# this is to simplify the test
if ($current_line =~ /\x{000d}/) {
print STDERR "$0: The line for utterance $utt_id contains CR (0x0D) character\n";
return 1;
}
$current_line =~ s/[\x{0009}\x{000a}\x{0020}]/./g;
if ($current_line =~/\s/) {
print STDERR "$0: The line for utterance $utt_id contains disallowed Unicode whitespaces\n";
return 1;
}
}
return 0;
}
# checks if the text in the file (supplied as the argument) is utf-8 compatible
# if yes, checks if it contains only allowed whitespaces. If no, then does not
# do anything. The function seeks to the original position in the file after
# reading the text.
sub check_allowed_whitespace {
my $file = shift;
my $filename = shift;
my $pos = tell($file);
(my $is_utf, my @lines) = get_utf8_or_bytestream($file);
seek($file, $pos, SEEK_SET);
if ($is_utf) {
my $has_invalid_whitespaces = validate_utf8_whitespaces(\@lines);
if ($has_invalid_whitespaces) {
print STDERR "$0: ERROR: text file '$filename' contains disallowed UTF-8 whitespace character(s)\n";
return 0;
}
}
return 1;
}
if(@ARGV != 1) {
die "Usage: validate_text.pl <text-file>\n" .
"e.g.: validate_text.pl data/train/text\n";
}
my $text = shift @ARGV;
if (-z "$text") {
print STDERR "$0: ERROR: file '$text' is empty or does not exist\n";
exit 1;
}
if(!open(FILE, "<$text")) {
print STDERR "$0: ERROR: failed to open $text\n";
exit 1;
}
check_allowed_whitespace(\*FILE, $text) or exit 1;
close(FILE);

egs/alimeeting/sa-asr/path.sh Executable file
View File

@ -0,0 +1,5 @@
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH

egs/alimeeting/sa-asr/run.sh Executable file
View File

@ -0,0 +1,50 @@
#!/usr/bin/env bash
# Set bash to 'debug' mode: it will exit on
# -e 'error', -u 'undefined variable' and -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail
ngpu=4
device="0,1,2,3"
stage=1
stop_stage=18
train_set=Train_Ali_far
valid_set=Eval_Ali_far
test_sets="Test_Ali_far"
asr_config=conf/train_asr_conformer.yaml
sa_asr_config=conf/train_sa_asr_conformer.yaml
inference_config=conf/decode_asr_rnn.yaml
lm_config=conf/train_lm_transformer.yaml
use_lm=false
use_wordlm=false
./asr_local.sh \
--device ${device} \
--ngpu ${ngpu} \
--stage ${stage} \
--stop_stage ${stop_stage} \
--gpu_inference true \
--njob_infer 4 \
--asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \
--sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \
--asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \
--lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \
--lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \
--lang zh \
--audio_format wav \
--feats_type raw \
--token_type char \
--use_lm ${use_lm} \
--use_word_lm ${use_wordlm} \
--lm_config "${lm_config}" \
--asr_config "${asr_config}" \
--sa_asr_config "${sa_asr_config}" \
--inference_config "${inference_config}" \
--train_set "${train_set}" \
--valid_set "${valid_set}" \
--test_sets "${test_sets}" \
--lm_train_text "data/${train_set}/text" "$@"
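Since the trailing `"$@"` forwards any extra command-line arguments to `asr_local.sh`, individual stages can be re-run without editing the script; later flags override the defaults set above (assuming the usual Kaldi-style option parsing). For example, to redo only SA-ASR training (stage 12) on two GPUs:
```shell
./run.sh --stage 12 --stop_stage 12 --ngpu 2 --device "0,1"
```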

View File

@ -0,0 +1,50 @@
#!/usr/bin/env bash
# Set bash to 'debug' mode: it will exit on
# -e 'error', -u 'undefined variable' and -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail
ngpu=4
device="0,1,2,3"
stage=1
stop_stage=4
train_set=Train_Ali_far
valid_set=Eval_Ali_far
test_sets="Test_2023_Ali_far"
asr_config=conf/train_asr_conformer.yaml
sa_asr_config=conf/train_sa_asr_conformer.yaml
inference_config=conf/decode_asr_rnn.yaml
lm_config=conf/train_lm_transformer.yaml
use_lm=false
use_wordlm=false
./asr_local_m2met_2023_infer.sh \
--device ${device} \
--ngpu ${ngpu} \
--stage ${stage} \
--stop_stage ${stop_stage} \
--gpu_inference true \
--njob_infer 4 \
--asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \
--sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \
--asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \
--lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \
--lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \
--lang zh \
--audio_format wav \
--feats_type raw \
--token_type char \
--use_lm ${use_lm} \
--use_word_lm ${use_wordlm} \
--lm_config "${lm_config}" \
--asr_config "${asr_config}" \
--sa_asr_config "${sa_asr_config}" \
--inference_config "${inference_config}" \
--train_set "${train_set}" \
--valid_set "${valid_set}" \
--test_sets "${test_sets}" \
--lm_train_text "data/${train_set}/text" "$@"
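The same forwarding applies here; for example, once inference (stage 3) has finished, the submission file can be regenerated alone:
```shell
./run_m2met_2023_infer.sh --stage 4 --stop_stage 4
```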

egs/alimeeting/sa-asr/utils Symbolic link
View File

@ -0,0 +1 @@
../../aishell/transformer/utils

View File

@ -41,6 +41,7 @@ from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@ -92,7 +93,11 @@ class Speech2Text:
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
if asr_train_args.frontend == 'wav_frontend':
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
else:
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
@ -111,7 +116,7 @@ class Speech2Text:
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
lm_train_config, lm_file, None, device
)
scorers["lm"] = lm.lm
@ -193,7 +198,7 @@ class Speech2Text:
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@ -280,6 +285,7 @@ def inference(
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
mc: bool = False,
**kwargs,
):
inference_pipeline = inference_modelscope(
@ -310,6 +316,7 @@ def inference(
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
mc=mc,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@ -342,6 +349,7 @@ def inference_modelscope(
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
mc: bool = False,
param_dict: dict = None,
**kwargs,
):
@ -355,6 +363,9 @@ def inference_modelscope(
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@ -408,6 +419,7 @@ def inference_modelscope(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
@ -416,7 +428,7 @@ def inference_modelscope(
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
# 7. Start for-loop
@ -452,7 +464,7 @@ def inference_modelscope(
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
@ -463,6 +475,9 @@ def inference_modelscope(
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
logging.info("uttid: {}".format(key))
logging.info("text predictions: {}\n".format(text))
return asr_result_list
return _forward
@ -637,4 +652,4 @@ def main(cmd=None):
if __name__ == "__main__":
main()
main()

View File

@ -71,7 +71,13 @@ def get_parser():
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group.add_argument(
"--mc",
type=bool,
default=False,
help="MultiChannel input",
)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
@ -288,6 +294,9 @@ def inference_launch_funasr(**kwargs):
if mode == "asr":
from funasr.bin.asr_inference import inference
return inference(**kwargs)
elif mode == "sa_asr":
from funasr.bin.sa_asr_inference import inference
return inference(**kwargs)
elif mode == "uniasr":
from funasr.bin.asr_inference_uniasr import inference
return inference(**kwargs)
@ -342,4 +351,4 @@ def main(cmd=None):
if __name__ == "__main__":
main()
main()

View File

@ -27,7 +27,8 @@ if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
if args.ngpu > 0:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
@ -38,9 +39,9 @@ if __name__ == '__main__':
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None:
if args.batch_size is not None and args.ngpu > 0:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None:
if args.batch_bins is not None and args.ngpu > 0:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)

View File

@ -0,0 +1,687 @@
import argparse
import logging
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.sa_asr import ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
batch_size: int = 1,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
streaming: bool = False,
frontend_conf: dict = None,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
if asr_train_args.frontend == 'wav_frontend':
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
else:
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
decoder = asr_model.decoder
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
decoder=decoder,
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, None, device
)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
) -> List[
Tuple[
Optional[str],
Optional[str],
List[str],
List[int],
Union[Hypothesis],
]
]:
"""Inference
Args:
speech: Input speech data
Returns:
text, text_id, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if isinstance(profile, np.ndarray):
profile = torch.tensor(profile)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
asr_enc, _, spk_enc = self.asr_model.encode(**batch)
if isinstance(asr_enc, tuple):
asr_enc = asr_enc[0]
if isinstance(spk_enc, tuple):
spk_enc = spk_enc[0]
assert len(asr_enc) == 1, len(asr_enc)
assert len(spk_enc) == 1, len(spk_enc)
# c. Pass the encoder output to the beam search
nbest_hyps = self.beam_search(
asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1: last_pos]
else:
token_int = hyp.yseq[1: last_pos].tolist()
spk_weights = torch.stack(hyp.spk_weigths, dim=0)  # attribute name typo ('spk_weigths') comes from beam_search_sa_asr
token_ori = self.converter.ids2tokens(token_int)
text_ori = self.tokenizer.tokens2text(token_ori)
text_ori_spklist = text_ori.split('$')
cur_index = 0
spk_choose = []
for i in range(len(text_ori_spklist)):
text_ori_split = text_ori_spklist[i]
n = len(text_ori_split)
spk_weights_local = spk_weights[cur_index: cur_index + n]
cur_index = cur_index + n + 1
spk_weights_local = spk_weights_local.mean(dim=0)
spk_choose_local = spk_weights_local.argmax(-1)
spk_choose.append(spk_choose_local.item() + 1)
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
text_spklist = text.split('$')
assert len(spk_choose) == len(text_spklist)
spk_list=[]
for i in range(len(text_spklist)):
text_split = text_spklist[i]
n = len(text_split)
spk_list.append(str(spk_choose[i]) * n)
text_id = '$'.join(spk_list)
assert len(text) == len(text_id)
results.append((text, text_id, token, token_int, hyp))
assert check_return_type(results)
return results
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
mc: bool = False,
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
mc=mc,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
mc: bool = False,
param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
streaming=streaming,
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2Text(**speech2text_kwargs)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
**kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
# 7. Start for-loop
# FIXME(kamo): The output format should be discussed about
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
else:
writer = None
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
ibest_writer["text_id"][key] = text_id
if text is not None:
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
logging.info("uttid: {}".format(key))
logging.info("text predictions: {}".format(text))
logging.info("text_id predictions: {}\n".format(text_id))
return asr_result_list
return _forward
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()
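For reference, `_forward` writes parallel, character-aligned `text` and `text_id` entries under each `{n}best_recog` output directory, in the same format consumed by the merge script earlier (hypothetical example):
```shell
# 1best_recog/text    : R0001_M0001_F_1 你好$在吗
# 1best_recog/text_id : R0001_M0001_F_1 11$22
```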

funasr/bin/sa_asr_train.py Executable file
View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
import os
from funasr.tasks.sa_asr import ASRTask
# for ASR Training
def parse_args():
parser = ASRTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
# for ASR Training
ASRTask.main(args=args, cmd=cmd)
if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
if args.ngpu > 0:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None and args.ngpu > 0:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None and args.ngpu > 0:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)

View File

@ -46,13 +46,15 @@ class SoundScpReader(collections.abc.Mapping):
if self.normalize:
# soundfile.read normalizes data to [-1,1] if dtype is not given
array, rate = librosa.load(
wav, sr=self.dest_sample_rate, mono=not self.always_2d
wav, sr=self.dest_sample_rate, mono=self.always_2d
)
else:
array, rate = librosa.load(
wav, sr=self.dest_sample_rate, mono=not self.always_2d, dtype=self.dtype
wav, sr=self.dest_sample_rate, mono=self.always_2d, dtype=self.dtype
)
if array.ndim==2:
array=array.transpose((1, 0))
return rate, array
def get_path(self, key):

View File

@ -79,3 +79,49 @@ class SequenceBinaryCrossEntropy(nn.Module):
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
return loss.masked_fill(pad_mask, 0).sum() / denom
class NllLoss(nn.Module):
"""Nll loss.
:param int size: the number of class
:param int padding_idx: ignored class id
:param bool normalize_length: normalize loss by sequence length if True
:param torch.nn.Module criterion: loss function
"""
def __init__(
self,
size,
padding_idx,
normalize_length=False,
criterion=nn.NLLLoss(reduction='none'),
):
"""Construct an NllLoss object."""
super(NllLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
self.size = size
self.true_dist = None
self.normalize_length = normalize_length
def forward(self, x, target):
"""Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
target signal masked with self.padding_id (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
loss = self.criterion(x, target)
denom = total if self.normalize_length else batch_size
return loss.masked_fill(ignore, 0).sum() / denom
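A minimal usage sketch for `NllLoss` (shapes, class count and padding id are illustrative assumptions). Note the input must already hold log-probabilities, since the default `nn.NLLLoss` criterion does not apply `log_softmax` itself:
```python
import torch

# 4 speaker classes, padding id -1, batch of 2, sequence length 3 (hypothetical)
loss_fn = NllLoss(size=4, padding_idx=-1, normalize_length=True)
logp = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)  # (batch, seqlen, class)
target = torch.tensor([[0, 2, -1], [1, 1, 3]])          # -1 marks padding
print(loss_fn(logp, target))  # scalar, averaged over the 5 non-padded tokens
```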

View File

@ -13,6 +13,7 @@ from typeguard import check_argument_types
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.attention import CosineDistanceAttention
from funasr.modules.dynamic_conv import DynamicConvolution
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
from funasr.modules.embedding import PositionalEncoding
@ -763,4 +764,429 @@ class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
normalize_before,
concat_after,
),
)
)
class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
spker_embedding_dim: int = 256,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
input_layer: str = "embed",
use_asr_output_layer: bool = True,
use_spk_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
):
assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(vocab_size, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
else:
raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_asr_output_layer:
self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.asr_output_layer = None
if use_spk_output_layer:
self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
else:
self.spk_output_layer = None
self.cos_distance_att = CosineDistanceAttention()
self.decoder1 = None
self.decoder2 = None
self.decoder3 = None
self.decoder4 = None
def forward(
self,
asr_hs_pad: torch.Tensor,
spk_hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
profile: torch.Tensor,
profile_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
tgt = ys_in_pad
# tgt_mask: (B, 1, L)
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
# m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
asr_memory = asr_hs_pad
spk_memory = spk_hs_pad
memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
# Spk decoder
x = self.embed(tgt)
x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
x, tgt_mask, asr_memory, spk_memory, memory_mask
)
x, tgt_mask, spk_memory, memory_mask = self.decoder2(
x, tgt_mask, spk_memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
if self.spk_output_layer is not None:
x = self.spk_output_layer(x)
dn, weights = self.cos_distance_att(x, profile, profile_lens)
# Asr decoder
x, tgt_mask, asr_memory, memory_mask = self.decoder3(
z, tgt_mask, asr_memory, memory_mask, dn
)
x, tgt_mask, asr_memory, memory_mask = self.decoder4(
x, tgt_mask, asr_memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
if self.asr_output_layer is not None:
x = self.asr_output_layer(x)
olens = tgt_mask.sum(1)
return x, weights, olens
def forward_one_step(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
asr_memory: torch.Tensor,
spk_memory: torch.Tensor,
profile: torch.Tensor,
cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
x = self.embed(tgt)
if cache is None:
cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
new_cache = []
x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
)
new_cache.append(x)
for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
x, tgt_mask, spk_memory, _ = decoder(
x, tgt_mask, spk_memory, None, cache=c
)
new_cache.append(x)
if self.normalize_before:
x = self.after_norm(x)
if self.spk_output_layer is not None:
x = self.spk_output_layer(x)
dn, weights = self.cos_distance_att(x, profile, None)
x, tgt_mask, asr_memory, _ = self.decoder3(
z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
)
new_cache.append(x)
        for c, decoder in zip(cache[len(self.decoder2) + 2:], self.decoder4):
x, tgt_mask, asr_memory, _ = decoder(
x, tgt_mask, asr_memory, None, cache=c
)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.asr_output_layer is not None:
y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
return y, weights, new_cache
def score(self, ys, state, asr_enc, spk_enc, profile):
"""Score."""
ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
logp, weights, state = self.forward_one_step(
ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
)
return logp.squeeze(0), weights.squeeze(), state
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
spker_embedding_dim: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
asr_num_blocks: int = 6,
spk_num_blocks: int = 3,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_asr_output_layer: bool = True,
use_spk_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
spker_embedding_dim=spker_embedding_dim,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_asr_output_layer=use_asr_output_layer,
use_spk_output_layer=use_spk_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
attention_dim,
MultiHeadedAttention(
attention_heads, attention_dim, self_attention_dropout_rate
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
)
self.decoder2 = repeat(
spk_num_blocks - 1,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(
attention_heads, attention_dim, self_attention_dropout_rate
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
attention_dim,
spker_embedding_dim,
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
)
self.decoder4 = repeat(
asr_num_blocks - 1,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(
attention_heads, attention_dim, self_attention_dropout_rate
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
def __init__(
self,
size,
self_attn,
src_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an DecoderLayer object."""
super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
tgt_concat = torch.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
if not self.normalize_before:
x = self.norm1(x)
z = x
residual = x
if self.normalize_before:
x = self.norm1(x)
skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
if self.concat_after:
x_concat = torch.cat(
(x, skip), dim=-1
)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(skip)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
def __init__(
self,
size,
d_size,
src_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an DecoderLayer object."""
super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
self.size = size
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
self.spk_linear = nn.Linear(d_size, size, bias=False)
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]
x = tgt_q
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat(
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
if not self.normalize_before:
x = self.norm2(x)
residual = x
        if dn is not None:
x = x + self.spk_linear(dn)
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
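To make the tensor plumbing of the two-branch decoder concrete, here is a minimal smoke test. This is a sketch only: it assumes FunASR is installed, and every shape and hyperparameter below is illustrative rather than the baseline configuration:

```python
import torch
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder

B, T, L, V, D, N = 2, 50, 10, 100, 256, 4   # batch, frames, tokens, vocab, dim, profiles
decoder = SAAsrTransformerDecoder(
    vocab_size=V, encoder_output_size=D, spker_embedding_dim=192,
    asr_num_blocks=2, spk_num_blocks=2,
)
asr_hs = torch.randn(B, T, D)               # ASR encoder output
spk_hs = torch.randn(B, T, D)               # speaker encoder output (time-aligned)
hlens = torch.full((B,), T)
ys_in = torch.randint(0, V, (B, L))
ys_in_lens = torch.full((B,), L)
profile = torch.randn(B, N, 192)            # N speaker profiles per utterance
profile_lens = torch.full((B,), N)

logits, spk_weights, olens = decoder(
    asr_hs, spk_hs, hlens, ys_in, ys_in_lens, profile, profile_lens
)
print(logits.shape, spk_weights.shape)      # (B, L, V) and (B, L, N)
```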

funasr/models/e2e_sa_asr.py (new file, 520 lines)

@ -0,0 +1,520 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, NllLoss # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
class ESPnetASRModel(AbsESPnetModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
self,
vocab_size: int,
max_spk_num: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
asr_encoder: AbsEncoder,
spk_encoder: torch.nn.Module,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
spk_weight: float = 0.5,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
ignore_id: int = -1,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
super().__init__()
        # note that sos (1) and eos (2) use distinct ids in this model
self.blank_id = 0
self.sos = 1
self.eos = 2
self.vocab_size = vocab_size
        self.max_spk_num = max_spk_num
self.ignore_id = ignore_id
self.spk_weight = spk_weight
self.ctc_weight = ctc_weight
self.interctc_weight = interctc_weight
self.token_list = token_list.copy()
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.postencoder = postencoder
self.asr_encoder = asr_encoder
self.spk_encoder = spk_encoder
if not hasattr(self.asr_encoder, "interctc_use_conditioning"):
self.asr_encoder.interctc_use_conditioning = False
if self.asr_encoder.interctc_use_conditioning:
self.asr_encoder.conditioning_layer = torch.nn.Linear(
vocab_size, self.asr_encoder.output_size()
)
self.error_calculator = None
# we set self.decoder = None in the CTC mode since
# self.decoder parameters were never used and PyTorch complained
# and threw an Exception in the multi-GPU experiment.
# thanks Jeff Farris for pointing out the issue.
if ctc_weight == 1.0:
self.decoder = None
else:
self.decoder = decoder
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.criterion_spk = NllLoss(
size=max_spk_num,
padding_idx=ignore_id,
normalize_length=length_normalized_loss,
)
if report_cer or report_wer:
self.error_calculator = ErrorCalculator(
token_list, sym_space, sym_blank, report_cer, report_wer
)
if ctc_weight == 0.0:
self.ctc = None
else:
self.ctc = ctc
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
profile: torch.Tensor,
profile_lengths: torch.Tensor,
text_id: torch.Tensor,
text_id_lengths: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
profile: (Batch, Length, Dim)
profile_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
# 1. Encoder
asr_encoder_out, encoder_out_lens, spk_encoder_out = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(asr_encoder_out, tuple):
intermediate_outs = asr_encoder_out[1]
asr_encoder_out = asr_encoder_out[0]
loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = None, None, None, None, None, None
loss_ctc, cer_ctc = None, None
stats = dict()
        # 2a. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
asr_encoder_out, encoder_out_lens, text, text_lengths
)
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
                # Collect intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss(
asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths, profile, profile_lengths, text_id, text_id_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss_asr = loss_att
elif self.ctc_weight == 1.0:
loss_asr = loss_ctc
else:
loss_asr = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
if self.spk_weight == 0.0:
loss = loss_asr
else:
loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr
        stats.update(dict(
loss=loss.detach(),
loss_asr=loss_asr.detach(),
loss_att=loss_att.detach() if loss_att is not None else None,
loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
loss_spk=loss_spk.detach() if loss_spk is not None else None,
acc=acc_att,
acc_spk=acc_spk,
cer=cer_att,
wer=wer_att,
cer_ctc=cer_ctc,
        ))
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
feats_raw = feats.clone()
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.asr_encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.asr_encoder(
feats, feats_lengths, ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0]
        if encoder_out_spk_ori.size(1) != encoder_out.size(1):
            encoder_out_spk = F.interpolate(
                encoder_out_spk_ori.transpose(-2, -1), size=encoder_out.size(1), mode='nearest'
            ).transpose(-2, -1)
        else:
            encoder_out_spk = encoder_out_spk_ori
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
assert encoder_out_spk.size(0) == speech.size(0), (
encoder_out_spk.size(),
speech.size(0),
)
if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
return encoder_out, encoder_out_lens, encoder_out_spk
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
Normally, this function is called in batchify_nll.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
) # [batch, seqlen, dim]
batch_size = decoder_out.size(0)
decoder_num_class = decoder_out.size(2)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
decoder_out.view(-1, decoder_num_class),
ys_out_pad.view(-1),
ignore_index=self.ignore_id,
reduction="none",
)
nll = nll.view(batch_size, -1)
nll = nll.sum(dim=1)
assert nll.size(0) == batch_size
return nll
def batchify_nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
            batch_size: int, number of samples per batch when computing nll;
                lower it to avoid OOM, or raise it to better
                utilize GPU memory
"""
total_num = encoder_out.size(0)
if total_num <= batch_size:
nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
else:
nll = []
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
batch_ys_pad = ys_pad[start_idx:end_idx, :]
batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
batch_nll = self.nll(
batch_encoder_out,
batch_encoder_out_lens,
batch_ys_pad,
batch_ys_pad_lens,
)
nll.append(batch_nll)
start_idx = end_idx
if start_idx == total_num:
break
nll = torch.cat(nll)
assert nll.size(0) == total_num
return nll
def _calc_att_loss(
self,
asr_encoder_out: torch.Tensor,
spk_encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
profile: torch.Tensor,
profile_lens: torch.Tensor,
text_id: torch.Tensor,
text_id_lengths: torch.Tensor
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, weights_no_pad, _ = self.decoder(
asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens
)
        spk_num_no_pad = weights_no_pad.size(-1)
        pad = (0, self.max_spk_num - spk_num_no_pad)
        weights = F.pad(weights_no_pad, pad, mode='constant', value=0)
# pre_id=weights.argmax(-1)
# pre_text=decoder_out.argmax(-1)
# id_mask=(pre_id==text_id).to(dtype=text_id.dtype)
        # pre_text_mask=pre_text*id_mask+1-id_mask  # keep positions where speaker ids match; set mismatches to 1 (<unk>)
# padding_mask= ys_out_pad != self.ignore_id
# numerator = torch.sum(pre_text_mask.masked_select(padding_mask) == ys_out_pad.masked_select(padding_mask))
# denominator = torch.sum(padding_mask)
# sd_acc = float(numerator) / float(denominator)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
loss_spk = self.criterion_spk(torch.log(weights), text_id)
        acc_spk = th_accuracy(
weights.view(-1, self.max_spk_num),
text_id,
ignore_label=self.ignore_id,
)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
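The final objective in `forward` is a two-level interpolation: CTC against attention, then speaker against ASR. A toy arithmetic check with made-up loss values:

```python
# Illustrative values only; weights match the defaults above (0.5 each).
loss_ctc, loss_att, loss_spk = 4.0, 2.0, 1.0
ctc_weight, spk_weight = 0.5, 0.5
loss_asr = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att   # 3.0
loss = spk_weight * loss_spk + (1 - spk_weight) * loss_asr       # 2.0
print(loss_asr, loss)
```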

funasr/models/frontend/default.py

@ -38,6 +38,7 @@ class DefaultFrontend(AbsFrontend):
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
        use_channel: Optional[int] = None,
):
assert check_argument_types()
super().__init__()
@ -77,6 +78,7 @@ class DefaultFrontend(AbsFrontend):
)
self.n_mels = n_mels
self.frontend_type = "default"
self.use_channel = use_channel
def output_size(self) -> int:
return self.n_mels
@ -100,9 +102,12 @@ class DefaultFrontend(AbsFrontend):
if input_stft.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
# Select 1ch randomly
ch = np.random.randint(input_stft.size(2))
input_stft = input_stft[:, :, ch, :]
                if self.use_channel is None:
input_stft = input_stft[:, :, 0, :]
else:
# Select 1ch randomly
ch = np.random.randint(input_stft.size(2))
input_stft = input_stft[:, :, ch, :]
else:
# Use the first channel
input_stft = input_stft[:, :, 0, :]
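A self-contained sketch of the channel-selection branch added above; `training` and `use_channel` stand in for the module attributes, and all shapes are illustrative:

```python
import numpy as np
import torch

input_stft = torch.randn(2, 100, 8, 257)       # (B, T, C, F): 8-channel STFT
training, use_channel = True, None
if input_stft.dim() == 4:
    if training:
        if use_channel is None:
            input_stft = input_stft[:, :, 0, :]          # fixed first channel
        else:
            ch = np.random.randint(input_stft.size(2))   # select 1ch randomly
            input_stft = input_stft[:, :, ch, :]
    else:
        input_stft = input_stft[:, :, 0, :]              # first channel at eval
print(input_stft.shape)                                  # (2, 100, 257)
```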


@ -83,9 +83,9 @@ def windowed_statistic_pooling(
num_chunk = int(math.ceil(tt / pooling_stride))
pad = pooling_size // 2
if len(xs_pad.shape) == 4:
features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
features = F.pad(xs_pad, (0, 0, pad, pad), "replicate")
else:
features = F.pad(xs_pad, (pad, pad), "reflect")
features = F.pad(xs_pad, (pad, pad), "replicate")
stat_list = []
for i in range(num_chunk):

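The change above swaps "reflect" padding for "replicate". The two modes differ only at the borders, and replicate, unlike reflect, is not restricted to padding widths smaller than the input length, which likely matters for segments shorter than the pooling window. A quick check:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[1., 2., 3., 4.]]])       # (N, C, T)
print(F.pad(x, (2, 2), "reflect"))           # [3, 2, 1, 2, 3, 4, 3, 2]
print(F.pad(x, (2, 2), "replicate"))         # [1, 1, 1, 2, 3, 4, 4, 4]
```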
funasr/modules/attention.py

@ -13,6 +13,9 @@ import torch
from torch import nn
from typing import Optional, Tuple
import torch.nn.functional as F
from funasr.modules.nets_utils import make_pad_mask
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@ -959,3 +962,37 @@ class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
q, k, v = self.forward_qkv(query, key, value)
scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
class CosineDistanceAttention(nn.Module):
""" Compute Cosine Distance between spk decoder output and speaker profile
Args:
profile_path: speaker profile file path (.npy file)
"""
def __init__(self):
super().__init__()
self.softmax = nn.Softmax(dim=-1)
def forward(self, spk_decoder_out, profile, profile_lens=None):
"""
Args:
            spk_decoder_out (torch.Tensor): (B, L, D)
            profile (torch.Tensor): (B, N, D)
"""
x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D)
if profile_lens is not None:
mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
min_value = float(
numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
)
            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0) # (B, L, N)
else:
x = x[:, -1:, :, :]
            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
weights = self.softmax(weights_not_softmax) # (B, 1, N)
spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D)
return spk_embedding, weights
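A toy run of `CosineDistanceAttention` (assuming FunASR is installed; shapes are illustrative) shows how padded profile slots receive zero attention weight:

```python
import torch
from funasr.modules.attention import CosineDistanceAttention

att = CosineDistanceAttention()
dec_out = torch.randn(2, 5, 192)     # (B, L, D) spk decoder output
profile = torch.randn(2, 4, 192)     # (B, N, D) speaker profiles
lens = torch.tensor([4, 2])          # second utterance has only 2 valid profiles
spk_emb, weights = att(dec_out, profile, lens)
print(spk_emb.shape, weights.shape)  # (2, 5, 192) and (2, 5, 4)
print(weights[1, 0])                 # last two entries are zero (padded slots)
```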


@ -0,0 +1,525 @@
"""Beam search module."""
from itertools import chain
import logging
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple
from typing import Union
import torch
from funasr.modules.e2e_asr_common import end_detect
from funasr.modules.scorers.scorer_interface import PartialScorerInterface
from funasr.modules.scorers.scorer_interface import ScorerInterface
from funasr.models.decoder.abs_decoder import AbsDecoder
class Hypothesis(NamedTuple):
"""Hypothesis data type."""
yseq: torch.Tensor
    spk_weigths: List
score: Union[float, torch.Tensor] = 0
scores: Dict[str, Union[float, torch.Tensor]] = dict()
states: Dict[str, Any] = dict()
def asdict(self) -> dict:
"""Convert data to JSON-friendly dict."""
return self._replace(
yseq=self.yseq.tolist(),
score=float(self.score),
scores={k: float(v) for k, v in self.scores.items()},
)._asdict()
class BeamSearch(torch.nn.Module):
"""Beam search implementation."""
def __init__(
self,
scorers: Dict[str, ScorerInterface],
weights: Dict[str, float],
beam_size: int,
vocab_size: int,
sos: int,
eos: int,
token_list: List[str] = None,
pre_beam_ratio: float = 1.5,
pre_beam_score_key: str = None,
):
"""Initialize beam search.
Args:
scorers (dict[str, ScorerInterface]): Dict of decoder modules
e.g., Decoder, CTCPrefixScorer, LM
The scorer will be ignored if it is `None`
weights (dict[str, float]): Dict of weights for each scorers
The scorer will be ignored if its weight is 0
beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The size of the vocabulary
sos (int): Start of sequence id
eos (int): End of sequence id
token_list (list[str]): List of tokens for debug log
pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): the beam size used in the pre-beam search
                is `int(pre_beam_ratio * beam_size)`
"""
super().__init__()
# set scorers
self.weights = weights
self.scorers = dict()
self.full_scorers = dict()
self.part_scorers = dict()
# this module dict is required for recursive cast
# `self.to(device, dtype)` in `recog.py`
self.nn_dict = torch.nn.ModuleDict()
for k, v in scorers.items():
w = weights.get(k, 0)
if w == 0 or v is None:
continue
assert isinstance(
v, ScorerInterface
), f"{k} ({type(v)}) does not implement ScorerInterface"
self.scorers[k] = v
if isinstance(v, PartialScorerInterface):
self.part_scorers[k] = v
else:
self.full_scorers[k] = v
if isinstance(v, torch.nn.Module):
self.nn_dict[k] = v
# set configurations
self.sos = sos
self.eos = eos
self.token_list = token_list
self.pre_beam_size = int(pre_beam_ratio * beam_size)
self.beam_size = beam_size
self.n_vocab = vocab_size
if (
pre_beam_score_key is not None
and pre_beam_score_key != "full"
and pre_beam_score_key not in self.full_scorers
):
raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
self.pre_beam_score_key = pre_beam_score_key
self.do_pre_beam = (
self.pre_beam_score_key is not None
and self.pre_beam_size < self.n_vocab
and len(self.part_scorers) > 0
)
def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
"""Get an initial hypothesis data.
Args:
x (torch.Tensor): The encoder output feature
Returns:
Hypothesis: The initial hypothesis.
"""
init_states = dict()
init_scores = dict()
for k, d in self.scorers.items():
init_states[k] = d.init_state(x)
init_scores[k] = 0.0
return [
Hypothesis(
score=0.0,
scores=init_scores,
states=init_states,
yseq=torch.tensor([self.sos], device=x.device),
spk_weigths=[],
)
]
@staticmethod
def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
"""Append new token to prefix tokens.
Args:
xs (torch.Tensor): The prefix token
x (int): The new token to append
Returns:
torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
"""
x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
return torch.cat((xs, x))
def score_full(
self, hyp: Hypothesis, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
            asr_enc (torch.Tensor): Corresponding ASR encoder output
            spk_enc (torch.Tensor): Corresponding speaker encoder output
            profile (torch.Tensor): Speaker profile matrix
        Returns:
            Tuple[Dict[str, torch.Tensor], torch.Tensor, Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                the speaker weights produced by the SA-ASR decoder,
                and state dict that has string keys
                and state values of `self.full_scorers`
"""
scores = dict()
        states = dict()
        spk_weigths = None  # set by the SA-ASR decoder scorer below
for k, d in self.full_scorers.items():
if isinstance(d, AbsDecoder):
scores[k], spk_weigths, states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc, spk_enc, profile)
else:
scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc)
return scores, spk_weigths, states
def score_partial(
self, hyp: Hypothesis, ids: torch.Tensor, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.part_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
ids (torch.Tensor): 1D tensor of new partial tokens to score
            asr_enc (torch.Tensor): Corresponding ASR encoder output
            spk_enc (torch.Tensor): Corresponding speaker encoder output
            profile (torch.Tensor): Speaker profile matrix
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.part_scorers`
and tensor score values of shape: `(len(ids),)`,
and state dict that has string keys
and state values of `self.part_scorers`
"""
scores = dict()
states = dict()
for k, d in self.part_scorers.items():
if isinstance(d, AbsDecoder):
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc, spk_enc, profile)
else:
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc)
return scores, states
def beam(
self, weighted_scores: torch.Tensor, ids: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute topk full token ids and partial token ids.
Args:
weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
Its shape is `(self.n_vocab,)`.
ids (torch.Tensor): The partial token ids to compute topk
Returns:
Tuple[torch.Tensor, torch.Tensor]:
The topk full token ids and partial token ids.
Their shapes are `(self.beam_size,)`
"""
# no pre beam performed
if weighted_scores.size(0) == ids.size(0):
top_ids = weighted_scores.topk(self.beam_size)[1]
return top_ids, top_ids
# mask pruned in pre-beam not to select in topk
tmp = weighted_scores[ids]
weighted_scores[:] = -float("inf")
weighted_scores[ids] = tmp
top_ids = weighted_scores.topk(self.beam_size)[1]
local_ids = weighted_scores[ids].topk(self.beam_size)[1]
return top_ids, local_ids
@staticmethod
def merge_scores(
prev_scores: Dict[str, float],
next_full_scores: Dict[str, torch.Tensor],
full_idx: int,
next_part_scores: Dict[str, torch.Tensor],
part_idx: int,
) -> Dict[str, torch.Tensor]:
"""Merge scores for new hypothesis.
Args:
prev_scores (Dict[str, float]):
The previous hypothesis scores by `self.scorers`
next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
full_idx (int): The next token id for `next_full_scores`
next_part_scores (Dict[str, torch.Tensor]):
scores of partial tokens by `self.part_scorers`
part_idx (int): The new token id for `next_part_scores`
Returns:
Dict[str, torch.Tensor]: The new score dict.
Its keys are names of `self.full_scorers` and `self.part_scorers`.
Its values are scalar tensors by the scorers.
"""
new_scores = dict()
for k, v in next_full_scores.items():
new_scores[k] = prev_scores[k] + v[full_idx]
for k, v in next_part_scores.items():
new_scores[k] = prev_scores[k] + v[part_idx]
return new_scores
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
"""Merge states for new hypothesis.
Args:
states: states of `self.full_scorers`
part_states: states of `self.part_scorers`
part_idx (int): The new token id for `part_scores`
Returns:
Dict[str, torch.Tensor]: The new score dict.
Its keys are names of `self.full_scorers` and `self.part_scorers`.
Its values are states of the scorers.
"""
new_states = dict()
for k, v in states.items():
new_states[k] = v
for k, d in self.part_scorers.items():
new_states[k] = d.select_state(part_states[k], part_idx)
return new_states
def search(
self, running_hyps: List[Hypothesis], asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor
) -> List[Hypothesis]:
"""Search new tokens for running hypotheses and encoded speech x.
Args:
running_hyps (List[Hypothesis]): Running hypotheses on beam
            asr_enc (torch.Tensor): Encoded speech feature (T, D)
            spk_enc (torch.Tensor): Speaker encoder output (T, D)
            profile (torch.Tensor): Speaker profile matrix (N, D')
        Returns:
            List[Hypothesis]: Best sorted hypotheses
"""
best_hyps = []
part_ids = torch.arange(self.n_vocab, device=asr_enc.device) # no pre-beam
for hyp in running_hyps:
# scoring
weighted_scores = torch.zeros(self.n_vocab, dtype=asr_enc.dtype, device=asr_enc.device)
scores, spk_weigths, states = self.score_full(hyp, asr_enc, spk_enc, profile)
for k in self.full_scorers:
weighted_scores += self.weights[k] * scores[k]
# partial scoring
if self.do_pre_beam:
pre_beam_scores = (
weighted_scores
if self.pre_beam_score_key == "full"
else scores[self.pre_beam_score_key]
)
part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
part_scores, part_states = self.score_partial(hyp, part_ids, asr_enc, spk_enc, profile)
for k in self.part_scorers:
weighted_scores[part_ids] += self.weights[k] * part_scores[k]
# add previous hyp score
weighted_scores += hyp.score
# update hyps
for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
# will be (2 x beam at most)
best_hyps.append(
Hypothesis(
score=weighted_scores[j],
yseq=self.append_token(hyp.yseq, j),
scores=self.merge_scores(
hyp.scores, scores, j, part_scores, part_j
),
states=self.merge_states(states, part_states, part_j),
                        spk_weigths=hyp.spk_weigths + [spk_weigths],
)
)
# sort and prune 2 x beam -> beam
best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
: min(len(best_hyps), self.beam_size)
]
return best_hyps
def forward(
self, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
) -> List[Hypothesis]:
"""Perform beam search.
Args:
            asr_enc (torch.Tensor): Encoded speech feature (T, D)
            spk_enc (torch.Tensor): Speaker encoder output (T, D)
            profile (torch.Tensor): Speaker profile matrix (N, D')
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses an end-detect function
to automatically find maximum hypothesis lengths
minlenratio (float): Input length ratio to obtain min output length.
Returns:
list[Hypothesis]: N-best decoding results
"""
# set length bounds
if maxlenratio == 0:
maxlen = asr_enc.shape[0]
else:
maxlen = max(1, int(maxlenratio * asr_enc.size(0)))
minlen = int(minlenratio * asr_enc.size(0))
logging.info("decoder input length: " + str(asr_enc.shape[0]))
logging.info("max output length: " + str(maxlen))
logging.info("min output length: " + str(minlen))
# main loop of prefix search
running_hyps = self.init_hyp(asr_enc)
ended_hyps = []
for i in range(maxlen):
logging.debug("position " + str(i))
best = self.search(running_hyps, asr_enc, spk_enc, profile)
# post process of one iteration
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
# end detection
if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
logging.info(f"end detected at {i}")
break
if len(running_hyps) == 0:
logging.info("no hypothesis. Finish decoding.")
break
else:
logging.debug(f"remained hypotheses: {len(running_hyps)}")
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
# check the number of hypotheses reaching to eos
if len(nbest_hyps) == 0:
logging.warning(
"there is no N-best results, perform recognition "
"again with smaller minlenratio."
)
return (
[]
if minlenratio < 0.1
else self.forward(asr_enc, spk_enc, profile, maxlenratio, max(0.0, minlenratio - 0.1))
)
# report the best result
best = nbest_hyps[0]
for k, v in best.scores.items():
logging.info(
f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
)
logging.info(f"total log probability: {best.score:.2f}")
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
if self.token_list is not None:
logging.info(
"best hypo: "
+ "".join([self.token_list[x] for x in best.yseq[1:-1]])
+ "\n"
)
return nbest_hyps
def post_process(
self,
i: int,
maxlen: int,
maxlenratio: float,
running_hyps: List[Hypothesis],
ended_hyps: List[Hypothesis],
) -> List[Hypothesis]:
"""Perform post-processing of beam search iterations.
Args:
i (int): The length of hypothesis tokens.
maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
running_hyps (List[Hypothesis]): The running hypotheses in beam search.
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
Returns:
List[Hypothesis]: The new running hypotheses.
"""
logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
if self.token_list is not None:
logging.debug(
"best hypo: "
+ "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
)
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info("adding <eos> in the last position in the loop")
running_hyps = [
h._replace(yseq=self.append_token(h.yseq, self.eos))
for h in running_hyps
]
        # add ended hypotheses to a final list, and remove them from the current hypotheses
# (this will be a problem, number of hyps < beam)
remained_hyps = []
for hyp in running_hyps:
if hyp.yseq[-1] == self.eos:
# e.g., Word LM needs to add final <eos> score
for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
s = d.final_score(hyp.states[k])
hyp.scores[k] += s
hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
return remained_hyps
def beam_search(
    asr_enc: torch.Tensor,
    spk_enc: torch.Tensor,
    profile: torch.Tensor,
sos: int,
eos: int,
beam_size: int,
vocab_size: int,
scorers: Dict[str, ScorerInterface],
weights: Dict[str, float],
token_list: List[str] = None,
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
pre_beam_ratio: float = 1.5,
pre_beam_score_key: str = "full",
) -> list:
"""Perform beam search with scorers.
Args:
        asr_enc (torch.Tensor): Encoded speech feature (T, D)
        spk_enc (torch.Tensor): Speaker encoder output (T, D)
        profile (torch.Tensor): Speaker profile matrix (N, D')
sos (int): Start of sequence id
eos (int): End of sequence id
beam_size (int): The number of hypotheses kept during search
        vocab_size (int): The size of the vocabulary
scorers (dict[str, ScorerInterface]): Dict of decoder modules
e.g., Decoder, CTCPrefixScorer, LM
The scorer will be ignored if it is `None`
weights (dict[str, float]): Dict of weights for each scorers
The scorer will be ignored if its weight is 0
token_list (list[str]): List of tokens for debug log
maxlenratio (float): Input length ratio to obtain max output length.
            If maxlenratio=0.0 (default), it uses an end-detect function
to automatically find maximum hypothesis lengths
minlenratio (float): Input length ratio to obtain min output length.
pre_beam_score_key (str): key of scores to perform pre-beam search
        pre_beam_ratio (float): the beam size used in the pre-beam search
            is `int(pre_beam_ratio * beam_size)`
Returns:
list: N-best decoding results
"""
ret = BeamSearch(
scorers,
weights,
beam_size=beam_size,
vocab_size=vocab_size,
pre_beam_ratio=pre_beam_ratio,
pre_beam_score_key=pre_beam_score_key,
sos=sos,
eos=eos,
token_list=token_list,
    ).forward(
        asr_enc=asr_enc, spk_enc=spk_enc, profile=profile,
        maxlenratio=maxlenratio, minlenratio=minlenratio,
    )
return [h.asdict() for h in ret]
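As a quick illustration of the pre-beam pruning in `beam()` above, here is a toy run with made-up scores (working on a masked copy, where the method mutates the score vector in place):

```python
import torch

weighted_scores = torch.tensor([0.1, 2.0, 0.5, 1.5, 0.2])
ids = torch.tensor([1, 3, 4])        # tokens surviving the pre-beam
beam_size = 2
masked = torch.full_like(weighted_scores, -float("inf"))
masked[ids] = weighted_scores[ids]   # everything pruned earlier gets -inf
top_ids = masked.topk(beam_size)[1]              # tensor([1, 3]) in vocab space
local_ids = masked[ids].topk(beam_size)[1]       # tensor([0, 1]) within `ids`
print(top_ids, local_ids)
```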

funasr/tasks/abs_task.py

@ -444,6 +444,12 @@ class AbsTask(ABC):
default=False,
help='Perform on "collect stats" mode',
)
group.add_argument(
"--mc",
        type=str2bool,
default=False,
help="MultiChannel input",
)
group.add_argument(
"--write_collected_feats",
type=str2bool,
@ -641,8 +647,8 @@ class AbsTask(ABC):
group.add_argument(
"--init_param",
type=str,
action="append",
default=[],
nargs="*",
help="Specify the file path used for initialization of parameters. "
"The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
"where file_path is the model file path, "
@ -668,7 +674,7 @@ class AbsTask(ABC):
"--freeze_param",
type=str,
default=[],
nargs="*",
action="append",
help="Freeze parameters",
)
@ -1159,10 +1165,10 @@ class AbsTask(ABC):
elif args.distributed and args.simple_ddp:
distributed_option.init_torch_distributed_pai(args)
args.ngpu = dist.get_world_size()
if args.dataset_type == "small":
if args.dataset_type == "small" and args.ngpu > 0:
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None:
if args.batch_bins is not None and args.ngpu > 0:
args.batch_bins = args.batch_bins * args.ngpu
# filter samples if wav.scp and text are mismatch
@ -1322,6 +1328,7 @@ class AbsTask(ABC):
data_path_and_name_and_type=args.train_data_path_and_name_and_type,
key_file=train_key_file,
batch_size=args.batch_size,
mc=args.mc,
dtype=args.train_dtype,
num_workers=args.num_workers,
allow_variable_data_keys=args.allow_variable_data_keys,
@ -1333,6 +1340,7 @@ class AbsTask(ABC):
data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
key_file=valid_key_file,
batch_size=args.valid_batch_size,
mc=args.mc,
dtype=args.train_dtype,
num_workers=args.num_workers,
allow_variable_data_keys=args.allow_variable_data_keys,

funasr/tasks/sa_asr.py (new file, 623 lines)

@ -0,0 +1,623 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
DynamicConvolution2DTransformerDecoder, # noqa: H301
)
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
LightweightConvolution2DTransformerDecoder, # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
LightweightConvolutionTransformerDecoder, # noqa: H301
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_sa_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34Diar
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
HuggingFaceTransformersPostEncoder, # noqa: H301
)
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
frontend_choices = ClassChoices(
name="frontend",
classes=dict(
default=DefaultFrontend,
sliding_window=SlidingWindow,
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
multichannelfrontend=MultiChannelFrontend,
),
type_check=AbsFrontend,
default="default",
)
specaug_choices = ClassChoices(
name="specaug",
classes=dict(
specaug=SpecAug,
specaug_lfr=SpecAugLFR,
),
type_check=AbsSpecAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
global_mvn=GlobalMVN,
utterance_mvn=UtteranceMVN,
),
type_check=AbsNormalize,
default=None,
optional=True,
)
model_choices = ClassChoices(
"model",
classes=dict(
asr=ESPnetASRModel,
uniasr=UniASR,
paraformer=Paraformer,
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
),
type_check=AbsESPnetModel,
default="asr",
)
preencoder_choices = ClassChoices(
name="preencoder",
classes=dict(
sinc=LightweightSincConvs,
linear=LinearProjection,
),
type_check=AbsPreEncoder,
default=None,
optional=True,
)
asr_encoder_choices = ClassChoices(
"asr_encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
),
type_check=AbsEncoder,
default="rnn",
)
spk_encoder_choices = ClassChoices(
"spk_encoder",
classes=dict(
resnet34_diar=ResNet34Diar,
),
default="resnet34_diar",
)
encoder_choices2 = ClassChoices(
"encoder2",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
),
type_check=AbsEncoder,
default="rnn",
)
postencoder_choices = ClassChoices(
name="postencoder",
classes=dict(
hugging_face_transformers=HuggingFaceTransformersPostEncoder,
),
type_check=AbsPostEncoder,
default=None,
optional=True,
)
decoder_choices = ClassChoices(
"decoder",
classes=dict(
transformer=TransformerDecoder,
lightweight_conv=LightweightConvolutionTransformerDecoder,
lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
dynamic_conv=DynamicConvolutionTransformerDecoder,
dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
rnn=RNNDecoder,
fsmn_scama_opt=FsmnDecoderSCAMAOpt,
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
sa_decoder=SAAsrTransformerDecoder,
),
type_check=AbsDecoder,
default="sa_decoder",
)
decoder_choices2 = ClassChoices(
"decoder2",
classes=dict(
transformer=TransformerDecoder,
lightweight_conv=LightweightConvolutionTransformerDecoder,
lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
dynamic_conv=DynamicConvolutionTransformerDecoder,
dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
rnn=RNNDecoder,
fsmn_scama_opt=FsmnDecoderSCAMAOpt,
paraformer_decoder_sanm=ParaformerSANMDecoder,
),
type_check=AbsDecoder,
default="rnn",
)
predictor_choices = ClassChoices(
name="predictor",
classes=dict(
cif_predictor=CifPredictor,
ctc_predictor=None,
cif_predictor_v2=CifPredictorV2,
cif_predictor_v3=CifPredictorV3,
),
type_check=None,
default="cif_predictor",
optional=True,
)
predictor_choices2 = ClassChoices(
name="predictor2",
classes=dict(
cif_predictor=CifPredictor,
ctc_predictor=None,
cif_predictor_v2=CifPredictorV2,
),
type_check=None,
default="cif_predictor",
optional=True,
)
stride_conv_choices = ClassChoices(
name="stride_conv",
classes=dict(
stride_conv1d=Conv1dSubsampling
),
type_check=None,
default="stride_conv1d",
optional=True,
)
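Each `ClassChoices` instance above maps a CLI flag value to a class, so `--decoder sa_decoder` selects the SA-ASR decoder. A minimal check, assuming FunASR is installed:

```python
from funasr.tasks.sa_asr import decoder_choices
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder

# "--decoder sa_decoder" on the command line resolves through ClassChoices:
assert decoder_choices.get_class("sa_decoder") is SAAsrTransformerDecoder
```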
class ASRTask(AbsTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --model and --model_conf
model_choices,
# --preencoder and --preencoder_conf
preencoder_choices,
# --asr_encoder and --asr_encoder_conf
asr_encoder_choices,
# --spk_encoder and --spk_encoder_conf
spk_encoder_choices,
# --postencoder and --postencoder_conf
postencoder_choices,
# --decoder and --decoder_conf
decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
# required = parser.get_default("required")
# required += ["token_list"]
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--split_with_space",
type=str2bool,
default=True,
help="whether to split text using <space>",
)
group.add_argument(
"--max_spk_num",
type=int_or_none,
default=None,
help="A text mapping int-id to token",
)
group.add_argument(
"--seg_dict_file",
type=str,
default=None,
help="seg_dict_file for text processing",
)
group.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
group.add_argument(
"--ctc_conf",
action=NestedDictAction,
default=get_default_kwargs(CTC),
help="The keyword arguments for CTC class.",
)
group.add_argument(
"--joint_net_conf",
action=NestedDictAction,
default=None,
help="The keyword arguments for joint network class.",
)
group = parser.add_argument_group(description="Preprocess related")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Apply preprocessing to data or not",
)
group.add_argument(
"--token_type",
type=str,
default="bpe",
choices=["bpe", "char", "word", "phn"],
help="The text will be tokenized " "in the specified level token",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model file of sentencepiece",
)
parser.add_argument(
"--non_linguistic_symbols",
type=str_or_none,
default=None,
help="non_linguistic_symbols file path",
)
parser.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese"],
default=None,
help="Apply text cleaning",
)
parser.add_argument(
"--g2p",
type=str_or_none,
choices=g2p_choices,
default=None,
help="Specify g2p method if --token_type=phn",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Scale the maximum amplitude to the given value.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The file path of rir scp file.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="THe probability for applying RIR convolution.",
)
parser.add_argument(
"--cmvn_file",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability applying Noise adding.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
else 1.0,
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
noise_apply_prob=args.noise_apply_prob
if hasattr(args, "noise_apply_prob")
else 1.0,
noise_db_range=args.noise_db_range
if hasattr(args, "noise_db_range")
else "13_15",
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, "speech_volume_normalize")
else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
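# Training/validation requires paired speech and transcripts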
retval = ("speech", "text")
else:
# Recognition mode
retval = ("speech",)
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
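# token_list may be given as a vocabulary file path or directly as a list of tokens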
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
# Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
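# The wav frontend additionally takes a CMVN statistics file for feature normalization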
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 4. Pre-encoder input block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
if getattr(args, "preencoder", None) is not None:
preencoder_class = preencoder_choices.get_class(args.preencoder)
preencoder = preencoder_class(**args.preencoder_conf)
input_size = preencoder.output_size()
else:
preencoder = None
# 5. Encoder
asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
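# Besides the ASR encoder, SA-ASR builds a separate speaker encoder whose
# output is used to attribute the recognized text to speakers.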
spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
# 6. Post-encoder block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
asr_encoder_output_size = asr_encoder.output_size()
if getattr(args, "postencoder", None) is not None:
postencoder_class = postencoder_choices.get_class(args.postencoder)
postencoder = postencoder_class(
input_size=asr_encoder_output_size, **args.postencoder_conf
)
asr_encoder_output_size = postencoder.output_size()
else:
postencoder = None
# 7. Decoder
decoder_class = decoder_choices.get_class(args.decoder)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=asr_encoder_output_size,
**args.decoder_conf,
)
# 8. CTC
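# A CTC head over the ASR encoder output, typically trained jointly with the attention decoder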
ctc = CTC(
odim=vocab_size, encoder_output_size=asr_encoder_output_size, **args.ctc_conf
)
max_spk_num = int(args.max_spk_num)
# 9. Build model
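# Fall back to the plain "asr" model class when args.model is not available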
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("asr")
model = model_class(
vocab_size=vocab_size,
max_spk_num=max_spk_num,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
asr_encoder=asr_encoder,
spk_encoder=spk_encoder,
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
token_list=token_list,
**args.model_conf,
)
# 10. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model

View File

@ -242,4 +242,4 @@ def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
if ch != ' ':
real_word_lists.append(ch)
sentence = ''.join(word_lists).strip()
return sentence, real_word_lists
return sentence, real_word_lists

View File

@ -13,7 +13,7 @@ requirements = {
"install": [
"setuptools>=38.5.1",
# "configargparse>=1.2.1",
"typeguard<=2.13.3",
"typeguard==2.13.3",
"humanfriendly",
"scipy>=1.4.1",
# "filelock",
@ -42,7 +42,10 @@ requirements = {
"oss2",
# "kaldi-native-fbank",
# timestamp
"edit-distance"
"edit-distance",
# textgrid
"textgrid",
"protobuf==3.20.0",
],
# train: The modules invoked when training only.
"train": [