add speaker-attributed ASR task for alimeeting

This commit is contained in:
smohan-speech 2023-05-07 02:27:58 +08:00
parent d76aea23d9
commit 49f13908de
7 changed files with 0 additions and 1053 deletions

View File

@ -1,243 +0,0 @@
#!/usr/bin/env python3
import argparse
import logging
from io import BytesIO
from pathlib import Path
from typing import Tuple, Optional
import kaldiio
import humanfriendly
import numpy as np
import resampy
import soundfile
from tqdm import tqdm
from typeguard import check_argument_types
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.read_text import read_2column_text
from funasr.fileio.sound_scp import SoundScpWriter
def humanfriendly_or_none(value: str):
    """argparse type: parse a human-friendly size string (e.g. "16k") or None.

    Accepts the same "null" spellings as ``str2int_tuple`` for consistency,
    so both ``--fs none`` and ``--fs null`` disable resampling.

    :param value: command-line string, e.g. "16000", "16k", or "none"
    :return: parsed integer size, or None for a none/null spelling
    """
    if value.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return humanfriendly.parse_size(value)
def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
    """Parse a comma-separated string of integers into a tuple.

    A none/null spelling (any capitalization used below) yields None.

    >>> str2int_tuple('3,4,5')
    (3, 4, 5)
    """
    assert check_argument_types()
    text = integers.strip()
    if text in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return tuple(int(token) for token in text.split(","))
