mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add speaker-attributed ASR task for alimeeting
This commit is contained in:
parent
3b7e4b0d34
commit
d76aea23d9
@ -434,14 +434,14 @@ if ! "${skip_data_prep}"; then
|
||||
log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"
|
||||
for factor in ${speed_perturb_factors}; do
|
||||
if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then
|
||||
scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
|
||||
local/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
|
||||
_dirs+="data/${train_set}_sp${factor} "
|
||||
else
|
||||
# If speed factor is 1, same as the original
|
||||
_dirs+="data/${train_set} "
|
||||
fi
|
||||
done
|
||||
utils/combine_data.sh "data/${train_set}_sp" ${_dirs}
|
||||
local/combine_data.sh "data/${train_set}_sp" ${_dirs}
|
||||
else
|
||||
log "Skip stage 2: Speed perturbation"
|
||||
fi
|
||||
@ -473,7 +473,7 @@ if ! "${skip_data_prep}"; then
|
||||
_suf=""
|
||||
fi
|
||||
fi
|
||||
utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
|
||||
local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
|
||||
|
||||
cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
|
||||
|
||||
@ -488,7 +488,7 @@ if ! "${skip_data_prep}"; then
|
||||
_opts+="--segments data/${dset}/segments "
|
||||
fi
|
||||
# shellcheck disable=SC2086
|
||||
scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
|
||||
local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
|
||||
--audio-format "${audio_format}" --fs "${fs}" ${_opts} \
|
||||
"data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
|
||||
|
||||
@ -515,7 +515,7 @@ if ! "${skip_data_prep}"; then
|
||||
for dset in $rm_dset; do
|
||||
|
||||
# Copy data dir
|
||||
utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
|
||||
local/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
|
||||
cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type"
|
||||
|
||||
# Remove short utterances
|
||||
@ -564,7 +564,7 @@ if ! "${skip_data_prep}"; then
|
||||
awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text"
|
||||
|
||||
# fix_data_dir.sh leaves only utts which exist in all files
|
||||
utils/fix_data_dir.sh "${data_feats}/${dset}"
|
||||
local/fix_data_dir.sh "${data_feats}/${dset}"
|
||||
|
||||
# generate uttid
|
||||
cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
|
||||
@ -1283,6 +1283,7 @@ if ! "${skip_eval}"; then
|
||||
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
|
||||
python -m funasr.bin.asr_inference_launch \
|
||||
--batch_size 1 \
|
||||
--mc True \
|
||||
--nbest 1 \
|
||||
--ngpu "${_ngpu}" \
|
||||
--njob ${njob_infer} \
|
||||
@ -1312,10 +1313,10 @@ if ! "${skip_eval}"; then
|
||||
_data="${data_feats}/${dset}"
|
||||
_dir="${asr_exp}/${inference_tag}/${dset}"
|
||||
|
||||
python local/proce_text.py ${_data}/text ${_data}/text.proc
|
||||
python local/proce_text.py ${_dir}/text ${_dir}/text.proc
|
||||
python utils/proce_text.py ${_data}/text ${_data}/text.proc
|
||||
python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
|
||||
|
||||
python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
|
||||
python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
|
||||
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
|
||||
cat ${_dir}/text.cer.txt
|
||||
|
||||
@ -1390,6 +1391,7 @@ if ! "${skip_eval}"; then
|
||||
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
|
||||
python -m funasr.bin.asr_inference_launch \
|
||||
--batch_size 1 \
|
||||
--mc True \
|
||||
--nbest 1 \
|
||||
--ngpu "${_ngpu}" \
|
||||
--njob ${njob_infer} \
|
||||
@ -1421,10 +1423,10 @@ if ! "${skip_eval}"; then
|
||||
_data="${data_feats}/${dset}"
|
||||
_dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
|
||||
|
||||
python local/proce_text.py ${_data}/text ${_data}/text.proc
|
||||
python local/proce_text.py ${_dir}/text ${_dir}/text.proc
|
||||
python utils/proce_text.py ${_data}/text ${_data}/text.proc
|
||||
python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
|
||||
|
||||
python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
|
||||
python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
|
||||
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
|
||||
cat ${_dir}/text.cer.txt
|
||||
|
||||
@ -1506,6 +1508,7 @@ if ! "${skip_eval}"; then
|
||||
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
|
||||
python -m funasr.bin.asr_inference_launch \
|
||||
--batch_size 1 \
|
||||
--mc True \
|
||||
--nbest 1 \
|
||||
--ngpu "${_ngpu}" \
|
||||
--njob ${njob_infer} \
|
||||
@ -1536,10 +1539,10 @@ if ! "${skip_eval}"; then
|
||||
_data="${data_feats}/${dset}"
|
||||
_dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
|
||||
|
||||
python local/proce_text.py ${_data}/text ${_data}/text.proc
|
||||
python local/proce_text.py ${_dir}/text ${_dir}/text.proc
|
||||
python utils/proce_text.py ${_data}/text ${_data}/text.proc
|
||||
python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
|
||||
|
||||
python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
|
||||
python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
|
||||
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
|
||||
cat ${_dir}/text.cer.txt
|
||||
|
||||
|
||||
@ -436,7 +436,7 @@ if ! "${skip_data_prep}"; then
|
||||
|
||||
_suf=""
|
||||
|
||||
utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
|
||||
local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
|
||||
|
||||
rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
|
||||
_opts=
|
||||
@ -548,6 +548,7 @@ if ! "${skip_eval}"; then
|
||||
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
|
||||
python -m funasr.bin.asr_inference_launch \
|
||||
--batch_size 1 \
|
||||
--mc True \
|
||||
--nbest 1 \
|
||||
--ngpu "${_ngpu}" \
|
||||
--njob ${njob_infer} \
|
||||
|
||||
@ -4,7 +4,6 @@ frontend_conf:
|
||||
n_fft: 400
|
||||
win_length: 400
|
||||
hop_length: 160
|
||||
use_channel: 0
|
||||
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
|
||||
@ -4,7 +4,6 @@ frontend_conf:
|
||||
n_fft: 400
|
||||
win_length: 400
|
||||
hop_length: 160
|
||||
use_channel: 0
|
||||
|
||||
# encoder related
|
||||
asr_encoder: conformer
|
||||
|
||||
@ -78,7 +78,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
|
||||
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
|
||||
#sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
|
||||
utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
|
||||
local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
|
||||
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
|
||||
sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
|
||||
sed -e 's/!//g' $near_dir/tmp1> $near_dir/tmp2
|
||||
@ -109,7 +109,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
|
||||
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
|
||||
|
||||
utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
|
||||
local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
|
||||
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
|
||||
sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
|
||||
sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
|
||||
@ -121,8 +121,8 @@ fi
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
log "stage 3: finali data process"
|
||||
|
||||
utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
|
||||
utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
|
||||
local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
|
||||
local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
|
||||
|
||||
sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
|
||||
sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
|
||||
@ -146,10 +146,10 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
|
||||
|
||||
cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
|
||||
utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
|
||||
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
|
||||
|
||||
./utils/fix_data_dir.sh $far_single_speaker_dir
|
||||
utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
|
||||
./local/fix_data_dir.sh $far_single_speaker_dir
|
||||
local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
|
||||
|
||||
# remove space in text
|
||||
for x in ${tgt}_Ali_far_single_speaker; do
|
||||
|
||||
@ -77,7 +77,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
|
||||
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
|
||||
|
||||
utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
|
||||
local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
|
||||
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
|
||||
sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
|
||||
sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
|
||||
@ -89,7 +89,7 @@ fi
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
log "stage 2: finali data process"
|
||||
|
||||
utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
|
||||
local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
|
||||
|
||||
sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
|
||||
sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
|
||||
@ -113,10 +113,10 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
|
||||
|
||||
cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
|
||||
utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
|
||||
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
|
||||
|
||||
./utils/fix_data_dir.sh $far_single_speaker_dir
|
||||
utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
|
||||
./local/fix_data_dir.sh $far_single_speaker_dir
|
||||
local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
|
||||
|
||||
# remove space in text
|
||||
for x in ${tgt}_Ali_far_single_speaker; do
|
||||
|
||||
@ -98,7 +98,7 @@ if $has_segments; then
|
||||
for in_dir in $*; do
|
||||
if [ ! -f $in_dir/segments ]; then
|
||||
echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
|
||||
utils/data/get_segments_for_data.sh $in_dir
|
||||
local/data/get_segments_for_data.sh $in_dir
|
||||
else
|
||||
cat $in_dir/segments
|
||||
fi
|
||||
@ -133,14 +133,14 @@ for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn
|
||||
fi
|
||||
done
|
||||
|
||||
utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
|
||||
local/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
|
||||
|
||||
if [[ $dir_with_frame_shift ]]; then
|
||||
cp $dir_with_frame_shift/frame_shift $dest
|
||||
fi
|
||||
|
||||
if ! $skip_fix ; then
|
||||
utils/fix_data_dir.sh $dest || exit 1;
|
||||
local/fix_data_dir.sh $dest || exit 1;
|
||||
fi
|
||||
|
||||
exit 0
|
||||
@ -71,25 +71,25 @@ else
|
||||
cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq
|
||||
fi
|
||||
|
||||
cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \
|
||||
utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
|
||||
cat $srcdir/utt2spk | local/apply_map.pl -f 1 $destdir/utt_map | \
|
||||
local/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
|
||||
|
||||
utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
|
||||
local/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
|
||||
|
||||
if [ -f $srcdir/feats.scp ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
|
||||
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
|
||||
fi
|
||||
|
||||
if [ -f $srcdir/vad.scp ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
|
||||
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
|
||||
fi
|
||||
|
||||
if [ -f $srcdir/segments ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
|
||||
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
|
||||
cp $srcdir/wav.scp $destdir
|
||||
else # no segments->wav indexed by utt.
|
||||
if [ -f $srcdir/wav.scp ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
|
||||
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -98,26 +98,26 @@ if [ -f $srcdir/reco2file_and_channel ]; then
|
||||
fi
|
||||
|
||||
if [ -f $srcdir/text ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
|
||||
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
|
||||
fi
|
||||
if [ -f $srcdir/utt2dur ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
|
||||
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
|
||||
fi
|
||||
if [ -f $srcdir/utt2num_frames ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
|
||||
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
|
||||
fi
|
||||
if [ -f $srcdir/reco2dur ]; then
|
||||
if [ -f $srcdir/segments ]; then
|
||||
cp $srcdir/reco2dur $destdir/reco2dur
|
||||
else
|
||||
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
|
||||
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
|
||||
fi
|
||||
fi
|
||||
if [ -f $srcdir/spk2gender ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
|
||||
local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
|
||||
fi
|
||||
if [ -f $srcdir/cmvn.scp ]; then
|
||||
utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
|
||||
local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
|
||||
fi
|
||||
for f in frame_shift stm glm ctm; do
|
||||
if [ -f $srcdir/$f ]; then
|
||||
@ -142,4 +142,4 @@ done
|
||||
[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats"
|
||||
[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text"
|
||||
|
||||
utils/validate_data_dir.sh $validate_opts $destdir
|
||||
local/validate_data_dir.sh $validate_opts $destdir
|
||||
@ -20,7 +20,7 @@ fi
|
||||
data=$1
|
||||
|
||||
if [ ! -s $data/utt2dur ]; then
|
||||
utils/data/get_utt2dur.sh $data 1>&2 || exit 1;
|
||||
local/data/get_utt2dur.sh $data 1>&2 || exit 1;
|
||||
fi
|
||||
|
||||
# <utt-id> <utt-id> 0 <utt-dur>
|
||||
@ -94,7 +94,7 @@ elif [ -f $data/wav.scp ]; then
|
||||
nj=$num_utts
|
||||
fi
|
||||
|
||||
utils/data/split_data.sh --per-utt $data $nj
|
||||
local/data/split_data.sh --per-utt $data $nj
|
||||
sdata=$data/split${nj}utt
|
||||
|
||||
$cmd JOB=1:$nj $data/log/get_durations.JOB.log \
|
||||
@ -60,11 +60,11 @@ nf=`cat $data/feats.scp 2>/dev/null | wc -l`
|
||||
nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
|
||||
if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
|
||||
echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can "
|
||||
echo "** use utils/fix_data_dir.sh $data to fix this."
|
||||
echo "** use local/fix_data_dir.sh $data to fix this."
|
||||
fi
|
||||
if [ -f $data/text ] && [ $nu -ne $nt ]; then
|
||||
echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can "
|
||||
echo "** use utils/fix_data_dir.sh to fix this."
|
||||
echo "** use local/fix_data_dir.sh to fix this."
|
||||
fi
|
||||
|
||||
|
||||
@ -112,7 +112,7 @@ utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1
|
||||
|
||||
for n in `seq $numsplit`; do
|
||||
dsn=$data/split${numsplit}${utt}/$n
|
||||
utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
|
||||
local/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
|
||||
done
|
||||
|
||||
maybe_wav_scp=
|
||||
@ -112,7 +112,7 @@ function filter_recordings {
|
||||
|
||||
function filter_speakers {
|
||||
# throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
|
||||
utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
|
||||
local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
|
||||
|
||||
cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
|
||||
for s in cmvn.scp spk2gender; do
|
||||
@ -123,7 +123,7 @@ function filter_speakers {
|
||||
done
|
||||
|
||||
filter_file $tmpdir/speakers $data/spk2utt
|
||||
utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
|
||||
local/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
|
||||
|
||||
for s in cmvn.scp spk2gender $spk_extra_files; do
|
||||
f=$data/$s
|
||||
@ -210,6 +210,6 @@ filter_utts
|
||||
filter_speakers
|
||||
filter_recordings
|
||||
|
||||
utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
|
||||
local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
|
||||
|
||||
echo "fix_data_dir.sh: old files are kept in $data/.backup"
|
||||
243
egs/alimeeting/sa-asr/local/format_wav_scp.py
Executable file
243
egs/alimeeting/sa-asr/local/format_wav_scp.py
Executable file
@ -0,0 +1,243 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import kaldiio
|
||||
import humanfriendly
|
||||
import numpy as np
|
||||
import resampy
|
||||
import soundfile
|
||||
from tqdm import tqdm
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.utils.cli_utils import get_commandline_args
|
||||
from funasr.fileio.read_text import read_2column_text
|
||||
from funasr.fileio.sound_scp import SoundScpWriter
|
||||
|
||||
|
||||
def humanfriendly_or_none(value: str):
|
||||
if value in ("none", "None", "NONE"):
|
||||
return None
|
||||
return humanfriendly.parse_size(value)
|
||||
|
||||
|
||||
def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
|
||||
"""
|
||||
|
||||
>>> str2int_tuple('3,4,5')
|
||||
(3, 4, 5)
|
||||
|
||||
"""
|
||||
assert check_argument_types()
|
||||
if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
|
||||
return None
|
||||
return tuple(map(int, integers.strip().split(",")))
|
||||
|
||||
|
||||
def main():
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
logging.info(get_commandline_args())
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Create waves list from "wav.scp"',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("scp")
|
||||
parser.add_argument("outdir")
|
||||
parser.add_argument(
|
||||
"--name",
|
||||
default="wav",
|
||||
help="Specify the prefix word of output file name " 'such as "wav.scp"',
|
||||
)
|
||||
parser.add_argument("--segments", default=None)
|
||||
parser.add_argument(
|
||||
"--fs",
|
||||
type=humanfriendly_or_none,
|
||||
default=None,
|
||||
help="If the sampling rate specified, " "Change the sampling rate.",
|
||||
)
|
||||
parser.add_argument("--audio-format", default="wav")
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--ref-channels", default=None, type=str2int_tuple)
|
||||
group.add_argument("--utt2ref-channels", default=None, type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
out_num_samples = Path(args.outdir) / f"utt2num_samples"
|
||||
|
||||
if args.ref_channels is not None:
|
||||
|
||||
def utt2ref_channels(x) -> Tuple[int, ...]:
|
||||
return args.ref_channels
|
||||
|
||||
elif args.utt2ref_channels is not None:
|
||||
utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)
|
||||
|
||||
def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
|
||||
chs_str = d[x]
|
||||
return tuple(map(int, chs_str.split()))
|
||||
|
||||
else:
|
||||
utt2ref_channels = None
|
||||
|
||||
Path(args.outdir).mkdir(parents=True, exist_ok=True)
|
||||
out_wavscp = Path(args.outdir) / f"{args.name}.scp"
|
||||
if args.segments is not None:
|
||||
# Note: kaldiio supports only wav-pcm-int16le file.
|
||||
loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
|
||||
if args.audio_format.endswith("ark"):
|
||||
fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
|
||||
fscp = out_wavscp.open("w")
|
||||
else:
|
||||
writer = SoundScpWriter(
|
||||
args.outdir,
|
||||
out_wavscp,
|
||||
format=args.audio_format,
|
||||
)
|
||||
|
||||
with out_num_samples.open("w") as fnum_samples:
|
||||
for uttid, (rate, wave) in tqdm(loader):
|
||||
# wave: (Time,) or (Time, Nmic)
|
||||
if wave.ndim == 2 and utt2ref_channels is not None:
|
||||
wave = wave[:, utt2ref_channels(uttid)]
|
||||
|
||||
if args.fs is not None and args.fs != rate:
|
||||
# FIXME(kamo): To use sox?
|
||||
wave = resampy.resample(
|
||||
wave.astype(np.float64), rate, args.fs, axis=0
|
||||
)
|
||||
wave = wave.astype(np.int16)
|
||||
rate = args.fs
|
||||
if args.audio_format.endswith("ark"):
|
||||
if "flac" in args.audio_format:
|
||||
suf = "flac"
|
||||
elif "wav" in args.audio_format:
|
||||
suf = "wav"
|
||||
else:
|
||||
raise RuntimeError("wav.ark or flac")
|
||||
|
||||
# NOTE(kamo): Using extended ark format style here.
|
||||
# This format is incompatible with Kaldi
|
||||
kaldiio.save_ark(
|
||||
fark,
|
||||
{uttid: (wave, rate)},
|
||||
scp=fscp,
|
||||
append=True,
|
||||
write_function=f"soundfile_{suf}",
|
||||
)
|
||||
|
||||
else:
|
||||
writer[uttid] = rate, wave
|
||||
fnum_samples.write(f"{uttid} {len(wave)}\n")
|
||||
else:
|
||||
if args.audio_format.endswith("ark"):
|
||||
fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
|
||||
else:
|
||||
wavdir = Path(args.outdir) / f"data_{args.name}"
|
||||
wavdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with Path(args.scp).open("r") as fscp, out_wavscp.open(
|
||||
"w"
|
||||
) as fout, out_num_samples.open("w") as fnum_samples:
|
||||
for line in tqdm(fscp):
|
||||
uttid, wavpath = line.strip().split(None, 1)
|
||||
|
||||
if wavpath.endswith("|"):
|
||||
# Streaming input e.g. cat a.wav |
|
||||
with kaldiio.open_like_kaldi(wavpath, "rb") as f:
|
||||
with BytesIO(f.read()) as g:
|
||||
wave, rate = soundfile.read(g, dtype=np.int16)
|
||||
if wave.ndim == 2 and utt2ref_channels is not None:
|
||||
wave = wave[:, utt2ref_channels(uttid)]
|
||||
|
||||
if args.fs is not None and args.fs != rate:
|
||||
# FIXME(kamo): To use sox?
|
||||
wave = resampy.resample(
|
||||
wave.astype(np.float64), rate, args.fs, axis=0
|
||||
)
|
||||
wave = wave.astype(np.int16)
|
||||
rate = args.fs
|
||||
|
||||
if args.audio_format.endswith("ark"):
|
||||
if "flac" in args.audio_format:
|
||||
suf = "flac"
|
||||
elif "wav" in args.audio_format:
|
||||
suf = "wav"
|
||||
else:
|
||||
raise RuntimeError("wav.ark or flac")
|
||||
|
||||
# NOTE(kamo): Using extended ark format style here.
|
||||
# This format is incompatible with Kaldi
|
||||
kaldiio.save_ark(
|
||||
fark,
|
||||
{uttid: (wave, rate)},
|
||||
scp=fout,
|
||||
append=True,
|
||||
write_function=f"soundfile_{suf}",
|
||||
)
|
||||
else:
|
||||
owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
|
||||
soundfile.write(owavpath, wave, rate)
|
||||
fout.write(f"{uttid} {owavpath}\n")
|
||||
else:
|
||||
wave, rate = soundfile.read(wavpath, dtype=np.int16)
|
||||
if wave.ndim == 2 and utt2ref_channels is not None:
|
||||
wave = wave[:, utt2ref_channels(uttid)]
|
||||
save_asis = False
|
||||
|
||||
elif args.audio_format.endswith("ark"):
|
||||
save_asis = False
|
||||
|
||||
elif Path(wavpath).suffix == "." + args.audio_format and (
|
||||
args.fs is None or args.fs == rate
|
||||
):
|
||||
save_asis = True
|
||||
|
||||
else:
|
||||
save_asis = False
|
||||
|
||||
if save_asis:
|
||||
# Neither --segments nor --fs are specified and
|
||||
# the line doesn't end with "|",
|
||||
# i.e. not using unix-pipe,
|
||||
# only in this case,
|
||||
# just using the original file as is.
|
||||
fout.write(f"{uttid} {wavpath}\n")
|
||||
else:
|
||||
if args.fs is not None and args.fs != rate:
|
||||
# FIXME(kamo): To use sox?
|
||||
wave = resampy.resample(
|
||||
wave.astype(np.float64), rate, args.fs, axis=0
|
||||
)
|
||||
wave = wave.astype(np.int16)
|
||||
rate = args.fs
|
||||
|
||||
if args.audio_format.endswith("ark"):
|
||||
if "flac" in args.audio_format:
|
||||
suf = "flac"
|
||||
elif "wav" in args.audio_format:
|
||||
suf = "wav"
|
||||
else:
|
||||
raise RuntimeError("wav.ark or flac")
|
||||
|
||||
# NOTE(kamo): Using extended ark format style here.
|
||||
# This format is not supported in Kaldi.
|
||||
kaldiio.save_ark(
|
||||
fark,
|
||||
{uttid: (wave, rate)},
|
||||
scp=fout,
|
||||
append=True,
|
||||
write_function=f"soundfile_{suf}",
|
||||
)
|
||||
else:
|
||||
owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
|
||||
soundfile.write(owavpath, wave, rate)
|
||||
fout.write(f"{uttid} {owavpath}\n")
|
||||
fnum_samples.write(f"{uttid} {len(wave)}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
142
egs/alimeeting/sa-asr/local/format_wav_scp.sh
Executable file
142
egs/alimeeting/sa-asr/local/format_wav_scp.sh
Executable file
@ -0,0 +1,142 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
SECONDS=0
|
||||
log() {
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
help_message=$(cat << EOF
|
||||
Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
|
||||
e.g.
|
||||
$0 data/test/wav.scp data/test_format/
|
||||
|
||||
Format 'wav.scp': In short words,
|
||||
changing "kaldi-datadir" to "modified-kaldi-datadir"
|
||||
|
||||
The 'wav.scp' format in kaldi is very flexible,
|
||||
e.g. It can use unix-pipe as describing that wav file,
|
||||
but it sometime looks confusing and make scripts more complex.
|
||||
This tools creates actual wav files from 'wav.scp'
|
||||
and also segments wav files using 'segments'.
|
||||
|
||||
Options
|
||||
--fs <fs>
|
||||
--segments <segments>
|
||||
--nj <nj>
|
||||
--cmd <cmd>
|
||||
EOF
|
||||
)
|
||||
|
||||
out_filename=wav.scp
|
||||
cmd=utils/run.pl
|
||||
nj=30
|
||||
fs=none
|
||||
segments=
|
||||
|
||||
ref_channels=
|
||||
utt2ref_channels=
|
||||
|
||||
audio_format=wav
|
||||
write_utt2num_samples=true
|
||||
|
||||
log "$0 $*"
|
||||
. utils/parse_options.sh
|
||||
|
||||
if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
|
||||
log "${help_message}"
|
||||
log "Error: invalid command line arguments"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
. ./path.sh # Setup the environment
|
||||
|
||||
scp=$1
|
||||
if [ ! -f "${scp}" ]; then
|
||||
log "${help_message}"
|
||||
echo "$0: Error: No such file: ${scp}"
|
||||
exit 1
|
||||
fi
|
||||
dir=$2
|
||||
|
||||
|
||||
if [ $# -eq 2 ]; then
|
||||
logdir=${dir}/logs
|
||||
outdir=${dir}/data
|
||||
|
||||
elif [ $# -eq 3 ]; then
|
||||
logdir=$3
|
||||
outdir=${dir}/data
|
||||
|
||||
elif [ $# -eq 4 ]; then
|
||||
logdir=$3
|
||||
outdir=$4
|
||||
fi
|
||||
|
||||
|
||||
mkdir -p ${logdir}
|
||||
|
||||
rm -f "${dir}/${out_filename}"
|
||||
|
||||
|
||||
opts=
|
||||
if [ -n "${utt2ref_channels}" ]; then
|
||||
opts="--utt2ref-channels ${utt2ref_channels} "
|
||||
elif [ -n "${ref_channels}" ]; then
|
||||
opts="--ref-channels ${ref_channels} "
|
||||
fi
|
||||
|
||||
|
||||
if [ -n "${segments}" ]; then
|
||||
log "[info]: using ${segments}"
|
||||
nutt=$(<${segments} wc -l)
|
||||
nj=$((nj<nutt?nj:nutt))
|
||||
|
||||
split_segments=""
|
||||
for n in $(seq ${nj}); do
|
||||
split_segments="${split_segments} ${logdir}/segments.${n}"
|
||||
done
|
||||
|
||||
utils/split_scp.pl "${segments}" ${split_segments}
|
||||
|
||||
${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
|
||||
local/format_wav_scp.py \
|
||||
${opts} \
|
||||
--fs ${fs} \
|
||||
--audio-format "${audio_format}" \
|
||||
"--segment=${logdir}/segments.JOB" \
|
||||
"${scp}" "${outdir}/format.JOB"
|
||||
|
||||
else
|
||||
log "[info]: without segments"
|
||||
nutt=$(<${scp} wc -l)
|
||||
nj=$((nj<nutt?nj:nutt))
|
||||
|
||||
split_scps=""
|
||||
for n in $(seq ${nj}); do
|
||||
split_scps="${split_scps} ${logdir}/wav.${n}.scp"
|
||||
done
|
||||
|
||||
utils/split_scp.pl "${scp}" ${split_scps}
|
||||
${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
|
||||
local/format_wav_scp.py \
|
||||
${opts} \
|
||||
--fs "${fs}" \
|
||||
--audio-format "${audio_format}" \
|
||||
"${logdir}/wav.JOB.scp" ${outdir}/format.JOB""
|
||||
fi
|
||||
|
||||
# Workaround for the NFS problem
|
||||
ls ${outdir}/format.* > /dev/null
|
||||
|
||||
# concatenate the .scp files together.
|
||||
for n in $(seq ${nj}); do
|
||||
cat "${outdir}/format.${n}/wav.scp" || exit 1;
|
||||
done > "${dir}/${out_filename}" || exit 1
|
||||
|
||||
if "${write_utt2num_samples}"; then
|
||||
for n in $(seq ${nj}); do
|
||||
cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
|
||||
done > "${dir}/utt2num_samples" || exit 1
|
||||
fi
|
||||
|
||||
log "Successfully finished. [elapsed=${SECONDS}s]"
|
||||
116
egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
Executable file
116
egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
Executable file
@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# 2020 @kamo-naoyuki
|
||||
# This file was copied from Kaldi and
|
||||
# I deleted parts related to wav duration
|
||||
# because we shouldn't use kaldi's command here
|
||||
# and we don't need the files actually.
|
||||
|
||||
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
|
||||
# 2014 Tom Ko
|
||||
# 2018 Emotech LTD (author: Pawel Swietojanski)
|
||||
# Apache 2.0
|
||||
|
||||
# This script operates on a directory, such as in data/train/,
|
||||
# that contains some subset of the following files:
|
||||
# wav.scp
|
||||
# spk2utt
|
||||
# utt2spk
|
||||
# text
|
||||
#
|
||||
# It generates the files which are used for perturbing the speed of the original data.
|
||||
|
||||
export LC_ALL=C
|
||||
set -euo pipefail
|
||||
|
||||
if [[ $# != 3 ]]; then
|
||||
echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
|
||||
echo "e.g.:"
|
||||
echo " $0 0.9 data/train_si284 data/train_si284p"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
factor=$1
|
||||
srcdir=$2
|
||||
destdir=$3
|
||||
label="sp"
|
||||
spk_prefix="${label}${factor}-"
|
||||
utt_prefix="${label}${factor}-"
|
||||
|
||||
#check is sox on the path
|
||||
|
||||
! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
|
||||
|
||||
if [[ ! -f ${srcdir}/utt2spk ]]; then
|
||||
echo "$0: no such file ${srcdir}/utt2spk"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if [[ ${destdir} == "${srcdir}" ]]; then
|
||||
echo "$0: this script requires <srcdir> and <destdir> to be different."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p "${destdir}"
|
||||
|
||||
<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
|
||||
<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
|
||||
<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
|
||||
if [[ ! -f ${srcdir}/utt2uniq ]]; then
|
||||
<"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
|
||||
else
|
||||
<"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
|
||||
fi
|
||||
|
||||
|
||||
<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
|
||||
utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
|
||||
|
||||
utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
|
||||
|
||||
if [[ -f ${srcdir}/segments ]]; then
|
||||
|
||||
utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
|
||||
utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
|
||||
awk -v factor="${factor}" \
|
||||
'{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
|
||||
>"${destdir}"/segments
|
||||
|
||||
utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
|
||||
# Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
|
||||
awk -v factor="${factor}" \
|
||||
'{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
|
||||
else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
|
||||
else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
|
||||
> "${destdir}"/wav.scp
|
||||
if [[ -f ${srcdir}/reco2file_and_channel ]]; then
|
||||
utils/apply_map.pl -f 1 "${destdir}"/reco_map \
|
||||
<"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
|
||||
fi
|
||||
|
||||
else # no segments->wav indexed by utterance.
|
||||
if [[ -f ${srcdir}/wav.scp ]]; then
|
||||
utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
|
||||
# Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
|
||||
awk -v factor="${factor}" \
|
||||
'{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
|
||||
else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
|
||||
else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
|
||||
> "${destdir}"/wav.scp
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -f ${srcdir}/text ]]; then
|
||||
utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
|
||||
fi
|
||||
if [[ -f ${srcdir}/spk2gender ]]; then
|
||||
utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
|
||||
fi
|
||||
if [[ -f ${srcdir}/utt2lang ]]; then
|
||||
utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
|
||||
fi
|
||||
|
||||
rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
|
||||
echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
|
||||
|
||||
utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
|
||||
@ -113,7 +113,7 @@ fi
|
||||
check_sorted_and_uniq $data/spk2utt
|
||||
|
||||
! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \
|
||||
<(utils/spk2utt_to_utt2spk.pl $data/spk2utt) && \
|
||||
<(local/spk2utt_to_utt2spk.pl $data/spk2utt) && \
|
||||
echo "$0: spk2utt and utt2spk do not seem to match" && exit 1;
|
||||
|
||||
cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts
|
||||
@ -135,7 +135,7 @@ if ! $no_text; then
|
||||
echo "$0: text contains $n_non_print lines with non-printable characters" &&\
|
||||
exit 1;
|
||||
fi
|
||||
utils/validate_text.pl $data/text || exit 1;
|
||||
local/validate_text.pl $data/text || exit 1;
|
||||
check_sorted_and_uniq $data/text
|
||||
text_len=`cat $data/text | wc -l`
|
||||
illegal_sym_list="<s> </s> #0"
|
||||
@ -2,5 +2,4 @@ export FUNASR_DIR=$PWD/../../..
|
||||
|
||||
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PATH=$FUNASR_DIR/funasr/bin:$PATH
|
||||
export PATH=$PWD/utils/:$PATH
|
||||
export PATH=$FUNASR_DIR/funasr/bin:$PATH
|
||||
1
egs/alimeeting/sa-asr/utils
Symbolic link
1
egs/alimeeting/sa-asr/utils
Symbolic link
@ -0,0 +1 @@
|
||||
../../aishell/transformer/utils
|
||||
@ -1,87 +0,0 @@
|
||||
#!/usr/bin/env perl
|
||||
# Copyright 2010-2012 Microsoft Corporation
|
||||
# Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# This script takes a list of utterance-ids or any file whose first field
|
||||
# of each line is an utterance-id, and filters an scp
|
||||
# file (or any file whose "n-th" field is an utterance id), printing
|
||||
# out only those lines whose "n-th" field is in id_list. The index of
|
||||
# the "n-th" field is 1, by default, but can be changed by using
|
||||
# the -f <n> switch
|
||||
|
||||
$exclude = 0;
|
||||
$field = 1;
|
||||
$shifted = 0;
|
||||
|
||||
do {
|
||||
$shifted=0;
|
||||
if ($ARGV[0] eq "--exclude") {
|
||||
$exclude = 1;
|
||||
shift @ARGV;
|
||||
$shifted=1;
|
||||
}
|
||||
if ($ARGV[0] eq "-f") {
|
||||
$field = $ARGV[1];
|
||||
shift @ARGV; shift @ARGV;
|
||||
$shifted=1
|
||||
}
|
||||
} while ($shifted);
|
||||
|
||||
if(@ARGV < 1 || @ARGV > 2) {
|
||||
die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
|
||||
"Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
|
||||
"Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
|
||||
"only the lines that were *not* in id_list.\n" .
|
||||
"Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
|
||||
"If your older scripts (written before Oct 2014) stopped working and you used the\n" .
|
||||
"-f option, add 1 to the argument.\n" .
|
||||
"See also: utils/filter_scp.pl .\n";
|
||||
}
|
||||
|
||||
|
||||
$idlist = shift @ARGV;
|
||||
open(F, "<$idlist") || die "Could not open id-list file $idlist";
|
||||
while(<F>) {
|
||||
@A = split;
|
||||
@A>=1 || die "Invalid id-list file line $_";
|
||||
$seen{$A[0]} = 1;
|
||||
}
|
||||
|
||||
if ($field == 1) { # Treat this as special case, since it is common.
|
||||
while(<>) {
|
||||
$_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
|
||||
# $1 is what we filter on.
|
||||
if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
|
||||
print $_;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
while(<>) {
|
||||
@A = split;
|
||||
@A > 0 || die "Invalid scp file line $_";
|
||||
@A >= $field || die "Invalid scp file line $_";
|
||||
if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
|
||||
print $_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# tests:
|
||||
# the following should print "foo 1"
|
||||
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
|
||||
# the following should print "bar 2".
|
||||
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)
|
||||
@ -1,97 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
|
||||
# Arnab Ghoshal, Karel Vesely
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Parse command-line options.
|
||||
# To be sourced by another script (as in ". parse_options.sh").
|
||||
# Option format is: --option-name arg
|
||||
# and shell variable "option_name" gets set to value "arg."
|
||||
# The exception is --help, which takes no arguments, but prints the
|
||||
# $help_message variable (if defined).
|
||||
|
||||
|
||||
###
|
||||
### The --config file options have lower priority to command line
|
||||
### options, so we need to import them first...
|
||||
###
|
||||
|
||||
# Now import all the configs specified by command-line, in left-to-right order
|
||||
for ((argpos=1; argpos<$#; argpos++)); do
|
||||
if [ "${!argpos}" == "--config" ]; then
|
||||
argpos_plus1=$((argpos+1))
|
||||
config=${!argpos_plus1}
|
||||
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
|
||||
. $config # source the config file.
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
###
|
||||
### Now we process the command line options
|
||||
###
|
||||
while true; do
|
||||
[ -z "${1:-}" ] && break; # break if there are no arguments
|
||||
case "$1" in
|
||||
# If the enclosing script is called with --help option, print the help
|
||||
# message and exit. Scripts should put help messages in $help_message
|
||||
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
|
||||
else printf "$help_message\n" 1>&2 ; fi;
|
||||
exit 0 ;;
|
||||
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
|
||||
exit 1 ;;
|
||||
# If the first command-line argument begins with "--" (e.g. --foo-bar),
|
||||
# then work out the variable name as $name, which will equal "foo_bar".
|
||||
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
|
||||
# Next we test whether the variable in question is undefned-- if so it's
|
||||
# an invalid option and we die. Note: $0 evaluates to the name of the
|
||||
# enclosing script.
|
||||
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
|
||||
# is undefined. We then have to wrap this test inside "eval" because
|
||||
# foo_bar is itself inside a variable ($name).
|
||||
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
||||
|
||||
oldval="`eval echo \\$$name`";
|
||||
# Work out whether we seem to be expecting a Boolean argument.
|
||||
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
|
||||
was_bool=true;
|
||||
else
|
||||
was_bool=false;
|
||||
fi
|
||||
|
||||
# Set the variable to the right value-- the escaped quotes make it work if
|
||||
# the option had spaces, like --cmd "queue.pl -sync y"
|
||||
eval $name=\"$2\";
|
||||
|
||||
# Check that Boolean-valued arguments are really Boolean.
|
||||
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
||||
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
||||
exit 1;
|
||||
fi
|
||||
shift 2;
|
||||
;;
|
||||
*) break;
|
||||
esac
|
||||
done
|
||||
|
||||
|
||||
# Check for an empty argument to the --cmd option, which can easily occur as a
|
||||
# result of scripting errors.
|
||||
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
|
||||
|
||||
|
||||
true; # so this script returns exit code 0.
|
||||
@ -1,246 +0,0 @@
|
||||
#!/usr/bin/env perl
|
||||
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# See ../../COPYING for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# This program splits up any kind of .scp or archive-type file.
|
||||
# If there is no utt2spk option it will work on any text file and
|
||||
# will split it up with an approximately equal number of lines in
|
||||
# each but.
|
||||
# With the --utt2spk option it will work on anything that has the
|
||||
# utterance-id as the first entry on each line; the utt2spk file is
|
||||
# of the form "utterance speaker" (on each line).
|
||||
# It splits it into equal size chunks as far as it can. If you use the utt2spk
|
||||
# option it will make sure these chunks coincide with speaker boundaries. In
|
||||
# this case, if there are more chunks than speakers (and in some other
|
||||
# circumstances), some of the resulting chunks will be empty and it will print
|
||||
# an error message and exit with nonzero status.
|
||||
# You will normally call this like:
|
||||
# split_scp.pl scp scp.1 scp.2 scp.3 ...
|
||||
# or
|
||||
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
|
||||
# Note that you can use this script to split the utt2spk file itself,
|
||||
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
|
||||
|
||||
# You can also call the scripts like:
|
||||
# split_scp.pl -j 3 0 scp scp.0
|
||||
# [note: with this option, it assumes zero-based indexing of the split parts,
|
||||
# i.e. the second number must be 0 <= n < num-jobs.]
|
||||
|
||||
use warnings;
|
||||
|
||||
$num_jobs = 0;
|
||||
$job_id = 0;
|
||||
$utt2spk_file = "";
|
||||
$one_based = 0;
|
||||
|
||||
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
|
||||
if ($ARGV[0] eq "-j") {
|
||||
shift @ARGV;
|
||||
$num_jobs = shift @ARGV;
|
||||
$job_id = shift @ARGV;
|
||||
}
|
||||
if ($ARGV[0] =~ /--utt2spk=(.+)/) {
|
||||
$utt2spk_file=$1;
|
||||
shift;
|
||||
}
|
||||
if ($ARGV[0] eq '--one-based') {
|
||||
$one_based = 1;
|
||||
shift @ARGV;
|
||||
}
|
||||
}
|
||||
|
||||
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
|
||||
$job_id - $one_based >= $num_jobs)) {
|
||||
die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
|
||||
($one_based ? " --one-based" : "") . "'\n"
|
||||
}
|
||||
|
||||
$one_based
|
||||
and $job_id--;
|
||||
|
||||
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
|
||||
die
|
||||
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
|
||||
or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
|
||||
... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
|
||||
}
|
||||
|
||||
$error = 0;
|
||||
$inscp = shift @ARGV;
|
||||
if ($num_jobs == 0) { # without -j option
|
||||
@OUTPUTS = @ARGV;
|
||||
} else {
|
||||
for ($j = 0; $j < $num_jobs; $j++) {
|
||||
if ($j == $job_id) {
|
||||
if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
|
||||
else { push @OUTPUTS, "-"; }
|
||||
} else {
|
||||
push @OUTPUTS, "/dev/null";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ($utt2spk_file ne "") { # We have the --utt2spk option...
|
||||
open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
|
||||
while(<$u_fh>) {
|
||||
@A = split;
|
||||
@A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
|
||||
($u,$s) = @A;
|
||||
$utt2spk{$u} = $s;
|
||||
}
|
||||
close $u_fh;
|
||||
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
|
||||
@spkrs = ();
|
||||
while(<$i_fh>) {
|
||||
@A = split;
|
||||
if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
|
||||
$u = $A[0];
|
||||
$s = $utt2spk{$u};
|
||||
defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
|
||||
if(!defined $spk_count{$s}) {
|
||||
push @spkrs, $s;
|
||||
$spk_count{$s} = 0;
|
||||
$spk_data{$s} = []; # ref to new empty array.
|
||||
}
|
||||
$spk_count{$s}++;
|
||||
push @{$spk_data{$s}}, $_;
|
||||
}
|
||||
# Now split as equally as possible ..
|
||||
# First allocate spks to files by allocating an approximately
|
||||
# equal number of speakers.
|
||||
$numspks = @spkrs; # number of speakers.
|
||||
$numscps = @OUTPUTS; # number of output files.
|
||||
if ($numspks < $numscps) {
|
||||
die "$0: Refusing to split data because number of speakers $numspks " .
|
||||
"is less than the number of output .scp files $numscps\n";
|
||||
}
|
||||
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
|
||||
$scparray[$scpidx] = []; # [] is array reference.
|
||||
}
|
||||
for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
|
||||
$scpidx = int(($spkidx*$numscps) / $numspks);
|
||||
$spk = $spkrs[$spkidx];
|
||||
push @{$scparray[$scpidx]}, $spk;
|
||||
$scpcount[$scpidx] += $spk_count{$spk};
|
||||
}
|
||||
|
||||
# Now will try to reassign beginning + ending speakers
|
||||
# to different scp's and see if it gets more balanced.
|
||||
# Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
|
||||
# We can show that if considering changing just 2 scp's, we minimize
|
||||
# this by minimizing the squared difference in sizes. This is
|
||||
# equivalent to minimizing the absolute difference in sizes. This
|
||||
# shows this method is bound to converge.
|
||||
|
||||
$changed = 1;
|
||||
while($changed) {
|
||||
$changed = 0;
|
||||
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
|
||||
# First try to reassign ending spk of this scp.
|
||||
if($scpidx < $numscps-1) {
|
||||
$sz = @{$scparray[$scpidx]};
|
||||
if($sz > 0) {
|
||||
$spk = $scparray[$scpidx]->[$sz-1];
|
||||
$count = $spk_count{$spk};
|
||||
$nutt1 = $scpcount[$scpidx];
|
||||
$nutt2 = $scpcount[$scpidx+1];
|
||||
if( abs( ($nutt2+$count) - ($nutt1-$count))
|
||||
< abs($nutt2 - $nutt1)) { # Would decrease
|
||||
# size-diff by reassigning spk...
|
||||
$scpcount[$scpidx+1] += $count;
|
||||
$scpcount[$scpidx] -= $count;
|
||||
pop @{$scparray[$scpidx]};
|
||||
unshift @{$scparray[$scpidx+1]}, $spk;
|
||||
$changed = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
|
||||
$spk = $scparray[$scpidx]->[0];
|
||||
$count = $spk_count{$spk};
|
||||
$nutt1 = $scpcount[$scpidx-1];
|
||||
$nutt2 = $scpcount[$scpidx];
|
||||
if( abs( ($nutt2-$count) - ($nutt1+$count))
|
||||
< abs($nutt2 - $nutt1)) { # Would decrease
|
||||
# size-diff by reassigning spk...
|
||||
$scpcount[$scpidx-1] += $count;
|
||||
$scpcount[$scpidx] -= $count;
|
||||
shift @{$scparray[$scpidx]};
|
||||
push @{$scparray[$scpidx-1]}, $spk;
|
||||
$changed = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
# Now print out the files...
|
||||
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
|
||||
$scpfile = $OUTPUTS[$scpidx];
|
||||
($scpfile ne '-' ? open($f_fh, '>', $scpfile)
|
||||
: open($f_fh, '>&', \*STDOUT)) ||
|
||||
die "$0: Could not open scp file $scpfile for writing: $!\n";
|
||||
$count = 0;
|
||||
if(@{$scparray[$scpidx]} == 0) {
|
||||
print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
|
||||
"$scpfile (too many splits and too few speakers?)\n";
|
||||
$error = 1;
|
||||
} else {
|
||||
foreach $spk ( @{$scparray[$scpidx]} ) {
|
||||
print $f_fh @{$spk_data{$spk}};
|
||||
$count += $spk_count{$spk};
|
||||
}
|
||||
$count == $scpcount[$scpidx] || die "Count mismatch [code error]";
|
||||
}
|
||||
close($f_fh);
|
||||
}
|
||||
} else {
|
||||
# This block is the "normal" case where there is no --utt2spk
|
||||
# option and we just break into equal size chunks.
|
||||
|
||||
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
|
||||
|
||||
$numscps = @OUTPUTS; # size of array.
|
||||
@F = ();
|
||||
while(<$i_fh>) {
|
||||
push @F, $_;
|
||||
}
|
||||
$numlines = @F;
|
||||
if($numlines == 0) {
|
||||
print STDERR "$0: error: empty input scp file $inscp\n";
|
||||
$error = 1;
|
||||
}
|
||||
$linesperscp = int( $numlines / $numscps); # the "whole part"..
|
||||
$linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
|
||||
$remainder = $numlines - ($linesperscp * $numscps);
|
||||
($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
|
||||
# [just doing int() rounds down].
|
||||
$n = 0;
|
||||
for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
|
||||
$scpfile = $OUTPUTS[$scpidx];
|
||||
($scpfile ne '-' ? open($o_fh, '>', $scpfile)
|
||||
: open($o_fh, '>&', \*STDOUT)) ||
|
||||
die "$0: Could not open scp file $scpfile for writing: $!\n";
|
||||
for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
|
||||
print $o_fh $F[$n++];
|
||||
}
|
||||
close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n";
|
||||
}
|
||||
$n == $numlines || die "$n != $numlines [code error]";
|
||||
}
|
||||
|
||||
exit ($error);
|
||||
@ -40,6 +40,8 @@ from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||
from funasr.tasks.asr import frontend_choices
|
||||
|
||||
|
||||
header_colors = '\033[95m'
|
||||
@ -90,6 +92,12 @@ class Speech2Text:
|
||||
asr_train_config, asr_model_file, cmvn_file, device
|
||||
)
|
||||
frontend = None
|
||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||
if asr_train_args.frontend=='wav_frontend':
|
||||
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
|
||||
else:
|
||||
frontend_class=frontend_choices.get_class(asr_train_args.frontend)
|
||||
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
|
||||
|
||||
logging.info("asr_model: {}".format(asr_model))
|
||||
logging.info("asr_train_args: {}".format(asr_train_args))
|
||||
@ -197,12 +205,21 @@ class Speech2Text:
|
||||
|
||||
"""
|
||||
assert check_argument_types()
|
||||
|
||||
|
||||
# Input as audio signal
|
||||
if isinstance(speech, np.ndarray):
|
||||
speech = torch.tensor(speech)
|
||||
|
||||
batch = {"speech": speech, "speech_lengths": speech_lengths}
|
||||
if self.frontend is not None:
|
||||
feats, feats_len = self.frontend.forward(speech, speech_lengths)
|
||||
feats = to_device(feats, device=self.device)
|
||||
feats_len = feats_len.int()
|
||||
self.asr_model.frontend = None
|
||||
else:
|
||||
feats = speech
|
||||
feats_len = speech_lengths
|
||||
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
|
||||
batch = {"speech": feats, "speech_lengths": feats_len}
|
||||
|
||||
# a. To device
|
||||
batch = to_device(batch, device=self.device)
|
||||
@ -275,6 +292,7 @@ def inference(
|
||||
ngram_weight: float = 0.9,
|
||||
nbest: int = 1,
|
||||
num_workers: int = 1,
|
||||
mc: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
inference_pipeline = inference_modelscope(
|
||||
@ -305,6 +323,7 @@ def inference(
|
||||
ngram_weight=ngram_weight,
|
||||
nbest=nbest,
|
||||
num_workers=num_workers,
|
||||
mc=mc,
|
||||
**kwargs,
|
||||
)
|
||||
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
|
||||
@ -337,6 +356,7 @@ def inference_modelscope(
|
||||
ngram_weight: float = 0.9,
|
||||
nbest: int = 1,
|
||||
num_workers: int = 1,
|
||||
mc: bool = False,
|
||||
param_dict: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -406,7 +426,7 @@ def inference_modelscope(
|
||||
data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
fs=fs,
|
||||
mc=True,
|
||||
mc=mc,
|
||||
batch_size=batch_size,
|
||||
key_file=key_file,
|
||||
num_workers=num_workers,
|
||||
@ -415,7 +435,7 @@ def inference_modelscope(
|
||||
allow_variable_data_keys=allow_variable_data_keys,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
|
||||
finish_count = 0
|
||||
file_count = 1
|
||||
# 7 .Start for-loop
|
||||
|
||||
@ -71,7 +71,13 @@ def get_parser():
|
||||
)
|
||||
group.add_argument("--key_file", type=str_or_none)
|
||||
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
|
||||
|
||||
group.add_argument(
|
||||
"--mc",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="MultiChannel input",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("The model configuration related")
|
||||
group.add_argument(
|
||||
"--vad_infer_config",
|
||||
|
||||
@ -2,14 +2,6 @@
|
||||
|
||||
import os
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level='INFO',
|
||||
format=f"[{os.uname()[1].split('.')[0]}]"
|
||||
f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
|
||||
from funasr.tasks.asr import ASRTask
|
||||
|
||||
|
||||
|
||||
@ -35,6 +35,8 @@ from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||
from funasr.tasks.asr import frontend_choices
|
||||
|
||||
|
||||
header_colors = '\033[95m'
|
||||
@ -85,6 +87,12 @@ class Speech2Text:
|
||||
asr_train_config, asr_model_file, cmvn_file, device
|
||||
)
|
||||
frontend = None
|
||||
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
||||
if asr_train_args.frontend=='wav_frontend':
|
||||
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
|
||||
else:
|
||||
frontend_class=frontend_choices.get_class(asr_train_args.frontend)
|
||||
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
|
||||
|
||||
logging.info("asr_model: {}".format(asr_model))
|
||||
logging.info("asr_train_args: {}".format(asr_train_args))
|
||||
@ -201,7 +209,16 @@ class Speech2Text:
|
||||
if isinstance(profile, np.ndarray):
|
||||
profile = torch.tensor(profile)
|
||||
|
||||
batch = {"speech": speech, "speech_lengths": speech_lengths}
|
||||
if self.frontend is not None:
|
||||
feats, feats_len = self.frontend.forward(speech, speech_lengths)
|
||||
feats = to_device(feats, device=self.device)
|
||||
feats_len = feats_len.int()
|
||||
self.asr_model.frontend = None
|
||||
else:
|
||||
feats = speech
|
||||
feats_len = speech_lengths
|
||||
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
|
||||
batch = {"speech": feats, "speech_lengths": feats_len}
|
||||
|
||||
# a. To device
|
||||
batch = to_device(batch, device=self.device)
|
||||
@ -308,6 +325,7 @@ def inference(
|
||||
ngram_weight: float = 0.9,
|
||||
nbest: int = 1,
|
||||
num_workers: int = 1,
|
||||
mc: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
inference_pipeline = inference_modelscope(
|
||||
@ -338,6 +356,7 @@ def inference(
|
||||
ngram_weight=ngram_weight,
|
||||
nbest=nbest,
|
||||
num_workers=num_workers,
|
||||
mc=mc,
|
||||
**kwargs,
|
||||
)
|
||||
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
|
||||
@ -370,6 +389,7 @@ def inference_modelscope(
|
||||
ngram_weight: float = 0.9,
|
||||
nbest: int = 1,
|
||||
num_workers: int = 1,
|
||||
mc: bool = False,
|
||||
param_dict: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -437,7 +457,7 @@ def inference_modelscope(
|
||||
data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
fs=fs,
|
||||
mc=True,
|
||||
mc=mc,
|
||||
batch_size=batch_size,
|
||||
key_file=key_file,
|
||||
num_workers=num_workers,
|
||||
|
||||
@ -2,14 +2,6 @@
|
||||
|
||||
import os
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level='INFO',
|
||||
format=f"[{os.uname()[1].split('.')[0]}]"
|
||||
f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
|
||||
from funasr.tasks.sa_asr import ASRTask
|
||||
|
||||
|
||||
|
||||
@ -79,3 +79,49 @@ class SequenceBinaryCrossEntropy(nn.Module):
|
||||
loss = self.criterion(pred, label)
|
||||
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
|
||||
return loss.masked_fill(pad_mask, 0).sum() / denom
|
||||
|
||||
|
||||
class NllLoss(nn.Module):
|
||||
"""Nll loss.
|
||||
|
||||
:param int size: the number of class
|
||||
:param int padding_idx: ignored class id
|
||||
:param bool normalize_length: normalize loss by sequence length if True
|
||||
:param torch.nn.Module criterion: loss function
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
padding_idx,
|
||||
normalize_length=False,
|
||||
criterion=nn.NLLLoss(reduction='none'),
|
||||
):
|
||||
"""Construct an LabelSmoothingLoss object."""
|
||||
super(NllLoss, self).__init__()
|
||||
self.criterion = criterion
|
||||
self.padding_idx = padding_idx
|
||||
self.size = size
|
||||
self.true_dist = None
|
||||
self.normalize_length = normalize_length
|
||||
|
||||
def forward(self, x, target):
|
||||
"""Compute loss between x and target.
|
||||
|
||||
:param torch.Tensor x: prediction (batch, seqlen, class)
|
||||
:param torch.Tensor target:
|
||||
target signal masked with self.padding_id (batch, seqlen)
|
||||
:return: scalar float value
|
||||
:rtype torch.Tensor
|
||||
"""
|
||||
assert x.size(2) == self.size
|
||||
batch_size = x.size(0)
|
||||
x = x.view(-1, self.size)
|
||||
target = target.view(-1)
|
||||
with torch.no_grad():
|
||||
ignore = target == self.padding_idx # (B,)
|
||||
total = len(target) - ignore.sum().item()
|
||||
target = target.masked_fill(ignore, 0) # avoid -1 index
|
||||
kl = self.criterion(x , target)
|
||||
denom = total if self.normalize_length else batch_size
|
||||
return kl.masked_fill(ignore, 0).sum() / denom
|
||||
|
||||
@ -13,6 +13,7 @@ from typeguard import check_argument_types
|
||||
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.modules.attention import MultiHeadedAttention
|
||||
from funasr.modules.attention import CosineDistanceAttention
|
||||
from funasr.modules.dynamic_conv import DynamicConvolution
|
||||
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
|
||||
from funasr.modules.embedding import PositionalEncoding
|
||||
@ -763,4 +764,429 @@ class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class for speaker-attributed (SA) ASR Transformer decoders.

    Two cascaded branches run over the same target sequence:

    * a speaker branch (``decoder1`` -> ``decoder2``) whose output is
      projected to a speaker embedding and matched against enrolled speaker
      ``profile`` vectors via cosine-distance attention, and
    * an ASR branch (``decoder3`` -> ``decoder4``) that consumes the carry
      ``z`` produced by ``decoder1`` together with the attended profile
      embedding ``dn`` and is projected to vocabulary scores.

    The sub-decoders are set to None here and must be assigned by a
    subclass (see SAAsrTransformerDecoder).
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        """Initialize the shared embedding, norms and output projections.

        :param vocab_size: output vocabulary size (also the input feature
            dimension when ``input_layer == "linear"``)
        :param encoder_output_size: encoder dimension, used as attention dim
        :param spker_embedding_dim: dimension of the speaker-embedding space
        :param dropout_rate: dropout for the "linear" input layer
        :param positional_dropout_rate: dropout inside positional encoding
        :param input_layer: "embed" (token ids) or "linear" (dense features)
        :param use_asr_output_layer: add the attention-dim -> vocab projection
        :param use_spk_output_layer: add the attention-dim ->
            speaker-embedding projection
        :param pos_enc_class: positional encoding class to instantiate
        :param normalize_before: pre-norm (True) vs post-norm (False) style
        """
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None

        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None

        # Matches decoder-side speaker embeddings against enrolled profiles.
        self.cos_distance_att = CosineDistanceAttention()

        # Placeholders; concrete layer stacks are built by subclasses.
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None

    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run both decoder branches over a padded batch.

        :param asr_hs_pad: padded ASR encoder memory — assumed (B, T, D)
        :param spk_hs_pad: padded speaker encoder memory — assumed (B, T, D);
            NOTE(review): both memories share one mask built from ``hlens``,
            so they are presumed time-aligned — confirm with the encoders
        :param hlens: encoder output lengths (B,)
        :param ys_in_pad: padded input token ids (B, L)
        :param ys_in_lens: target lengths (B,)
        :param profile: enrolled speaker profile embeddings
        :param profile_lens: number of valid profiles per utterance
        :return: (token scores before softmax, speaker attention weights,
            olens)
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L) — padding mask combined with causal mask
        tgt_mask = tgt_mask & m

        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
        # Spk decoder
        x = self.embed(tgt)

        # decoder1 also returns z, the carry consumed by the ASR branch below.
        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        # dn: profile embedding attended per output position;
        # weights: attention over the enrolled profiles (speaker decision).
        dn, weights = self.cos_distance_att(x, profile, profile_lens)
        # Asr decoder
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )

        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)

        olens = tgt_mask.sum(1)
        return x, weights, olens


    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        asr_memory: torch.Tensor,
        spk_memory: torch.Tensor,
        profile: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Incrementally decode one step (for beam search).

        :param tgt: input token ids decoded so far (B, L)
        :param tgt_mask: causal mask for ``tgt``
        :param asr_memory: ASR encoder memory
        :param spk_memory: speaker encoder memory
        :param profile: enrolled speaker profile embeddings
        :param cache: one cached output per layer, ordered
            [decoder1, decoder2..., decoder3, decoder4...]
        :return: (log-prob of next token, speaker weights, updated cache)
        """
        x = self.embed(tgt)

        # One cache slot per layer: decoder1 + decoder2 stack + decoder3 +
        # decoder4 stack == 2 + len(decoder2) + len(decoder4).
        if cache is None:
            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
        new_cache = []
        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
            x, tgt_mask, spk_memory, _ = decoder(
                x, tgt_mask, spk_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            x = self.after_norm(x)
        else:
            x = x  # no-op kept for symmetry with the pre-norm branch
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        dn, weights = self.cos_distance_att(x, profile, None)

        x, tgt_mask, asr_memory, _ = self.decoder3(
            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
        )
        new_cache.append(x)

        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
            x, tgt_mask, asr_memory, _ = decoder(
                x, tgt_mask, asr_memory, None, cache=c
            )
            new_cache.append(x)

        # Only the newest position is projected to the vocabulary.
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.asr_output_layer is not None:
            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)

        return y, weights, new_cache

    def score(self, ys, state, asr_enc, spk_enc, profile):
        """Score a single hypothesis (ScorerInterface-style entry point)."""
        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
        logp, weights, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), weights.squeeze(), state
|
||||
|
||||
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
    """Concrete speaker-attributed ASR Transformer decoder.

    Builds the four sub-decoders required by BaseSAAsrTransformerDecoder:
    one joint first layer per branch (speaker / ASR) plus a stack of
    standard Transformer decoder layers for the remainder of each branch.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        asr_num_blocks: int = 6,
        spk_num_blocks: int = 3,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct the decoder; see BaseSAAsrTransformerDecoder for the
        shared parameters.

        :param attention_heads: number of attention heads per layer
        :param linear_units: hidden units of the feed-forward sub-blocks
        :param asr_num_blocks: total layers in the ASR branch (first layer
            included)
        :param spk_num_blocks: total layers in the speaker branch (first
            layer included)
        :param self_attention_dropout_rate: dropout inside self-attention
        :param src_attention_dropout_rate: dropout inside source attention
        :param concat_after: concat-and-project residual variant in layers
        """
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            spker_embedding_dim=spker_embedding_dim,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_asr_output_layer=use_asr_output_layer,
            use_spk_output_layer=use_spk_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size

        def _standard_layer(_lnum):
            # Plain Transformer decoder layer (self-attn, src-attn, FFN);
            # a fresh instance is created for every call by `repeat`.
            return DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            )

        # Speaker branch: joint first layer + (spk_num_blocks - 1) standard layers.
        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
            attention_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, self_attention_dropout_rate
            ),
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder2 = repeat(spk_num_blocks - 1, _standard_layer)

        # ASR branch: speaker-injecting first layer + (asr_num_blocks - 1)
        # standard layers.
        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
            attention_dim,
            spker_embedding_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder4 = repeat(asr_num_blocks - 1, _standard_layer)
|
||||
|
||||
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
    """First decoder layer of the speaker branch in the SA-ASR decoder.

    Like a standard Transformer decoder layer, but its source-attention call
    passes the ASR encoder memory and the speaker encoder memory as two
    separate arguments, and it additionally returns ``z`` — the output of
    the self-attention sub-block — which the ASR branch's first layer
    consumes as its input carry.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object.

        :param size: attention dimension of the layer
        :param self_attn: self-attention module over the target sequence
        :param src_attn: source-attention module; called with
            (query, asr_memory, spk_memory, mask)
        :param feed_forward: position-wise feed-forward module
        :param dropout_rate: dropout applied to sub-block outputs
        :param normalize_before: pre-norm (True) vs post-norm (False) style
        :param concat_after: concatenate attention input and output and fuse
            with a linear layer instead of a plain residual sum
        """
        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # NOTE(review): only two LayerNorms exist; norm1 is used for BOTH the
        # self-attention and the source-attention sub-blocks in forward()
        # (weight sharing), while norm2 guards the feed-forward sub-block.
        # A standard DecoderLayer would use three — confirm this is intended.
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
        """Compute decoded features.

        :param tgt: decoded previous features (B, L, size) — assumed layout
        :param tgt_mask: target mask (B, L, L) or None
        :param asr_memory: ASR encoder memory (used as attention keys)
        :param spk_memory: speaker encoder memory (used as attention values)
        :param memory_mask: encoder memory mask or None
        :param cache: cached output of this layer for incremental decoding
        :return: (x, tgt_mask, asr_memory, spk_memory, memory_mask, z) where
            ``z`` is the self-attention sub-block output (carry for the ASR
            branch)
        """
        # --- self-attention sub-block ---
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)
        # Carry for the ASR branch: taken AFTER self-attention, BEFORE
        # source attention.
        z = x

        # --- source-attention sub-block (keys: ASR memory, values: speaker
        # memory — assuming src_attn follows the (query, key, value, mask)
        # convention of MultiHeadedAttention) ---
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)

        if self.concat_after:
            x_concat = torch.cat(
                (x, skip), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(skip)
        if not self.normalize_before:
            x = self.norm1(x)

        # --- feed-forward sub-block ---
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        # Incremental decoding: prepend the cached prefix so callers always
        # receive the full-length sequence.
        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
|
||||
|
||||
class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
    """First decoder layer of the ASR branch in the SA-ASR decoder.

    Unlike a standard Transformer decoder layer it has no self-attention
    sub-block: it takes the carry ``tgt`` (the ``z`` output of the speaker
    branch's first layer), applies source attention over the ASR encoder
    memory, adds a linearly projected speaker embedding ``dn``, and finishes
    with a position-wise feed-forward sub-block.
    """

    def __init__(
        self,
        size,
        d_size,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeAsrDecoderFirstLayer object.

        :param size: attention dimension of the layer
        :param d_size: dimension of the incoming speaker embedding ``dn``
        :param src_attn: source-attention module over the ASR encoder memory
        :param feed_forward: position-wise feed-forward module
        :param dropout_rate: dropout applied to sub-block outputs
        :param normalize_before: pre-norm (True) vs post-norm (False) style
        :param concat_after: concatenate attention input and output and fuse
            with a linear layer instead of a plain residual sum
        """
        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        # Projects the speaker embedding into the model dimension.
        self.spk_linear = nn.Linear(d_size, size, bias=False)
        if self.concat_after:
            # NOTE(review): concat_linear1 is never used in forward() (there
            # is no self-attention sub-block); kept for checkpoint/state-dict
            # compatibility with the original layout.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
        """Compute decoded features.

        :param tgt: carry from the speaker branch (B, L, size) — assumed
        :param tgt_mask: target mask (B, L, L) or None
        :param memory: ASR encoder memory
        :param memory_mask: encoder memory mask or None
        :param dn: attended speaker profile embedding (or None to skip the
            speaker injection)
        :param cache: cached output of this layer for incremental decoding
        :return: (x, tgt_mask, memory, memory_mask)
        """
        residual = tgt
        if self.normalize_before:
            # NOTE(review): tgt is normalized by norm1 here and again by
            # norm2 just below with no sub-block in between — confirm the
            # double pre-norm is intended.
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # Incremental decoding: only the newest position is computed.
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        # --- source-attention sub-block over the ASR encoder memory ---
        x = tgt_q
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)
        residual = x

        # --- speaker injection + feed-forward sub-block ---
        # Fixed: was `if dn!=None:`. Comparing a Tensor to None with `!=`
        # relies on PyTorch special-casing the operator and is elementwise in
        # older releases; PEP 8 mandates an identity test for None.
        if dn is not None:
            x = x + self.spk_linear(dn)
        if self.normalize_before:
            x = self.norm3(x)

        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
|
||||
@ -16,9 +16,8 @@ from typeguard import check_argument_types
|
||||
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.losses.label_smoothing_loss import (
|
||||
LabelSmoothingLoss, # noqa: H301
|
||||
LabelSmoothingLoss, NllLoss # noqa: H301
|
||||
)
|
||||
from funasr.losses.nll_loss import NllLoss
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
|
||||
@ -28,7 +28,7 @@ from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecode
|
||||
from funasr.models.decoder.transformer_decoder import (
|
||||
DynamicConvolution2DTransformerDecoder, # noqa: H301
|
||||
)
|
||||
from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder
|
||||
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
|
||||
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
|
||||
from funasr.models.decoder.transformer_decoder import (
|
||||
LightweightConvolution2DTransformerDecoder, # noqa: H301
|
||||
|
||||
Loading…
Reference in New Issue
Block a user