add speaker-attributed ASR task for alimeeting

This commit is contained in:
smohan-speech 2023-05-07 02:21:58 +08:00
parent 3b7e4b0d34
commit d76aea23d9
35 changed files with 1090 additions and 516 deletions

View File

@ -434,14 +434,14 @@ if ! "${skip_data_prep}"; then
log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"
for factor in ${speed_perturb_factors}; do
if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then
scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
local/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
_dirs+="data/${train_set}_sp${factor} "
else
# If speed factor is 1, same as the original
_dirs+="data/${train_set} "
fi
done
utils/combine_data.sh "data/${train_set}_sp" ${_dirs}
local/combine_data.sh "data/${train_set}_sp" ${_dirs}
else
log "Skip stage 2: Speed perturbation"
fi
@ -473,7 +473,7 @@ if ! "${skip_data_prep}"; then
_suf=""
fi
fi
utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
@ -488,7 +488,7 @@ if ! "${skip_data_prep}"; then
_opts+="--segments data/${dset}/segments "
fi
# shellcheck disable=SC2086
scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
--audio-format "${audio_format}" --fs "${fs}" ${_opts} \
"data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
@ -515,7 +515,7 @@ if ! "${skip_data_prep}"; then
for dset in $rm_dset; do
# Copy data dir
utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
local/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type"
# Remove short utterances
@ -564,7 +564,7 @@ if ! "${skip_data_prep}"; then
awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text"
# fix_data_dir.sh leaves only utts which exist in all files
utils/fix_data_dir.sh "${data_feats}/${dset}"
local/fix_data_dir.sh "${data_feats}/${dset}"
# generate uttid
cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
@ -1283,6 +1283,7 @@ if ! "${skip_eval}"; then
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
--mc True \
--nbest 1 \
--ngpu "${_ngpu}" \
--njob ${njob_infer} \
@ -1312,10 +1313,10 @@ if ! "${skip_eval}"; then
_data="${data_feats}/${dset}"
_dir="${asr_exp}/${inference_tag}/${dset}"
python local/proce_text.py ${_data}/text ${_data}/text.proc
python local/proce_text.py ${_dir}/text ${_dir}/text.proc
python utils/proce_text.py ${_data}/text ${_data}/text.proc
python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
@ -1390,6 +1391,7 @@ if ! "${skip_eval}"; then
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
--mc True \
--nbest 1 \
--ngpu "${_ngpu}" \
--njob ${njob_infer} \
@ -1421,10 +1423,10 @@ if ! "${skip_eval}"; then
_data="${data_feats}/${dset}"
_dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
python local/proce_text.py ${_data}/text ${_data}/text.proc
python local/proce_text.py ${_dir}/text ${_dir}/text.proc
python utils/proce_text.py ${_data}/text ${_data}/text.proc
python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
@ -1506,6 +1508,7 @@ if ! "${skip_eval}"; then
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
--mc True \
--nbest 1 \
--ngpu "${_ngpu}" \
--njob ${njob_infer} \
@ -1536,10 +1539,10 @@ if ! "${skip_eval}"; then
_data="${data_feats}/${dset}"
_dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
python local/proce_text.py ${_data}/text ${_data}/text.proc
python local/proce_text.py ${_dir}/text ${_dir}/text.proc
python utils/proce_text.py ${_data}/text ${_data}/text.proc
python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt

View File

@ -436,7 +436,7 @@ if ! "${skip_data_prep}"; then
_suf=""
utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
_opts=
@ -548,6 +548,7 @@ if ! "${skip_eval}"; then
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
--mc True \
--nbest 1 \
--ngpu "${_ngpu}" \
--njob ${njob_infer} \

View File

@ -4,7 +4,6 @@ frontend_conf:
n_fft: 400
win_length: 400
hop_length: 160
use_channel: 0
# encoder related
encoder: conformer

View File

@ -4,7 +4,6 @@ frontend_conf:
n_fft: 400
win_length: 400
hop_length: 160
use_channel: 0
# encoder related
asr_encoder: conformer

View File

@ -78,7 +78,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
#sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
sed -e 's///g' $near_dir/tmp1> $near_dir/tmp2
@ -109,7 +109,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
@ -121,8 +121,8 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
log "stage 3: finali data process"
utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
@ -146,10 +146,10 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
./utils/fix_data_dir.sh $far_single_speaker_dir
utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
./local/fix_data_dir.sh $far_single_speaker_dir
local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
# remove space in text
for x in ${tgt}_Ali_far_single_speaker; do

View File

@ -77,7 +77,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
@ -89,7 +89,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
log "stage 2: finali data process"
utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
@ -113,10 +113,10 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
./utils/fix_data_dir.sh $far_single_speaker_dir
utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
./local/fix_data_dir.sh $far_single_speaker_dir
local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
# remove space in text
for x in ${tgt}_Ali_far_single_speaker; do

View File

@ -98,7 +98,7 @@ if $has_segments; then
for in_dir in $*; do
if [ ! -f $in_dir/segments ]; then
echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
utils/data/get_segments_for_data.sh $in_dir
local/data/get_segments_for_data.sh $in_dir
else
cat $in_dir/segments
fi
@ -133,14 +133,14 @@ for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn
fi
done
utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
local/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
if [[ $dir_with_frame_shift ]]; then
cp $dir_with_frame_shift/frame_shift $dest
fi
if ! $skip_fix ; then
utils/fix_data_dir.sh $dest || exit 1;
local/fix_data_dir.sh $dest || exit 1;
fi
exit 0

View File