def main():
    """Command-line entry point.

    Reads a Kaldi-style "wav.scp" (optionally paired with a "segments"
    file), optionally selects reference channels and resamples, then writes
    the audio either as individual files plus "{name}.scp" or as an
    extended-format ark archive, together with an "utt2num_samples" file.
    """
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())
    parser = argparse.ArgumentParser(
        description='Create waves list from "wav.scp"',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("scp")
    parser.add_argument("outdir")
    parser.add_argument(
        "--name",
        default="wav",
        help="Specify the prefix word of output file name " 'such as "wav.scp"',
    )
    parser.add_argument("--segments", default=None)
    parser.add_argument(
        "--fs",
        type=humanfriendly_or_none,
        default=None,
        help="If the sampling rate specified, " "Change the sampling rate.",
    )
    parser.add_argument("--audio-format", default="wav")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
    group.add_argument("--utt2ref-channels", default=None, type=str)
    args = parser.parse_args()

    out_num_samples = Path(args.outdir) / f"utt2num_samples"

    if args.ref_channels is not None:
        # Same fixed reference channels for every utterance.
        def utt2ref_channels(x) -> Tuple[int, ...]:
            return args.ref_channels

    elif args.utt2ref_channels is not None:
        # Per-utterance channels looked up in a 2-column text file.
        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)

        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
            chs_str = d[x]
            return tuple(map(int, chs_str.split()))

    else:
        utt2ref_channels = None

    Path(args.outdir).mkdir(parents=True, exist_ok=True)
    out_wavscp = Path(args.outdir) / f"{args.name}.scp"

    if args.segments is not None:
        # Note: kaldiio supports only wav-pcm-int16le file.
        loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
        if args.audio_format.endswith("ark"):
            # NOTE(review): fark/fscp are never explicitly closed; they are
            # only released at interpreter exit.
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
            fscp = out_wavscp.open("w")
        else:
            writer = SoundScpWriter(
                args.outdir,
                out_wavscp,
                format=args.audio_format,
            )
        with out_num_samples.open("w") as fnum_samples:
            for uttid, (rate, wave) in tqdm(loader):
                # wave: (Time,) or (Time, Nmic)
                if wave.ndim == 2 and utt2ref_channels is not None:
                    wave = wave[:, utt2ref_channels(uttid)]
                if args.fs is not None and args.fs != rate:
                    # FIXME(kamo): To use sox?
                    wave = resampy.resample(
                        wave.astype(np.float64), rate, args.fs, axis=0
                    )
                    wave = wave.astype(np.int16)
                    rate = args.fs
                if args.audio_format.endswith("ark"):
                    if "flac" in args.audio_format:
                        suf = "flac"
                    elif "wav" in args.audio_format:
                        suf = "wav"
                    else:
                        raise RuntimeError("wav.ark or flac")
                    # NOTE(kamo): Using extended ark format style here.
                    # This format is incompatible with Kaldi
                    kaldiio.save_ark(
                        fark,
                        {uttid: (wave, rate)},
                        scp=fscp,
                        append=True,
                        write_function=f"soundfile_{suf}",
                    )
                else:
                    writer[uttid] = rate, wave
                fnum_samples.write(f"{uttid} {len(wave)}\n")
    else:
        # No "segments" file: iterate wav.scp lines directly.
        if args.audio_format.endswith("ark"):
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
        else:
            wavdir = Path(args.outdir) / f"data_{args.name}"
            wavdir.mkdir(parents=True, exist_ok=True)
        with Path(args.scp).open("r") as fscp, out_wavscp.open(
            "w"
        ) as fout, out_num_samples.open("w") as fnum_samples:
            for line in tqdm(fscp):
                uttid, wavpath = line.strip().split(None, 1)
                if wavpath.endswith("|"):
                    # Streaming input e.g. cat a.wav |
                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
                        with BytesIO(f.read()) as g:
                            wave, rate = soundfile.read(g, dtype=np.int16)
                            if wave.ndim == 2 and utt2ref_channels is not None:
                                wave = wave[:, utt2ref_channels(uttid)]
                    if args.fs is not None and args.fs != rate:
                        # FIXME(kamo): To use sox?
                        wave = resampy.resample(
                            wave.astype(np.float64), rate, args.fs, axis=0
                        )
                        wave = wave.astype(np.int16)
                        rate = args.fs
                    if args.audio_format.endswith("ark"):
                        if "flac" in args.audio_format:
                            suf = "flac"
                        elif "wav" in args.audio_format:
                            suf = "wav"
                        else:
                            raise RuntimeError("wav.ark or flac")
                        # NOTE(kamo): Using extended ark format style here.
                        # This format is incompatible with Kaldi
                        kaldiio.save_ark(
                            fark,
                            {uttid: (wave, rate)},
                            scp=fout,
                            append=True,
                            write_function=f"soundfile_{suf}",
                        )
                    else:
                        owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                        soundfile.write(owavpath, wave, rate)
                        fout.write(f"{uttid} {owavpath}\n")
                else:
                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
                    # Decide whether the original file can be referenced as-is
                    # (no channel selection, no ark output, matching format
                    # and sampling rate).
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                        save_asis = False
                    elif args.audio_format.endswith("ark"):
                        save_asis = False
                    elif Path(wavpath).suffix == "." + args.audio_format and (
                        args.fs is None or args.fs == rate
                    ):
                        save_asis = True
                    else:
                        save_asis = False
                    if save_asis:
                        # Neither --segments nor --fs are specified and
                        # the line doesn't end with "|",
                        # i.e. not using unix-pipe,
                        # only in this case,
                        # just using the original file as is.
                        fout.write(f"{uttid} {wavpath}\n")
                    else:
                        if args.fs is not None and args.fs != rate:
                            # FIXME(kamo): To use sox?
                            wave = resampy.resample(
                                wave.astype(np.float64), rate, args.fs, axis=0
                            )
                            wave = wave.astype(np.int16)
                            rate = args.fs
                        if args.audio_format.endswith("ark"):
                            if "flac" in args.audio_format:
                                suf = "flac"
                            elif "wav" in args.audio_format:
                                suf = "wav"
                            else:
                                raise RuntimeError("wav.ark or flac")
                            # NOTE(kamo): Using extended ark format style here.
                            # This format is not supported in Kaldi.
                            kaldiio.save_ark(
                                fark,
                                {uttid: (wave, rate)},
                                scp=fout,
                                append=True,
                                write_function=f"soundfile_{suf}",
                            )
                        else:
                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                            soundfile.write(owavpath, wave, rate)
                            fout.write(f"{uttid} {owavpath}\n")
                fnum_samples.write(f"{uttid} {len(wave)}\n")


