Merge branch 'main' of github.com:alibaba-damo-academy/FunASR

add
游雁 2023-04-18 14:44:59 +08:00
commit a4bd736b03
23 changed files with 4998 additions and 702 deletions


@ -0,0 +1,18 @@
# Streaming RNN-T Results
## Training Config
- 8 GPUs (Tesla V100)
- Feature info: 80-dim fbank, global CMVN, speed perturb (0.9, 1.0, 1.1), SpecAugment
- Train config: conf/train_conformer_rnnt_unified.yaml
- Chunk config: chunk size 16, full left chunks
- LM config: LM was not used
- Model size: 90M
## Results (CER)
- Decode config: conf/decode_rnnt_conformer_streaming.yaml
| testset | CER(%) |
|:-----------:|:-------:|
| dev | 5.53 |
| test | 6.24 |


@ -0,0 +1,8 @@
# Conformer transducer streaming (chunk-by-chunk) decoding configuration, from @jeon30c
beam_size: 10
simu_streaming: false
streaming: true
chunk_size: 16
left_context: 16
right_context: 0


@ -0,0 +1,5 @@
# Conformer transducer simulated-streaming (full-utterance, chunk-masked) decoding configuration, from @jeon30c
beam_size: 10
simu_streaming: true
streaming: false
chunk_size: 16
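The two configs above map onto the `BeamSearchTransducer` constructor added later in this commit: `streaming: true` drives true chunk-by-chunk search, while `simu_streaming: true` evidently pairs with the encoder's full-utterance `simu_chunk_forward` pass, so the search itself stays non-streaming. A hedged construction sketch; the `RNNTDecoder` import path is an assumption (the diff does not show that file's location), and the sizes are only illustrative:

import torch

from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.decoder.rnnt_decoder import RNNTDecoder  # path assumed

decoder = RNNTDecoder(vocab_size=100, embed_size=512, hidden_size=512)
joint_network = JointNetwork(100, encoder_size=512, decoder_size=512, joint_space_size=512)
beam_search = BeamSearchTransducer(
    decoder=decoder,
    joint_network=joint_network,
    beam_size=10,      # beam_size from the configs above
    streaming=True,    # chunk-by-chunk search (decode_rnnt_conformer_streaming.yaml)
)
nbest = beam_search(torch.randn(30, 512))  # (T, D_enc) for one utterance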


@ -0,0 +1,80 @@
encoder: chunk_conformer
encoder_conf:
    activation_type: swish
    time_reduction_factor: 2
    unified_model_training: true
    default_chunk_size: 16
    jitter_range: 4
    left_chunk_size: 0
    embed_vgg_like: false
    subsampling_factor: 4
    linear_units: 2048
    output_size: 512
    attention_heads: 8
    dropout_rate: 0.5
    positional_dropout_rate: 0.5
    attention_dropout_rate: 0.5
    cnn_module_kernel: 15
    num_blocks: 12

# decoder related
rnnt_decoder: rnnt
rnnt_decoder_conf:
    embed_size: 512
    hidden_size: 512
    embed_dropout_rate: 0.5
    dropout_rate: 0.5
joint_network_conf:
    joint_space_size: 512

# Auxiliary CTC
model_conf:
    auxiliary_ctc_weight: 0.0

# minibatch related
use_amp: true
batch_type: unsorted
batch_size: 16
num_workers: 16

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 200
val_scheduler_criterion:
    - valid
    - loss
best_model_criterion:
-   - valid
    - cer_transducer_chunk
    - min
keep_nbest_models: 10
optim: adam
optim_conf:
    lr: 0.001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000

normalize: None

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 40
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 50
    num_time_mask: 5

log_interval: 50
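With `unified_model_training: true`, the encoder runs both a full-context pass and a chunk-masked pass on every batch, and it samples a fresh chunk size per batch from `default_chunk_size ± jitter_range` (so 12 to 20 frames here). A minimal Python sketch of that sampling, mirroring `ConformerChunkEncoder.forward` later in this commit:

import torch

# Sketch: chunk-size jitter used when unified_model_training is true.
default_chunk_size, jitter_range = 16, 4
chunk_size = default_chunk_size + torch.randint(
    -jitter_range, jitter_range + 1, (1,)
).item()
assert 12 <= chunk_size <= 20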


@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2017 Xingyu Na
# Apache 2.0

# . ./path.sh || exit 1;

if [ $# != 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3

train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp

mkdir -p $train_dir
mkdir -p $dev_dir
mkdir -p $test_dir
mkdir -p $tmp_dir

# data directory check
if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
  echo "Error: $0 requires an audio directory and a transcript file"
  exit 1;
fi

# find wav audio files for train, dev and test resp.
find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
n=`cat $tmp_dir/wav.flist | wc -l`
[ $n -ne 141925 ] && \
  echo "Warning: expected 141925 data files, found $n"

grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
rm -r $tmp_dir

# Transcriptions preparation
for dir in $train_dir $dev_dir $test_dir; do
  echo "Preparing $dir transcriptions"
  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
  paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
  utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
  utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
  sort -u $dir/transcripts.txt > $dir/text
done

mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
for f in wav.scp text; do
  cp $train_dir/$f $output_dir/data/train/$f || exit 1;
  cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
  cp $test_dir/$f $output_dir/data/test/$f || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
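The resulting `wav.scp` maps each utterance ID (the wav file's basename) to its path, paired line-by-line with the transcripts in `text`. The same ID derivation the sed/awk pipeline performs, sketched in Python with an illustrative path:

from pathlib import Path

# Sketch: one wav.scp entry is "<utt-id> <wav-path>"; the utt-id is the
# basename with ".wav" stripped, as the sed/awk pipe above produces.
wav = Path("/data/data_aishell/wav/train/S0002/BAC009S0002W0122.wav")
print(f"{wav.stem} {wav}")
# BAC009S0002W0122 /data/data_aishell/wav/train/S0002/BAC009S0002W0122.wav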

egs/aishell/rnnt/path.sh Normal file

@ -0,0 +1,5 @@
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH

egs/aishell/rnnt/run.sh Executable file

@ -0,0 +1,247 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="0,1,2,3"
gpu_num=4
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir= # feature output directory
exp_dir=
lang=zh
dumpdir=dump/fbank
feats_type=fbank
token_type=char
scp=feats.scp
type=kaldi_ark
stage=0
stop_stage=4
# feature configuration
feats_dim=80
sample_frequency=16000
nj=32
speed_perturb="0.9,1.0,1.1"
# data
data_aishell=
# exp tag
tag="exp1"
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail
train_set=train
valid_set=dev
test_sets="dev test"
asr_config=conf/train_conformer_rnnt_unified.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
inference_config=conf/decode_rnnt_conformer_streaming.yaml
inference_asr_model=valid.cer_transducer_chunk.ave_5best.pth
# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
    inference_nj=$((ngpu * njob))
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # Data preparation
    local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
    for x in train dev test; do
        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
            > ${feats_dir}/data/${x}/text
        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
    done
fi
feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature Generation"
    # compute fbank features
    fbankdir=${feats_dir}/fbank
    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
    utils/fix_data_feat.sh ${fbankdir}/train
    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
        ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
    utils/fix_data_feat.sh ${fbankdir}/dev
    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
        ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
    utils/fix_data_feat.sh ${fbankdir}/test

    # compute global cmvn
    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train

    # apply cmvn
    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
        ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
        ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
        ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}

    cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
    cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
    cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
    utils/fix_data_feat.sh ${feat_train_dir}
    utils/fix_data_feat.sh ${feat_dev_dir}
    utils/fix_data_feat.sh ${feat_test_dir}

    # generate ark list
    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p ${feats_dir}/data/${lang}_token_list/char/

    echo "make a dictionary"
    echo "<blank>" > ${token_list}
    echo "<s>" >> ${token_list}
    echo "</s>" >> ${token_list}
    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
    num_token=$(cat ${token_list} | wc -l)
    echo "<unk>" >> ${token_list}
    vocab_size=$(cat ${token_list} | wc -l)

    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
# Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: Training"
    mkdir -p ${exp_dir}/exp/${model_dir}
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ]; then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
            asr_train_transducer.py \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --input_size $feats_dim \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi
# Testing Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        if [ -d ${_dir} ]; then
            echo "${_dir} already exists. If you want to decode again, please delete this dir first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/${dumpdir}/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode rnnt \
                ${_opts}

        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 > "${_dir}/${f}"
            fi
        done
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi

egs/aishell/rnnt/utils Symbolic link

@ -0,0 +1 @@
../transformer/utils


@ -134,6 +134,11 @@ def get_parser():
        help="Pretrained model tag. If this option is specified, *_train_config and "
        "*_file will be overwritten",
    )
    group.add_argument(
        "--beam_search_config",
        default={},
        help="The keyword arguments for transducer beam search.",
    )

    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
@ -171,6 +176,41 @@ def get_parser():
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)
    group.add_argument("--simu_streaming", type=str2bool, default=False)
    group.add_argument("--chunk_size", type=int, default=16)
    group.add_argument("--left_context", type=int, default=16)
    group.add_argument("--right_context", type=int, default=0)
    group.add_argument(
        "--display_partial_hypotheses",
        type=bool,
        default=False,
        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
    )

    group = parser.add_argument_group("Dynamic quantization related")
    group.add_argument(
        "--quantize_asr_model",
        type=bool,
        default=False,
        help="Apply dynamic quantization to ASR model.",
    )
    group.add_argument(
        "--quantize_modules",
        nargs="*",
        default=None,
        help="""Module names to apply dynamic quantization on.
        The module names are provided as a list, where each name is separated
        by a space (e.g.: --quantize_modules Linear LSTM GRU).
        Each specified name should be an attribute of 'torch.nn', e.g.:
        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
    )
    group.add_argument(
        "--quantize_dtype",
        type=str,
        default="qint8",
        choices=["float16", "qint8"],
        help="Dtype for dynamic quantization.",
    )

    group = parser.add_argument_group("Text converter related")
    group.add_argument(
@ -268,6 +308,9 @@ def inference_launch_funasr(**kwargs):
    elif mode == "mfcca":
        from funasr.bin.asr_inference_mfcca import inference_modelscope

        return inference_modelscope(**kwargs)
    elif mode == "rnnt":
        from funasr.bin.asr_inference_rnnt import inference

        return inference(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
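The new `rnnt` branch means the launcher can be driven either from the CLI (as run.sh does in stage 4) or programmatically. A hedged sketch of a programmatic call: the keyword names mirror the CLI flags defined in `get_parser` above, the paths are placeholders, and which kwargs are ultimately accepted depends on the signature of `asr_inference_rnnt.inference`:

from funasr.bin.asr_inference_launch import inference_launch_funasr

# Sketch only: kwargs mirror the CLI flags above; paths are placeholders.
inference_launch_funasr(
    mode="rnnt",
    asr_train_config="exp/model_dir/config.yaml",
    asr_model_file="exp/model_dir/valid.cer_transducer_chunk.ave_5best.pth",
    data_path_and_name_and_type="dump/fbank/test/feats.scp,speech,kaldi_ark",
    output_dir="exp/decode_rnnt/test",
    streaming=True,
    chunk_size=16,
    left_context=16,
    right_context=0,
    ngpu=1,
    batch_size=1,
)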

File diff suppressed because it is too large.


@ -0,0 +1,46 @@
#!/usr/bin/env python3
import os

from funasr.tasks.asr import ASRTransducerTask


# for ASR Training
def parse_args():
    parser = ASRTransducerTask.get_parser()
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )
    args = parser.parse_args()
    return args


def main(args=None, cmd=None):
    # for ASR Training
    ASRTransducerTask.main(args=args, cmd=cmd)


if __name__ == '__main__':
    args = parse_args()

    # setup local gpu_id
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP settings
    if args.ngpu > 1:
        args.distributed = True
    else:
        args.distributed = False
    assert args.num_worker_count == 1

    # re-compute batch size: when dataset type is small
    if args.dataset_type == "small":
        if args.batch_size is not None:
            args.batch_size = args.batch_size * args.ngpu
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * args.ngpu

    main(args=args)
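Two details above are easy to miss: each spawned rank pins itself to one GPU by setting CUDA_VISIBLE_DEVICES before torch initializes, and for the small dataset type the configured batch size is multiplied by the GPU count. A tiny sketch of both under this recipe's settings:

import os

# Sketch: rank 2 is launched with --gpu_id 2, so only that device is
# visible inside the process and torch addresses it as cuda:0.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Sketch: task-level batch size after the rescaling above, assuming
# batch_size=16 from the training config and gpu_num=4 from run.sh
# (presumably re-sharded across the ranks by the data loader).
ngpu, batch_size = 4, 16
assert batch_size * ngpu == 64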


@ -0,0 +1,258 @@
"""RNN decoder definition for Transducer models."""
from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.specaug.specaug import SpecAug
class RNNTDecoder(torch.nn.Module):
"""RNN decoder module.
Args:
vocab_size: Vocabulary size.
embed_size: Embedding size.
hidden_size: Hidden size..
rnn_type: Decoder layers type.
num_layers: Number of decoder layers.
dropout_rate: Dropout rate for decoder layers.
embed_dropout_rate: Dropout rate for embedding layer.
embed_pad: Embedding padding symbol ID.
"""
def __init__(
self,
vocab_size: int,
embed_size: int = 256,
hidden_size: int = 256,
rnn_type: str = "lstm",
num_layers: int = 1,
dropout_rate: float = 0.0,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
) -> None:
"""Construct a RNNDecoder object."""
super().__init__()
assert check_argument_types()
if rnn_type not in ("lstm", "gru"):
raise ValueError(f"Not supported: rnn_type={rnn_type}")
self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)
rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
self.rnn = torch.nn.ModuleList(
[rnn_class(embed_size, hidden_size, 1, batch_first=True)]
)
for _ in range(1, num_layers):
self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
self.dropout_rnn = torch.nn.ModuleList(
[torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
)
self.dlayers = num_layers
self.dtype = rnn_type
self.output_size = hidden_size
self.vocab_size = vocab_size
self.device = next(self.parameters()).device
self.score_cache = {}
def forward(
self,
labels: torch.Tensor,
label_lens: torch.Tensor,
states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
"""Encode source label sequences.
Args:
labels: Label ID sequences. (B, L)
states: Decoder hidden states.
((N, B, D_dec), (N, B, D_dec) or None) or None
Returns:
dec_out: Decoder output sequences. (B, U, D_dec)
"""
if states is None:
states = self.init_state(labels.size(0))
dec_embed = self.dropout_embed(self.embed(labels))
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out
def rnn_forward(
self,
x: torch.Tensor,
state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Encode source label sequences.
Args:
x: RNN input sequences. (B, D_emb)
state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
Returns:
x: RNN output sequences. (B, D_dec)
(h_next, c_next): Decoder hidden states.
(N, B, D_dec), (N, B, D_dec) or None)
"""
h_prev, c_prev = state
h_next, c_next = self.init_state(x.size(0))
for layer in range(self.dlayers):
if self.dtype == "lstm":
x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
layer
](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
else:
x, h_next[layer : layer + 1] = self.rnn[layer](
x, hx=h_prev[layer : layer + 1]
)
x = self.dropout_rnn[layer](x)
return x, (h_next, c_next)
def score(
self,
label: torch.Tensor,
label_sequence: List[int],
dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""One-step forward hypothesis.
Args:
label: Previous label. (1, 1)
label_sequence: Current label sequence.
dec_state: Previous decoder hidden states.
((N, 1, D_dec), (N, 1, D_dec) or None)
Returns:
dec_out: Decoder output sequence. (1, D_dec)
dec_state: Decoder hidden states.
((N, 1, D_dec), (N, 1, D_dec) or None)
"""
str_labels = "_".join(map(str, label_sequence))
if str_labels in self.score_cache:
dec_out, dec_state = self.score_cache[str_labels]
else:
dec_embed = self.embed(label)
dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
self.score_cache[str_labels] = (dec_out, dec_state)
return dec_out[0], dec_state
def batch_score(
self,
hyps: List[Hypothesis],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""One-step forward hypotheses.
Args:
hyps: Hypotheses.
Returns:
dec_out: Decoder output sequences. (B, D_dec)
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
dec_embed = self.embed(labels)
states = self.create_batch_states([h.dec_state for h in hyps])
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out.squeeze(1), states
def set_device(self, device: torch.device) -> None:
"""Set GPU device to use.
Args:
device: Device ID.
"""
self.device = device
def init_state(
self, batch_size: int
) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
"""Initialize decoder states.
Args:
batch_size: Batch size.
Returns:
: Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
h_n = torch.zeros(
self.dlayers,
batch_size,
self.output_size,
device=self.device,
)
if self.dtype == "lstm":
c_n = torch.zeros(
self.dlayers,
batch_size,
self.output_size,
device=self.device,
)
return (h_n, c_n)
return (h_n, None)
def select_state(
self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Get specified ID state from decoder hidden states.
Args:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
idx: State ID to extract.
Returns:
: Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
"""
return (
states[0][:, idx : idx + 1, :],
states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
)
def create_batch_states(
self,
new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Create decoder hidden states.
Args:
new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
Returns:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
return (
torch.cat([s[0] for s in new_states], dim=1),
torch.cat([s[1] for s in new_states], dim=1)
if self.dtype == "lstm"
else None,
)
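A minimal sketch of the decoder's training-time forward; shapes follow the docstrings above, the sizes come from the training config, and the import path is an assumption (the diff does not show this file's location):

import torch

from funasr.models.decoder.rnnt_decoder import RNNTDecoder  # path assumed

# Sketch: run the prediction network over a batch of label sequences.
decoder = RNNTDecoder(vocab_size=100, embed_size=512, hidden_size=512)
labels = torch.randint(1, 100, (4, 7))    # (B, U) label IDs; 0 is blank/pad
label_lens = torch.tensor([7, 6, 5, 7])
dec_out = decoder(labels, label_lens)     # (B, U, D_dec)
assert dec_out.shape == (4, 7, 512)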

File diff suppressed because it is too large.


@ -8,6 +8,7 @@ from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from typing import Dict

import torch
from torch import nn
@ -18,6 +19,7 @@ from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttentionChunk,
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
)
from funasr.modules.embedding import (
@ -25,16 +27,23 @@ from funasr.modules.embedding import (
    ScaledPositionalEncoding,  # noqa: H301
    RelPositionalEncoding,  # noqa: H301
    LegacyRelPositionalEncoding,  # noqa: H301
    StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.nets_utils import (
    TooShortUttError,
    check_short_utt,
    make_chunk_mask,
    make_source_mask,
)
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat, MultiBlocks
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
@ -42,6 +51,8 @@ from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
from funasr.modules.subsampling import StreamingConvInput
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
@ -276,6 +287,188 @@ class EncoderLayer(nn.Module):
        return x, mask


class ChunkEncoderLayer(torch.nn.Module):
    """Chunk Conformer module definition.

    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_args: Dict = {},
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a ChunkEncoderLayer object."""
        super().__init__()

        self.self_att = self_att
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.feed_forward_scale = 0.5
        self.conv_mod = conv_mod
        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)
        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)
        self.dropout = torch.nn.Dropout(dropout_rate)

        self.block_size = block_size
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        self.cache = [
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x

        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        residual = x

        x = self.norm_self_att(x)
        x_q = x

        x = residual + self.dropout(
            self.self_att(
                x_q,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )

        residual = x

        x = self.norm_conv(x)
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)

        residual = x

        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)

        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode one chunk of the input sequence.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            chunk_size: Size of the current chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x

        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)

        residual = x
        x = self.norm_self_att(x)

        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key

        if right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]

        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )

        residual = x

        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x

        residual = x

        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)

        x = self.norm_final(x)

        self.cache = [att_cache, conv_cache]

        return x, pos_enc


class ConformerEncoder(AbsEncoder):
    """Conformer encoder module.
@ -604,3 +797,442 @@ class ConformerEncoder(AbsEncoder):
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None

        return xs_pad, olens, None
class CausalConvolution(torch.nn.Module):
    """CausalConvolution module definition (Conformer convolution with optional causal padding).

    Args:
        channels: The number of channels.
        kernel_size: Size of the convolving kernel.
        activation: Type of activation function.
        norm_args: Normalization module arguments.
        causal: Whether to use causal convolution (set to True if streaming).
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        norm_args: Dict = {},
        causal: bool = False,
    ) -> None:
        """Construct a CausalConvolution object."""
        super().__init__()

        assert (kernel_size - 1) % 2 == 0
        self.kernel_size = kernel_size

        self.pointwise_conv1 = torch.nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        if causal:
            self.lorder = kernel_size - 1
            padding = 0
        else:
            self.lorder = 0
            padding = (kernel_size - 1) // 2

        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
        )

        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)

        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x: CausalConvolution input sequences. (B, T, D_hidden)
            cache: CausalConvolution input cache. (1, D_hidden, conv_kernel - 1)
            right_context: Number of frames in right context.

        Returns:
            x: CausalConvolution output sequences. (B, T, D_hidden)
            cache: CausalConvolution output cache. (1, D_hidden, conv_kernel - 1)
        """
        x = self.pointwise_conv1(x.transpose(1, 2))
        x = torch.nn.functional.glu(x, dim=1)

        if self.lorder > 0:
            if cache is None:
                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                x = torch.cat([cache, x], dim=2)

                if right_context > 0:
                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
                else:
                    cache = x[:, :, -self.lorder :]

        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x).transpose(1, 2)

        return x, cache
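With `causal=True` the depthwise convolution uses no padding and instead consumes `kernel_size - 1` cached (or zero-padded) left frames, so each output frame depends only on current and past input and the time dimension is preserved. A small shape sketch; the import path is an assumption:

import torch

from funasr.models.encoder.conformer_encoder import CausalConvolution  # path assumed

conv = CausalConvolution(channels=8, kernel_size=15, causal=True)
x = torch.randn(2, 50, 8)              # (B, T, D_hidden)
cache = torch.zeros(2, 8, 14)          # (B, D_hidden, kernel_size - 1)
y, new_cache = conv(x, cache=cache)
assert y.shape == (2, 50, 8)           # causal padding keeps T unchanged
assert new_cache.shape == (2, 8, 14)   # last kernel_size - 1 frames, reused next call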
class ConformerChunkEncoder(AbsEncoder):
    """Conformer encoder module supporting chunk-wise (streaming) training and inference.

    Args:
        input_size: Input feature size.
        output_size: Encoder output (attention) dimension.
        attention_heads: Number of attention heads.
        linear_units: Feed-forward hidden units.
        num_blocks: Number of encoder blocks.
        dynamic_chunk_training: Whether to train with randomly sampled chunk sizes.
        unified_model_training: Whether to jointly train full-context and chunked views.
        default_chunk_size: Chunk size used by unified training and decoding.
        jitter_range: Jitter applied to the chunk size during unified training.
        time_reduction_factor: Frame subsampling applied to the encoder output.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        embed_vgg_like: bool = False,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        norm_type: str = "layer_norm",
        cnn_module_kernel: int = 31,
        conv_mod_norm_eps: float = 0.00001,
        conv_mod_norm_momentum: float = 0.1,
        simplified_att_score: bool = False,
        dynamic_chunk_training: bool = False,
        short_chunk_threshold: float = 0.75,
        short_chunk_size: int = 25,
        left_chunk_size: int = 0,
        time_reduction_factor: int = 1,
        unified_model_training: bool = False,
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()

        assert check_argument_types()

        self.embed = StreamingConvInput(
            input_size,
            output_size,
            subsampling_factor,
            vgg_like=embed_vgg_like,
            output_size=output_size,
        )

        self.pos_enc = StreamingRelPositionalEncoding(
            output_size,
            positional_dropout_rate,
        )

        activation = get_activation(activation_type)

        pos_wise_args = (
            output_size,
            linear_units,
            positional_dropout_rate,
            activation,
        )

        conv_mod_norm_args = {
            "eps": conv_mod_norm_eps,
            "momentum": conv_mod_norm_momentum,
        }

        conv_mod_args = (
            output_size,
            cnn_module_kernel,
            activation,
            conv_mod_norm_args,
            dynamic_chunk_training or unified_model_training,
        )

        mult_att_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            simplified_att_score,
        )

        fn_modules = []
        for _ in range(num_blocks):
            module = lambda: ChunkEncoderLayer(
                output_size,
                RelPositionMultiHeadedAttentionChunk(*mult_att_args),
                PositionwiseFeedForward(*pos_wise_args),
                PositionwiseFeedForward(*pos_wise_args),
                CausalConvolution(*conv_mod_args),
                dropout_rate=dropout_rate,
            )
            fn_modules.append(module)

        self.encoders = MultiBlocks(
            [fn() for fn in fn_modules],
            output_size,
        )

        self._output_size = output_size

        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
        self.short_chunk_size = short_chunk_size
        self.left_chunk_size = left_chunk_size

        self.unified_model_training = unified_model_training
        self.default_chunk_size = default_chunk_size
        self.jitter_range = jitter_range

        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        return self._output_size

    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of raw samples for a given chunk size,
        where size is the number of feature frames after subsampling.

        Args:
            size: Number of frames after subsampling.
            hop_length: Frontend's hop length.

        Returns:
            : Number of raw samples.
        """
        return self.embed.get_size_before_subsampling(size) * hop_length

    def get_encoder_input_size(self, size: int) -> int:
        """Return the corresponding number of input feature frames for a given
        chunk size, where size is the number of frames after subsampling.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of feature frames before subsampling.
        """
        return self.embed.get_size_before_subsampling(size)

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.

        Args:
            left_context: Number of frames in left context.
            device: Device to use for cache tensors.
        """
        return self.encoders.reset_streaming_cache(left_context, device)

    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Union[
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        """Encode input sequences.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
            x_len: Encoder outputs lengths. (B,)
            (With unified_model_training, returns full-context outputs,
            chunk-masked outputs and lengths.)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)

        if self.unified_model_training:
            chunk_size = self.default_chunk_size + torch.randint(
                -self.jitter_range, self.jitter_range + 1, (1,)
            ).item()
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
            x_utt = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=None,
            )
            x_chunk = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )

            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = x_utt[:, :: self.time_reduction_factor, :]
                x_chunk = x_chunk[:, :: self.time_reduction_factor, :]
                olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

            return x_utt, x_chunk, olens

        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            chunk_size = torch.randint(1, max_len, (1,)).item()

            if chunk_size > (max_len * self.short_chunk_threshold):
                chunk_size = max_len
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1

            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)

            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)

            chunk_mask = None

        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens

    def simu_chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode the full utterance at once with a chunk mask (simulated streaming).

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
            chunk_size: Size of the decoding chunks.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)

        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )

        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]

        return x

    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode input sequences as chunks.

        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            chunk_size: Size of the decoding chunks.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, None)

        if left_context > 0:
            processed_mask = (
                torch.arange(left_context, device=x.device)
                .view(1, left_context)
                .flip(1)
            )
            processed_mask = processed_mask >= processed_frames
            mask = torch.cat([processed_mask, mask], dim=1)

        pos_enc = self.pos_enc(x, left_context=left_context)

        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )

        if right_context > 0:
            x = x[:, 0:-right_context, :]

        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]

        return x
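A hedged sketch of chunk-by-chunk use of the encoder, combining `reset_streaming_cache` and `chunk_forward`. The import path, constructor arguments, and the exact `processed_frames` bookkeeping are assumptions; frontend feature extraction is omitted and sizes are illustrative:

import torch

from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder  # path assumed

encoder = ConformerChunkEncoder(
    input_size=80, output_size=256, num_blocks=2, subsampling_factor=4
)
encoder.eval()

chunk_size, left_context = 16, 16
feats = torch.randn(1, 640, 80)                    # (1, T, F) fbank-like input
hop = encoder.get_encoder_input_size(chunk_size)   # input frames per chunk
encoder.reset_streaming_cache(left_context, device=feats.device)

processed = torch.tensor(0)
with torch.no_grad():
    for start in range(0, feats.size(1) - hop + 1, hop):
        chunk = feats[:, start : start + hop, :]
        enc_out = encoder.chunk_forward(
            chunk,
            torch.tensor([chunk.size(1)]),
            processed,
            chunk_size=chunk_size,
            left_context=left_context,
        )
        processed = processed + enc_out.size(1)    # encoder frames seen so far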


@ -0,0 +1,61 @@
"""Transducer joint network implementation."""
import torch
from funasr.modules.nets_utils import get_activation
class JointNetwork(torch.nn.Module):
"""Transducer joint network module.
Args:
output_size: Output size.
encoder_size: Encoder output size.
decoder_size: Decoder output size..
joint_space_size: Joint space size.
joint_act_type: Type of activation for joint network.
**activation_parameters: Parameters for the activation function.
"""
def __init__(
self,
output_size: int,
encoder_size: int,
decoder_size: int,
joint_space_size: int = 256,
joint_activation_type: str = "tanh",
) -> None:
"""Construct a JointNetwork object."""
super().__init__()
self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size)
self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False)
self.lin_out = torch.nn.Linear(joint_space_size, output_size)
self.joint_activation = get_activation(
joint_activation_type
)
def forward(
self,
enc_out: torch.Tensor,
dec_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
"""Joint computation of encoder and decoder hidden state sequences.
Args:
enc_out: Expanded encoder output state sequences (B, T, 1, D_enc)
dec_out: Expanded decoder output state sequences (B, 1, U, D_dec)
Returns:
joint_out: Joint output state sequences. (B, T, U, D_out)
"""
if project_input:
joint_out = self.joint_activation(self.lin_enc(enc_out) + self.lin_dec(dec_out))
else:
joint_out = self.joint_activation(enc_out + dec_out)
return self.lin_out(joint_out)
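During training the encoder output is unsqueezed to (B, T, 1, D_enc) and the decoder output to (B, 1, U, D_dec), so the sum inside the joint network broadcasts to the full (T, U) lattice required by the transducer loss. A minimal shape check (the vocabulary size is arbitrary):

import torch

from funasr.models.joint_net.joint_network import JointNetwork

# Sketch: lattice-shaped joint output for the RNN-T loss, (B, T, U, V).
joint = JointNetwork(output_size=100, encoder_size=512, decoder_size=512,
                     joint_space_size=512)
enc_out = torch.randn(2, 30, 1, 512)   # (B, T, 1, D_enc)
dec_out = torch.randn(2, 1, 8, 512)    # (B, 1, U, D_dec)
out = joint(enc_out, dec_out)
assert out.shape == (2, 30, 8, 100)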


@ -11,7 +11,7 @@ import math
import numpy
import torch
from torch import nn
from typing import Optional, Tuple


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.
@ -741,3 +741,221 @@ class MultiHeadSelfAttention(nn.Module):
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """RelPositionMultiHeadedAttentionChunk definition.

    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size.
        dropout_rate: Dropout rate.
        simplified_attention_score: Whether to use the simplified attention
            score computation (see compute_simplified_attention_score below).
    """

    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct a RelPositionMultiHeadedAttentionChunk object."""
        super().__init__()

        self.d_k = embed_size // num_heads
        self.num_heads = num_heads

        assert self.d_k * num_heads == embed_size, (
            "embed_size (%d) must be divisible by num_heads (%d)"
            % (embed_size, num_heads)
        )

        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)
        self.linear_out = torch.nn.Linear(embed_size, embed_size)

        if simplified_attention_score:
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)

            self.compute_att_score = self.compute_simplified_attention_score
        else:
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)

            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)

            self.compute_att_score = self.compute_attention_score

        self.dropout = torch.nn.Dropout(p=dropout_rate)

        self.attn = None

    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding shift.

        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
            left_context: Number of frames in left context.

        Returns:
            x: Output sequence. (B, H, T_1, T_2)
        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context

        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()

        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )

    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.

        Reference: https://github.com/k2-fsa/icefall/pull/458

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        pos_enc = self.linear_pos(pos_enc)

        matrix_ac = torch.matmul(query, key.transpose(2, 3))

        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation.

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)

        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)

        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))

        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)

        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)
        """
        n_batch = query.size(0)

        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)

        Returns:
            attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
        """
        batch_size = scores.size(0)
        mask = mask.unsqueeze(1).unsqueeze(2)

        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask

        scores = scores.masked_fill(mask, float("-inf"))
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)

        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)

        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )

        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
            left_context: Number of frames in left context.

        Returns:
            : Output tensor. (B, T_1, H * d_k)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)

        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
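`rel_shift` above converts scores indexed by relative offset, shape (B, H, T_1, 2 * T_1 - 1) plus any left-context columns, into scores indexed by absolute key position (B, H, T_1, T_2), using a strided view instead of the usual pad-and-reshape. A quick shape check on random data:

import torch

from funasr.modules.attention import RelPositionMultiHeadedAttentionChunk

att = RelPositionMultiHeadedAttentionChunk(num_heads=4, embed_size=64)
t1, left_context = 16, 8
# Relative-position axis covers 2 * T_1 - 1 offsets plus the left context.
x = torch.randn(2, 4, t1, 2 * t1 - 1 + left_context)
shifted = att.rel_shift(x, left_context=left_context)
assert shifted.shape == (2, 4, t1, t1 + left_context)   # (B, H, T_1, T_2)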


@ -0,0 +1,704 @@
"""Search algorithms for Transducer models."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from funasr.models.joint_net.joint_network import JointNetwork
@dataclass
class Hypothesis:
"""Default hypothesis definition for Transducer search algorithms.
Args:
score: Total log-probability.
yseq: Label sequence as integer ID sequence.
dec_state: RNNDecoder or StatelessDecoder state.
((N, 1, D_dec), (N, 1, D_dec) or None) or None
lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None
"""
score: float
yseq: List[int]
dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None
@dataclass
class ExtendedHypothesis(Hypothesis):
"""Extended hypothesis definition for NSC beam search and mAES.
Args:
: Hypothesis dataclass arguments.
dec_out: Decoder output sequence. (B, D_dec)
lm_score: Log-probabilities of the LM for given label. (vocab_size)
"""
dec_out: torch.Tensor = None
lm_score: torch.Tensor = None
class BeamSearchTransducer:
"""Beam search implementation for Transducer.
Args:
decoder: Decoder module.
joint_network: Joint network module.
beam_size: Size of the beam.
lm: LM class.
lm_weight: LM weight for soft fusion.
search_type: Search algorithm to use during inference.
max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
u_max: Maximum expected target sequence length. (ALSD)
nstep: Number of maximum expansion steps at each time step. (mAES)
expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
expansion_beta:
Number of additional candidates for expanded hypotheses selection. (mAES)
score_norm: Normalize final scores by length.
nbest: Number of final hypothesis.
streaming: Whether to perform chunk-by-chunk beam search.
"""
def __init__(
self,
decoder,
joint_network: JointNetwork,
beam_size: int,
lm: Optional[torch.nn.Module] = None,
lm_weight: float = 0.1,
search_type: str = "default",
max_sym_exp: int = 3,
u_max: int = 50,
nstep: int = 2,
expansion_gamma: float = 2.3,
expansion_beta: int = 2,
score_norm: bool = False,
nbest: int = 1,
streaming: bool = False,
) -> None:
"""Construct a BeamSearchTransducer object."""
super().__init__()
self.decoder = decoder
self.joint_network = joint_network
self.vocab_size = decoder.vocab_size
assert beam_size <= self.vocab_size, (
"beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
% (
beam_size,
self.vocab_size,
)
)
self.beam_size = beam_size
if search_type == "default":
self.search_algorithm = self.default_beam_search
elif search_type == "tsd":
assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
max_sym_exp
)
self.max_sym_exp = max_sym_exp
self.search_algorithm = self.time_sync_decoding
elif search_type == "alsd":
assert not streaming, "ALSD is not available in streaming mode."
assert u_max >= 0, "u_max should be a positive integer, a portion of max_T."
self.u_max = u_max
self.search_algorithm = self.align_length_sync_decoding
elif search_type == "maes":
assert self.vocab_size >= beam_size + expansion_beta, (
"beam_size (%d) + expansion_beta (%d) "
" should be smaller than or equal to vocab size (%d)."
% (beam_size, expansion_beta, self.vocab_size)
)
self.max_candidates = beam_size + expansion_beta
self.nstep = nstep
self.expansion_gamma = expansion_gamma
self.search_algorithm = self.modified_adaptive_expansion_search
else:
raise NotImplementedError(
"Specified search type (%s) is not supported." % search_type
)
self.use_lm = lm is not None
if self.use_lm:
assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported."
self.sos = self.vocab_size - 1
self.lm = lm
self.lm_weight = lm_weight
self.score_norm = score_norm
self.nbest = nbest
self.reset_inference_cache()
def __call__(
self,
enc_out: torch.Tensor,
is_final: bool = True,
) -> List[Hypothesis]:
"""Perform beam search.
Args:
enc_out: Encoder output sequence. (T, D_enc)
is_final: Whether enc_out is the final chunk of data.
Returns:
nbest_hyps: N-best decoding results
"""
self.decoder.set_device(enc_out.device)
hyps = self.search_algorithm(enc_out)
if is_final:
self.reset_inference_cache()
return self.sort_nbest(hyps)
self.search_cache = hyps
return hyps
def reset_inference_cache(self) -> None:
"""Reset cache for decoder scoring and streaming."""
self.decoder.score_cache = {}
self.search_cache = None
def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
"""Sort in-place hypotheses by score or score given sequence length.
Args:
hyps: Hypothesis.
Return:
hyps: Sorted hypothesis.
"""
if self.score_norm:
hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
else:
hyps.sort(key=lambda x: x.score, reverse=True)
return hyps[: self.nbest]
def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
"""Recombine hypotheses with same label ID sequence.
Args:
hyps: Hypotheses.
Returns:
final: Recombined hypotheses.
"""
final = {}
for hyp in hyps:
str_yseq = "_".join(map(str, hyp.yseq))
if str_yseq in final:
final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score)
else:
final[str_yseq] = hyp
return [*final.values()]
def select_k_expansions(
self,
hyps: List[ExtendedHypothesis],
topk_idx: torch.Tensor,
topk_logp: torch.Tensor,
) -> List[ExtendedHypothesis]:
"""Return K hypotheses candidates for expansion from a list of hypothesis.
K candidates are selected according to the extended hypotheses probabilities
and a prune-by-value method. Where K is equal to beam_size + beta.
Args:
hyps: Hypotheses.
topk_idx: Indices of candidates hypothesis.
topk_logp: Log-probabilities of candidates hypothesis.
Returns:
k_expansions: Best K expansion hypotheses candidates.
"""
k_expansions = []
for i, hyp in enumerate(hyps):
hyp_i = [
(int(k), hyp.score + float(v))
for k, v in zip(topk_idx[i], topk_logp[i])
]
k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
k_expansions.append(
sorted(
filter(
lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
),
key=lambda x: x[1],
reverse=True,
)
)
return k_expansions
def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor:
"""Make batch of inputs with left padding for LM scoring.
Args:
hyps_seq: Hypothesis sequences.
Returns:
: Padded batch of sequences.
"""
max_len = max([len(h) for h in hyps_seq])
return torch.LongTensor(
[[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq],
device=self.decoder.device,
)
def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
"""Beam search implementation without prefix search.
Modified from https://arxiv.org/pdf/1211.3711.pdf
Args:
enc_out: Encoder output sequence. (T, D)
Returns:
nbest_hyps: N-best hypothesis.
"""
beam_k = min(self.beam_size, (self.vocab_size - 1))
max_t = len(enc_out)
if self.search_cache is not None:
kept_hyps = self.search_cache
else:
kept_hyps = [
Hypothesis(
score=0.0,
yseq=[0],
dec_state=self.decoder.init_state(1),
)
]
for t in range(max_t):
hyps = kept_hyps
kept_hyps = []
while True:
max_hyp = max(hyps, key=lambda x: x.score)
hyps.remove(max_hyp)
label = torch.full(
(1, 1),
max_hyp.yseq[-1],
dtype=torch.long,
device=self.decoder.device,
)
dec_out, state = self.decoder.score(
label,
max_hyp.yseq,
max_hyp.dec_state,
)
logp = torch.log_softmax(
self.joint_network(enc_out[t : t + 1, :], dec_out),
dim=-1,
).squeeze(0)
top_k = logp[1:].topk(beam_k, dim=-1)
kept_hyps.append(
Hypothesis(
score=(max_hyp.score + float(logp[0:1])),
yseq=max_hyp.yseq,
dec_state=max_hyp.dec_state,
lm_state=max_hyp.lm_state,
)
)
if self.use_lm:
lm_scores, lm_state = self.lm.score(
torch.LongTensor(
[self.sos] + max_hyp.yseq[1:], device=self.decoder.device
),
max_hyp.lm_state,
None,
)
else:
lm_state = max_hyp.lm_state
for logp, k in zip(*top_k):
score = max_hyp.score + float(logp)
if self.use_lm:
score += self.lm_weight * lm_scores[k + 1]
hyps.append(
Hypothesis(
score=score,
yseq=max_hyp.yseq + [int(k + 1)],
dec_state=state,
lm_state=lm_state,
)
)
hyps_max = float(max(hyps, key=lambda x: x.score).score)
kept_most_prob = sorted(
[hyp for hyp in kept_hyps if hyp.score > hyps_max],
key=lambda x: x.score,
)
if len(kept_most_prob) >= self.beam_size:
kept_hyps = kept_most_prob
break
return kept_hyps
def align_length_sync_decoding(
self,
enc_out: torch.Tensor,
) -> List[Hypothesis]:
"""Alignment-length synchronous beam search implementation.
Based on https://ieeexplore.ieee.org/document/9053040
Args:
h: Encoder output sequences. (T, D)
Returns:
nbest_hyps: N-best hypothesis.
"""
t_max = int(enc_out.size(0))
u_max = min(self.u_max, (t_max - 1))
B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))]
final = []
if self.use_lm:
B[0].lm_state = self.lm.zero_state()
for i in range(t_max + u_max):
A = []
B_ = []
B_enc_out = []
for hyp in B:
u = len(hyp.yseq) - 1
t = i - u
if t > (t_max - 1):
continue
B_.append(hyp)
B_enc_out.append((t, enc_out[t]))
if B_:
beam_enc_out = torch.stack([b[1] for b in B_enc_out])
beam_dec_out, beam_state = self.decoder.batch_score(B_)
beam_logp = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out),
dim=-1,
)
beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)
if self.use_lm:
beam_lm_scores, beam_lm_states = self.lm.batch_score(
self.create_lm_batch_inputs([b.yseq for b in B_]),
[b.lm_state for b in B_],
None,
)
                for j, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[j, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )
                    A.append(new_hyp)
                    if B_enc_out[j][0] == (t_max - 1):
                        final.append(new_hyp)
                    for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, j),
                            lm_state=hyp.lm_state,
                        )
                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[j, k]
                            new_hyp.lm_state = beam_lm_states[j]
                        A.append(new_hyp)
B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
B = self.recombine_hyps(B)
if final:
return final
return B
def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
"""Time synchronous beam search implementation.
Based on https://ieeexplore.ieee.org/document/9053040
Args:
enc_out: Encoder output sequence. (T, D)
Returns:
nbest_hyps: N-best hypothesis.
"""
if self.search_cache is not None:
B = self.search_cache
else:
B = [
Hypothesis(
yseq=[0],
score=0.0,
dec_state=self.decoder.init_state(1),
)
]
if self.use_lm:
B[0].lm_state = self.lm.zero_state()
for enc_out_t in enc_out:
A = []
C = B
enc_out_t = enc_out_t.unsqueeze(0)
for v in range(self.max_sym_exp):
D = []
beam_dec_out, beam_state = self.decoder.batch_score(C)
beam_logp = torch.log_softmax(
self.joint_network(enc_out_t, beam_dec_out),
dim=-1,
)
beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)
seq_A = [h.yseq for h in A]
for i, hyp in enumerate(C):
if hyp.yseq not in seq_A:
A.append(
Hypothesis(
score=(hyp.score + float(beam_logp[i, 0])),
yseq=hyp.yseq[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
)
)
else:
dict_pos = seq_A.index(hyp.yseq)
A[dict_pos].score = np.logaddexp(
A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
)
if v < (self.max_sym_exp - 1):
if self.use_lm:
beam_lm_scores, beam_lm_states = self.lm.batch_score(
self.create_lm_batch_inputs([c.yseq for c in C]),
[c.lm_state for c in C],
None,
)
for i, hyp in enumerate(C):
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
new_hyp = Hypothesis(
score=(hyp.score + float(logp)),
yseq=(hyp.yseq + [int(k)]),
dec_state=self.decoder.select_state(beam_state, i),
lm_state=hyp.lm_state,
)
if self.use_lm:
new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
new_hyp.lm_state = beam_lm_states[i]
D.append(new_hyp)
C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]
B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
return B
def modified_adaptive_expansion_search(
self,
enc_out: torch.Tensor,
) -> List[ExtendedHypothesis]:
"""Modified version of Adaptive Expansion Search (mAES).
Based on AES (https://ieeexplore.ieee.org/document/9250505) and
NSC (https://arxiv.org/abs/2201.05420).
Args:
enc_out: Encoder output sequence. (T, D_enc)
Returns:
nbest_hyps: N-best hypothesis.
"""
if self.search_cache is not None:
kept_hyps = self.search_cache
else:
init_tokens = [
ExtendedHypothesis(
yseq=[0],
score=0.0,
dec_state=self.decoder.init_state(1),
)
]
beam_dec_out, beam_state = self.decoder.batch_score(
init_tokens,
)
if self.use_lm:
beam_lm_scores, beam_lm_states = self.lm.batch_score(
self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
[h.lm_state for h in init_tokens],
None,
)
lm_state = beam_lm_states[0]
lm_score = beam_lm_scores[0]
else:
lm_state = None
lm_score = None
kept_hyps = [
ExtendedHypothesis(
yseq=[0],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
dec_out=beam_dec_out[0],
lm_state=lm_state,
lm_score=lm_score,
)
]
for enc_out_t in enc_out:
hyps = kept_hyps
kept_hyps = []
beam_enc_out = enc_out_t.unsqueeze(0)
list_b = []
for n in range(self.nstep):
beam_dec_out = torch.stack([h.dec_out for h in hyps])
beam_logp, beam_idx = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out),
dim=-1,
).topk(self.max_candidates, dim=-1)
k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)
list_exp = []
for i, hyp in enumerate(hyps):
for k, new_score in k_expansions[i]:
new_hyp = ExtendedHypothesis(
yseq=hyp.yseq[:],
score=new_score,
dec_out=hyp.dec_out,
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
lm_score=hyp.lm_score,
)
if k == 0:
list_b.append(new_hyp)
else:
new_hyp.yseq.append(int(k))
if self.use_lm:
new_hyp.score += self.lm_weight * float(hyp.lm_score[k])
list_exp.append(new_hyp)
if not list_exp:
kept_hyps = sorted(
self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
)[: self.beam_size]
break
else:
beam_dec_out, beam_state = self.decoder.batch_score(
list_exp,
)
if self.use_lm:
beam_lm_scores, beam_lm_states = self.lm.batch_score(
self.create_lm_batch_inputs([h.yseq for h in list_exp]),
[h.lm_state for h in list_exp],
None,
)
if n < (self.nstep - 1):
for i, hyp in enumerate(list_exp):
hyp.dec_out = beam_dec_out[i]
hyp.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
hyp.lm_state = beam_lm_states[i]
hyp.lm_score = beam_lm_scores[i]
hyps = list_exp[:]
else:
beam_logp = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out),
dim=-1,
)
for i, hyp in enumerate(list_exp):
hyp.score += float(beam_logp[i, 0])
hyp.dec_out = beam_dec_out[i]
hyp.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
hyp.lm_state = beam_lm_states[i]
hyp.lm_score = beam_lm_scores[i]
kept_hyps = sorted(
self.recombine_hyps(list_b + list_exp),
key=lambda x: x.score,
reverse=True,
)[: self.beam_size]
return kept_hyps

View File

@ -6,6 +6,8 @@
"""Common functions for ASR."""
from typing import List, Optional, Tuple
import json
import logging
import sys
@ -13,7 +15,10 @@ import sys
from itertools import groupby
import numpy as np
import six
import torch
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.models.joint_net.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
"""End detection.
@ -247,3 +252,148 @@ class ErrorCalculator(object):
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
class ErrorCalculatorTransducer:
"""Calculate CER and WER for transducer models.
Args:
decoder: Decoder module.
joint_network: Joint Network module.
token_list: List of token units.
sym_space: Space symbol.
sym_blank: Blank symbol.
report_cer: Whether to compute CER.
report_wer: Whether to compute WER.
"""
def __init__(
self,
decoder,
joint_network: JointNetwork,
token_list: List[int],
sym_space: str,
sym_blank: str,
report_cer: bool = False,
report_wer: bool = False,
) -> None:
"""Construct an ErrorCalculatorTransducer object."""
super().__init__()
self.beam_search = BeamSearchTransducer(
decoder=decoder,
joint_network=joint_network,
beam_size=1,
search_type="default",
score_norm=False,
)
self.decoder = decoder
self.token_list = token_list
self.space = sym_space
self.blank = sym_blank
self.report_cer = report_cer
self.report_wer = report_wer
def __call__(
self, encoder_out: torch.Tensor, target: torch.Tensor
) -> Tuple[Optional[float], Optional[float]]:
"""Calculate sentence-level WER or/and CER score for Transducer model.
Args:
encoder_out: Encoder output sequences. (B, T, D_enc)
target: Target label ID sequences. (B, L)
Returns:
: Sentence-level CER score.
: Sentence-level WER score.
"""
cer, wer = None, None
batchsize = int(encoder_out.size(0))
encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
char_pred, char_target = self.convert_to_char(pred, target)
if self.report_cer:
cer = self.calculate_cer(char_pred, char_target)
if self.report_wer:
wer = self.calculate_wer(char_pred, char_target)
return cer, wer
    def convert_to_char(
        self, pred: List[List[int]], target: torch.Tensor
    ) -> Tuple[List, List]:
"""Convert label ID sequences to character sequences.
Args:
pred: Prediction label ID sequences. (B, U)
target: Target label ID sequences. (B, L)
Returns:
char_pred: Prediction character sequences. (B, ?)
char_target: Target character sequences. (B, ?)
"""
char_pred, char_target = [], []
for i, pred_i in enumerate(pred):
char_pred_i = [self.token_list[int(h)] for h in pred_i]
char_target_i = [self.token_list[int(r)] for r in target[i]]
char_pred_i = "".join(char_pred_i).replace(self.space, " ")
char_pred_i = char_pred_i.replace(self.blank, "")
char_target_i = "".join(char_target_i).replace(self.space, " ")
char_target_i = char_target_i.replace(self.blank, "")
char_pred.append(char_pred_i)
char_target.append(char_target_i)
return char_pred, char_target
    def calculate_cer(
        self, char_pred: List[str], char_target: List[str]
    ) -> float:
"""Calculate sentence-level CER score.
Args:
char_pred: Prediction character sequences. (B, ?)
char_target: Target character sequences. (B, ?)
Returns:
: Average sentence-level CER score.
"""
import editdistance
distances, lens = [], []
for i, char_pred_i in enumerate(char_pred):
pred = char_pred_i.replace(" ", "")
target = char_target[i].replace(" ", "")
distances.append(editdistance.eval(pred, target))
lens.append(len(target))
return float(sum(distances)) / sum(lens)
    def calculate_wer(
        self, char_pred: List[str], char_target: List[str]
    ) -> float:
"""Calculate sentence-level WER score.
Args:
char_pred: Prediction character sequences. (B, ?)
char_target: Target character sequences. (B, ?)
Returns:
            : Average sentence-level WER score.
"""
import editdistance
distances, lens = [], []
for i, char_pred_i in enumerate(char_pred):
pred = char_pred_i.replace("", " ").split()
target = char_target[i].replace("", " ").split()
distances.append(editdistance.eval(pred, target))
lens.append(len(target))
return float(sum(distances)) / sum(lens)
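# Illustrative usage sketch (shapes and symbols are hypothetical):
#
#     error_calc = ErrorCalculatorTransducer(
#         decoder, joint_network, token_list,
#         sym_space="<space>", sym_blank="<blank>",
#         report_cer=True, report_wer=True,
#     )
#     cer, wer = error_calc(encoder_out, target)  # (B, T, D_enc), (B, L)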

View File

@ -440,4 +440,79 @@ class StreamSinusoidalPositionEncoder(torch.nn.Module):
outputs = F.pad(outputs, (pad_left, pad_right))
outputs = outputs.transpose(1, 2)
return outputs
class StreamingRelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding.
Args:
size: Module size.
        dropout_rate: Dropout rate.
        max_len: Maximum input length.
"""
def __init__(
self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
) -> None:
"""Construct a RelativePositionalEncoding object."""
super().__init__()
self.size = size
self.pe = None
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
self._register_load_state_dict_pre_hook(_pre_hook)
def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
"""Reset positional encoding.
Args:
x: Input sequences. (B, T, ?)
left_context: Number of frames in left context.
"""
time1 = x.size(1) + left_context
if self.pe is not None:
if self.pe.size(1) >= time1 * 2 - 1:
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(device=x.device, dtype=x.dtype)
return
pe_positive = torch.zeros(time1, self.size)
pe_negative = torch.zeros(time1, self.size)
position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.size, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.size)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
pe_negative = pe_negative[1:].unsqueeze(0)
self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
dtype=x.dtype, device=x.device
)
def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
"""Compute positional encoding.
Args:
x: Input sequences. (B, T, ?)
left_context: Number of frames in left context.
Returns:
            pos_enc: Positional embedding sequences. (B, 2 * T - 1 + left_context, ?)
"""
self.extend_pe(x, left_context=left_context)
time1 = x.size(1) + left_context
pos_enc = self.pe[
:, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
]
pos_enc = self.dropout(pos_enc)
return pos_enc
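# Illustrative usage sketch: for an input of shape (B, T, size) with
# left_context = 0, forward() returns embeddings covering relative offsets
# T-1 ... -(T-1), i.e. shape (B, 2 * T - 1, size).
#
#     pos = StreamingRelPositionalEncoding(size=256, dropout_rate=0.1)
#     pos_enc = pos(torch.zeros(1, 16, 256))  # -> (1, 31, 256)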

View File

@ -3,7 +3,7 @@
"""Network related utility tools."""
import logging
from typing import Dict
from typing import Dict, List, Tuple
import numpy as np
import torch
@ -506,3 +506,196 @@ def get_activation(act):
}
return activation_funcs[act]()
class TooShortUttError(Exception):
"""Raised when the utt is too short for subsampling.
Args:
message: Error message to display.
actual_size: The size that cannot pass the subsampling.
limit: The size limit for subsampling.
"""
def __init__(self, message: str, actual_size: int, limit: int) -> None:
"""Construct a TooShortUttError module."""
super().__init__(message)
self.actual_size = actual_size
self.limit = limit
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
"""Check if the input is too short for subsampling.
Args:
sub_factor: Subsampling factor for Conv2DSubsampling.
size: Input size.
Returns:
: Whether an error should be sent.
: Size limit for specified subsampling factor.
"""
if sub_factor == 2 and size < 3:
return True, 7
elif sub_factor == 4 and size < 7:
return True, 7
elif sub_factor == 6 and size < 11:
return True, 11
return False, -1
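# Worked example: with subsampling factor 4, inputs shorter than 7 frames
# cannot pass Conv2dSubsampling:
#     check_short_utt(4, 5)    # -> (True, 7)
#     check_short_utt(4, 100)  # -> (False, -1)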
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
"""Get conv2D second layer parameters for given subsampling factor.
Args:
sub_factor: Subsampling factor (1/X).
input_size: Input size.
Returns:
: Kernel size for second convolution.
: Stride for second convolution.
: Conv2DSubsampling output size.
"""
if sub_factor == 2:
return 3, 1, (((input_size - 1) // 2 - 2))
elif sub_factor == 4:
return 3, 2, (((input_size - 1) // 2 - 1) // 2)
elif sub_factor == 6:
return 5, 3, (((input_size - 1) // 2 - 2) // 3)
else:
raise ValueError(
"subsampling_factor parameter should be set to either 2, 4 or 6."
)
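# Worked example: for subsampling factor 4 and 80-dim input features,
#     sub_factor_to_params(4, 80)  # -> (3, 2, 19)
# since ((80 - 1) // 2 - 1) // 2 = 38 // 2 = 19: kernel 3 and stride 2 for the
# second convolution, with a 19-dim frequency axis at its output.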
def make_chunk_mask(
size: int,
chunk_size: int,
left_chunk_size: int = 0,
device: torch.device = None,
) -> torch.Tensor:
"""Create chunk mask for the subsequent steps (size, size).
Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
Args:
size: Size of the source mask.
chunk_size: Number of frames in chunk.
left_chunk_size: Size of the left context in chunks (0 means full context).
device: Device for the mask tensor.
Returns:
mask: Chunk mask. (size, size)
"""
mask = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if left_chunk_size <= 0:
start = 0
else:
start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
end = min((i // chunk_size + 1) * chunk_size, size)
mask[i, start:end] = True
return ~mask
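# Worked example: make_chunk_mask(6, 2, left_chunk_size=1) lets each frame
# attend to its own chunk plus one chunk of left context (True = masked):
# frames 0-1 see columns 0-1, frames 2-3 see 0-3, frames 4-5 see 2-5.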
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Create source mask for given lengths.
Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
Args:
lengths: Sequence lengths. (B,)
Returns:
: Mask for the sequence lengths. (B, max_len)
"""
max_len = lengths.max()
batch_size = lengths.size(0)
expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
return expanded_lengths >= lengths.unsqueeze(1)
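# Worked example: make_source_mask(torch.tensor([3, 1])) returns
#     tensor([[False, False, False],
#             [False,  True,  True]])
# where True marks padded positions.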
def get_transducer_task_io(
labels: torch.Tensor,
encoder_out_lens: torch.Tensor,
ignore_id: int = -1,
blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get Transducer loss I/O.
Args:
labels: Label ID sequences. (B, L)
encoder_out_lens: Encoder output lengths. (B,)
ignore_id: Padding symbol ID.
blank_id: Blank symbol ID.
Returns:
decoder_in: Decoder inputs. (B, U)
target: Target label ID sequences. (B, U)
t_len: Time lengths. (B,)
u_len: Label lengths. (B,)
"""
def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
"""Create padded batch of labels from a list of labels sequences.
Args:
labels: Labels sequences. [B x (?)]
padding_value: Padding value.
Returns:
            padded: Batch of padded label sequences. (B, max_len)
"""
batch_size = len(labels)
padded = (
labels[0]
.new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
.fill_(padding_value)
)
for i in range(batch_size):
padded[i, : labels[i].size(0)] = labels[i]
return padded
device = labels.device
labels_unpad = [y[y != ignore_id] for y in labels]
blank = labels[0].new([blank_id])
decoder_in = pad_list(
[torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
).to(device)
target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
encoder_out_lens = list(map(int, encoder_out_lens))
t_len = torch.IntTensor(encoder_out_lens).to(device)
u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
return decoder_in, target, t_len, u_len
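# Worked example (toy values): with ignore_id = -1 and blank_id = 0,
#     labels = torch.tensor([[5, 2, -1], [3, -1, -1]])
#     encoder_out_lens = torch.tensor([4, 3])
# yields decoder_in = [[0, 5, 2], [0, 3, 0]] (blank-prefixed inputs),
# target = [[5, 2], [3, 0]], t_len = [4, 3] and u_len = [2, 1].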
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
"""Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
if t.size(dim) == pad_len:
return t
else:
pad_size = list(t.shape)
pad_size[dim] = pad_len - t.size(dim)
return torch.cat(
[t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
)
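# Worked example: pad_to_len(torch.ones(2, 3), 5, dim=1) right-pads with zeros
# to shape (2, 5); when t.size(dim) already equals pad_len, t is returned as-is.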

View File

@ -6,6 +6,8 @@
"""Repeat the same layer definition."""
from typing import Dict, List, Optional
import torch
@ -31,3 +33,92 @@ def repeat(N, fn):
"""
return MultiSequential(*[fn(n) for n in range(N)])
class MultiBlocks(torch.nn.Module):
"""MultiBlocks definition.
Args:
block_list: Individual blocks of the encoder architecture.
output_size: Architecture output size.
norm_class: Normalization module class.
norm_args: Normalization module arguments.
"""
def __init__(
self,
block_list: List[torch.nn.Module],
output_size: int,
norm_class: torch.nn.Module = torch.nn.LayerNorm,
) -> None:
"""Construct a MultiBlocks object."""
super().__init__()
self.blocks = torch.nn.ModuleList(block_list)
self.norm_blocks = norm_class(output_size)
self.num_blocks = len(block_list)
def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
"""Initialize/Reset encoder streaming cache.
Args:
left_context: Number of left frames during chunk-by-chunk inference.
device: Device to use for cache tensor.
"""
for idx in range(self.num_blocks):
self.blocks[idx].reset_streaming_cache(left_context, device)
def forward(
self,
x: torch.Tensor,
pos_enc: torch.Tensor,
mask: torch.Tensor,
chunk_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward each block of the encoder architecture.
Args:
x: MultiBlocks input sequences. (B, T, D_block_1)
pos_enc: Positional embedding sequences.
mask: Source mask. (B, T)
chunk_mask: Chunk mask. (T_2, T_2)
Returns:
x: Output sequences. (B, T, D_block_N)
"""
        for block in self.blocks:
            x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
x = self.norm_blocks(x)
return x
def chunk_forward(
self,
x: torch.Tensor,
pos_enc: torch.Tensor,
mask: torch.Tensor,
chunk_size: int = 0,
left_context: int = 0,
right_context: int = 0,
) -> torch.Tensor:
"""Forward each block of the encoder architecture.
Args:
x: MultiBlocks input sequences. (B, T, D_block_1)
pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
mask: Source mask. (B, T_2)
left_context: Number of frames in left context.
right_context: Number of frames in right context.
Returns:
x: MultiBlocks output sequences. (B, T, D_block_N)
"""
        for block in self.blocks:
            x, pos_enc = block.chunk_forward(
                x,
                pos_enc,
                mask,
                chunk_size=chunk_size,
                left_context=left_context,
                right_context=right_context,
            )
x = self.norm_blocks(x)
return x
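# Illustrative streaming sketch (names are hypothetical): reset the cache once,
# then feed chunks with their positional embeddings and masks.
#
#     encoder_blocks.reset_streaming_cache(left_context=16, device=x.device)
#     for x_chunk, pos_chunk, mask_chunk in chunks:
#         y = encoder_blocks.chunk_forward(
#             x_chunk, pos_chunk, mask_chunk,
#             chunk_size=16, left_context=16, right_context=0,
#         )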

View File

@ -11,6 +11,10 @@ import torch.nn.functional as F
from funasr.modules.embedding import PositionalEncoding
import logging
from funasr.modules.streaming_utils.utils import sequence_mask
from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
from typing import Optional, Tuple, Union
import math
class TooShortUttError(Exception):
"""Raised when the utt is too short for subsampling.
@ -407,3 +411,201 @@ class Conv1dSubsampling(torch.nn.Module):
var_dict_tf[name_tf].shape))
return var_dict_torch_update
class StreamingConvInput(torch.nn.Module):
"""Streaming ConvInput module definition.
Args:
input_size: Input size.
conv_size: Convolution size.
subsampling_factor: Subsampling factor.
vgg_like: Whether to use a VGG-like network.
output_size: Block output dimension.
"""
def __init__(
self,
input_size: int,
conv_size: Union[int, Tuple],
subsampling_factor: int = 4,
vgg_like: bool = True,
output_size: Optional[int] = None,
) -> None:
"""Construct a ConvInput object."""
super().__init__()
if vgg_like:
if subsampling_factor == 1:
conv_size1, conv_size2 = conv_size
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d((1, 2)),
torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d((1, 2)),
)
output_proj = conv_size2 * ((input_size // 2) // 2)
self.subsampling_factor = 1
self.stride_1 = 1
self.create_new_mask = self.create_new_vgg_mask
else:
conv_size1, conv_size2 = conv_size
kernel_1 = int(subsampling_factor / 2)
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d((kernel_1, 2)),
torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d((2, 2)),
)
output_proj = conv_size2 * ((input_size // 2) // 2)
self.subsampling_factor = subsampling_factor
self.create_new_mask = self.create_new_vgg_mask
self.stride_1 = kernel_1
else:
if subsampling_factor == 1:
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
torch.nn.ReLU(),
torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
torch.nn.ReLU(),
)
output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
self.subsampling_factor = subsampling_factor
self.kernel_2 = 3
self.stride_2 = 1
self.create_new_mask = self.create_new_conv2d_mask
else:
kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
subsampling_factor,
input_size,
)
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, conv_size, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
torch.nn.ReLU(),
)
output_proj = conv_size * conv_2_output_size
self.subsampling_factor = subsampling_factor
self.kernel_2 = kernel_2
self.stride_2 = stride_2
self.create_new_mask = self.create_new_conv2d_mask
self.vgg_like = vgg_like
self.min_frame_length = 7
if output_size is not None:
self.output = torch.nn.Linear(output_proj, output_size)
self.output_size = output_size
else:
self.output = None
self.output_size = output_proj
    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[int]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.
        Args:
            x: ConvInput input sequences. (B, T, D_feats)
            mask: Mask of input sequences. (B, 1, T)
            chunk_size: Number of frames per chunk after subsampling, or None for full context.
        Returns:
            x: ConvInput output sequences. (B, sub(T), D_out)
            mask: Mask of output sequences. (B, 1, sub(T))
"""
if mask is not None:
mask = self.create_new_mask(mask)
olens = max(mask.eq(0).sum(1))
b, t, f = x.size()
        x = x.unsqueeze(1)  # (b, 1, t, f)
        if chunk_size is not None:
            max_input_length = int(
                chunk_size
                * self.subsampling_factor
                * math.ceil(float(t) / (chunk_size * self.subsampling_factor))
            )
            x = torch.stack(
                [pad_to_len(x_i, max_input_length, 1) for x_i in x], dim=0
            )
            n_chunks = max_input_length // (chunk_size * self.subsampling_factor)
            x = x.view(b * n_chunks, 1, chunk_size * self.subsampling_factor, f)
x = self.conv(x)
_, c, _, f = x.size()
        if chunk_size is not None:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:, :olens, :]
        else:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
        if self.output is not None:
            x = self.output(x)
        return x, mask[:, :olens][:, : x.size(1)]
def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
"""Create a new mask for VGG output sequences.
Args:
mask: Mask of input sequences. (B, T)
Returns:
mask: Mask of output sequences. (B, sub(T))
"""
        if self.subsampling_factor > 1:
            vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2))
            mask = mask[:, :vgg1_t_len][:, :: self.subsampling_factor // 2]
            vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
            mask = mask[:, :vgg2_t_len][:, ::2]
        return mask
def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
"""Create new conformer mask for Conv2d output sequences.
Args:
mask: Mask of input sequences. (B, T)
Returns:
mask: Mask of output sequences. (B, sub(T))
"""
if self.subsampling_factor > 1:
return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
else:
return mask
def get_size_before_subsampling(self, size: int) -> int:
"""Return the original size before subsampling for a given size.
Args:
size: Number of frames after subsampling.
Returns:
: Number of frames before subsampling.
"""
return size * self.subsampling_factor

View File

@ -38,13 +38,16 @@ from funasr.models.decoder.transformer_decoder import (
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
@ -151,6 +154,7 @@ encoder_choices = ClassChoices(
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
chunk_conformer=ConformerChunkEncoder,
),
type_check=AbsEncoder,
default="rnn",
@ -208,6 +212,16 @@ decoder_choices2 = ClassChoices(
type_check=AbsDecoder,
default="rnn",
)
rnnt_decoder_choices = ClassChoices(
"rnnt_decoder",
classes=dict(
rnnt=RNNTDecoder,
),
type_check=RNNTDecoder,
default="rnnt",
)
predictor_choices = ClassChoices(
name="predictor",
classes=dict(
@ -1332,3 +1346,378 @@ class ASRTaskAligner(ASRTaskParaformer):
) -> Tuple[str, ...]:
retval = ("speech", "text")
return retval
class ASRTransducerTask(AbsTask):
"""ASR Transducer Task definition."""
num_optimizers: int = 1
class_choices_list = [
frontend_choices,
specaug_choices,
normalize_choices,
encoder_choices,
rnnt_decoder_choices,
]
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
"""Add Transducer task arguments.
Args:
cls: ASRTransducerTask object.
parser: Transducer arguments parser.
"""
group = parser.add_argument_group(description="Task related.")
# required = parser.get_default("required")
# required += ["token_list"]
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="Integer-string mapper for tokens.",
)
group.add_argument(
"--split_with_space",
type=str2bool,
default=True,
help="whether to split text using <space>",
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of dimensions for input features.",
)
group.add_argument(
"--init",
type=str_or_none,
default=None,
help="Type of model initialization to use.",
)
group.add_argument(
"--model_conf",
action=NestedDictAction,
default=get_default_kwargs(TransducerModel),
help="The keyword arguments for the model class.",
)
# group.add_argument(
# "--encoder_conf",
# action=NestedDictAction,
# default={},
# help="The keyword arguments for the encoder class.",
# )
group.add_argument(
"--joint_network_conf",
action=NestedDictAction,
default={},
help="The keyword arguments for the joint network class.",
)
group = parser.add_argument_group(description="Preprocess related.")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Whether to apply preprocessing to input data.",
)
group.add_argument(
"--token_type",
type=str,
default="bpe",
choices=["bpe", "char", "word", "phn"],
help="The type of tokens to use during tokenization.",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The path of the sentencepiece model.",
)
parser.add_argument(
"--non_linguistic_symbols",
type=str_or_none,
help="The 'non_linguistic_symbols' file path.",
)
parser.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese"],
default=None,
help="Text cleaner to use.",
)
parser.add_argument(
"--g2p",
type=str_or_none,
choices=g2p_choices,
default=None,
help="g2p method to use if --token_type=phn.",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Normalization value for maximum amplitude scaling.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The RIR SCP file path.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="The probability of the applied RIR convolution.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The path of noise SCP file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability of the applied noise addition.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of the noise decibel level.",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --decoder and --decoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
"""Build collate function.
Args:
cls: ASRTransducerTask object.
args: Task arguments.
train: Training mode.
Return:
: Callable collate function.
"""
assert check_argument_types()
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
"""Build pre-processing function.
Args:
cls: ASRTransducerTask object.
args: Task arguments.
train: Training mode.
Return:
: Callable pre-processing function.
"""
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
else 1.0,
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
noise_apply_prob=args.noise_apply_prob
if hasattr(args, "noise_apply_prob")
else 1.0,
noise_db_range=args.noise_db_range
if hasattr(args, "noise_db_range")
else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "speech_volume_normalize")
                else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
"""Required data depending on task mode.
Args:
cls: ASRTransducerTask object.
train: Training mode.
inference: Inference mode.
Return:
retval: Required task data.
"""
if not inference:
retval = ("speech", "text")
else:
retval = ("speech",)
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
"""Optional data depending on task mode.
Args:
cls: ASRTransducerTask object.
train: Training mode.
inference: Inference mode.
Return:
retval: Optional task data.
"""
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> TransducerModel:
"""Required data depending on task mode.
Args:
cls: ASRTransducerTask object.
args: Task arguments.
Return:
model: ASR Transducer model.
"""
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
            # Overwrite token_list to keep the model config portable.
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size }")
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 4. Encoder
if getattr(args, "encoder", None) is not None:
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size, **args.encoder_conf)
else:
encoder = Encoder(input_size, **args.encoder_conf)
encoder_output_size = encoder.output_size()
# 5. Decoder
rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
decoder = rnnt_decoder_class(
vocab_size,
**args.rnnt_decoder_conf,
)
decoder_output_size = decoder.output_size
if getattr(args, "decoder", None) is not None:
att_decoder_class = decoder_choices.get_class(args.att_decoder)
att_decoder = att_decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
**args.decoder_conf,
)
else:
att_decoder = None
# 6. Joint Network
joint_network = JointNetwork(
vocab_size,
encoder_output_size,
decoder_output_size,
**args.joint_network_conf,
)
# 7. Build model
if encoder.unified_model_training:
model = UnifiedTransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
encoder=encoder,
decoder=decoder,
att_decoder=att_decoder,
joint_network=joint_network,
**args.model_conf,
)
else:
model = TransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
encoder=encoder,
decoder=decoder,
att_decoder=att_decoder,
joint_network=joint_network,
**args.model_conf,
)
# 8. Initialize model
if args.init is not None:
            raise NotImplementedError(
                "Model initialization is currently not supported; "
                "the initialization part will be reworked in the near future."
            )
#assert check_return_type(model)
return model