@ -71,25 +71,25 @@ else
cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq
fi
cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \
utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
cat $srcdir/utt2spk | local/apply_map.pl -f 1 $destdir/utt_map | \
local/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
local/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
if [ -f $srcdir/feats.scp ]; then
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
fi
if [ -f $srcdir/vad.scp ]; then
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
fi
if [ -f $srcdir/segments ]; then
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
cp $srcdir/wav.scp $destdir
else # no segments->wav indexed by utt.
if [ -f $srcdir/wav.scp ]; then
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
fi
fi
@ -98,26 +98,26 @@ if [ -f $srcdir/reco2file_and_channel ]; then
fi
if [ -f $srcdir/text ]; then
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
fi
if [ -f $srcdir/utt2dur ]; then
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
fi
if [ -f $srcdir/utt2num_frames ]; then
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
fi
if [ -f $srcdir/reco2dur ]; then
if [ -f $srcdir/segments ]; then
cp $srcdir/reco2dur $destdir/reco2dur
else
utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
fi
fi
if [ -f $srcdir/spk2gender ]; then
utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
fi
if [ -f $srcdir/cmvn.scp ]; then
utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
fi
for f in frame_shift stm glm ctm; do
if [ -f $srcdir/$f ]; then
@ -142,4 +142,4 @@ done
[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats"
[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text"
utils/validate_data_dir.sh $validate_opts $destdir
local/validate_data_dir.sh $validate_opts $destdir

View File

@ -20,7 +20,7 @@ fi
data=$1
if [ ! -s $data/utt2dur ]; then
utils/data/get_utt2dur.sh $data 1>&2 || exit 1;
local/data/get_utt2dur.sh $data 1>&2 || exit 1;
fi
# <utt-id> <utt-id> 0 <utt-dur>

View File

@ -94,7 +94,7 @@ elif [ -f $data/wav.scp ]; then
nj=$num_utts
fi
utils/data/split_data.sh --per-utt $data $nj
local/data/split_data.sh --per-utt $data $nj
sdata=$data/split${nj}utt
$cmd JOB=1:$nj $data/log/get_durations.JOB.log \

View File

@ -60,11 +60,11 @@ nf=`cat $data/feats.scp 2>/dev/null | wc -l`
nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can "
echo "** use utils/fix_data_dir.sh $data to fix this."
echo "** use local/fix_data_dir.sh $data to fix this."
fi
if [ -f $data/text ] && [ $nu -ne $nt ]; then
echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can "
echo "** use utils/fix_data_dir.sh to fix this."
echo "** use local/fix_data_dir.sh to fix this."
fi
@ -112,7 +112,7 @@ utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1
for n in `seq $numsplit`; do
dsn=$data/split${numsplit}${utt}/$n
utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
local/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
done
maybe_wav_scp=

View File

@ -112,7 +112,7 @@ function filter_recordings {
function filter_speakers {
# throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
for s in cmvn.scp spk2gender; do
@ -123,7 +123,7 @@ function filter_speakers {
done
filter_file $tmpdir/speakers $data/spk2utt
utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
local/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
for s in cmvn.scp spk2gender $spk_extra_files; do
f=$data/$s
@ -210,6 +210,6 @@ filter_utts
filter_speakers
filter_recordings
utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
echo "fix_data_dir.sh: old files are kept in $data/.backup"

View File

@ -0,0 +1,243 @@
#!/usr/bin/env python3
import argparse
import logging
from io import BytesIO
from pathlib import Path
from typing import Tuple, Optional
import kaldiio
import humanfriendly
import numpy as np
import resampy
import soundfile
from tqdm import tqdm
from typeguard import check_argument_types
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.read_text import read_2column_text
from funasr.fileio.sound_scp import SoundScpWriter
def humanfriendly_or_none(value: str) -> Optional[int]:
    """Parse a human-readable size given on the command line.

    Used as the ``type=`` converter of ``--fs``.  The "no value" spellings are
    the same ones ``str2int_tuple`` accepts, so both options treat
    ``none``/``null`` (and surrounding whitespace) alike.

    Args:
        value: Raw option text.

    Returns:
        ``None`` when ``value`` is one of the "no value" spellings, otherwise
        the result of ``humanfriendly.parse_size(value)``.
    """
    if value.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return humanfriendly.parse_size(value)
def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
    """Convert a comma-separated list of integers into a tuple.

    Used as the ``type=`` converter of ``--ref-channels``.

    Args:
        integers: Text such as ``"3,4,5"``, or one of the "no value"
            spellings ``none``/``null`` (any of the listed capitalisations).

    Returns:
        A tuple of ints, or ``None`` for a "no value" spelling.

    Raises:
        TypeError: If ``integers`` is not a string.
        ValueError: If an element cannot be parsed by ``int``.

    >>> str2int_tuple('3,4,5')
    (3, 4, 5)
    """
    # Validate explicitly instead of via ``assert``: assertions are stripped
    # when Python runs with -O, which would silently disable the check.
    if not isinstance(integers, str):
        raise TypeError(
            f"integers must be str, but got {type(integers).__name__}"
        )
    text = integers.strip()
    if text in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return tuple(int(token) for token in text.split(","))
def _resample_to(wave, rate, fs):
    """Resample ``wave`` to ``fs`` when a target rate is set and differs.

    Returns ``(wave, rate)``.  The input is handed back untouched when ``fs``
    is ``None`` or already equals ``rate``; otherwise the resampled signal is
    cast back to int16 and ``fs`` becomes the reported rate.
    """
    if fs is None or fs == rate:
        return wave, rate
    # FIXME(kamo): To use sox?
    resampled = resampy.resample(wave.astype(np.float64), rate, fs, axis=0)
    return resampled.astype(np.int16), fs


def _append_to_ark(fark, fscp, uttid, wave, rate, audio_format):
    """Append one utterance to the open ark ``fark`` and index it in ``fscp``.

    Raises:
        RuntimeError: If ``audio_format`` names neither flac nor wav.
    """
    if "flac" in audio_format:
        suf = "flac"
    elif "wav" in audio_format:
        suf = "wav"
    else:
        raise RuntimeError("wav.ark or flac")

    # NOTE(kamo): Using extended ark format style here.
    # This format is incompatible with Kaldi
    kaldiio.save_ark(
        fark,
        {uttid: (wave, rate)},
        scp=fscp,
        append=True,
        write_function=f"soundfile_{suf}",
    )


def main():
    """Command-line entry point: materialise the audio listed in a wav.scp.

    Reads the ``scp`` argument (optionally cut by ``--segments``), optionally
    keeps only the reference channels and resamples to ``--fs``, then writes
    ``<outdir>/<name>.scp`` and ``<outdir>/utt2num_samples``.  Audio goes to
    ``data_<name>.ark`` when ``--audio-format`` ends with "ark"; otherwise it
    is written as individual files (through ``SoundScpWriter`` in the segments
    case, under ``data_<name>/`` in the other case).  Without segments, a
    plain file that already has the requested extension and rate is referenced
    in place instead of being rewritten.

    Every output handle is registered on one ``ExitStack`` so it is closed
    (and flushed) even when processing fails half-way.
    """
    # Local import: the only new dependency of this block.
    from contextlib import ExitStack

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    parser = argparse.ArgumentParser(
        description='Create waves list from "wav.scp"',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("scp")
    parser.add_argument("outdir")
    parser.add_argument(
        "--name",
        default="wav",
        help="Specify the prefix word of output file name " 'such as "wav.scp"',
    )
    parser.add_argument("--segments", default=None)
    parser.add_argument(
        "--fs",
        type=humanfriendly_or_none,
        default=None,
        help="If the sampling rate specified, " "Change the sampling rate.",
    )
    parser.add_argument("--audio-format", default="wav")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
    group.add_argument("--utt2ref-channels", default=None, type=str)
    args = parser.parse_args()

    outdir = Path(args.outdir)
    out_num_samples = outdir / "utt2num_samples"

    # utt2ref_channels maps an utterance id to the channel indices to keep,
    # or is None when no channel selection was requested.
    if args.ref_channels is not None:

        def utt2ref_channels(x) -> Tuple[int, ...]:
            return args.ref_channels

    elif args.utt2ref_channels is not None:
        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)

        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
            chs_str = d[x]
            return tuple(map(int, chs_str.split()))

    else:
        utt2ref_channels = None

    outdir.mkdir(parents=True, exist_ok=True)
    out_wavscp = outdir / f"{args.name}.scp"
    to_ark = args.audio_format.endswith("ark")

    with ExitStack() as stack:
        if args.segments is not None:
            # Note: kaldiio supports only wav-pcm-int16le file.
            loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
            if to_ark:
                fark = stack.enter_context(
                    open(outdir / f"data_{args.name}.ark", "wb")
                )
                fscp = stack.enter_context(out_wavscp.open("w"))
            else:
                writer = SoundScpWriter(
                    args.outdir,
                    out_wavscp,
                    format=args.audio_format,
                )
                # SoundScpWriter is defined outside this file; release it on
                # exit when it offers close().
                close_writer = getattr(writer, "close", None)
                if callable(close_writer):
                    stack.callback(close_writer)

            fnum_samples = stack.enter_context(out_num_samples.open("w"))
            for uttid, (rate, wave) in tqdm(loader):
                # wave: (Time,) or (Time, Nmic)
                if wave.ndim == 2 and utt2ref_channels is not None:
                    wave = wave[:, utt2ref_channels(uttid)]

                wave, rate = _resample_to(wave, rate, args.fs)
                if to_ark:
                    _append_to_ark(
                        fark, fscp, uttid, wave, rate, args.audio_format
                    )
                else:
                    writer[uttid] = rate, wave
                fnum_samples.write(f"{uttid} {len(wave)}\n")
        else:
            if to_ark:
                fark = stack.enter_context(
                    open(outdir / f"data_{args.name}.ark", "wb")
                )
            else:
                wavdir = outdir / f"data_{args.name}"
                wavdir.mkdir(parents=True, exist_ok=True)

            fscp = stack.enter_context(Path(args.scp).open("r"))
            fout = stack.enter_context(out_wavscp.open("w"))
            fnum_samples = stack.enter_context(out_num_samples.open("w"))
            for line in tqdm(fscp):
                uttid, wavpath = line.strip().split(None, 1)

                if wavpath.endswith("|"):
                    # Streaming input e.g. cat a.wav |
                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
                        with BytesIO(f.read()) as g:
                            wave, rate = soundfile.read(g, dtype=np.int16)
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                    # Piped audio has no file to point at: always rewrite.
                    save_asis = False
                else:
                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
                    picked = wave.ndim == 2 and utt2ref_channels is not None
                    if picked:
                        wave = wave[:, utt2ref_channels(uttid)]
                    # Neither channel selection, ark output nor --fs changes
                    # the data and the extension already matches: only in
                    # this case, just use the original file as is.
                    save_asis = (
                        not picked
                        and not to_ark
                        and Path(wavpath).suffix == "." + args.audio_format
                        and (args.fs is None or args.fs == rate)
                    )

                if save_asis:
                    fout.write(f"{uttid} {wavpath}\n")
                else:
                    wave, rate = _resample_to(wave, rate, args.fs)
                    if to_ark:
                        _append_to_ark(
                            fark, fout, uttid, wave, rate, args.audio_format
                        )
                    else:
                        owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                        soundfile.write(owavpath, wave, rate)
                        fout.write(f"{uttid} {owavpath}\n")
                fnum_samples.write(f"{uttid} {len(wave)}\n")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,142 @@
#!/usr/bin/env bash
# Turn the entries of a Kaldi-style wav.scp into real audio files (or ark
# entries) by running local/format_wav_scp.py on ${nj} splits of the input in
# parallel, then concatenate the per-job wav.scp (and utt2num_samples) files
# into <out-datadir>.
set -euo pipefail
SECONDS=0

# Print a timestamped message tagged with the caller's file, line and function.
log() {
    local fname=${BASH_SOURCE[1]##*/}
    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

help_message=$(cat << EOF
Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
e.g.
$0 data/test/wav.scp data/test_format/
Format 'wav.scp': In short words,
changing "kaldi-datadir" to "modified-kaldi-datadir"
The 'wav.scp' format in kaldi is very flexible,
e.g. It can use unix-pipe as describing that wav file,
but it sometime looks confusing and make scripts more complex.
This tools creates actual wav files from 'wav.scp'
and also segments wav files using 'segments'.
Options
--fs <fs>
--segments <segments>
--nj <nj>
--cmd <cmd>
EOF
)

# Defaults; each can be overridden as --<name> via utils/parse_options.sh.
out_filename=wav.scp
cmd=utils/run.pl
nj=30
fs=none
segments=
ref_channels=
utt2ref_channels=
audio_format=wav
write_utt2num_samples=true

log "$0 $*"
. utils/parse_options.sh

if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
    log "${help_message}"
    log "Error: invalid command line arguments"
    exit 1
fi

. ./path.sh  # Setup the environment

scp=$1
if [ ! -f "${scp}" ]; then
    log "${help_message}"
    echo "$0: Error: No such file: ${scp}"
    exit 1
fi
dir=$2

# Optional 3rd/4th positional arguments relocate the logs and the audio.
if [ $# -eq 2 ]; then
    logdir=${dir}/logs
    outdir=${dir}/data
elif [ $# -eq 3 ]; then
    logdir=$3
    outdir=${dir}/data
elif [ $# -eq 4 ]; then
    logdir=$3
    outdir=$4
fi

mkdir -p "${logdir}"
rm -f "${dir}/${out_filename}"

# Channel-selection options forwarded to format_wav_scp.py; left unquoted at
# the call sites on purpose so that an empty value expands to nothing.
opts=
if [ -n "${utt2ref_channels}" ]; then
    opts="--utt2ref-channels ${utt2ref_channels} "
elif [ -n "${ref_channels}" ]; then
    opts="--ref-channels ${ref_channels} "
fi

if [ -n "${segments}" ]; then
    log "[info]: using ${segments}"
    nutt=$(<"${segments}" wc -l)
    if [ "${nutt}" -eq 0 ]; then
        # Fail with a clear message instead of a usage error from split_scp.pl.
        log "Error: ${segments} is empty"
        exit 1
    fi
    # Never start more jobs than there are entries to process.
    nj=$((nj<nutt?nj:nutt))

    split_segments=""
    for n in $(seq ${nj}); do
        split_segments="${split_segments} ${logdir}/segments.${n}"
    done

    # shellcheck disable=SC2086
    utils/split_scp.pl "${segments}" ${split_segments}

    # shellcheck disable=SC2086
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        local/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            "--segments=${logdir}/segments.JOB" \
            "${scp}" "${outdir}/format.JOB"
else
    log "[info]: without segments"
    nutt=$(<"${scp}" wc -l)
    if [ "${nutt}" -eq 0 ]; then
        # Fail with a clear message instead of a usage error from split_scp.pl.
        log "Error: ${scp} is empty"
        exit 1
    fi
    # Never start more jobs than there are entries to process.
    nj=$((nj<nutt?nj:nutt))

    split_scps=""
    for n in $(seq ${nj}); do
        split_scps="${split_scps} ${logdir}/wav.${n}.scp"
    done

    # shellcheck disable=SC2086
    utils/split_scp.pl "${scp}" ${split_scps}

    # shellcheck disable=SC2086
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        local/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            "${logdir}/wav.JOB.scp" "${outdir}/format.JOB"
fi

# Workaround for the NFS problem
ls "${outdir}"/format.* > /dev/null

# concatenate the .scp files together.
for n in $(seq ${nj}); do
    cat "${outdir}/format.${n}/wav.scp" || exit 1;
done > "${dir}/${out_filename}" || exit 1

if "${write_utt2num_samples}"; then
    for n in $(seq ${nj}); do
        cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
    done > "${dir}/utt2num_samples" || exit 1
fi

log "Successfully finished. [elapsed=${SECONDS}s]"

View File

@ -0,0 +1,116 @@
#!/usr/bin/env bash
# 2020 @kamo-naoyuki
# This file was copied from Kaldi and
# I deleted parts related to wav duration
# because we shouldn't use kaldi's command here
# and we don't need the files actually.
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
# 2014 Tom Ko
# 2018 Emotech LTD (author: Pawel Swietojanski)
# Apache 2.0
# This script operates on a directory, such as in data/train/,
# that contains some subset of the following files:
# wav.scp
# spk2utt
# utt2spk
# text
#
# It generates the files which are used for perturbing the speed of the original data.
#
# Every utterance, speaker and recording id gets the prefix "sp<factor>-", and
# each wav.scp entry is rewritten into a pipe that runs `sox ... speed <factor>`.
#
# NOTE(review): the Kaldi helpers (apply_map.pl, utt2spk_to_spk2utt.pl,
# validate_data_dir.sh) are invoked from local/, matching the other data
# scripts of this recipe - confirm they are present there.
export LC_ALL=C
set -euo pipefail

if [[ $# != 3 ]]; then
    echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
    echo "e.g.:"
    echo " $0 0.9 data/train_si284 data/train_si284p"
    exit 1
fi

factor=$1
srcdir=$2
destdir=$3
label="sp"
spk_prefix="${label}${factor}-"
utt_prefix="${label}${factor}-"

# Check that sox is on the PATH.
! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;

if [[ ! -f ${srcdir}/utt2spk ]]; then
    echo "$0: no such file ${srcdir}/utt2spk"
    exit 1;
fi

if [[ ${destdir} == "${srcdir}" ]]; then
    echo "$0: this script requires <srcdir> and <destdir> to be different."
    exit 1
fi

mkdir -p "${destdir}"

# Temporary "<old-id> <prefixed-id>" maps; removed again at the end.
<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"

# utt2uniq maps each perturbed utterance back to its unperturbed source.
if [[ ! -f ${srcdir}/utt2uniq ]]; then
    <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
else
    <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
fi

<"${srcdir}"/utt2spk local/apply_map.pl -f 1 "${destdir}"/utt_map | \
    local/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk

local/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt

if [[ -f ${srcdir}/segments ]]; then
    # Segment times scale by 1/factor; segments that end up 10 ms or shorter
    # are dropped.
    local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
        local/apply_map.pl -f 2 "${destdir}"/reco_map | \
        awk -v factor="${factor}" \
            '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
            >"${destdir}"/segments

    local/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
        # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
        awk -v factor="${factor}" \
            '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
              else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
              else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
            > "${destdir}"/wav.scp
    if [[ -f ${srcdir}/reco2file_and_channel ]]; then
        local/apply_map.pl -f 1 "${destdir}"/reco_map \
            <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
    fi

else # no segments->wav indexed by utterance.
    if [[ -f ${srcdir}/wav.scp ]]; then
        local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
            # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
            awk -v factor="${factor}" \
                '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
                  else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
                  else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
                > "${destdir}"/wav.scp
    fi
fi

if [[ -f ${srcdir}/text ]]; then
    local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
fi
if [[ -f ${srcdir}/spk2gender ]]; then
    local/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
fi
if [[ -f ${srcdir}/utt2lang ]]; then
    local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
fi

rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"

local/validate_data_dir.sh --no-feats --no-text "${destdir}"

View File

@ -113,7 +113,7 @@ fi
check_sorted_and_uniq $data/spk2utt
! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \
<(utils/spk2utt_to_utt2spk.pl $data/spk2utt) && \
<(local/spk2utt_to_utt2spk.pl $data/spk2utt) && \
echo "$0: spk2utt and utt2spk do not seem to match" && exit 1;
cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts
@ -135,7 +135,7 @@ if ! $no_text; then
echo "$0: text contains $n_non_print lines with non-printable characters" &&\
exit 1;
fi
utils/validate_text.pl $data/text || exit 1;
local/validate_text.pl $data/text || exit 1;
check_sorted_and_uniq $data/text
text_len=`cat $data/text | wc -l`
illegal_sym_list="<s> </s> #0"

View File

@ -2,5 +2,4 @@ export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
export PATH=$PWD/utils/:$PATH
export PATH=$FUNASR_DIR/funasr/bin:$PATH

1
egs/alimeeting/sa-asr/utils Symbolic link
View File

@ -0,0 +1 @@
../../aishell/transformer/utils

View File

@ -1,87 +0,0 @@
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation
# Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.

# Filter an scp-like file by a list of ids.
#
# The id list is any file whose first field on each line is an id.  Input
# lines (from the optional second argument, else stdin) are printed when their
# n-th field appears in the id list; with --exclude, when it does not.
# n is 1 unless changed with "-f <n>".

$exclude = 0;
$field = 1;

# Consume leading switches, in any order, until something else shows up.
while (@ARGV > 0) {
  if ($ARGV[0] eq "--exclude") {
    $exclude = 1;
    shift @ARGV;
  } elsif ($ARGV[0] eq "-f") {
    $field = $ARGV[1];
    shift @ARGV;
    shift @ARGV;
  } else {
    last;
  }
}

if(@ARGV < 1 || @ARGV > 2) {
  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
      "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
      "Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
      "only the lines that were *not* in id_list.\n" .
      "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
      "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
      "-f option, add 1 to the argument.\n" .
      "See also: utils/filter_scp.pl .\n";
}

# Remember every id named in the first column of the id list.
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
  @A = split;
  die "Invalid id-list file line $_" unless @A >= 1;
  $seen{$A[0]} = 1;
}

# True when a line carrying this key belongs in the output.
sub wanted {
  my ($key) = @_;
  return $exclude ? !defined $seen{$key} : $seen{$key};
}

if ($field == 1) { # Treat this as special case, since it is common.
  while(<>) {
    unless ($_ =~ m/\s*(\S+)\s*/) {
      die "Bad line $_, could not get first field.";
    }
    # $1 is what we filter on.
    print $_ if wanted($1);
  }
} else {
  while(<>) {
    @A = split;
    die "Invalid scp file line $_" unless @A > 0;
    die "Invalid scp file line $_" unless @A >= $field;
    print $_ if wanted($A[$field-1]);
  }
}

# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)

View File

@ -1,97 +0,0 @@
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
#
# Because this file is sourced, "$0", "$1", "$#" and "shift" below all refer to
# the enclosing script, and "exit" terminates that script.
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
# (the loop stops before the last argument, since --config needs a value after it).
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
# Options are consumed from the front of the argument list; the loop stops at
# the first empty argument or the first argument that does not start with "--"
# (other than -h), leaving the positional arguments in place.
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
# The "--name=value" spelling is rejected explicitly.
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefined-- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
# Remember the current (default) value so we can tell whether it is Boolean.
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
# NOTE(review): the value is expanded inside eval, so it is interpreted by the
# shell; only pass trusted option values.
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
# NOTE(review): if the option is the last argument (no value follows), "shift 2"
# presumably fails without shifting and this loop would not terminate -- confirm.
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.

View File

@ -1,246 +0,0 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# See ../../COPYING for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program splits up any kind of .scp or archive-type file.
# If there is no utt2spk option it will work on any text file and
# will split it up with an approximately equal number of lines in
# each part.
# With the --utt2spk option it will work on anything that has the
# utterance-id as the first entry on each line; the utt2spk file is
# of the form "utterance speaker" (on each line).
# It splits it into equal size chunks as far as it can. If you use the utt2spk
# option it will make sure these chunks coincide with speaker boundaries. In
# this case, if there are more chunks than speakers (and in some other
# circumstances), some of the resulting chunks will be empty and it will print
# an error message and exit with nonzero status.
# You will normally call this like:
# split_scp.pl scp scp.1 scp.2 scp.3 ...
# or
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
# Note that you can use this script to split the utt2spk file itself,
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
# You can also call the scripts like:
# split_scp.pl -j 3 0 scp scp.0
# [note: with this option, it assumes zero-based indexing of the split parts,
# i.e. the second number must be 0 <= n < num-jobs.]
use warnings;

# Option defaults.
$num_jobs = 0;       # 0 means "-j was not given": write every split to its own file.
$job_id = 0;         # index of the split to emit when -j is given.
$utt2spk_file = "";  # empty means: split by line count, ignoring speakers.
$one_based = 0;      # 1 means the job-id on the command line counts from 1.

# Shared by every command-line error below.
$usage =
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
... where 0 <= job-id < num-jobs, or 1 <= job-id <= num-jobs if --one-based.\n";

# Consume the leading options in any order.  The loop stops at the first
# argument that is not an option, so $ARGV[0] is never read once @ARGV is empty
# (which would trigger "uninitialized value" warnings under "use warnings").
while (@ARGV > 0) {
  if ($ARGV[0] eq "-j") {
    # -j must be followed by both num-jobs and job-id.
    @ARGV >= 3 || die $usage;
    shift @ARGV;
    $num_jobs = shift @ARGV;
    $job_id = shift @ARGV;
  } elsif ($ARGV[0] =~ /--utt2spk=(.+)/) {
    $utt2spk_file = $1;
    shift @ARGV;
  } elsif ($ARGV[0] eq '--one-based') {
    $one_based = 1;
    shift @ARGV;
  } else {
    last;
  }
}

# With -j, the job index (after removing the one-based offset) must lie in
# [0, num_jobs).
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
                       $job_id - $one_based >= $num_jobs)) {
  die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
      ($one_based ? " --one-based" : "") . "'\n"
}

# From here on $job_id is always zero-based.
$job_id-- if $one_based;

# Without -j we need the input plus at least one output; with -j we need the
# input and at most one output.
if (($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
  die $usage;
}
# Exit status reported at the very end; the non-fatal error paths below set it to 1.
$error = 0;
$inscp = shift @ARGV;
# Build @OUTPUTS: one destination per split.
if ($num_jobs == 0) { # without -j option
@OUTPUTS = @ARGV;
} else {
# With -j only the split numbered $job_id is kept (in the given file, or on
# stdout via "-"); every other split is written to /dev/null.
for ($j = 0; $j < $num_jobs; $j++) {
if ($j == $job_id) {
if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
else { push @OUTPUTS, "-"; }
} else {
push @OUTPUTS, "/dev/null";
}
}
}
if ($utt2spk_file ne "") { # We have the --utt2spk option...
# Read the utterance -> speaker map; every line must have exactly two fields.
open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
while(<$u_fh>) {
@A = split;
@A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
($u,$s) = @A;
$utt2spk{$u} = $s;
}
close $u_fh;
# Group the input lines by speaker, remembering speakers in order of first
# appearance (@spkrs), their line counts (%spk_count) and lines (%spk_data).
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
@spkrs = ();
while(<$i_fh>) {
@A = split;
if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
$u = $A[0];
$s = $utt2spk{$u};
defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
if(!defined $spk_count{$s}) {
push @spkrs, $s;
$spk_count{$s} = 0;
$spk_data{$s} = []; # ref to new empty array.
}
$spk_count{$s}++;
push @{$spk_data{$s}}, $_;
}
# Now split as equally as possible ..
# First allocate spks to files by allocating an approximately
# equal number of speakers.
$numspks = @spkrs; # number of speakers.
$numscps = @OUTPUTS; # number of output files.
if ($numspks < $numscps) {
die "$0: Refusing to split data because number of speakers $numspks " .
"is less than the number of output .scp files $numscps\n";
}
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scparray[$scpidx] = []; # [] is array reference.
}
# Initial assignment: consecutive speakers go to consecutive outputs;
# @scpcount tracks the number of lines assigned to each output.
for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
$scpidx = int(($spkidx*$numscps) / $numspks);
$spk = $spkrs[$spkidx];
push @{$scparray[$scpidx]}, $spk;
$scpcount[$scpidx] += $spk_count{$spk};
}
# Now will try to reassign beginning + ending speakers
# to different scp's and see if it gets more balanced.
# Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
# We can show that if considering changing just 2 scp's, we minimize
# this by minimizing the squared difference in sizes. This is
# equivalent to minimizing the absolute difference in sizes. This
# shows this method is bound to converge.
$changed = 1;
while($changed) {
$changed = 0;
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
# First try to reassign ending spk of this scp.
if($scpidx < $numscps-1) {
$sz = @{$scparray[$scpidx]};
if($sz > 0) {
$spk = $scparray[$scpidx]->[$sz-1];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx];
$nutt2 = $scpcount[$scpidx+1];
if( abs( ($nutt2+$count) - ($nutt1-$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx+1] += $count;
$scpcount[$scpidx] -= $count;
pop @{$scparray[$scpidx]};
unshift @{$scparray[$scpidx+1]}, $spk;
$changed = 1;
}
}
}
# Then try to move the first speaker of this scp to the end of the previous one.
if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
$spk = $scparray[$scpidx]->[0];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx-1];
$nutt2 = $scpcount[$scpidx];
if( abs( ($nutt2-$count) - ($nutt1+$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx-1] += $count;
$scpcount[$scpidx] -= $count;
shift @{$scparray[$scpidx]};
push @{$scparray[$scpidx-1]}, $spk;
$changed = 1;
}
}
}
}
# Now print out the files...
# An output that ended up with no speakers is reported on stderr and makes the
# script exit with status 1, but the remaining outputs are still written.
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($f_fh, '>', $scpfile)
: open($f_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
$count = 0;
if(@{$scparray[$scpidx]} == 0) {
print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
"$scpfile (too many splits and too few speakers?)\n";
$error = 1;
} else {
foreach $spk ( @{$scparray[$scpidx]} ) {
print $f_fh @{$spk_data{$spk}};
$count += $spk_count{$spk};
}
# Sanity check: the lines written must match the bookkeeping in @scpcount.
$count == $scpcount[$scpidx] || die "Count mismatch [code error]";
}
close($f_fh);
}
} else {
# This block is the "normal" case where there is no --utt2spk
# option and we just break into equal size chunks.
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
$numscps = @OUTPUTS; # size of array.
# The whole input is read into memory (@F) before being split.
@F = ();
while(<$i_fh>) {
push @F, $_;
}
$numlines = @F;
if($numlines == 0) {
print STDERR "$0: error: empty input scp file $inscp\n";
$error = 1;
}
$linesperscp = int( $numlines / $numscps); # the "whole part"..
$linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
$remainder = $numlines - ($linesperscp * $numscps);
($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
# [just doing int() rounds down].
# The first $remainder outputs receive one extra line each.
$n = 0;
for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($o_fh, '>', $scpfile)
: open($o_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
print $o_fh $F[$n++];
}
close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n";
}
# Sanity check: every input line must have been written exactly once.
$n == $numlines || die "$n != $numlines [code error]";
}
exit ($error);

View File

@ -40,6 +40,8 @@ from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@ -90,6 +92,12 @@ class Speech2Text:
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
if asr_train_args.frontend=='wav_frontend':
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
else:
frontend_class=frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
@ -197,12 +205,21 @@ class Speech2Text:
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
batch = {"speech": speech, "speech_lengths": speech_lengths}
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
@ -275,6 +292,7 @@ def inference(
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
mc: bool = False,
**kwargs,
):
inference_pipeline = inference_modelscope(
@ -305,6 +323,7 @@ def inference(
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
mc=mc,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@ -337,6 +356,7 @@ def inference_modelscope(
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
mc: bool = False,
param_dict: dict = None,
**kwargs,
):
@ -406,7 +426,7 @@ def inference_modelscope(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
mc=True,
mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
@ -415,7 +435,7 @@ def inference_modelscope(
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
# 7 .Start for-loop

View File

@ -71,7 +71,13 @@ def get_parser():
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
# Parse the flag with str2bool (as --allow_variable_data_keys does): argparse's
# type=bool would turn any non-empty string, including "False", into True.
group.add_argument(
    "--mc",
    type=str2bool,
    default=False,
    help="MultiChannel input",
)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",

View File

@ -2,14 +2,6 @@
import os
import logging
logging.basicConfig(
level='INFO',
format=f"[{os.uname()[1].split('.')[0]}]"
f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
from funasr.tasks.asr import ASRTask

View File

@ -35,6 +35,8 @@ from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@ -85,6 +87,12 @@ class Speech2Text:
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
if asr_train_args.frontend=='wav_frontend':
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
else:
frontend_class=frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
@ -201,7 +209,16 @@ class Speech2Text:
if isinstance(profile, np.ndarray):
profile = torch.tensor(profile)
batch = {"speech": speech, "speech_lengths": speech_lengths}
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
@ -308,6 +325,7 @@ def inference(
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
mc: bool = False,
**kwargs,
):
inference_pipeline = inference_modelscope(
@ -338,6 +356,7 @@ def inference(
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
mc=mc,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@ -370,6 +389,7 @@ def inference_modelscope(
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
mc: bool = False,
param_dict: dict = None,
**kwargs,
):
@ -437,7 +457,7 @@ def inference_modelscope(
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
mc=True,
mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,

View File

@ -2,14 +2,6 @@
import os
import logging
logging.basicConfig(
level='INFO',
format=f"[{os.uname()[1].split('.')[0]}]"
f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
from funasr.tasks.sa_asr import ASRTask

View File

@ -79,3 +79,49 @@ class SequenceBinaryCrossEntropy(nn.Module):
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
return loss.masked_fill(pad_mask, 0).sum() / denom
class NllLoss(nn.Module):
    """Negative log-likelihood loss that ignores padded target positions.

    :param int size: the number of class
    :param int padding_idx: ignored class id
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function; it must return one loss
        value per element (no reduction). When None, a fresh
        ``nn.NLLLoss(reduction='none')`` is created for this instance.
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=None,
    ):
        """Construct an NllLoss object."""
        super(NllLoss, self).__init__()
        # Build the default criterion per instance: a module created in the
        # signature would be one object shared by every NllLoss.
        if criterion is None:
            criterion = nn.NLLLoss(reduction='none')
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.size = size
        # Not used by this class; kept so the attribute set is unchanged.
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_id (batch, seqlen)
        :return: scalar float value
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        # Flatten batch and time so the criterion sees one row per token.
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
        token_loss = self.criterion(x, target)
        # Normalise by the number of non-padded tokens or by the batch size.
        denom = total if self.normalize_length else batch_size
        # Padded positions were scored against class 0 above; zero them out.
        return token_loss.masked_fill(ignore, 0).sum() / denom

View File

@ -13,6 +13,7 @@ from typeguard import check_argument_types
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.attention import CosineDistanceAttention
from funasr.modules.dynamic_conv import DynamicConvolution
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
from funasr.modules.embedding import PositionalEncoding
@ -763,4 +764,429 @@ class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
normalize_before,
concat_after,
),
)
)
class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base decoder with a speaker branch and an ASR branch over one token embedding.

    This class builds the embedding, the final norm, the two output layers and
    the cosine-distance attention. The four decoder stages (``decoder1`` ..
    ``decoder4``) are set to None here and must be assigned by a subclass
    before ``forward`` / ``forward_one_step`` are called.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        """Build the shared sub-modules.

        Args:
            vocab_size: input size of the token embedding and output size of
                the ASR output layer.
            encoder_output_size: used as the attention dimension of the decoder.
            spker_embedding_dim: output size of the speaker output layer.
            dropout_rate: dropout used by the "linear" input layer.
            positional_dropout_rate: dropout passed to ``pos_enc_class``.
            input_layer: "embed" or "linear".
            use_asr_output_layer: create the projection to ``vocab_size``.
            use_spk_output_layer: create the projection to ``spker_embedding_dim``.
            pos_enc_class: positional-encoding class appended to the input layer.
            normalize_before: when True, a final LayerNorm (``after_norm``) is
                created and applied before the output layers.

        Raises:
            ValueError: if ``input_layer`` is neither "embed" nor "linear".
        """
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            # A single norm instance; it is applied on both the speaker and the
            # ASR branch in forward()/forward_one_step().
            self.after_norm = LayerNorm(attention_dim)
        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None
        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None
        self.cos_distance_att = CosineDistanceAttention()
        # Placeholders for the decoder stages; subclasses assign the real modules.
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None

    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run both decoder branches over a whole (padded) target batch.

        Args:
            asr_hs_pad: memory used by the ASR branch (and by decoder1).
            spk_hs_pad: memory used by the speaker branch.
            hlens: lengths used to build the memory padding mask.
            ys_in_pad: padded input tokens, fed to ``self.embed``.
            ys_in_lens: lengths used to build the target padding mask.
            profile: speaker profiles passed to ``cos_distance_att``.
            profile_lens: profile lengths passed to ``cos_distance_att``.

        Returns:
            A 3-tuple ``(x, weights, olens)``: the ASR branch output (after
            ``asr_output_layer`` when it exists), the second value returned by
            ``cos_distance_att``, and ``tgt_mask.sum(1)``.
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m
        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
        # Spk decoder
        x = self.embed(tgt)
        # decoder1 also returns z, which is the input of the ASR branch below.
        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        # NOTE(review): dn is presumably a profile-weighted speaker embedding and
        # weights the attention over profiles -- confirm in CosineDistanceAttention.
        dn, weights = self.cos_distance_att(x, profile, profile_lens)
        # Asr decoder
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)
        olens = tgt_mask.sum(1)
        return x, weights, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        asr_memory: torch.Tensor,
        spk_memory: torch.Tensor,
        profile: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Decode one step, reusing the per-layer outputs of previous steps.

        The cache is a flat list with one entry per layer, in this order:
        index 0 for ``decoder1``, the next ``len(self.decoder2)`` entries for
        the ``decoder2`` layers, one entry for ``decoder3``, and the remaining
        entries for the ``decoder4`` layers.

        Args:
            tgt: input tokens so far, fed to ``self.embed``.
            tgt_mask: target mask forwarded to every layer.
            asr_memory: memory used by the ASR branch (and by decoder1).
            spk_memory: memory used by the speaker branch.
            profile: speaker profiles passed to ``cos_distance_att``.
            cache: per-layer outputs from the previous step, or None on the
                first step.

        Returns:
            A 3-tuple ``(y, weights, new_cache)``: the output for the last
            position (log-softmax over the ASR output layer when it exists),
            the second value returned by ``cos_distance_att``, and the updated
            cache list.
        """
        x = self.embed(tgt)
        if cache is None:
            # One slot each for decoder1 and decoder3, plus one per stacked layer.
            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
        new_cache = []
        # No memory mask is used in step-wise decoding (None is passed).
        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
            x, tgt_mask, spk_memory, _ = decoder(
                x, tgt_mask, spk_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            x = self.after_norm(x)
        else:
            x = x
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        dn, weights = self.cos_distance_att(x, profile, None)
        # The ASR branch starts from z (returned by decoder1), not from x.
        x, tgt_mask, asr_memory, _ = self.decoder3(
            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
            x, tgt_mask, asr_memory, _ = decoder(
                x, tgt_mask, asr_memory, None, cache=c
            )
            new_cache.append(x)
        # Only the last position along dim 1 is scored.
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.asr_output_layer is not None:
            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
        return y, weights, new_cache

    def score(self, ys, state, asr_enc, spk_enc, profile):
        """Score one unbatched hypothesis prefix.

        A leading batch dimension is added to every input, a subsequent mask of
        size ``len(ys)`` is built, and ``forward_one_step`` is called with
        ``state`` as the cache.

        Returns:
            ``(logp, weights, state)`` with the batch dimension squeezed out of
            ``logp``, all size-1 dimensions squeezed out of ``weights``, and the
            new cache as ``state``.
        """
        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
        logp, weights, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), weights.squeeze(), state
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
    """Speaker-attributed ASR decoder built from multi-head attention layers.

    The speaker branch is ``decoder1`` followed by ``spk_num_blocks - 1``
    ``DecoderLayer`` blocks (``decoder2``); the ASR branch is ``decoder3``
    followed by ``asr_num_blocks - 1`` ``DecoderLayer`` blocks (``decoder4``).
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        asr_num_blocks: int = 6,
        spk_num_blocks: int = 3,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Create the four decoder stages on top of the shared base modules."""
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            spker_embedding_dim=spker_embedding_dim,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_asr_output_layer=use_asr_output_layer,
            use_spk_output_layer=use_spk_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        dim = encoder_output_size

        # Small factories so every layer gets its own, freshly built sub-modules.
        def new_self_attn():
            return MultiHeadedAttention(
                attention_heads, dim, self_attention_dropout_rate
            )

        def new_src_attn():
            return MultiHeadedAttention(
                attention_heads, dim, src_attention_dropout_rate
            )

        def new_feed_forward():
            return PositionwiseFeedForward(dim, linear_units, dropout_rate)

        def new_decoder_layer(lnum):
            # lnum is the layer index supplied by repeat(); it is not used.
            return DecoderLayer(
                dim,
                new_self_attn(),
                new_src_attn(),
                new_feed_forward(),
                dropout_rate,
                normalize_before,
                concat_after,
            )

        # Speaker branch: special first layer, then plain decoder layers.
        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
            dim,
            new_self_attn(),
            new_src_attn(),
            new_feed_forward(),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder2 = repeat(spk_num_blocks - 1, new_decoder_layer)

        # ASR branch: special first layer (no self-attention), then plain layers.
        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
            dim,
            spker_embedding_dim,
            new_src_attn(),
            new_feed_forward(),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder4 = repeat(asr_num_blocks - 1, new_decoder_layer)
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
    """First layer of the speaker branch.

    It applies self-attention over the target, then a source attention whose
    keys come from ``asr_memory`` and whose values come from ``spk_memory``,
    then a feed-forward block. The output of the self-attention sub-block is
    returned as well (``z``).
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Store the sub-modules and create the norms, dropout and concat layers."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
        """Run the layer.

        With ``cache`` given, only the last target frame is used as the query
        and the new frame is appended to ``cache`` along dim 1.

        Returns ``(x, tgt_mask, asr_memory, spk_memory, memory_mask, z)`` where
        ``z`` is the output of the self-attention sub-block.
        """
        pre_norm = self.normalize_before
        shortcut = tgt
        if pre_norm:
            tgt = self.norm1(tgt)

        if cache is not None:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            query = tgt[:, -1:, :]
            shortcut = shortcut[:, -1:, :]
            query_mask = tgt_mask[:, -1:, :] if tgt_mask is not None else None
        else:
            query = tgt
            query_mask = tgt_mask

        # Self-attention sub-block.
        attended = self.self_attn(query, tgt, tgt, query_mask)
        if self.concat_after:
            out = shortcut + self.concat_linear1(torch.cat((query, attended), dim=-1))
        else:
            out = shortcut + self.dropout(attended)
        if not pre_norm:
            out = self.norm1(out)
        z = out

        # Source-attention sub-block (norm1 is used again here).
        shortcut = out
        if pre_norm:
            out = self.norm1(out)
        skip = self.src_attn(out, asr_memory, spk_memory, memory_mask)
        if self.concat_after:
            out = shortcut + self.concat_linear2(torch.cat((out, skip), dim=-1))
        else:
            out = shortcut + self.dropout(skip)
        if not pre_norm:
            out = self.norm1(out)

        # Feed-forward sub-block.
        shortcut = out
        if pre_norm:
            out = self.norm2(out)
        out = shortcut + self.dropout(self.feed_forward(out))
        if not pre_norm:
            out = self.norm2(out)

        if cache is not None:
            out = torch.cat([cache, out], dim=1)
        return out, tgt_mask, asr_memory, spk_memory, memory_mask, z
class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
    """First layer of the ASR branch.

    It has no self-attention: it applies source attention over ``memory``,
    optionally adds a linear projection of ``dn`` to the result, and then runs
    the feed-forward block.
    """

    def __init__(
        self,
        size,
        d_size,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeAsrDecoderFirstLayer object.

        Args:
            size: model dimension of the layer.
            d_size: dimension of ``dn``; it is projected to ``size`` without bias.
            src_attn: source-attention module, called as
                ``src_attn(query, memory, memory, memory_mask)``.
            feed_forward: feed-forward module.
            dropout_rate: dropout applied to the attention and feed-forward outputs.
            normalize_before: apply the layer norms before (True) or after
                (False) each sub-block.
            concat_after: concatenate the attention input and output and project
                them with a linear layer instead of applying dropout.
        """
        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        self.spk_linear = nn.Linear(d_size, size, bias=False)
        if self.concat_after:
            # concat_linear1 is not used by forward(); it is still created so
            # the module's parameter set is unchanged.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
        """Run the layer.

        Args:
            tgt: input of the layer.
            tgt_mask: returned unchanged; this layer has no self-attention.
            memory: keys and values of the source attention.
            memory_mask: mask passed to the source attention.
            dn: tensor projected by ``spk_linear`` and added before the
                feed-forward block, or None to skip that step.
            cache: output of this layer for the previous frames, or None. When
                given, only the last frame of ``tgt`` is processed and the
                result is appended to ``cache`` along dim 1.

        Returns:
            ``(x, tgt_mask, memory, memory_mask)``.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
        else:
            # Only the newest frame has to be computed.
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]

        # Source-attention sub-block.
        x = tgt_q
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)

        # Feed-forward sub-block; the projected dn is added to its input only,
        # not to the residual path.
        residual = x
        if dn is not None:
            x = x + self.spk_linear(dn)
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, memory, memory_mask

View File

@ -16,9 +16,8 @@ from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
LabelSmoothingLoss, NllLoss # noqa: H301
)
from funasr.losses.nll_loss import NllLoss
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder

View File

@ -28,7 +28,7 @@ from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecode
from funasr.models.decoder.transformer_decoder import (
DynamicConvolution2DTransformerDecoder, # noqa: H301
)
from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
LightweightConvolution2DTransformerDecoder, # noqa: H301