if __name__ == "__main__":
    main()

View File

@ -1,45 +0,0 @@
#!/usr/bin/env python
import sys
def get_commandline_args(no_executable=True):
    """Reconstruct the current command line as a shell-safe string.

    Args:
        no_executable: If True, omit the interpreter and the script name,
            returning only the arguments after ``sys.argv[0]``.

    Returns:
        A single string that can be pasted back into a POSIX shell.
    """
    # Any of these characters forces the argument to be single-quoted.
    special = (
        " ", ";", "&", "|", "<", ">", "?", "*", "~", "`",
        '"', "'", "\\", "{", "}", "(", ")",
    )

    def quote(arg):
        # Escape embedded single quotes; wrap the whole argument in single
        # quotes only when it contains a shell-special character.
        escaped = arg.replace("'", "'\\''")
        if any(ch in arg for ch in special):
            return "'" + escaped + "'"
        return escaped

    safe_args = [quote(arg) for arg in sys.argv]
    if no_executable:
        return " ".join(safe_args[1:])
    return sys.executable + " " + " ".join(safe_args)
def main():
    """CLI entry point: echo the reconstructed, shell-safe command line."""
    print(get_commandline_args())


if __name__ == "__main__":
    main()

View File

@ -1,142 +0,0 @@
#!/usr/bin/env bash
# Convert a Kaldi-style 'wav.scp' (which may contain unix pipes and be paired
# with a 'segments' file) into real audio files plus a simplified scp,
# splitting the work over ${nj} parallel jobs.
set -euo pipefail
SECONDS=0

