mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add speaker-attributed ASR task for alimeeting
This commit is contained in:
parent
d76aea23d9
commit
49f13908de
@ -1,243 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import kaldiio
|
||||
import humanfriendly
|
||||
import numpy as np
|
||||
import resampy
|
||||
import soundfile
|
||||
from tqdm import tqdm
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.utils.cli_utils import get_commandline_args
|
||||
from funasr.fileio.read_text import read_2column_text
|
||||
from funasr.fileio.sound_scp import SoundScpWriter
|
||||
|
||||
|
||||
def humanfriendly_or_none(value: str):
|
||||
if value in ("none", "None", "NONE"):
|
||||
return None
|
||||
return humanfriendly.parse_size(value)
|
||||
|
||||
|
||||
def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
|
||||
"""
|
||||
|
||||
>>> str2int_tuple('3,4,5')
|
||||
(3, 4, 5)
|
||||
|
||||
"""
|
||||
assert check_argument_types()
|
||||
if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
|
||||
return None
|
||||
return tuple(map(int, integers.strip().split(",")))
|
||||
|
||||
|
||||
def main():
|
||||
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=logfmt)
|
||||
logging.info(get_commandline_args())
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Create waves list from "wav.scp"',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("scp")
|
||||
parser.add_argument("outdir")
|
||||
parser.add_argument(
|
||||
"--name",
|
||||
default="wav",
|
||||
help="Specify the prefix word of output file name " 'such as "wav.scp"',
|
||||
)
|
||||
parser.add_argument("--segments", default=None)
|
||||
parser.add_argument(
|
||||
"--fs",
|
||||
type=humanfriendly_or_none,
|
||||
default=None,
|
||||
help="If the sampling rate specified, " "Change the sampling rate.",
|
||||
)
|
||||
parser.add_argument("--audio-format", default="wav")
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--ref-channels", default=None, type=str2int_tuple)
|
||||
group.add_argument("--utt2ref-channels", default=None, type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
out_num_samples = Path(args.outdir) / f"utt2num_samples"
|
||||
|
||||
if args.ref_channels is not None:
|
||||
|
||||
def utt2ref_channels(x) -> Tuple[int, ...]:
|
||||
return args.ref_channels
|
||||
|
||||
elif args.utt2ref_channels is not None:
|
||||
utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)
|
||||
|
||||
def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
|
||||
chs_str = d[x]
|
||||
return tuple(map(int, chs_str.split()))
|
||||
|
||||
else:
|
||||
utt2ref_channels = None
|
||||
|
||||
Path(args.outdir).mkdir(parents=True, exist_ok=True)
|
||||
out_wavscp = Path(args.outdir) / f"{args.name}.scp"
|
||||
if args.segments is not None:
|
||||
# Note: kaldiio supports only wav-pcm-int16le file.
|
||||
loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
|
||||
if args.audio_format.endswith("ark"):
|
||||
fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
|
||||
fscp = out_wavscp.open("w")
|
||||
else:
|
||||
writer = SoundScpWriter(
|
||||
args.outdir,
|
||||
out_wavscp,
|
||||
format=args.audio_format,
|
||||
)
|
||||
|
||||
with out_num_samples.open("w") as fnum_samples:
|
||||
for uttid, (rate, wave) in tqdm(loader):
|
||||
# wave: (Time,) or (Time, Nmic)
|
||||
if wave.ndim == 2 and utt2ref_channels is not None:
|
||||
wave = wave[:, utt2ref_channels(uttid)]
|
||||
|
||||
if args.fs is not None and args.fs != rate:
|
||||
# FIXME(kamo): To use sox?
|
||||
wave = resampy.resample(
|
||||
wave.astype(np.float64), rate, args.fs, axis=0
|
||||
)
|
||||
wave = wave.astype(np.int16)
|
||||
rate = args.fs
|
||||
if args.audio_format.endswith("ark"):
|
||||
if "flac" in args.audio_format:
|
||||
suf = "flac"
|
||||
elif "wav" in args.audio_format:
|
||||
suf = "wav"
|
||||
else:
|
||||
raise RuntimeError("wav.ark or flac")
|
||||
|
||||
# NOTE(kamo): Using extended ark format style here.
|
||||
# This format is incompatible with Kaldi
|
||||
kaldiio.save_ark(
|
||||
fark,
|
||||
{uttid: (wave, rate)},
|
||||
scp=fscp,
|
||||
append=True,
|
||||
write_function=f"soundfile_{suf}",
|
||||
)
|
||||
|
||||
else:
|
||||
writer[uttid] = rate, wave
|
||||
fnum_samples.write(f"{uttid} {len(wave)}\n")
|
||||
else:
|
||||
if args.audio_format.endswith("ark"):
|
||||
fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
|
||||
else:
|
||||
wavdir = Path(args.outdir) / f"data_{args.name}"
|
||||
wavdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with Path(args.scp).open("r") as fscp, out_wavscp.open(
|
||||
"w"
|
||||
) as fout, out_num_samples.open("w") as fnum_samples:
|
||||
for line in tqdm(fscp):
|
||||
uttid, wavpath = line.strip().split(None, 1)
|
||||
|
||||
if wavpath.endswith("|"):
|
||||
# Streaming input e.g. cat a.wav |
|
||||
with kaldiio.open_like_kaldi(wavpath, "rb") as f:
|
||||
with BytesIO(f.read()) as g:
|
||||
wave, rate = soundfile.read(g, dtype=np.int16)
|
||||
if wave.ndim == 2 and utt2ref_channels is not None:
|
||||
wave = wave[:, utt2ref_channels(uttid)]
|
||||
|
||||
if args.fs is not None and args.fs != rate:
|
||||
# FIXME(kamo): To use sox?
|
||||
wave = resampy.resample(
|
||||
wave.astype(np.float64), rate, args.fs, axis=0
|
||||
)
|
||||
wave = wave.astype(np.int16)
|
||||
rate = args.fs
|
||||
|
||||
if args.audio_format.endswith("ark"):
|
||||
if "flac" in args.audio_format:
|
||||
suf = "flac"
|
||||
elif "wav" in args.audio_format:
|
||||
suf = "wav"
|
||||
else:
|
||||
raise RuntimeError("wav.ark or flac")
|
||||
|
||||
# NOTE(kamo): Using extended ark format style here.
|
||||
# This format is incompatible with Kaldi
|
||||
kaldiio.save_ark(
|
||||
fark,
|
||||
{uttid: (wave, rate)},
|
||||
scp=fout,
|
||||
append=True,
|
||||
write_function=f"soundfile_{suf}",
|
||||
)
|
||||
else:
|
||||
owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
|
||||
soundfile.write(owavpath, wave, rate)
|
||||
fout.write(f"{uttid} {owavpath}\n")
|
||||
else:
|
||||
wave, rate = soundfile.read(wavpath, dtype=np.int16)
|
||||
if wave.ndim == 2 and utt2ref_channels is not None:
|
||||
wave = wave[:, utt2ref_channels(uttid)]
|
||||
save_asis = False
|
||||
|
||||
elif args.audio_format.endswith("ark"):
|
||||
save_asis = False
|
||||
|
||||
elif Path(wavpath).suffix == "." + args.audio_format and (
|
||||
args.fs is None or args.fs == rate
|
||||
):
|
||||
save_asis = True
|
||||
|
||||
else:
|
||||
save_asis = False
|
||||
|
||||
if save_asis:
|
||||
# Neither --segments nor --fs are specified and
|
||||
# the line doesn't end with "|",
|
||||
# i.e. not using unix-pipe,
|
||||
# only in this case,
|
||||
# just using the original file as is.
|
||||
fout.write(f"{uttid} {wavpath}\n")
|
||||
else:
|
||||
if args.fs is not None and args.fs != rate:
|
||||
# FIXME(kamo): To use sox?
|
||||
wave = resampy.resample(
|
||||
wave.astype(np.float64), rate, args.fs, axis=0
|
||||
)
|
||||
wave = wave.astype(np.int16)
|
||||
rate = args.fs
|
||||
|
||||
if args.audio_format.endswith("ark"):
|
||||
if "flac" in args.audio_format:
|
||||
suf = "flac"
|
||||
elif "wav" in args.audio_format:
|
||||
suf = "wav"
|
||||
else:
|
||||
raise RuntimeError("wav.ark or flac")
|
||||
|
||||
# NOTE(kamo): Using extended ark format style here.
|
||||
# This format is not supported in Kaldi.
|
||||
kaldiio.save_ark(
|
||||
fark,
|
||||
{uttid: (wave, rate)},
|
||||
scp=fout,
|
||||
append=True,
|
||||
write_function=f"soundfile_{suf}",
|
||||
)
|
||||
else:
|
||||
owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
|
||||
soundfile.write(owavpath, wave, rate)
|
||||
fout.write(f"{uttid} {owavpath}\n")
|
||||
fnum_samples.write(f"{uttid} {len(wave)}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,45 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
import sys
|
||||
|
||||
|
||||
def get_commandline_args(no_executable=True):
|
||||
extra_chars = [
|
||||
" ",
|
||||
";",
|
||||
"&",
|
||||
"|",
|
||||
"<",
|
||||
">",
|
||||
"?",
|
||||
"*",
|
||||
"~",
|
||||
"`",
|
||||
'"',
|
||||
"'",
|
||||
"\\",
|
||||
"{",
|
||||
"}",
|
||||
"(",
|
||||
")",
|
||||
]
|
||||
|
||||
# Escape the extra characters for shell
|
||||
argv = [
|
||||
arg.replace("'", "'\\''")
|
||||
if all(char not in arg for char in extra_chars)
|
||||
else "'" + arg.replace("'", "'\\''") + "'"
|
||||
for arg in sys.argv
|
||||
]
|
||||
|
||||
if no_executable:
|
||||
return " ".join(argv[1:])
|
||||
else:
|
||||
return sys.executable + " " + " ".join(argv)
|
||||
|
||||
|
||||
def main():
|
||||
print(get_commandline_args())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,142 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
SECONDS=0
|
||||
log() {
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
help_message=$(cat << EOF
|
||||
Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
|
||||
e.g.
|
||||
$0 data/test/wav.scp data/test_format/
|
||||
|
||||
Format 'wav.scp': In short words,
|
||||
changing "kaldi-datadir" to "modified-kaldi-datadir"
|
||||
|
||||
The 'wav.scp' format in kaldi is very flexible,
|
||||
e.g. It can use unix-pipe as describing that wav file,
|
||||
but it sometime looks confusing and make scripts more complex.
|
||||
This tools creates actual wav files from 'wav.scp'
|
||||
and also segments wav files using 'segments'.
|
||||
|
||||
Options
|
||||
--fs <fs>
|
||||
--segments <segments>
|
||||
--nj <nj>
|
||||
--cmd <cmd>
|
||||
EOF
|
||||
)
|
||||
|
||||
out_filename=wav.scp
|
||||
cmd=utils/run.pl
|
||||
nj=30
|
||||
fs=none
|
||||
segments=
|
||||
|
||||
ref_channels=
|
||||
utt2ref_channels=
|
||||
|
||||
audio_format=wav
|
||||
write_utt2num_samples=true
|
||||
|
||||
log "$0 $*"
|
||||
. utils/parse_options.sh
|
||||
|
||||
if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
|
||||
log "${help_message}"
|
||||
log "Error: invalid command line arguments"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
. ./path.sh # Setup the environment
|
||||
|
||||
scp=$1
|
||||
if [ ! -f "${scp}" ]; then
|
||||
log "${help_message}"
|
||||
echo "$0: Error: No such file: ${scp}"
|
||||
exit 1
|
||||
fi
|
||||
dir=$2
|
||||
|
||||
|
||||
if [ $# -eq 2 ]; then
|
||||
logdir=${dir}/logs
|
||||
outdir=${dir}/data
|
||||
|
||||
elif [ $# -eq 3 ]; then
|
||||
logdir=$3
|
||||
outdir=${dir}/data
|
||||
|
||||
elif [ $# -eq 4 ]; then
|
||||
logdir=$3
|
||||
outdir=$4
|
||||
fi
|
||||
|
||||
|
||||
mkdir -p ${logdir}
|
||||
|
||||
rm -f "${dir}/${out_filename}"
|
||||
|
||||
|
||||
opts=
|
||||
if [ -n "${utt2ref_channels}" ]; then
|
||||
opts="--utt2ref-channels ${utt2ref_channels} "
|
||||
elif [ -n "${ref_channels}" ]; then
|
||||
opts="--ref-channels ${ref_channels} "
|
||||
fi
|
||||
|
||||
|
||||
if [ -n "${segments}" ]; then
|
||||
log "[info]: using ${segments}"
|
||||
nutt=$(<${segments} wc -l)
|
||||
nj=$((nj<nutt?nj:nutt))
|
||||
|
||||
split_segments=""
|
||||
for n in $(seq ${nj}); do
|
||||
split_segments="${split_segments} ${logdir}/segments.${n}"
|
||||
done
|
||||
|
||||
utils/split_scp.pl "${segments}" ${split_segments}
|
||||
|
||||
${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
|
||||
pyscripts/audio/format_wav_scp.py \
|
||||
${opts} \
|
||||
--fs ${fs} \
|
||||
--audio-format "${audio_format}" \
|
||||
"--segment=${logdir}/segments.JOB" \
|
||||
"${scp}" "${outdir}/format.JOB"
|
||||
|
||||
else
|
||||
log "[info]: without segments"
|
||||
nutt=$(<${scp} wc -l)
|
||||
nj=$((nj<nutt?nj:nutt))
|
||||
|
||||
split_scps=""
|
||||
for n in $(seq ${nj}); do
|
||||
split_scps="${split_scps} ${logdir}/wav.${n}.scp"
|
||||
done
|
||||
|
||||
utils/split_scp.pl "${scp}" ${split_scps}
|
||||
${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
|
||||
pyscripts/audio/format_wav_scp.py \
|
||||
${opts} \
|
||||
--fs "${fs}" \
|
||||
--audio-format "${audio_format}" \
|
||||
"${logdir}/wav.JOB.scp" ${outdir}/format.JOB""
|
||||
fi
|
||||
|
||||
# Workaround for the NFS problem
|
||||
ls ${outdir}/format.* > /dev/null
|
||||
|
||||
# concatenate the .scp files together.
|
||||
for n in $(seq ${nj}); do
|
||||
cat "${outdir}/format.${n}/wav.scp" || exit 1;
|
||||
done > "${dir}/${out_filename}" || exit 1
|
||||
|
||||
if "${write_utt2num_samples}"; then
|
||||
for n in $(seq ${nj}); do
|
||||
cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
|
||||
done > "${dir}/utt2num_samples" || exit 1
|
||||
fi
|
||||
|
||||
log "Successfully finished. [elapsed=${SECONDS}s]"
|
||||
@ -1,116 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# 2020 @kamo-naoyuki
|
||||
# This file was copied from Kaldi and
|
||||
# I deleted parts related to wav duration
|
||||
# because we shouldn't use kaldi's command here
|
||||
# and we don't need the files actually.
|
||||
|
||||
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
|
||||
# 2014 Tom Ko
|
||||
# 2018 Emotech LTD (author: Pawel Swietojanski)
|
||||
# Apache 2.0
|
||||
|
||||
# This script operates on a directory, such as in data/train/,
|
||||
# that contains some subset of the following files:
|
||||
# wav.scp
|
||||
# spk2utt
|
||||
# utt2spk
|
||||
# text
|
||||
#
|
||||
# It generates the files which are used for perturbing the speed of the original data.
|
||||
|
||||
export LC_ALL=C
|
||||
set -euo pipefail
|
||||
|
||||
if [[ $# != 3 ]]; then
|
||||
echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
|
||||
echo "e.g.:"
|
||||
echo " $0 0.9 data/train_si284 data/train_si284p"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
factor=$1
|
||||
srcdir=$2
|
||||
destdir=$3
|
||||
label="sp"
|
||||
spk_prefix="${label}${factor}-"
|
||||
utt_prefix="${label}${factor}-"
|
||||
|
||||
#check is sox on the path
|
||||
|
||||
! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
|
||||
|
||||
if [[ ! -f ${srcdir}/utt2spk ]]; then
|
||||
echo "$0: no such file ${srcdir}/utt2spk"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if [[ ${destdir} == "${srcdir}" ]]; then
|
||||
echo "$0: this script requires <srcdir> and <destdir> to be different."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p "${destdir}"
|
||||
|
||||
<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
|
||||
<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
|
||||
<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
|
||||
if [[ ! -f ${srcdir}/utt2uniq ]]; then
|
||||
<"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
|
||||
else
|
||||
<"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
|
||||
fi
|
||||
|
||||
|
||||
<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
|
||||
utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
|
||||
|
||||
utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
|
||||
|
||||
if [[ -f ${srcdir}/segments ]]; then
|
||||
|
||||
utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
|
||||
utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
|
||||
awk -v factor="${factor}" \
|
||||
'{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
|
||||
>"${destdir}"/segments
|
||||
|
||||
utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
|
||||
# Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
|
||||
awk -v factor="${factor}" \
|
||||
'{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
|
||||
else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
|
||||
else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
|
||||
> "${destdir}"/wav.scp
|
||||
if [[ -f ${srcdir}/reco2file_and_channel ]]; then
|
||||
utils/apply_map.pl -f 1 "${destdir}"/reco_map \
|
||||
<"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
|
||||
fi
|
||||
|
||||
else # no segments->wav indexed by utterance.
|
||||
if [[ -f ${srcdir}/wav.scp ]]; then
|
||||
utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
|
||||
# Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
|
||||
awk -v factor="${factor}" \
|
||||
'{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
|
||||
else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
|
||||
else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
|
||||
> "${destdir}"/wav.scp
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -f ${srcdir}/text ]]; then
|
||||
utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
|
||||
fi
|
||||
if [[ -f ${srcdir}/spk2gender ]]; then
|
||||
utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
|
||||
fi
|
||||
if [[ -f ${srcdir}/utt2lang ]]; then
|
||||
utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
|
||||
fi
|
||||
|
||||
rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
|
||||
echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
|
||||
|
||||
utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
|
||||
@ -1,47 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class NllLoss(nn.Module):
|
||||
"""Nll loss.
|
||||
|
||||
:param int size: the number of class
|
||||
:param int padding_idx: ignored class id
|
||||
:param bool normalize_length: normalize loss by sequence length if True
|
||||
:param torch.nn.Module criterion: loss function
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
padding_idx,
|
||||
normalize_length=False,
|
||||
criterion=nn.NLLLoss(reduction='none'),
|
||||
):
|
||||
"""Construct an LabelSmoothingLoss object."""
|
||||
super(NllLoss, self).__init__()
|
||||
self.criterion = criterion
|
||||
self.padding_idx = padding_idx
|
||||
self.size = size
|
||||
self.true_dist = None
|
||||
self.normalize_length = normalize_length
|
||||
|
||||
def forward(self, x, target):
|
||||
"""Compute loss between x and target.
|
||||
|
||||
:param torch.Tensor x: prediction (batch, seqlen, class)
|
||||
:param torch.Tensor target:
|
||||
target signal masked with self.padding_id (batch, seqlen)
|
||||
:return: scalar float value
|
||||
:rtype torch.Tensor
|
||||
"""
|
||||
assert x.size(2) == self.size
|
||||
batch_size = x.size(0)
|
||||
x = x.view(-1, self.size)
|
||||
target = target.view(-1)
|
||||
with torch.no_grad():
|
||||
ignore = target == self.padding_idx # (B,)
|
||||
total = len(target) - ignore.sum().item()
|
||||
target = target.masked_fill(ignore, 0) # avoid -1 index
|
||||
kl = self.criterion(x , target)
|
||||
denom = total if self.normalize_length else batch_size
|
||||
return kl.masked_fill(ignore, 0).sum() / denom
|
||||
@ -1,169 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from funasr.modules.layer_norm import LayerNorm
|
||||
|
||||
|
||||
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
self_attn,
|
||||
src_attn,
|
||||
feed_forward,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
):
|
||||
"""Construct an DecoderLayer object."""
|
||||
super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
|
||||
self.size = size
|
||||
self.self_attn = self_attn
|
||||
self.src_attn = src_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = LayerNorm(size)
|
||||
self.norm2 = LayerNorm(size)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear1 = nn.Linear(size + size, size)
|
||||
self.concat_linear2 = nn.Linear(size + size, size)
|
||||
|
||||
def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
|
||||
|
||||
residual = tgt
|
||||
if self.normalize_before:
|
||||
tgt = self.norm1(tgt)
|
||||
|
||||
if cache is None:
|
||||
tgt_q = tgt
|
||||
tgt_q_mask = tgt_mask
|
||||
else:
|
||||
# compute only the last frame query keeping dim: max_time_out -> 1
|
||||
assert cache.shape == (
|
||||
tgt.shape[0],
|
||||
tgt.shape[1] - 1,
|
||||
self.size,
|
||||
), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
|
||||
tgt_q = tgt[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
tgt_q_mask = None
|
||||
if tgt_mask is not None:
|
||||
tgt_q_mask = tgt_mask[:, -1:, :]
|
||||
|
||||
if self.concat_after:
|
||||
tgt_concat = torch.cat(
|
||||
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
|
||||
)
|
||||
x = residual + self.concat_linear1(tgt_concat)
|
||||
else:
|
||||
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
z = x
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat(
|
||||
(x, skip), dim=-1
|
||||
)
|
||||
x = residual + self.concat_linear2(x_concat)
|
||||
else:
|
||||
x = residual + self.dropout(skip)
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
x = residual + self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
|
||||
|
||||
class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
d_size,
|
||||
src_attn,
|
||||
feed_forward,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
):
|
||||
"""Construct an DecoderLayer object."""
|
||||
super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
|
||||
self.size = size
|
||||
self.src_attn = src_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = LayerNorm(size)
|
||||
self.norm2 = LayerNorm(size)
|
||||
self.norm3 = LayerNorm(size)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
self.spk_linear = nn.Linear(d_size, size, bias=False)
|
||||
if self.concat_after:
|
||||
self.concat_linear1 = nn.Linear(size + size, size)
|
||||
self.concat_linear2 = nn.Linear(size + size, size)
|
||||
|
||||
def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
|
||||
|
||||
residual = tgt
|
||||
if self.normalize_before:
|
||||
tgt = self.norm1(tgt)
|
||||
|
||||
if cache is None:
|
||||
tgt_q = tgt
|
||||
tgt_q_mask = tgt_mask
|
||||
else:
|
||||
|
||||
tgt_q = tgt[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
tgt_q_mask = None
|
||||
if tgt_mask is not None:
|
||||
tgt_q_mask = tgt_mask[:, -1:, :]
|
||||
|
||||
x = tgt_q
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat(
|
||||
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
|
||||
)
|
||||
x = residual + self.concat_linear2(x_concat)
|
||||
else:
|
||||
x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
residual = x
|
||||
|
||||
if dn!=None:
|
||||
x = x + self.spk_linear(dn)
|
||||
if self.normalize_before:
|
||||
x = self.norm3(x)
|
||||
|
||||
x = residual + self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm3(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
return x, tgt_mask, memory, memory_mask
|
||||
|
||||
|
||||
|
||||
@ -1,291 +0,0 @@
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
from funasr.modules.attention import MultiHeadedAttention
|
||||
from funasr.modules.attention import CosineDistanceAttention
|
||||
from funasr.models.decoder.transformer_decoder import DecoderLayer
|
||||
from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer
|
||||
from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer
|
||||
from funasr.modules.dynamic_conv import DynamicConvolution
|
||||
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
|
||||
from funasr.modules.embedding import PositionalEncoding
|
||||
from funasr.modules.layer_norm import LayerNorm
|
||||
from funasr.modules.lightconv import LightweightConvolution
|
||||
from funasr.modules.lightconv2d import LightweightConvolution2D
|
||||
from funasr.modules.mask import subsequent_mask
|
||||
from funasr.modules.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr.modules.repeat import repeat
|
||||
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
|
||||
class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
spker_embedding_dim: int = 256,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
input_layer: str = "embed",
|
||||
use_asr_output_layer: bool = True,
|
||||
use_spk_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
attention_dim = encoder_output_size
|
||||
|
||||
if input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(vocab_size, attention_dim),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(vocab_size, attention_dim),
|
||||
torch.nn.LayerNorm(attention_dim),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
if use_asr_output_layer:
|
||||
self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
|
||||
else:
|
||||
self.asr_output_layer = None
|
||||
|
||||
if use_spk_output_layer:
|
||||
self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
|
||||
else:
|
||||
self.spk_output_layer = None
|
||||
|
||||
self.cos_distance_att = CosineDistanceAttention()
|
||||
|
||||
self.decoder1 = None
|
||||
self.decoder2 = None
|
||||
self.decoder3 = None
|
||||
self.decoder4 = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
asr_hs_pad: torch.Tensor,
|
||||
spk_hs_pad: torch.Tensor,
|
||||
hlens: torch.Tensor,
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
profile: torch.Tensor,
|
||||
profile_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
tgt = ys_in_pad
|
||||
# tgt_mask: (B, 1, L)
|
||||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
|
||||
# m: (1, L, L)
|
||||
m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
|
||||
# tgt_mask: (B, L, L)
|
||||
tgt_mask = tgt_mask & m
|
||||
|
||||
asr_memory = asr_hs_pad
|
||||
spk_memory = spk_hs_pad
|
||||
memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
|
||||
# Spk decoder
|
||||
x = self.embed(tgt)
|
||||
|
||||
x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
|
||||
x, tgt_mask, asr_memory, spk_memory, memory_mask
|
||||
)
|
||||
x, tgt_mask, spk_memory, memory_mask = self.decoder2(
|
||||
x, tgt_mask, spk_memory, memory_mask
|
||||
)
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
if self.spk_output_layer is not None:
|
||||
x = self.spk_output_layer(x)
|
||||
dn, weights = self.cos_distance_att(x, profile, profile_lens)
|
||||
# Asr decoder
|
||||
x, tgt_mask, asr_memory, memory_mask = self.decoder3(
|
||||
z, tgt_mask, asr_memory, memory_mask, dn
|
||||
)
|
||||
x, tgt_mask, asr_memory, memory_mask = self.decoder4(
|
||||
x, tgt_mask, asr_memory, memory_mask
|
||||
)
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
if self.asr_output_layer is not None:
|
||||
x = self.asr_output_layer(x)
|
||||
|
||||
olens = tgt_mask.sum(1)
|
||||
return x, weights, olens
|
||||
|
||||
|
||||
def forward_one_step(
|
||||
self,
|
||||
tgt: torch.Tensor,
|
||||
tgt_mask: torch.Tensor,
|
||||
asr_memory: torch.Tensor,
|
||||
spk_memory: torch.Tensor,
|
||||
profile: torch.Tensor,
|
||||
cache: List[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
x = self.embed(tgt)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
|
||||
new_cache = []
|
||||
x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
|
||||
x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
|
||||
)
|
||||
new_cache.append(x)
|
||||
for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
|
||||
x, tgt_mask, spk_memory, _ = decoder(
|
||||
x, tgt_mask, spk_memory, None, cache=c
|
||||
)
|
||||
new_cache.append(x)
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
else:
|
||||
x = x
|
||||
if self.spk_output_layer is not None:
|
||||
x = self.spk_output_layer(x)
|
||||
dn, weights = self.cos_distance_att(x, profile, None)
|
||||
|
||||
x, tgt_mask, asr_memory, _ = self.decoder3(
|
||||
z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
|
||||
)
|
||||
new_cache.append(x)
|
||||
|
||||
for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
|
||||
x, tgt_mask, asr_memory, _ = decoder(
|
||||
x, tgt_mask, asr_memory, None, cache=c
|
||||
)
|
||||
new_cache.append(x)
|
||||
|
||||
if self.normalize_before:
|
||||
y = self.after_norm(x[:, -1])
|
||||
else:
|
||||
y = x[:, -1]
|
||||
if self.asr_output_layer is not None:
|
||||
y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
|
||||
|
||||
return y, weights, new_cache
|
||||
|
||||
def score(self, ys, state, asr_enc, spk_enc, profile):
|
||||
"""Score."""
|
||||
ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
|
||||
logp, weights, state = self.forward_one_step(
|
||||
ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
|
||||
)
|
||||
return logp.squeeze(0), weights.squeeze(), state
|
||||
|
||||
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
spker_embedding_dim: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
asr_num_blocks: int = 6,
|
||||
spk_num_blocks: int = 3,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "embed",
|
||||
use_asr_output_layer: bool = True,
|
||||
use_spk_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
spker_embedding_dim=spker_embedding_dim,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_asr_output_layer=use_asr_output_layer,
|
||||
use_spk_output_layer=use_spk_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
attention_dim = encoder_output_size
|
||||
|
||||
self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
|
||||
attention_dim,
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, self_attention_dropout_rate
|
||||
),
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, src_attention_dropout_rate
|
||||
),
|
||||
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
)
|
||||
self.decoder2 = repeat(
|
||||
spk_num_blocks - 1,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, self_attention_dropout_rate
|
||||
),
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, src_attention_dropout_rate
|
||||
),
|
||||
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
|
||||
attention_dim,
|
||||
spker_embedding_dim,
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, src_attention_dropout_rate
|
||||
),
|
||||
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
)
|
||||
self.decoder4 = repeat(
|
||||
asr_num_blocks - 1,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, self_attention_dropout_rate
|
||||
),
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, src_attention_dropout_rate
|
||||
),
|
||||
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user