# log <message...>: print a timestamped message with the caller's location.
log() {
    local fname=${BASH_SOURCE[1]##*/}
    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

help_message=$(cat << EOF
Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
e.g.
$0 data/test/wav.scp data/test_format/
Format 'wav.scp': In short words,
changing "kaldi-datadir" to "modified-kaldi-datadir"
The 'wav.scp' format in kaldi is very flexible,
e.g. It can use unix-pipe as describing that wav file,
but it sometime looks confusing and make scripts more complex.
This tools creates actual wav files from 'wav.scp'
and also segments wav files using 'segments'.
Options
--fs <fs>
--segments <segments>
--nj <nj>
--cmd <cmd>
EOF
)

# Defaults, overridable on the command line via utils/parse_options.sh.
out_filename=wav.scp
cmd=utils/run.pl
nj=30
fs=none
segments=
ref_channels=
utt2ref_channels=
audio_format=wav
write_utt2num_samples=true

log "$0 $*"
. utils/parse_options.sh

if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
    log "${help_message}"
    log "Error: invalid command line arguments"
    exit 1
fi

. ./path.sh  # Setup the environment

scp=$1
if [ ! -f "${scp}" ]; then
    log "${help_message}"
    echo "$0: Error: No such file: ${scp}"
    exit 1
fi
dir=$2

# Resolve the optional <logdir> and <outdir> positional arguments.
if [ $# -eq 2 ]; then
    logdir=${dir}/logs
    outdir=${dir}/data
elif [ $# -eq 3 ]; then
    logdir=$3
    outdir=${dir}/data
elif [ $# -eq 4 ]; then
    logdir=$3
    outdir=$4
fi
mkdir -p ${logdir}
rm -f "${dir}/${out_filename}"

opts=
if [ -n "${utt2ref_channels}" ]; then
    opts="--utt2ref-channels ${utt2ref_channels} "
elif [ -n "${ref_channels}" ]; then
    opts="--ref-channels ${ref_channels} "
fi

if [ -n "${segments}" ]; then
    log "[info]: using ${segments}"
    nutt=$(<${segments} wc -l)
    # Never use more jobs than there are utterances.
    nj=$((nj<nutt?nj:nutt))

    split_segments=""
    for n in $(seq ${nj}); do
        split_segments="${split_segments} ${logdir}/segments.${n}"
    done
    utils/split_scp.pl "${segments}" ${split_segments}

    # BUGFIX: option spelled '--segments' (was '--segment=', which only
    # worked through argparse prefix matching).
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        pyscripts/audio/format_wav_scp.py \
            ${opts} \
            --fs ${fs} \
            --audio-format "${audio_format}" \
            "--segments=${logdir}/segments.JOB" \
            "${scp}" "${outdir}/format.JOB"
else
    log "[info]: without segments"
    nutt=$(<${scp} wc -l)
    nj=$((nj<nutt?nj:nutt))

    split_scps=""
    for n in $(seq ${nj}); do
        split_scps="${split_scps} ${logdir}/wav.${n}.scp"
    done
    utils/split_scp.pl "${scp}" ${split_scps}

    # BUGFIX: quote the output directory properly (was '${outdir}/format.JOB""',
    # i.e. unquoted expansion followed by a stray empty string).
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        pyscripts/audio/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            "${logdir}/wav.JOB.scp" "${outdir}/format.JOB"
fi

# Workaround for the NFS problem
ls ${outdir}/format.* > /dev/null

# concatenate the .scp files together.
for n in $(seq ${nj}); do
    cat "${outdir}/format.${n}/wav.scp" || exit 1;
done > "${dir}/${out_filename}" || exit 1

if "${write_utt2num_samples}"; then
    for n in $(seq ${nj}); do
        cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
    done > "${dir}/utt2num_samples" || exit 1
fi

log "Successfully finished. [elapsed=${SECONDS}s]"

View File

@ -1,116 +0,0 @@
#!/usr/bin/env bash
# 2020 @kamo-naoyuki
# This file was copied from Kaldi and
# I deleted parts related to wav duration
# because we shouldn't use kaldi's command here
# and we don't need the files actually.
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
# 2014 Tom Ko
# 2018 Emotech LTD (author: Pawel Swietojanski)
# Apache 2.0
# This script operates on a directory, such as in data/train/,
# that contains some subset of the following files:
# wav.scp
# spk2utt
# utt2spk
# text
#
# It generates the files which are used for perturbing the speed of the original data.
export LC_ALL=C
set -euo pipefail

if [[ $# != 3 ]]; then
    echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
    echo "e.g.:"
    echo " $0 0.9 data/train_si284 data/train_si284p"
    exit 1
fi

factor=$1
srcdir=$2
destdir=$3
label="sp"
# Every speaker/utterance/recording id is prefixed, e.g. "sp0.9-".
spk_prefix="${label}${factor}-"
utt_prefix="${label}${factor}-"

#check is sox on the path
! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;

if [[ ! -f ${srcdir}/utt2spk ]]; then
    echo "$0: no such file ${srcdir}/utt2spk"
    exit 1;
fi

if [[ ${destdir} == "${srcdir}" ]]; then
    echo "$0: this script requires <srcdir> and <destdir> to be different."
    exit 1
fi

mkdir -p "${destdir}"

# Build old-id -> new-id maps for utterances, speakers and recordings.
<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"

# utt2uniq maps each perturbed utterance back to its original utterance.
if [[ ! -f ${srcdir}/utt2uniq ]]; then
    <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
else
    <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
fi

# Apply the id maps to utt2spk and regenerate spk2utt.
<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
    utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt

if [[ -f ${srcdir}/segments ]]; then
    # Rescale segment times by 1/factor; segments shorter than 10 ms after
    # rescaling are dropped.
    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
        utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
        awk -v factor="${factor}" \
            '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
            >"${destdir}"/segments
    utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
    # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
    awk -v factor="${factor}" \
        '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
        else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
        else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
        > "${destdir}"/wav.scp
    if [[ -f ${srcdir}/reco2file_and_channel ]]; then
        utils/apply_map.pl -f 1 "${destdir}"/reco_map \
            <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
    fi
else # no segments->wav indexed by utterance.
    if [[ -f ${srcdir}/wav.scp ]]; then
        utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
        # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
        awk -v factor="${factor}" \
            '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
            else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
            else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
            > "${destdir}"/wav.scp
    fi
fi

# Map the remaining optional files through the corresponding id maps.
if [[ -f ${srcdir}/text ]]; then
    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
fi
if [[ -f ${srcdir}/spk2gender ]]; then
    utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
fi
if [[ -f ${srcdir}/utt2lang ]]; then
    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
fi

rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"

utils/validate_data_dir.sh --no-feats --no-text "${destdir}"

View File

@ -1,47 +0,0 @@
import torch
from torch import nn
class NllLoss(nn.Module):
    """Negative log-likelihood loss with padding-aware masking.

    :param int size: the number of classes
    :param int padding_idx: target id to ignore (masked out of the loss)
    :param bool normalize_length: if True, normalize by the number of
        non-padding tokens; otherwise by the batch size
    :param torch.nn.Module criterion: underlying loss function; must use
        ``reduction='none'`` so the padding mask can be applied per token.
        Defaults to ``nn.NLLLoss(reduction='none')``.
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=None,
    ):
        """Construct an NllLoss object."""
        super(NllLoss, self).__init__()
        # Create the default criterion per instance instead of sharing a
        # single module object across all instances via a default argument.
        if criterion is None:
            criterion = nn.NLLLoss(reduction="none")
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class),
            expected to be log-probabilities for the default NLL criterion
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
        loss = self.criterion(x, target)
        denom = total if self.normalize_length else batch_size
        # Zero out the padded positions before reducing.
        return loss.masked_fill(ignore, 0).sum() / denom

View File

@ -1,169 +0,0 @@
import torch
from torch import nn
from funasr.modules.layer_norm import LayerNorm
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
    """First layer of the speaker decoder for speaker-attributed ASR.

    Structure: self-attention -> source attention -> feed-forward.
    The source attention is called as
    ``src_attn(x, asr_memory, spk_memory, memory_mask)`` — presumably
    (query, key, value, mask), i.e. attending over the ASR encoder output
    while retrieving speaker-encoder features; confirm against the
    attention module's signature.

    Besides the usual outputs, ``forward`` also returns ``z``, the hidden
    state right after the self-attention block, which is consumed by the
    first ASR-decoder layer.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeSpkDecoderFirstLayer object."""
        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # NOTE(review): only two LayerNorms exist; norm1 is reused for both
        # the self-attention and the source-attention sub-blocks (a standard
        # decoder layer would use three distinct norms) — confirm intended.
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt: Decoder input (batch, maxlen_out, size).
            tgt_mask: Mask for tgt, or None.
            asr_memory: ASR encoder output.
            spk_memory: Speaker encoder output.
            memory_mask: Mask for the encoder memory, or None.
            cache: Cached output of this layer for incremental decoding,
                shape (batch, maxlen_out - 1, size), or None.

        Returns:
            Tuple (x, tgt_mask, asr_memory, spk_memory, memory_mask, z),
            where z is the hidden state after the self-attention block.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]
        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)
        # z: post-self-attention state handed to the first ASR-decoder layer.
        z = x
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
        if self.concat_after:
            x_concat = torch.cat(
                (x, skip), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(skip)
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
    """First layer of the ASR decoder for speaker-attributed ASR.

    Unlike a standard decoder layer this one has no self-attention: it
    consumes the (already self-attended) hidden state from the speaker
    decoder, applies source attention over the ASR encoder memory, and
    injects a speaker-profile embedding ``dn`` through a linear map.

    Args:
        size: Hidden dimension of the decoder.
        d_size: Dimension of the speaker embedding ``dn``.
        src_attn: Source-attention module, called as (query, key, value, mask).
        feed_forward: Position-wise feed-forward module.
        dropout_rate: Dropout probability.
        normalize_before: Apply LayerNorm before each sub-block if True.
        concat_after: Concatenate attention input and output, then project.
    """

    def __init__(
        self,
        size,
        d_size,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeAsrDecoderFirstLayer object."""
        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        # Projects the speaker embedding into the decoder hidden space.
        self.spk_linear = nn.Linear(d_size, size, bias=False)
        if self.concat_after:
            # NOTE(review): concat_linear1 is never used in forward();
            # kept to preserve the parameter set / state-dict layout.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
        """Compute decoded features.

        Args:
            tgt: Decoder input (batch, maxlen_out, size).
            tgt_mask: Mask for tgt, or None (passed through unchanged).
            memory: ASR encoder output (batch, maxlen_in, size).
            memory_mask: Mask for memory, or None.
            dn: Speaker embedding (..., d_size), or None to skip injection.
            cache: Cached output of this layer for incremental decoding
                (batch, maxlen_out - 1, size), or None.

        Returns:
            Tuple of (x, tgt_mask, memory, memory_mask).
        """
        residual = tgt
        if self.normalize_before:
            # NOTE(review): tgt is normalized by norm1 here and the query is
            # normalized again by norm2 below — confirm the double
            # normalization is intended.
            tgt = self.norm1(tgt)
        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # Compute only the last frame query: max_time_out -> 1.
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]
        x = tgt_q
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)
        residual = x
        # Inject the speaker embedding into the feed-forward input only;
        # note the residual path deliberately excludes the dn contribution.
        # Fixed: was `dn != None`, which relied on Tensor.__ne__ falling back
        # to identity comparison; `is not None` is the explicit equivalent.
        if dn is not None:
            x = x + self.spk_linear(dn)
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, memory, memory_mask

View File

@ -1,291 +0,0 @@
from typing import Any
from typing import List
from typing import Sequence
from typing import Tuple
import torch
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.attention import CosineDistanceAttention
from funasr.models.decoder.transformer_decoder import DecoderLayer
from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeAsrDecoderFirstLayer
from funasr.models.decoder.decoder_layer_sa_asr import SpeakerAttributeSpkDecoderFirstLayer
from funasr.modules.dynamic_conv import DynamicConvolution
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.lightconv import LightweightConvolution
from funasr.modules.lightconv2d import LightweightConvolution2D
from funasr.modules.mask import subsequent_mask
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.models.decoder.abs_decoder import AbsDecoder
class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class for speaker-attributed ASR (SA-ASR) transformer decoders.

    The decoder runs in four stages, which concrete subclasses must assign:
        decoder1: first speaker-decoder layer (also emits ``z``, the hidden
            state handed to the ASR branch),
        decoder2: remaining speaker-decoder layers,
        decoder3: first ASR-decoder layer (injects the speaker profile
            representation ``dn``),
        decoder4: remaining ASR-decoder layers.

    Args:
        vocab_size: Output vocabulary size.
        encoder_output_size: Dimension of encoder outputs (= attention dim).
        spker_embedding_dim: Dimension of speaker profile embeddings.
        dropout_rate: Dropout rate.
        positional_dropout_rate: Dropout rate for positional encoding.
        input_layer: Input layer type, "embed" or "linear".
        use_asr_output_layer: Add a linear projection to the vocabulary.
        use_spk_output_layer: Add a linear projection to the speaker space.
        pos_enc_class: Positional encoding class.
        normalize_before: Apply LayerNorm before each block if True.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None
        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None
        # Attends speaker-decoder outputs over the profile inventory.
        self.cos_distance_att = CosineDistanceAttention()
        # The four decoder stages; concrete subclasses must assign these.
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None

    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            asr_hs_pad: ASR encoder output (batch, maxlen_in, feat).
            spk_hs_pad: Speaker encoder output (batch, maxlen_in, feat).
            hlens: Encoder output lengths (batch,).
            ys_in_pad: Input token ids (batch, maxlen_out).
            ys_in_lens: Output lengths (batch,).
            profile: Speaker profile embeddings — shape assumed
                (batch, n_spk, spker_embedding_dim); confirm against caller.
            profile_lens: Number of valid profiles per batch entry.

        Returns:
            Tuple of (token scores before softmax, speaker attention
            weights from the cosine-distance attention, olens).
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m
        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
        # Spk decoder
        x = self.embed(tgt)
        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        # dn: profile representation per output position; weights: attention
        # over the speaker profiles.
        dn, weights = self.cos_distance_att(x, profile, profile_lens)
        # Asr decoder (starts from z, the speaker decoder's post-self-attn
        # state, not from x).
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)
        # NOTE(review): tgt_mask is (B, L, L) here, so sum(1) yields a
        # (B, L) tensor rather than per-utterance lengths — confirm callers
        # expect this.
        olens = tgt_mask.sum(1)
        return x, weights, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        asr_memory: torch.Tensor,
        spk_memory: torch.Tensor,
        profile: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one decoding step with layer-output caching.

        The cache layout is: [decoder1, *decoder2, decoder3, *decoder4],
        i.e. 2 + len(decoder2) + len(decoder4) entries.

        Returns:
            Tuple of (log-prob of the next token, speaker attention
            weights, updated cache list).
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
        new_cache = []
        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
            x, tgt_mask, spk_memory, _ = decoder(
                x, tgt_mask, spk_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            x = self.after_norm(x)
        else:
            # No-op branch kept for symmetry with the normalized case.
            x = x
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        dn, weights = self.cos_distance_att(x, profile, None)
        x, tgt_mask, asr_memory, _ = self.decoder3(
            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
            x, tgt_mask, asr_memory, _ = decoder(
                x, tgt_mask, asr_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.asr_output_layer is not None:
            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
        return y, weights, new_cache

    def score(self, ys, state, asr_enc, spk_enc, profile):
        """Score a hypothesis prefix for beam search.

        NOTE(review): the signature takes extra arguments (spk_enc, profile)
        and returns an extra weights tensor compared to the usual
        ScorerInterface contract — confirm the beam-search caller matches.
        """
        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
        logp, weights, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), weights.squeeze(), state
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        asr_num_blocks: int = 6,
        spk_num_blocks: int = 3,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct an SA-ASR transformer decoder.

        Builds the four decoder stages declared by the base class:
        a special first speaker-decoder layer, spk_num_blocks - 1 standard
        speaker-decoder layers, a special first ASR-decoder layer (which
        injects the speaker profile), and asr_num_blocks - 1 standard
        ASR-decoder layers.
        """
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            spker_embedding_dim=spker_embedding_dim,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_asr_output_layer=use_asr_output_layer,
            use_spk_output_layer=use_spk_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        # Stage 1: first speaker-decoder layer (self-attn + cross-attn over
        # ASR/speaker memories; also emits z for the ASR branch).
        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
            attention_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, self_attention_dropout_rate
            ),
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        # Stage 2: remaining speaker-decoder layers (standard decoder layers).
        self.decoder2 = repeat(
            spk_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # Stage 3: first ASR-decoder layer (no self-attn; injects the
        # speaker-profile embedding via spk_linear).
        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
            attention_dim,
            spker_embedding_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        # Stage 4: remaining ASR-decoder layers (standard decoder layers).
        self.decoder4 = repeat(
            asr_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )