diff --git a/egs/aishell/rnnt/README.md b/egs/aishell/rnnt/README.md new file mode 100644 index 000000000..45f1f3f98 --- /dev/null +++ b/egs/aishell/rnnt/README.md @@ -0,0 +1,18 @@ + +# Streaming RNN-T Result + +## Training Config +- 8 gpu(Tesla V100) +- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment +- Train config: conf/train_conformer_rnnt_unified.yaml +- chunk config: chunk size 16, full left chunk +- LM config: LM was not used +- Model size: 90M + +## Results (CER) +- Decode config: conf/train_conformer_rnnt_unified.yaml + +| testset | CER(%) | +|:-----------:|:-------:| +| dev | 5.53 | +| test | 6.24 | diff --git a/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml new file mode 100644 index 000000000..26e43c64d --- /dev/null +++ b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml @@ -0,0 +1,8 @@ +# The conformer transducer decoding configuration from @jeon30c +beam_size: 10 +simu_streaming: false +streaming: true +chunk_size: 16 +left_context: 16 +right_context: 0 + diff --git a/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml new file mode 100644 index 000000000..dc3eff2a5 --- /dev/null +++ b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml @@ -0,0 +1,5 @@ +# The conformer transducer decoding configuration from @jeon30c +beam_size: 10 +simu_streaming: true +streaming: false +chunk_size: 16 diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml new file mode 100644 index 000000000..8a1c40cac --- /dev/null +++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml @@ -0,0 +1,80 @@ +encoder: chunk_conformer +encoder_conf: + activation_type: swish + positional_dropout_rate: 0.5 + time_reduction_factor: 2 + unified_model_training: true + default_chunk_size: 16 + jitter_range: 
4 + left_chunk_size: 0 + embed_vgg_like: false + subsampling_factor: 4 + linear_units: 2048 + output_size: 512 + attention_heads: 8 + dropout_rate: 0.5 + positional_dropout_rate: 0.5 + attention_dropout_rate: 0.5 + cnn_module_kernel: 15 + num_blocks: 12 + +# decoder related +rnnt_decoder: rnnt +rnnt_decoder_conf: + embed_size: 512 + hidden_size: 512 + embed_dropout_rate: 0.5 + dropout_rate: 0.5 + +joint_network_conf: + joint_space_size: 512 + +# Auxiliary CTC +model_conf: + auxiliary_ctc_weight: 0.0 + +# minibatch related +use_amp: true +batch_type: unsorted +batch_size: 16 +num_workers: 16 + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 200 +val_scheduler_criterion: + - valid + - loss +best_model_criterion: +- - valid + - cer_transducer_chunk + - min +keep_nbest_models: 10 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + +normalize: None + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 40 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_range: + - 0 + - 50 + num_time_mask: 5 + +log_interval: 50 diff --git a/egs/aishell/rnnt/local/aishell_data_prep.sh b/egs/aishell/rnnt/local/aishell_data_prep.sh new file mode 100755 index 000000000..83f489b3c --- /dev/null +++ b/egs/aishell/rnnt/local/aishell_data_prep.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Copyright 2017 Xingyu Na +# Apache 2.0 + +#. 
./path.sh || exit 1; + +if [ $# != 3 ]; then + echo "Usage: $0 " + echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data" + exit 1; +fi + +aishell_audio_dir=$1 +aishell_text=$2/aishell_transcript_v0.8.txt +output_dir=$3 + +train_dir=$output_dir/data/local/train +dev_dir=$output_dir/data/local/dev +test_dir=$output_dir/data/local/test +tmp_dir=$output_dir/data/local/tmp + +mkdir -p $train_dir +mkdir -p $dev_dir +mkdir -p $test_dir +mkdir -p $tmp_dir + +# data directory check +if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then + echo "Error: $0 requires two directory arguments" + exit 1; +fi + +# find wav audio file for train, dev and test resp. +find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist +n=`cat $tmp_dir/wav.flist | wc -l` +[ $n -ne 141925 ] && \ + echo Warning: expected 141925 data data files, found $n + +grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1; +grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1; +grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1; + +rm -r $tmp_dir + +# Transcriptions preparation +for dir in $train_dir $dev_dir $test_dir; do + echo Preparing $dir transcriptions + sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list + paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all + utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt + awk '{print $1}' $dir/transcripts.txt > $dir/utt.list + utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp + sort -u $dir/transcripts.txt > $dir/text +done + +mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test + +for f in wav.scp text; do + cp $train_dir/$f $output_dir/data/train/$f || exit 1; + cp $dev_dir/$f $output_dir/data/dev/$f || exit 1; + cp $test_dir/$f $output_dir/data/test/$f || exit 1; +done + +echo "$0: AISHELL data preparation succeeded" +exit 0; diff --git 
a/egs/aishell/rnnt/path.sh b/egs/aishell/rnnt/path.sh new file mode 100644 index 000000000..7972642d0 --- /dev/null +++ b/egs/aishell/rnnt/path.sh @@ -0,0 +1,5 @@ +export FUNASR_DIR=$PWD/../../.. + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PATH=$FUNASR_DIR/funasr/bin:$PATH diff --git a/egs/aishell/rnnt/run.sh b/egs/aishell/rnnt/run.sh new file mode 100755 index 000000000..bcd4a8b9f --- /dev/null +++ b/egs/aishell/rnnt/run.sh @@ -0,0 +1,247 @@ +#!/usr/bin/env bash + +. ./path.sh || exit 1; + +# machines configuration +CUDA_VISIBLE_DEVICES="0,1,2,3" +gpu_num=4 +count=1 +gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding +# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob +njob=5 +train_cmd=utils/run.pl +infer_cmd=utils/run.pl + +# general configuration +feats_dir= #feature output dictionary +exp_dir= +lang=zh +dumpdir=dump/fbank +feats_type=fbank +token_type=char +scp=feats.scp +type=kaldi_ark +stage=0 +stop_stage=4 + +# feature configuration +feats_dim=80 +sample_frequency=16000 +nj=32 +speed_perturb="0.9,1.0,1.1" + +# data +data_aishell= + +# exp tag +tag="exp1" + +. utils/parse_options.sh || exit 1; + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 
'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +train_set=train +valid_set=dev +test_sets="dev test" + +asr_config=conf/train_conformer_rnnt_unified.yaml +model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}" + +inference_config=conf/decode_rnnt_conformer_streaming.yaml +inference_asr_model=valid.cer_transducer_chunk.ave_5best.pth + +# you can set gpu num for decoding here +gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default +ngpu=$(echo $gpuid_list | awk -F "," '{print NF}') + +if ${gpu_inference}; then + inference_nj=$[${ngpu}*${njob}] + _ngpu=1 +else + inference_nj=$njob + _ngpu=0 +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: Data preparation" + # Data preparation + local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir} + for x in train dev test; do + cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org + paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \ + > ${feats_dir}/data/${x}/text + utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org + mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text + done +fi + +feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir} +feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir} +feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir} +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "stage 1: Feature Generation" + # compute fbank features + fbankdir=${feats_dir}/fbank + utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \ + ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train + utils/fix_data_feat.sh 
${fbankdir}/train + utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \ + ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev + utils/fix_data_feat.sh ${fbankdir}/dev + utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \ + ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test + utils/fix_data_feat.sh ${fbankdir}/test + + # compute global cmvn + utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \ + ${fbankdir}/train ${exp_dir}/exp/make_fbank/train + + # apply cmvn + utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \ + ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir} + utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \ + ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir} + utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \ + ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir} + + cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir} + cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir} + cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir} + + utils/fix_data_feat.sh ${feat_train_dir} + utils/fix_data_feat.sh ${feat_dev_dir} + utils/fix_data_feat.sh ${feat_test_dir} + + #generate ark list + utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir} + utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir} +fi + +token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt +echo "dictionary: ${token_list}" +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "stage 2: Dictionary Preparation" + mkdir -p 
${feats_dir}/data/${lang}_token_list/char/ + + echo "make a dictionary" + echo "" > ${token_list} + echo "" >> ${token_list} + echo "" >> ${token_list} + utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \ + | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list} + num_token=$(cat ${token_list} | wc -l) + echo "" >> ${token_list} + vocab_size=$(cat ${token_list} | wc -l) + awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char + awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char + mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train + mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev + cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train + cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev +fi + +# Training Stage +world_size=$gpu_num # run on one machine +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "stage 3: Training" + mkdir -p ${exp_dir}/exp/${model_dir} + mkdir -p ${exp_dir}/exp/${model_dir}/log + INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + asr_train_transducer.py \ + --gpu_id $gpu_id \ + --use_preprocessor true \ + --token_type char \ + --token_list $token_list \ + --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \ + --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \ + --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \ + 
--train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \ + --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \ + --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \ + --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \ + --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \ + --resume true \ + --output_dir ${exp_dir}/exp/${model_dir} \ + --config $asr_config \ + --input_size $feats_dim \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --multiprocessing_distributed true \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1 + } & + done + wait +fi + +# Testing Stage +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "stage 4: Inference" + for dset in ${test_sets}; do + asr_exp=${exp_dir}/exp/${model_dir} + inference_tag="$(basename "${inference_config}" .yaml)" + _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}" + _logdir="${_dir}/logdir" + if [ -d ${_dir} ]; then + echo "${_dir} is already exists. if you want to decode again, please delete this dir first." 
+ exit 0 + fi + mkdir -p "${_logdir}" + _data="${feats_dir}/${dumpdir}/${dset}" + key_file=${_data}/${scp} + num_scp_file="$(<${key_file} wc -l)" + _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file") + split_scps= + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob} \ + --gpuid_list ${gpuid_list} \ + --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \ + --key_file "${_logdir}"/keys.JOB.scp \ + --asr_train_config "${asr_exp}"/config.yaml \ + --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode rnnt \ + ${_opts} + + for f in token token_int score text; do + if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | sort -k1 >"${_dir}/${f}" + fi + done + python utils/proce_text.py ${_dir}/text ${_dir}/text.proc + python utils/proce_text.py ${_data}/text ${_data}/text.proc + python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt + cat ${_dir}/text.cer.txt + done +fi diff --git a/egs/aishell/rnnt/utils b/egs/aishell/rnnt/utils new file mode 120000 index 000000000..4072eacc1 --- /dev/null +++ b/egs/aishell/rnnt/utils @@ -0,0 +1 @@ +../transformer/utils \ No newline at end of file diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 7add9604b..2b6716ed8 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -134,6 +134,11 @@ def 
get_parser(): help="Pretrained model tag. If specify this option, *_train_config and " "*_file will be overwritten", ) + group.add_argument( + "--beam_search_config", + default={}, + help="The keyword arguments for transducer beam search.", + ) group = parser.add_argument_group("Beam-search related") group.add_argument( @@ -171,6 +176,41 @@ def get_parser(): group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") group.add_argument("--streaming", type=str2bool, default=False) + group.add_argument("--simu_streaming", type=str2bool, default=False) + group.add_argument("--chunk_size", type=int, default=16) + group.add_argument("--left_context", type=int, default=16) + group.add_argument("--right_context", type=int, default=0) + group.add_argument( + "--display_partial_hypotheses", + type=bool, + default=False, + help="Whether to display partial hypotheses during chunk-by-chunk inference.", + ) + + group = parser.add_argument_group("Dynamic quantization related") + group.add_argument( + "--quantize_asr_model", + type=bool, + default=False, + help="Apply dynamic quantization to ASR model.", + ) + group.add_argument( + "--quantize_modules", + nargs="*", + default=None, + help="""Module names to apply dynamic quantization on. + The module names are provided as a list, where each name is separated + by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]). 
+ Each specified name should be an attribute of 'torch.nn', e.g.: + torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""", + ) + group.add_argument( + "--quantize_dtype", + type=str, + default="qint8", + choices=["float16", "qint8"], + help="Dtype for dynamic quantization.", + ) group = parser.add_argument_group("Text converter related") group.add_argument( @@ -268,6 +308,9 @@ def inference_launch_funasr(**kwargs): elif mode == "mfcca": from funasr.bin.asr_inference_mfcca import inference_modelscope return inference_modelscope(**kwargs) + elif mode == "rnnt": + from funasr.bin.asr_inference_rnnt import inference + return inference(**kwargs) else: logging.info("Unknown decoding mode: {}".format(mode)) return None diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py index 2189a718d..bff87022e 100644 --- a/funasr/bin/asr_inference_rnnt.py +++ b/funasr/bin/asr_inference_rnnt.py @@ -1,396 +1,149 @@ #!/usr/bin/env python3 + +""" Inference class definition for Transducer models.""" + +from __future__ import annotations + import argparse import logging +import math import sys -import time -import copy -import os -import codecs -import tempfile -import requests from pathlib import Path -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union -from typing import Dict -from typing import Any -from typing import List +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch -from typeguard import check_argument_types +from packaging.version import parse as V +from typeguard import check_argument_types, check_return_type +from funasr.modules.beam_search.beam_search_transducer import ( + BeamSearchTransducer, + Hypothesis, +) +from funasr.modules.nets_utils import TooShortUttError from funasr.fileio.datadir_writer import DatadirWriter -from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch -from 
funasr.modules.beam_search.beam_search import Hypothesis -from funasr.modules.scorers.ctc import CTCPrefixScorer -from funasr.modules.scorers.length_bonus import LengthBonus -from funasr.modules.subsampling import TooShortUttError -from funasr.tasks.asr import ASRTaskParaformer as ASRTask +from funasr.tasks.asr import ASRTransducerTask from funasr.tasks.lm import LMTask from funasr.text.build_tokenizer import build_tokenizer from funasr.text.token_id_converter import TokenIDConverter from funasr.torch_utils.device_funcs import to_device from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse +from funasr.utils.types import str2bool, str2triple_str, str_or_none from funasr.utils.cli_utils import get_commandline_args -from funasr.utils.types import str2bool -from funasr.utils.types import str2triple_str -from funasr.utils.types import str_or_none -from funasr.utils import asr_utils, wav_utils, postprocess_utils from funasr.models.frontend.wav_frontend import WavFrontend -from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer -from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export - class Speech2Text: - """Speech2Text class - - Examples: - >>> import soundfile - >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") - >>> audio, rate = soundfile.read("speech.wav") - >>> speech2text(audio) - [(text, token, token_int, hypothesis object), ...] - + """Speech2Text class for Transducer models. + Args: + asr_train_config: ASR model training config path. + asr_model_file: ASR model path. + beam_search_config: Beam search config path. + lm_train_config: Language Model training config path. + lm_file: Language Model config path. + token_type: Type of token units. + bpemodel: BPE model path. + device: Device to use for inference. + beam_size: Size of beam during search. + dtype: Data type. + lm_weight: Language model weight. 
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model. + quantize_modules: List of module names to apply dynamic quantization on. + quantize_dtype: Dynamic quantization data type. + nbest: Number of final hypothesis. + streaming: Whether to perform chunk-by-chunk inference. + chunk_size: Number of frames in chunk AFTER subsampling. + left_context: Number of frames in left context AFTER subsampling. + right_context: Number of frames in right context AFTER subsampling. + display_partial_hypotheses: Whether to display partial hypotheses. """ def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - frontend_conf: dict = None, - hotword_list_or_file: str = None, - **kwargs, - ): - assert check_argument_types() + self, + asr_train_config: Union[Path, str] = None, + asr_model_file: Union[Path, str] = None, + cmvn_file: Union[Path, str] = None, + beam_search_config: Dict[str, Any] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + beam_size: int = 5, + dtype: str = "float32", + lm_weight: float = 1.0, + quantize_asr_model: bool = False, + quantize_modules: List[str] = None, + quantize_dtype: str = "qint8", + nbest: int = 1, + streaming: bool = False, + simu_streaming: bool = False, + chunk_size: int = 16, + left_context: int = 32, + right_context: int = 0, + display_partial_hypotheses: bool = False, + ) -> None: + """Construct a Speech2Text object.""" + super().__init__() - # 1. 
Build ASR model - scorers = {} - asr_model, asr_train_args = ASRTask.build_model_from_file( + assert check_argument_types() + asr_model, asr_train_args = ASRTransducerTask.build_model_from_file( asr_train_config, asr_model_file, cmvn_file, device ) + frontend = None if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) - logging.info("asr_model: {}".format(asr_model)) - logging.info("asr_train_args: {}".format(asr_train_args)) - asr_model.to(dtype=getattr(torch, dtype)).eval() + if quantize_asr_model: + if quantize_modules is not None: + if not all([q in ["LSTM", "Linear"] for q in quantize_modules]): + raise ValueError( + "Only 'Linear' and 'LSTM' modules are currently supported" + " by PyTorch and in --quantize_modules" + ) - if asr_model.ctc != None: - ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) - scorers.update( - ctc=ctc - ) - token_list = asr_model.token_list - scorers.update( - length_bonus=LengthBonus(len(token_list)), - ) + q_config = set([getattr(torch.nn, q) for q in quantize_modules]) + else: + q_config = {torch.nn.Linear} + + if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")): + raise ValueError( + "float16 dtype for dynamic quantization is not supported with torch" + " version < 1.5.0. Switching to qint8 dtype instead." + ) + q_dtype = getattr(torch, quantize_dtype) + + asr_model = torch.quantization.quantize_dynamic( + asr_model, q_config, dtype=q_dtype + ).eval() + else: + asr_model.to(dtype=getattr(torch, dtype)).eval() - # 2. Build Language model if lm_train_config is not None: lm, lm_train_args = LMTask.build_model_from_file( lm_train_config, lm_file, device ) - scorers["lm"] = lm.lm - - # 3. Build ngram model - # ngram is not supported now - ngram = None - scorers["ngram"] = ngram + lm_scorer = lm.lm + else: + lm_scorer = None # 4. 
Build BeamSearch object - # transducer is not supported now - beam_search_transducer = None + if beam_search_config is None: + beam_search_config = {} - weights = dict( - decoder=1.0 - ctc_weight, - ctc=ctc_weight, - lm=lm_weight, - ngram=ngram_weight, - length_bonus=penalty, + beam_search = BeamSearchTransducer( + asr_model.decoder, + asr_model.joint_network, + beam_size, + lm=lm_scorer, + lm_weight=lm_weight, + nbest=nbest, + **beam_search_config, ) - beam_search = BeamSearch( - beam_size=beam_size, - weights=weights, - scorers=scorers, - sos=asr_model.sos, - eos=asr_model.eos, - vocab_size=len(token_list), - token_list=token_list, - pre_beam_score_key=None if ctc_weight == 1.0 else "full", - ) - - beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() - for scorer in scorers.values(): - if isinstance(scorer, torch.nn.Module): - scorer.to(device=device, dtype=getattr(torch, dtype)).eval() - - logging.info(f"Decoding device={device}, dtype={dtype}") - - # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text - if token_type is None: - token_type = asr_train_args.token_type - if bpemodel is None: - bpemodel = asr_train_args.bpemodel - - if token_type is None: - tokenizer = None - elif token_type == "bpe": - if bpemodel is not None: - tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) - else: - tokenizer = None - else: - tokenizer = build_tokenizer(token_type=token_type) - converter = TokenIDConverter(token_list=token_list) - logging.info(f"Text tokenizer: {tokenizer}") - - self.asr_model = asr_model - self.asr_train_args = asr_train_args - self.converter = converter - self.tokenizer = tokenizer - - # 6. 
[Optional] Build hotword list from str, local file or url - self.hotword_list = None - self.hotword_list = self.generate_hotwords_list(hotword_list_or_file) - - is_use_lm = lm_weight != 0.0 and lm_file is not None - if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: - beam_search = None - self.beam_search = beam_search - logging.info(f"Beam_search: {self.beam_search}") - self.beam_search_transducer = beam_search_transducer - self.maxlenratio = maxlenratio - self.minlenratio = minlenratio - self.device = device - self.dtype = dtype - self.nbest = nbest - self.frontend = frontend - self.encoder_downsampling_factor = 1 - if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d": - self.encoder_downsampling_factor = 4 - - @torch.no_grad() - def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None - ): - """Inference - - Args: - speech: Input speech data - Returns: - text, token, token_int, hyp - - """ - assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None - else: - feats = speech - feats_len = speech_lengths - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - batch = {"speech": feats, "speech_lengths": feats_len} - - # a. To device - batch = to_device(batch, device=self.device) - - # b. 
Forward Encoder - enc, enc_len = self.asr_model.encode(**batch) - if isinstance(enc, tuple): - enc = enc[0] - # assert len(enc) == 1, len(enc) - enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor - - predictor_outs = self.asr_model.calc_predictor(enc, enc_len) - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ - predictor_outs[2], predictor_outs[3] - pre_token_length = pre_token_length.round().long() - if torch.max(pre_token_length) < 1: - return [] - if not isinstance(self.asr_model, ContextualParaformer): - if self.hotword_list: - logging.warning("Hotword is given but asr model is not a ContextualParaformer.") - decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] - else: - decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] - - results = [] - b, n, d = decoder_out.size() - for i in range(b): - x = enc[i, :enc_len[i], :] - am_scores = decoder_out[i, :pre_token_length[i], :] - if self.beam_search is not None: - nbest_hyps = self.beam_search( - x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio - ) - - nbest_hyps = nbest_hyps[: self.nbest] - else: - yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to ensure compatibility with sos/eos tokens - yseq = torch.tensor( - [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device - ) - nbest_hyps = [Hypothesis(yseq=yseq, score=score)] - - for hyp in nbest_hyps: - assert isinstance(hyp, (Hypothesis)), type(hyp) - - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - 
token_int = hyp.yseq[1:last_pos].tolist() - - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) - - # Change integer-ids to tokens - token = self.converter.ids2tokens(token_int) - - if self.tokenizer is not None: - text = self.tokenizer.tokens2text(token) - else: - text = None - - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) - - # assert check_return_type(results) - return results - - def generate_hotwords_list(self, hotword_list_or_file): - # for None - if hotword_list_or_file is None: - hotword_list = None - # for local txt inputs - elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'): - logging.info("Attempting to parse hotwords from local txt...") - hotword_list = [] - hotword_str_list = [] - with codecs.open(hotword_list_or_file, 'r') as fin: - for line in fin.readlines(): - hw = line.strip() - hotword_str_list.append(hw) - hotword_list.append(self.converter.tokens2ids([i for i in hw])) - hotword_list.append([self.asr_model.sos]) - hotword_str_list.append('') - logging.info("Initialized hotword list from file: {}, hotword list: {}." 
- .format(hotword_list_or_file, hotword_str_list)) - # for url, download and generate txt - elif hotword_list_or_file.startswith('http'): - logging.info("Attempting to parse hotwords from url...") - work_dir = tempfile.TemporaryDirectory().name - if not os.path.exists(work_dir): - os.makedirs(work_dir) - text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file)) - local_file = requests.get(hotword_list_or_file) - open(text_file_path, "wb").write(local_file.content) - hotword_list_or_file = text_file_path - hotword_list = [] - hotword_str_list = [] - with codecs.open(hotword_list_or_file, 'r') as fin: - for line in fin.readlines(): - hw = line.strip() - hotword_str_list.append(hw) - hotword_list.append(self.converter.tokens2ids([i for i in hw])) - hotword_list.append([self.asr_model.sos]) - hotword_str_list.append('') - logging.info("Initialized hotword list from file: {}, hotword list: {}." - .format(hotword_list_or_file, hotword_str_list)) - # for text str input - elif not hotword_list_or_file.endswith('.txt'): - logging.info("Attempting to parse hotwords as str...") - hotword_list = [] - hotword_str_list = [] - for hw in hotword_list_or_file.strip().split(): - hotword_str_list.append(hw) - hotword_list.append(self.converter.tokens2ids([i for i in hw])) - hotword_list.append([self.asr_model.sos]) - hotword_str_list.append('') - logging.info("Hotword list: {}.".format(hotword_str_list)) - else: - hotword_list = None - return hotword_list - -class Speech2TextExport: - """Speech2TextExport class - - """ - - def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: 
float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - frontend_conf: dict = None, - hotword_list_or_file: str = None, - **kwargs, - ): - - # 1. Build ASR model - asr_model, asr_train_args = ASRTask.build_model_from_file( - asr_train_config, asr_model_file, cmvn_file, device - ) - frontend = None - if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) - - logging.info("asr_model: {}".format(asr_model)) - logging.info("asr_train_args: {}".format(asr_train_args)) - asr_model.to(dtype=getattr(torch, dtype)).eval() token_list = asr_model.token_list - - - logging.info(f"Decoding device={device}, dtype={dtype}") - - # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text if token_type is None: token_type = asr_train_args.token_type if bpemodel is None: @@ -407,197 +160,277 @@ class Speech2TextExport: tokenizer = build_tokenizer(token_type=token_type) converter = TokenIDConverter(token_list=token_list) logging.info(f"Text tokenizer: {tokenizer}") - - # self.asr_model = asr_model + + self.asr_model = asr_model self.asr_train_args = asr_train_args - self.converter = converter - self.tokenizer = tokenizer - self.device = device self.dtype = dtype self.nbest = nbest - self.frontend = frontend - model = Paraformer_export(asr_model, onnx=False) - self.asr_model = model + self.converter = converter + self.tokenizer = tokenizer + + self.beam_search = beam_search + self.streaming = streaming + self.simu_streaming = simu_streaming + self.chunk_size = max(chunk_size, 0) + self.left_context = max(left_context, 0) + self.right_context = max(right_context, 0) + + if not streaming or chunk_size == 0: + self.streaming = False + self.asr_model.encoder.dynamic_chunk_training = False + if not simu_streaming or chunk_size == 0: + self.simu_streaming = False + self.asr_model.encoder.dynamic_chunk_training = False + + self.frontend = frontend + 
self.window_size = self.chunk_size + self.right_context + + self._ctx = self.asr_model.encoder.get_encoder_input_size( + self.window_size + ) + + #self.last_chunk_length = ( + # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 + #) * self.hop_length + + self.last_chunk_length = ( + self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 + ) + self.reset_inference_cache() + + def reset_inference_cache(self) -> None: + """Reset Speech2Text parameters.""" + self.frontend_cache = None + + self.asr_model.encoder.reset_streaming_cache( + self.left_context, device=self.device + ) + self.beam_search.reset_inference_cache() + + self.num_processed_frames = torch.tensor([[0]], device=self.device) + @torch.no_grad() - def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None - ): - """Inference - + def streaming_decode( + self, + speech: Union[torch.Tensor, np.ndarray], + is_final: bool = True, + ) -> List[Hypothesis]: + """Speech2Text streaming call. Args: - speech: Input speech data + speech: Chunk of speech data. (S) + is_final: Whether speech corresponds to the final chunk of data. Returns: - text, token, token_int, hyp + nbest_hypothesis: N-best hypothesis. 
+ """ + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + if is_final: + if self.streaming and speech.size(0) < self.last_chunk_length: + pad = torch.zeros( + self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype + ) + speech = torch.cat([speech, pad], dim=0) #feats, feats_length = self.apply_frontend(speech, is_final=is_final) + feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) + + if self.asr_model.normalize is not None: + feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) + + feats = to_device(feats, device=self.device) + feats_lengths = to_device(feats_lengths, device=self.device) + enc_out = self.asr_model.encoder.chunk_forward( + feats, + feats_lengths, + self.num_processed_frames, + chunk_size=self.chunk_size, + left_context=self.left_context, + right_context=self.right_context, + ) + nbest_hyps = self.beam_search(enc_out[0], is_final=is_final) + + self.num_processed_frames += self.chunk_size + + if is_final: + self.reset_inference_cache() + + return nbest_hyps + + @torch.no_grad() + def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]: + """Speech2Text call. + Args: + speech: Speech data. (S) + Returns: + nbest_hypothesis: N-best hypothesis. """ assert check_argument_types() - # Input as audio signal if isinstance(speech, np.ndarray): speech = torch.tensor(speech) - - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None - else: - feats = speech - feats_len = speech_lengths - - enc_len_batch_total = feats_len.sum() - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - batch = {"speech": feats, "speech_lengths": feats_len} - - # a. 
To device - batch = to_device(batch, device=self.device) - - decoder_outs = self.asr_model(**batch) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) + + if self.asr_model.normalize is not None: + feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) + + feats = to_device(feats, device=self.device) + feats_lengths = to_device(feats_lengths, device=self.device) + enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context, self.right_context) + nbest_hyps = self.beam_search(enc_out[0]) + + return nbest_hyps + + @torch.no_grad() + def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]: + """Speech2Text call. + Args: + speech: Speech data. (S) + Returns: + nbest_hypothesis: N-best hypothesis. + """ + assert check_argument_types() + + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) + + feats = to_device(feats, device=self.device) + feats_lengths = to_device(feats_lengths, device=self.device) + + enc_out, _ = self.asr_model.encoder(feats, feats_lengths) + + nbest_hyps = self.beam_search(enc_out[0]) + + return nbest_hyps + + def hypotheses_to_results(self, nbest_hyps: List[Hypothesis]) -> List[Any]: + """Build partial or final results from the hypotheses. + Args: + nbest_hyps: N-best hypothesis. + Returns: + results: Results containing different representation for the hypothesis. 
+ """ results = [] - b, n, d = decoder_out.size() - for i in range(b): - am_scores = decoder_out[i, :ys_pad_lens[i], :] - yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to ensure compatibility with sos/eos tokens - yseq = torch.tensor( - yseq.tolist(), device=yseq.device - ) - nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + for hyp in nbest_hyps: + token_int = list(filter(lambda x: x != 0, hyp.yseq)) - for hyp in nbest_hyps: - assert isinstance(hyp, (Hypothesis)), type(hyp) + token = self.converter.ids2tokens(token_int) - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + results.append((text, token, token_int, hyp)) - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) - - # Change integer-ids to tokens - token = self.converter.ids2tokens(token_int) - - if self.tokenizer is not None: - text = self.tokenizer.tokens2text(token) - else: - text = None - - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) + assert check_return_type(results) return results + @staticmethod + def from_pretrained( + model_tag: Optional[str] = None, + **kwargs: Optional[Any], + ) -> Speech2Text: + """Build Speech2Text instance from the pretrained model. + Args: + model_tag: Model tag of the pretrained models. + Return: + : Speech2Text instance. + """ + if model_tag is not None: + try: + from espnet_model_zoo.downloader import ModelDownloader + + except ImportError: + logging.error( + "`espnet_model_zoo` is not installed. " + "Please install via `pip install -U espnet_model_zoo`." 
+ ) + raise + d = ModelDownloader() + kwargs.update(**d.download_and_unpack(model_tag)) + + return Speech2Text(**kwargs) + def inference( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, str], - data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - streaming: bool = False, - output_dir: Optional[str] = None, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - - **kwargs, -): - inference_pipeline = inference_modelscope( - maxlenratio=maxlenratio, - minlenratio=minlenratio, - batch_size=batch_size, - beam_size=beam_size, - ngpu=ngpu, - ctc_weight=ctc_weight, - lm_weight=lm_weight, - penalty=penalty, - log_level=log_level, - asr_train_config=asr_train_config, - asr_model_file=asr_model_file, - cmvn_file=cmvn_file, - raw_inputs=raw_inputs, - lm_train_config=lm_train_config, - lm_file=lm_file, - token_type=token_type, - key_file=key_file, - word_lm_train_config=word_lm_train_config, - bpemodel=bpemodel, - allow_variable_data_keys=allow_variable_data_keys, - streaming=streaming, - output_dir=output_dir, - dtype=dtype, - seed=seed, - ngram_weight=ngram_weight, - nbest=nbest, - num_workers=num_workers, - - **kwargs, - ) - return inference_pipeline(data_path_and_name_and_type, raw_inputs) - - -def inference_modelscope( - maxlenratio: float, - minlenratio: float, - batch_size: int, - beam_size: int, - ngpu: int, - ctc_weight: float, - lm_weight: float, - penalty: float, - log_level: Union[int, 
str], - # data_path_and_name_and_type, - asr_train_config: Optional[str], - asr_model_file: Optional[str], - cmvn_file: Optional[str] = None, - lm_train_config: Optional[str] = None, - lm_file: Optional[str] = None, - token_type: Optional[str] = None, - key_file: Optional[str] = None, - word_lm_train_config: Optional[str] = None, - bpemodel: Optional[str] = None, - allow_variable_data_keys: bool = False, - dtype: str = "float32", - seed: int = 0, - ngram_weight: float = 0.9, - nbest: int = 1, - num_workers: int = 1, - output_dir: Optional[str] = None, - param_dict: dict = None, - **kwargs, -): + output_dir: str, + batch_size: int, + dtype: str, + beam_size: int, + ngpu: int, + seed: int, + lm_weight: float, + nbest: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + asr_train_config: Optional[str], + asr_model_file: Optional[str], + cmvn_file: Optional[str], + beam_search_config: Optional[dict], + lm_train_config: Optional[str], + lm_file: Optional[str], + model_tag: Optional[str], + token_type: Optional[str], + bpemodel: Optional[str], + key_file: Optional[str], + allow_variable_data_keys: bool, + quantize_asr_model: Optional[bool], + quantize_modules: Optional[List[str]], + quantize_dtype: Optional[str], + streaming: Optional[bool], + simu_streaming: Optional[bool], + chunk_size: Optional[int], + left_context: Optional[int], + right_context: Optional[int], + display_partial_hypotheses: bool, + **kwargs, +) -> None: + """Transducer model inference. + Args: + output_dir: Output directory path. + batch_size: Batch decoding size. + dtype: Data type. + beam_size: Beam size. + ngpu: Number of GPUs. + seed: Random number generator seed. + lm_weight: Weight of language model. + nbest: Number of final hypothesis. + num_workers: Number of workers. + log_level: Level of verbose for logs. + data_path_and_name_and_type: + asr_train_config: ASR model training config path. + asr_model_file: ASR model path. 
+ beam_search_config: Beam search config path. + lm_train_config: Language Model training config path. + lm_file: Language Model path. + model_tag: Model tag. + token_type: Type of token units. + bpemodel: BPE model path. + key_file: File key. + allow_variable_data_keys: Whether to allow variable data keys. + quantize_asr_model: Whether to apply dynamic quantization to ASR model. + quantize_modules: List of module names to apply dynamic quantization on. + quantize_dtype: Dynamic quantization data type. + streaming: Whether to perform chunk-by-chunk inference. + chunk_size: Number of frames in chunk AFTER subsampling. + left_context: Number of frames in left context AFTER subsampling. + right_context: Number of frames in right context AFTER subsampling. + display_partial_hypotheses: Whether to display partial hypotheses. + """ assert check_argument_types() - if word_lm_train_config is not None: - raise NotImplementedError("Word LM is not implemented") + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") if ngpu > 1: raise NotImplementedError("only single GPU decoding is supported") @@ -605,20 +438,11 @@ def inference_modelscope( level=log_level, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) - - export_mode = False - if param_dict is not None: - hotword_list_or_file = param_dict.get('hotword') - export_mode = param_dict.get("export_mode", False) - else: - hotword_list_or_file = None - if ngpu >= 1 and torch.cuda.is_available(): + if ngpu >= 1: device = "cuda" else: device = "cpu" - batch_size = 1 - # 1. 
Set random-seed set_all_random_seed(seed) @@ -627,143 +451,105 @@ def inference_modelscope( asr_train_config=asr_train_config, asr_model_file=asr_model_file, cmvn_file=cmvn_file, + beam_search_config=beam_search_config, lm_train_config=lm_train_config, lm_file=lm_file, token_type=token_type, bpemodel=bpemodel, device=device, - maxlenratio=maxlenratio, - minlenratio=minlenratio, dtype=dtype, beam_size=beam_size, - ctc_weight=ctc_weight, lm_weight=lm_weight, - ngram_weight=ngram_weight, - penalty=penalty, nbest=nbest, - hotword_list_or_file=hotword_list_or_file, + quantize_asr_model=quantize_asr_model, + quantize_modules=quantize_modules, + quantize_dtype=quantize_dtype, + streaming=streaming, + simu_streaming=simu_streaming, + chunk_size=chunk_size, + left_context=left_context, + right_context=right_context, + ) + speech2text = Speech2Text.from_pretrained( + model_tag=model_tag, + **speech2text_kwargs, ) - if export_mode: - speech2text = Speech2TextExport(**speech2text_kwargs) - else: - speech2text = Speech2Text(**speech2text_kwargs) - def _forward( - data_path_and_name_and_type, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - output_dir_v2: Optional[str] = None, - fs: dict = None, - param_dict: dict = None, - **kwargs, - ): - - hotword_list_or_file = None - if param_dict is not None: - hotword_list_or_file = param_dict.get('hotword') - if 'hotword' in kwargs: - hotword_list_or_file = kwargs['hotword'] - if hotword_list_or_file is not None or 'hotword' in kwargs: - speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file) - cache = None - if 'cache' in param_dict: - cache = param_dict['cache'] - # 3. 
Build data-iterator - if data_path_and_name_and_type is None and raw_inputs is not None: - if isinstance(raw_inputs, torch.Tensor): - raw_inputs = raw_inputs.numpy() - data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] - loader = ASRTask.build_streaming_iterator( - data_path_and_name_and_type, - dtype=dtype, - fs=fs, - batch_size=batch_size, - key_file=key_file, - num_workers=num_workers, - preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), - collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), - allow_variable_data_keys=allow_variable_data_keys, - inference=True, - ) - - forward_time_total = 0.0 - length_total = 0.0 - finish_count = 0 - file_count = 1 - # 7 .Start for-loop - # FIXME(kamo): The output format should be discussed about - asr_result_list = [] - output_path = output_dir_v2 if output_dir_v2 is not None else output_dir - if output_path is not None: - writer = DatadirWriter(output_path) - else: - writer = None + # 3. Build data-iterator + loader = ASRTransducerTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=ASRTransducerTask.build_preprocess_fn( + speech2text.asr_train_args, False + ), + collate_fn=ASRTransducerTask.build_collate_fn( + speech2text.asr_train_args, False + ), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + # 4 .Start for-loop + with DatadirWriter(output_dir) as writer: for keys, batch in loader: assert isinstance(batch, dict), type(batch) assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) assert len(keys) == _bs, f"{len(keys)} != {_bs}" - # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")} + batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + assert len(batch.keys()) == 1 - logging.info("decoding, utt_id: {}".format(keys)) - # N-best list of (text, 
token, token_int, hyp_object) + try: + if speech2text.streaming: + speech = batch["speech"] - time_beg = time.time() - results = speech2text(cache=cache, **batch) - if len(results) < 1: - hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) - results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest - time_end = time.time() - forward_time = time_end - time_beg - lfr_factor = results[0][-1] - length = results[0][-2] - forward_time_total += forward_time - length_total += length - rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor)) - logging.info(rtf_cur) + _steps = len(speech) // speech2text._ctx + _end = 0 + for i in range(_steps): + _end = (i + 1) * speech2text._ctx - for batch_id in range(_bs): - result = [results[batch_id][:-2]] + speech2text.streaming_decode( + speech[i * speech2text._ctx : _end], is_final=False + ) - key = keys[batch_id] - for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result): - # Create a directory: outdir/{n}best_recog - if writer is not None: - ibest_writer = writer[f"{n}best_recog"] + final_hyps = speech2text.streaming_decode( + speech[_end : len(speech)], is_final=True + ) + elif speech2text.simu_streaming: + final_hyps = speech2text.simu_streaming_decode(**batch) + else: + final_hyps = speech2text(**batch) - # Write the result to each file - ibest_writer["token"][key] = " ".join(token) - # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) - ibest_writer["score"][key] = str(hyp.score) - ibest_writer["rtf"][key] = rtf_cur + results = speech2text.hypotheses_to_results(final_hyps) + except TooShortUttError as e: + logging.warning(f"Utterance {keys} {e}") + hyp = Hypothesis(score=0.0, yseq=[], dec_state=None) + results = [[" ", [""], [2], hyp]] * nbest - if text is not None: - text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token) - item = {'key': key, 'value': text_postprocessed} - 
asr_result_list.append(item) - finish_count += 1 - # asr_utils.print_progress(finish_count / file_count) - if writer is not None: - ibest_writer["text"][key] = " ".join(word_lists) + key = keys[0] + for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): + ibest_writer = writer[f"{n}best_recog"] - logging.info("decoding, utt: {}, predictions: {}".format(key, text)) - rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)) - logging.info(rtf_avg) - if writer is not None: - ibest_writer["rtf"]["rtf_avf"] = rtf_avg - return asr_result_list + ibest_writer["token"][key] = " ".join(token) + ibest_writer["token_int"][key] = " ".join(map(str, token_int)) + ibest_writer["score"][key] = str(hyp.score) - return _forward + if text is not None: + ibest_writer["text"][key] = text def get_parser(): + """Get Transducer model inference parser.""" + parser = config_argparse.ArgumentParser( - description="ASR Decoding", + description="ASR Transducer Decoding", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - # Note(kamo): Use '_' instead of '-' as separator. - # '-' is confusing if written in yaml. 
parser.add_argument( "--log_level", type=lambda x: x.upper(), @@ -792,17 +578,12 @@ def get_parser(): default=1, help="The number of workers used for DataLoader", ) - parser.add_argument( - "--hotword", - type=str_or_none, - default=None, - help="hotword file path or hotwords seperated by space" - ) + group = parser.add_argument_group("Input data related") group.add_argument( "--data_path_and_name_and_type", type=str2triple_str, - required=False, + required=True, action="append", ) group.add_argument("--key_file", type=str_or_none) @@ -834,26 +615,11 @@ def get_parser(): type=str, help="LM parameter file", ) - group.add_argument( - "--word_lm_train_config", - type=str, - help="Word LM training configuration", - ) - group.add_argument( - "--word_lm_file", - type=str, - help="Word LM parameter file", - ) - group.add_argument( - "--ngram_file", - type=str, - help="N-gram parameter file", - ) group.add_argument( "--model_tag", type=str, help="Pretrained model tag. If specify this option, *_train_config and " - "*_file will be overwritten", + "*_file will be overwritten", ) group = parser.add_argument_group("Beam-search related") @@ -864,42 +630,13 @@ def get_parser(): help="The batch size for inference", ) group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") - group.add_argument("--beam_size", type=int, default=20, help="Beam size") - group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") - group.add_argument( - "--maxlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain max output length. " - "If maxlenratio=0.0 (default), it uses a end-detect " - "function " - "to automatically find maximum hypothesis lengths." 
- "If maxlenratio<0.0, its absolute value is interpreted" - "as a constant max output length", - ) - group.add_argument( - "--minlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain min output length", - ) - group.add_argument( - "--ctc_weight", - type=float, - default=0.5, - help="CTC weight in joint decoding", - ) + group.add_argument("--beam_size", type=int, default=5, help="Beam size") group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") - group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") - group.add_argument("--streaming", type=str2bool, default=False) - group.add_argument( - "--frontend_conf", - default=None, - help="", + "--beam_search_config", + default={}, + help="The keyword arguments for transducer beam search.", ) - group.add_argument("--raw_inputs", type=list, default=None) - # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) group = parser.add_argument_group("Text converter related") group.add_argument( @@ -908,14 +645,77 @@ def get_parser(): default=None, choices=["char", "bpe", None], help="The token type for ASR model. " - "If not given, refers from the training args", + "If not given, refers from the training args", ) group.add_argument( "--bpemodel", type=str_or_none, default=None, help="The model path of sentencepiece. " - "If not given, refers from the training args", + "If not given, refers from the training args", + ) + + group = parser.add_argument_group("Dynamic quantization related") + parser.add_argument( + "--quantize_asr_model", + type=bool, + default=False, + help="Apply dynamic quantization to ASR model.", + ) + parser.add_argument( + "--quantize_modules", + nargs="*", + default=None, + help="""Module names to apply dynamic quantization on. 
+ The module names are provided as a list, where each name is separated + by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]). + Each specified name should be an attribute of 'torch.nn', e.g.: + torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""", + ) + parser.add_argument( + "--quantize_dtype", + type=str, + default="qint8", + choices=["float16", "qint8"], + help="Dtype for dynamic quantization.", + ) + + group = parser.add_argument_group("Streaming related") + parser.add_argument( + "--streaming", + type=bool, + default=False, + help="Whether to perform chunk-by-chunk inference.", + ) + parser.add_argument( + "--simu_streaming", + type=bool, + default=False, + help="Whether to simulate chunk-by-chunk inference.", + ) + parser.add_argument( + "--chunk_size", + type=int, + default=16, + help="Number of frames in chunk AFTER subsampling.", + ) + parser.add_argument( + "--left_context", + type=int, + default=32, + help="Number of frames in left context of the chunk AFTER subsampling.", + ) + parser.add_argument( + "--right_context", + type=int, + default=0, + help="Number of frames in right context of the chunk AFTER subsampling.", + ) + parser.add_argument( + "--display_partial_hypotheses", + type=bool, + default=False, + help="Whether to display partial hypotheses during chunk-by-chunk inference.", ) return parser @@ -923,24 +723,15 @@ def get_parser(): def main(cmd=None): print(get_commandline_args(), file=sys.stderr) + parser = get_parser() args = parser.parse_args(cmd) - param_dict = {'hotword': args.hotword} kwargs = vars(args) + kwargs.pop("config", None) - kwargs['param_dict'] = param_dict inference(**kwargs) if __name__ == "__main__": main() - # from modelscope.pipelines import pipeline - # from modelscope.utils.constant import Tasks - # - # inference_16k_pipline = pipeline( - # task=Tasks.auto_speech_recognition, - # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') - # - # rec_result = 
inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav') - # print(rec_result) diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py new file mode 100755 index 000000000..fe418dbc9 --- /dev/null +++ b/funasr/bin/asr_train_transducer.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +import os + +from funasr.tasks.asr import ASRTransducerTask + + +# for ASR Training +def parse_args(): + parser = ASRTransducerTask.get_parser() + parser.add_argument( + "--gpu_id", + type=int, + default=0, + help="local gpu id.", + ) + args = parser.parse_args() + return args + + +def main(args=None, cmd=None): + # for ASR Training + ASRTransducerTask.main(args=args, cmd=cmd) + + +if __name__ == '__main__': + args = parse_args() + + # setup local gpu_id + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) + + # DDP settings + if args.ngpu > 1: + args.distributed = True + else: + args.distributed = False + assert args.num_worker_count == 1 + + # re-compute batch size: when dataset type is small + if args.dataset_type == "small": + if args.batch_size is not None: + args.batch_size = args.batch_size * args.ngpu + if args.batch_bins is not None: + args.batch_bins = args.batch_bins * args.ngpu + + main(args=args) diff --git a/funasr/models/decoder/rnnt_decoder.py b/funasr/models/decoder/rnnt_decoder.py new file mode 100644 index 000000000..5401ab20c --- /dev/null +++ b/funasr/models/decoder/rnnt_decoder.py @@ -0,0 +1,258 @@ +"""RNN decoder definition for Transducer models.""" + +from typing import List, Optional, Tuple + +import torch +from typeguard import check_argument_types + +from funasr.modules.beam_search.beam_search_transducer import Hypothesis +from funasr.models.specaug.specaug import SpecAug + +class RNNTDecoder(torch.nn.Module): + """RNN decoder module. + + Args: + vocab_size: Vocabulary size. + embed_size: Embedding size. + hidden_size: Hidden size.. + rnn_type: Decoder layers type. 
+ num_layers: Number of decoder layers. + dropout_rate: Dropout rate for decoder layers. + embed_dropout_rate: Dropout rate for embedding layer. + embed_pad: Embedding padding symbol ID. + + """ + + def __init__( + self, + vocab_size: int, + embed_size: int = 256, + hidden_size: int = 256, + rnn_type: str = "lstm", + num_layers: int = 1, + dropout_rate: float = 0.0, + embed_dropout_rate: float = 0.0, + embed_pad: int = 0, + ) -> None: + """Construct a RNNDecoder object.""" + super().__init__() + + assert check_argument_types() + + if rnn_type not in ("lstm", "gru"): + raise ValueError(f"Not supported: rnn_type={rnn_type}") + + self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad) + self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate) + + rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU + + self.rnn = torch.nn.ModuleList( + [rnn_class(embed_size, hidden_size, 1, batch_first=True)] + ) + + for _ in range(1, num_layers): + self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)] + + self.dropout_rnn = torch.nn.ModuleList( + [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)] + ) + + self.dlayers = num_layers + self.dtype = rnn_type + + self.output_size = hidden_size + self.vocab_size = vocab_size + + self.device = next(self.parameters()).device + self.score_cache = {} + + def forward( + self, + labels: torch.Tensor, + label_lens: torch.Tensor, + states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """Encode source label sequences. + + Args: + labels: Label ID sequences. (B, L) + states: Decoder hidden states. + ((N, B, D_dec), (N, B, D_dec) or None) or None + + Returns: + dec_out: Decoder output sequences. 
(B, U, D_dec) + + """ + if states is None: + states = self.init_state(labels.size(0)) + + dec_embed = self.dropout_embed(self.embed(labels)) + dec_out, states = self.rnn_forward(dec_embed, states) + return dec_out + + def rnn_forward( + self, + x: torch.Tensor, + state: Tuple[torch.Tensor, Optional[torch.Tensor]], + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Encode source label sequences. + + Args: + x: RNN input sequences. (B, D_emb) + state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None) + + Returns: + x: RNN output sequences. (B, D_dec) + (h_next, c_next): Decoder hidden states. + (N, B, D_dec), (N, B, D_dec) or None) + + """ + h_prev, c_prev = state + h_next, c_next = self.init_state(x.size(0)) + + for layer in range(self.dlayers): + if self.dtype == "lstm": + x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[ + layer + ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])) + else: + x, h_next[layer : layer + 1] = self.rnn[layer]( + x, hx=h_prev[layer : layer + 1] + ) + + x = self.dropout_rnn[layer](x) + + return x, (h_next, c_next) + + def score( + self, + label: torch.Tensor, + label_sequence: List[int], + dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]], + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """One-step forward hypothesis. + + Args: + label: Previous label. (1, 1) + label_sequence: Current label sequence. + dec_state: Previous decoder hidden states. + ((N, 1, D_dec), (N, 1, D_dec) or None) + + Returns: + dec_out: Decoder output sequence. (1, D_dec) + dec_state: Decoder hidden states. 
+ ((N, 1, D_dec), (N, 1, D_dec) or None) + + """ + str_labels = "_".join(map(str, label_sequence)) + + if str_labels in self.score_cache: + dec_out, dec_state = self.score_cache[str_labels] + else: + dec_embed = self.embed(label) + dec_out, dec_state = self.rnn_forward(dec_embed, dec_state) + + self.score_cache[str_labels] = (dec_out, dec_state) + + return dec_out[0], dec_state + + def batch_score( + self, + hyps: List[Hypothesis], + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """One-step forward hypotheses. + + Args: + hyps: Hypotheses. + + Returns: + dec_out: Decoder output sequences. (B, D_dec) + states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None) + + """ + labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device) + dec_embed = self.embed(labels) + + states = self.create_batch_states([h.dec_state for h in hyps]) + dec_out, states = self.rnn_forward(dec_embed, states) + + return dec_out.squeeze(1), states + + def set_device(self, device: torch.device) -> None: + """Set GPU device to use. + + Args: + device: Device ID. + + """ + self.device = device + + def init_state( + self, batch_size: int + ) -> Tuple[torch.Tensor, Optional[torch.tensor]]: + """Initialize decoder states. + + Args: + batch_size: Batch size. + + Returns: + : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None) + + """ + h_n = torch.zeros( + self.dlayers, + batch_size, + self.output_size, + device=self.device, + ) + + if self.dtype == "lstm": + c_n = torch.zeros( + self.dlayers, + batch_size, + self.output_size, + device=self.device, + ) + + return (h_n, c_n) + + return (h_n, None) + + def select_state( + self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Get specified ID state from decoder hidden states. + + Args: + states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None) + idx: State ID to extract. 
+ + Returns: + : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None) + + """ + return ( + states[0][:, idx : idx + 1, :], + states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None, + ) + + def create_batch_states( + self, + new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Create decoder hidden states. + + Args: + new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)] + + Returns: + states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None) + + """ + return ( + torch.cat([s[0] for s in new_states], dim=1), + torch.cat([s[1] for s in new_states], dim=1) + if self.dtype == "lstm" + else None, + ) diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py new file mode 100644 index 000000000..0cae30605 --- /dev/null +++ b/funasr/models/e2e_asr_transducer.py @@ -0,0 +1,1013 @@ +"""ESPnet2 ASR Transducer model.""" + +import logging +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Union + +import torch +from packaging.version import parse as V +from typeguard import check_argument_types + +from funasr.models.frontend.abs_frontend import AbsFrontend +from funasr.models.specaug.abs_specaug import AbsSpecAug +from funasr.models.decoder.rnnt_decoder import RNNTDecoder +from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder +from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder +from funasr.models.joint_net.joint_network import JointNetwork +from funasr.modules.nets_utils import get_transducer_task_io +from funasr.layers.abs_normalize import AbsNormalize +from funasr.torch_utils.device_funcs import force_gatherable +from funasr.train.abs_espnet_model import AbsESPnetModel + +if V(torch.__version__) >= V("1.6.0"): + from torch.cuda.amp import autocast +else: + + @contextmanager + def autocast(enabled=True): + yield + + +class 
TransducerModel(AbsESPnetModel): + """ESPnet2ASRTransducerModel module definition. + + Args: + vocab_size: Size of complete vocabulary (w/ EOS and blank included). + token_list: List of token + frontend: Frontend module. + specaug: SpecAugment module. + normalize: Normalization module. + encoder: Encoder module. + decoder: Decoder module. + joint_network: Joint Network module. + transducer_weight: Weight of the Transducer loss. + fastemit_lambda: FastEmit lambda value. + auxiliary_ctc_weight: Weight of auxiliary CTC loss. + auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. + auxiliary_lm_loss_weight: Weight of auxiliary LM loss. + auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. + ignore_id: Initial padding ID. + sym_space: Space symbol. + sym_blank: Blank Symbol + report_cer: Whether to report Character Error Rate during validation. + report_wer: Whether to report Word Error Rate during validation. + extract_feats_in_collect_stats: Whether to use extract_feats stats collection. 
+ + """ + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + encoder: Encoder, + decoder: RNNTDecoder, + joint_network: JointNetwork, + att_decoder: Optional[AbsAttDecoder] = None, + transducer_weight: float = 1.0, + fastemit_lambda: float = 0.0, + auxiliary_ctc_weight: float = 0.0, + auxiliary_ctc_dropout_rate: float = 0.0, + auxiliary_lm_loss_weight: float = 0.0, + auxiliary_lm_loss_smoothing: float = 0.0, + ignore_id: int = -1, + sym_space: str = "", + sym_blank: str = "", + report_cer: bool = True, + report_wer: bool = True, + extract_feats_in_collect_stats: bool = True, + ) -> None: + """Construct an ESPnetASRTransducerModel object.""" + super().__init__() + + assert check_argument_types() + + # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) + self.blank_id = 0 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.token_list = token_list.copy() + + self.sym_space = sym_space + self.sym_blank = sym_blank + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + + self.encoder = encoder + self.decoder = decoder + self.joint_network = joint_network + + self.criterion_transducer = None + self.error_calculator = None + + self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 + self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 + + if self.use_auxiliary_ctc: + self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) + self.ctc_dropout_rate = auxiliary_ctc_dropout_rate + + if self.use_auxiliary_lm_loss: + self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) + self.lm_loss_smoothing = auxiliary_lm_loss_smoothing + + self.transducer_weight = transducer_weight + self.fastemit_lambda = fastemit_lambda + + self.auxiliary_ctc_weight = auxiliary_ctc_weight + self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight + + self.report_cer = 
report_cer + self.report_wer = report_wer + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Forward architecture and compute loss(es). + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + + Return: + loss: Main loss value. + stats: Task statistics. + weight: Task weights. + + """ + assert text_lengths.dim() == 1, text_lengths.shape + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + + batch_size = speech.shape[0] + text = text[:, : text_lengths.max()] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + # 2. Transducer-related I/O preparation + decoder_in, target, t_len, u_len = get_transducer_task_io( + text, + encoder_out_lens, + ignore_id=self.ignore_id, + ) + + # 3. Decoder + self.decoder.set_device(encoder_out.device) + decoder_out = self.decoder(decoder_in, u_len) + + # 4. Joint Network + joint_out = self.joint_network( + encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) + ) + + # 5. 
Losses + loss_trans, cer_trans, wer_trans = self._calc_transducer_loss( + encoder_out, + joint_out, + target, + t_len, + u_len, + ) + + loss_ctc, loss_lm = 0.0, 0.0 + + if self.use_auxiliary_ctc: + loss_ctc = self._calc_ctc_loss( + encoder_out, + target, + t_len, + u_len, + ) + + if self.use_auxiliary_lm_loss: + loss_lm = self._calc_lm_loss(decoder_out, target) + + loss = ( + self.transducer_weight * loss_trans + + self.auxiliary_ctc_weight * loss_ctc + + self.auxiliary_lm_loss_weight * loss_lm + ) + + stats = dict( + loss=loss.detach(), + loss_transducer=loss_trans.detach(), + aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None, + aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, + cer_transducer=cer_trans, + wer_transducer=wer_trans, + ) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Collect features sequences and features lengths sequences. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + + Return: + {}: "feats": Features sequences. (B, T, D_feats), + "feats_lengths": Features sequences lengths. 
(B,) + + """ + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + "Generating dummy stats for feats and feats_lengths, " + "because encoder_conf.extract_feats_in_collect_stats is " + f"{self.extract_feats_in_collect_stats}" + ) + + feats, feats_lengths = speech, speech_lengths + + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encoder speech sequences. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + + Return: + encoder_out: Encoder outputs. (B, T, D_enc) + encoder_out_lens: Encoder outputs lengths. (B,) + + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # 4. Forward encoder + encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + return encoder_out, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Extract features sequences and features sequences lengths. + + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + + Return: + feats: Features sequences. (B, T, D_feats) + feats_lengths: Features sequences lengths. 
(B,) + + """ + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + feats, feats_lengths = speech, speech_lengths + + return feats, feats_lengths + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + joint_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: + """Compute Transducer loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + joint_out: Joint Network output sequences (B, T, U, D_joint) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + + Return: + loss_transducer: Transducer loss value. + cer_transducer: Character error rate for Transducer. + wer_transducer: Word Error Rate for Transducer. + + """ + if self.criterion_transducer is None: + try: + # from warprnnt_pytorch import RNNTLoss + # self.criterion_transducer = RNNTLoss( + # reduction="mean", + # fastemit_lambda=self.fastemit_lambda, + # ) + from warp_rnnt import rnnt_loss as RNNTLoss + self.criterion_transducer = RNNTLoss + + except ImportError: + logging.error( + "warp-rnnt was not installed." + "Please consult the installation documentation." 
+ ) + exit(1) + + # loss_transducer = self.criterion_transducer( + # joint_out, + # target, + # t_len, + # u_len, + # ) + log_probs = torch.log_softmax(joint_out, dim=-1) + + loss_transducer = self.criterion_transducer( + log_probs, + target, + t_len, + u_len, + reduction="mean", + blank=self.blank_id, + fastemit_lambda=self.fastemit_lambda, + gather=True, + ) + + if not self.training and (self.report_cer or self.report_wer): + if self.error_calculator is None: + from espnet2.asr_transducer.error_calculator import ErrorCalculator + + self.error_calculator = ErrorCalculator( + self.decoder, + self.joint_network, + self.token_list, + self.sym_space, + self.sym_blank, + report_cer=self.report_cer, + report_wer=self.report_wer, + ) + + cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) + + return loss_transducer, cer_transducer, wer_transducer + + return loss_transducer, None, None + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + + Return: + loss_ctc: CTC loss value. + + """ + ctc_in = self.ctc_lin( + torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) + ) + ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) + + target_mask = target != 0 + ctc_target = target[target_mask].cpu() + + with torch.backends.cudnn.flags(deterministic=True): + loss_ctc = torch.nn.functional.ctc_loss( + ctc_in, + ctc_target, + t_len, + u_len, + zero_infinity=True, + reduction="sum", + ) + loss_ctc /= target.size(0) + + return loss_ctc + + def _calc_lm_loss( + self, + decoder_out: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """Compute LM loss. + + Args: + decoder_out: Decoder output sequences. 
(B, U, D_dec) + target: Target label ID sequences. (B, L) + + Return: + loss_lm: LM loss value. + + """ + lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) + lm_target = target.view(-1).type(torch.int64) + + with torch.no_grad(): + true_dist = lm_loss_in.clone() + true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) + + # Ignore blank ID (0) + ignore = lm_target == 0 + lm_target = lm_target.masked_fill(ignore, 0) + + true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) + + loss_lm = torch.nn.functional.kl_div( + torch.log_softmax(lm_loss_in, dim=1), + true_dist, + reduction="none", + ) + loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( + 0 + ) + + return loss_lm + +class UnifiedTransducerModel(AbsESPnetModel): + """ESPnet2ASRTransducerModel module definition. + Args: + vocab_size: Size of complete vocabulary (w/ EOS and blank included). + token_list: List of token + frontend: Frontend module. + specaug: SpecAugment module. + normalize: Normalization module. + encoder: Encoder module. + decoder: Decoder module. + joint_network: Joint Network module. + transducer_weight: Weight of the Transducer loss. + fastemit_lambda: FastEmit lambda value. + auxiliary_ctc_weight: Weight of auxiliary CTC loss. + auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. + auxiliary_lm_loss_weight: Weight of auxiliary LM loss. + auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. + ignore_id: Initial padding ID. + sym_space: Space symbol. + sym_blank: Blank Symbol + report_cer: Whether to report Character Error Rate during validation. + report_wer: Whether to report Word Error Rate during validation. + extract_feats_in_collect_stats: Whether to use extract_feats stats collection. 
+ """ + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + encoder: Encoder, + decoder: RNNTDecoder, + joint_network: JointNetwork, + att_decoder: Optional[AbsAttDecoder] = None, + transducer_weight: float = 1.0, + fastemit_lambda: float = 0.0, + auxiliary_ctc_weight: float = 0.0, + auxiliary_att_weight: float = 0.0, + auxiliary_ctc_dropout_rate: float = 0.0, + auxiliary_lm_loss_weight: float = 0.0, + auxiliary_lm_loss_smoothing: float = 0.0, + ignore_id: int = -1, + sym_space: str = "", + sym_blank: str = "", + report_cer: bool = True, + report_wer: bool = True, + sym_sos: str = "", + sym_eos: str = "", + extract_feats_in_collect_stats: bool = True, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + ) -> None: + """Construct an ESPnetASRTransducerModel object.""" + super().__init__() + + assert check_argument_types() + + # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) + self.blank_id = 0 + + if sym_sos in token_list: + self.sos = token_list.index(sym_sos) + else: + self.sos = vocab_size - 1 + if sym_eos in token_list: + self.eos = token_list.index(sym_eos) + else: + self.eos = vocab_size - 1 + + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.token_list = token_list.copy() + + self.sym_space = sym_space + self.sym_blank = sym_blank + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + + self.encoder = encoder + self.decoder = decoder + self.joint_network = joint_network + + self.criterion_transducer = None + self.error_calculator = None + + self.use_auxiliary_ctc = auxiliary_ctc_weight > 0 + self.use_auxiliary_att = auxiliary_att_weight > 0 + self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 + + if self.use_auxiliary_ctc: + self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) + self.ctc_dropout_rate = 
auxiliary_ctc_dropout_rate + + if self.use_auxiliary_att: + self.att_decoder = att_decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + if self.use_auxiliary_lm_loss: + self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size) + self.lm_loss_smoothing = auxiliary_lm_loss_smoothing + + self.transducer_weight = transducer_weight + self.fastemit_lambda = fastemit_lambda + + self.auxiliary_ctc_weight = auxiliary_ctc_weight + self.auxiliary_att_weight = auxiliary_att_weight + self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight + + self.report_cer = report_cer + self.report_wer = report_wer + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Forward architecture and compute loss(es). + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + Return: + loss: Main loss value. + stats: Task statistics. + weight: Task weights. + """ + assert text_lengths.dim() == 1, text_lengths.shape + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + + batch_size = speech.shape[0] + text = text[:, : text_lengths.max()] + #print(speech.shape) + # 1. 
Encoder + encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths) + + loss_att, loss_att_chunk = 0.0, 0.0 + + if self.use_auxiliary_att: + loss_att, _ = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + loss_att_chunk, _ = self._calc_att_loss( + encoder_out_chunk, encoder_out_lens, text, text_lengths + ) + + # 2. Transducer-related I/O preparation + decoder_in, target, t_len, u_len = get_transducer_task_io( + text, + encoder_out_lens, + ignore_id=self.ignore_id, + ) + + # 3. Decoder + self.decoder.set_device(encoder_out.device) + decoder_out = self.decoder(decoder_in, u_len) + + # 4. Joint Network + joint_out = self.joint_network( + encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) + ) + + joint_out_chunk = self.joint_network( + encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1) + ) + + # 5. Losses + loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss( + encoder_out, + joint_out, + target, + t_len, + u_len, + ) + + loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss( + encoder_out_chunk, + joint_out_chunk, + target, + t_len, + u_len, + ) + + loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0 + + if self.use_auxiliary_ctc: + loss_ctc = self._calc_ctc_loss( + encoder_out, + target, + t_len, + u_len, + ) + loss_ctc_chunk = self._calc_ctc_loss( + encoder_out_chunk, + target, + t_len, + u_len, + ) + + if self.use_auxiliary_lm_loss: + loss_lm = self._calc_lm_loss(decoder_out, target) + + loss_trans = loss_trans_utt + loss_trans_chunk + loss_ctc = loss_ctc + loss_ctc_chunk + loss_ctc = loss_att + loss_att_chunk + + loss = ( + self.transducer_weight * loss_trans + + self.auxiliary_ctc_weight * loss_ctc + + self.auxiliary_att_weight * loss_att + + self.auxiliary_lm_loss_weight * loss_lm + ) + + stats = dict( + loss=loss.detach(), + loss_transducer=loss_trans_utt.detach(), + loss_transducer_chunk=loss_trans_chunk.detach(), + aux_ctc_loss=loss_ctc.detach() if loss_ctc > 
0.0 else None, + aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None, + aux_att_loss=loss_att.detach() if loss_att > 0.0 else None, + aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None, + aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, + cer_transducer=cer_trans, + wer_transducer=wer_trans, + cer_transducer_chunk=cer_trans_chunk, + wer_transducer_chunk=wer_trans_chunk, + ) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Collect features sequences and features lengths sequences. + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + text: Label ID sequences. (B, L) + text_lengths: Label ID sequences lengths. (B,) + kwargs: Contains "utts_id". + Return: + {}: "feats": Features sequences. (B, T, D_feats), + "feats_lengths": Features sequences lengths. (B,) + """ + if self.extract_feats_in_collect_stats: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + else: + # Generate dummy stats if extract_feats_in_collect_stats is False + logging.warning( + "Generating dummy stats for feats and feats_lengths, " + "because encoder_conf.extract_feats_in_collect_stats is " + f"{self.extract_feats_in_collect_stats}" + ) + + feats, feats_lengths = speech, speech_lengths + + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encoder speech sequences. + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + Return: + encoder_out: Encoder outputs. 
(B, T, D_enc) + encoder_out_lens: Encoder outputs lengths. (B,) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # 4. Forward encoder + encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + return encoder_out, encoder_out_chunk, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Extract features sequences and features sequences lengths. + Args: + speech: Speech sequences. (B, S) + speech_lengths: Speech sequences lengths. (B,) + Return: + feats: Features sequences. (B, T, D_feats) + feats_lengths: Features sequences lengths. (B,) + """ + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + feats, feats_lengths = speech, speech_lengths + + return feats, feats_lengths + + def _calc_transducer_loss( + self, + encoder_out: torch.Tensor, + joint_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: + """Compute Transducer loss. + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + joint_out: Joint Network output sequences (B, T, U, D_joint) + target: Target label ID sequences. 
(B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + Return: + loss_transducer: Transducer loss value. + cer_transducer: Character error rate for Transducer. + wer_transducer: Word Error Rate for Transducer. + """ + if self.criterion_transducer is None: + try: + # from warprnnt_pytorch import RNNTLoss + # self.criterion_transducer = RNNTLoss( + # reduction="mean", + # fastemit_lambda=self.fastemit_lambda, + # ) + from warp_rnnt import rnnt_loss as RNNTLoss + self.criterion_transducer = RNNTLoss + + except ImportError: + logging.error( + "warp-rnnt was not installed." + "Please consult the installation documentation." + ) + exit(1) + + # loss_transducer = self.criterion_transducer( + # joint_out, + # target, + # t_len, + # u_len, + # ) + log_probs = torch.log_softmax(joint_out, dim=-1) + + loss_transducer = self.criterion_transducer( + log_probs, + target, + t_len, + u_len, + reduction="mean", + blank=self.blank_id, + fastemit_lambda=self.fastemit_lambda, + gather=True, + ) + + if not self.training and (self.report_cer or self.report_wer): + if self.error_calculator is None: + self.error_calculator = ErrorCalculator( + self.decoder, + self.joint_network, + self.token_list, + self.sym_space, + self.sym_blank, + report_cer=self.report_cer, + report_wer=self.report_wer, + ) + + cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) + return loss_transducer, cer_transducer, wer_transducer + + return loss_transducer, None, None + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + target: torch.Tensor, + t_len: torch.Tensor, + u_len: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + t_len: Encoder output sequences lengths. (B,) + u_len: Target label ID sequences lengths. (B,) + Return: + loss_ctc: CTC loss value. 
+ """ + ctc_in = self.ctc_lin( + torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate) + ) + ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1) + + target_mask = target != 0 + ctc_target = target[target_mask].cpu() + + with torch.backends.cudnn.flags(deterministic=True): + loss_ctc = torch.nn.functional.ctc_loss( + ctc_in, + ctc_target, + t_len, + u_len, + zero_infinity=True, + reduction="sum", + ) + loss_ctc /= target.size(0) + + return loss_ctc + + def _calc_lm_loss( + self, + decoder_out: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: + """Compute LM loss. + Args: + decoder_out: Decoder output sequences. (B, U, D_dec) + target: Target label ID sequences. (B, L) + Return: + loss_lm: LM loss value. + """ + lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size) + lm_target = target.view(-1).type(torch.int64) + + with torch.no_grad(): + true_dist = lm_loss_in.clone() + true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1)) + + # Ignore blank ID (0) + ignore = lm_target == 0 + lm_target = lm_target.masked_fill(ignore, 0) + + true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing)) + + loss_lm = torch.nn.functional.kl_div( + torch.log_softmax(lm_loss_in, dim=1), + true_dist, + reduction="none", + ) + loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size( + 0 + ) + + return loss_lm + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + if hasattr(self, "lang_token_id") and self.lang_token_id is not None: + ys_pad = torch.cat( + [ + self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device), + ys_pad, + ], + dim=1, + ) + ys_pad_lens += 1 + + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. 
Forward decoder + decoder_out, _ = self.att_decoder( + encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens + ) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + return loss_att, acc_att diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py index 7c7f66142..9777ceed6 100644 --- a/funasr/models/encoder/conformer_encoder.py +++ b/funasr/models/encoder/conformer_encoder.py @@ -8,6 +8,7 @@ from typing import List from typing import Optional from typing import Tuple from typing import Union +from typing import Dict import torch from torch import nn @@ -18,6 +19,7 @@ from funasr.models.encoder.abs_encoder import AbsEncoder from funasr.modules.attention import ( MultiHeadedAttention, # noqa: H301 RelPositionMultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttentionChunk, LegacyRelPositionMultiHeadedAttention, # noqa: H301 ) from funasr.modules.embedding import ( @@ -25,16 +27,23 @@ from funasr.modules.embedding import ( ScaledPositionalEncoding, # noqa: H301 RelPositionalEncoding, # noqa: H301 LegacyRelPositionalEncoding, # noqa: H301 + StreamingRelPositionalEncoding, ) from funasr.modules.layer_norm import LayerNorm from funasr.modules.multi_layer_conv import Conv1dLinear from funasr.modules.multi_layer_conv import MultiLayeredConv1d from funasr.modules.nets_utils import get_activation from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.nets_utils import ( + TooShortUttError, + check_short_utt, + make_chunk_mask, + make_source_mask, +) from funasr.modules.positionwise_feed_forward import ( PositionwiseFeedForward, # noqa: H301 ) -from funasr.modules.repeat import repeat +from funasr.modules.repeat import repeat, MultiBlocks from funasr.modules.subsampling import Conv2dSubsampling from funasr.modules.subsampling import Conv2dSubsampling2 
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
from funasr.modules.subsampling import StreamingConvInput

# (diff context: ConvolutionModule and EncoderLayer are unchanged here;
#  EncoderLayer.forward still ends with `return x, mask`.)


class ChunkEncoderLayer(torch.nn.Module):
    """Chunk-capable Conformer layer: macaron FF -> self-att -> conv -> FF.

    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_args: Dict = {},
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a ChunkEncoderLayer object."""
        super().__init__()

        self.self_att = self_att

        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        # Macaron style halves each feed-forward contribution.
        self.feed_forward_scale = 0.5

        self.conv_mod = conv_mod

        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)

        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.block_size = block_size
        # [attention cache, convolution cache] for streaming inference.
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution caches for streaming.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensors.
        """
        self.cache = [
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences (training / full-utterance path).

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        # 1. Macaron feed-forward (half-scaled).
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        # 2. Self-attention, optionally restricted by the chunk mask.
        residual = x
        x = self.norm_self_att(x)
        x_q = x
        x = residual + self.dropout(
            self.self_att(
                x_q,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )

        # 3. Convolution module.
        residual = x
        x = self.norm_conv(x)
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)

        # 4. Second (half-scaled) feed-forward, then final normalization.
        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)

        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode one chunk at inference time (no dropout, caches updated).

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            chunk_size: Number of frames per chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x

        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)

        residual = x
        x = self.norm_self_att(x)

        # Prepend the cached left-context frames as attention keys/values.
        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key

        # NOTE(review): with left_context == 0 the slice `key[:, -0:, :]` keeps
        # the WHOLE key as cache; benign because the cache is unused in that
        # configuration, but worth confirming upstream.
        if right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]

        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )

        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x

        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)

        x = self.norm_final(x)
        self.cache = [att_cache, conv_cache]

        return x, pos_enc


# (diff context: ConformerEncoder is unchanged; its forward still ends with
#      if len(intermediate_outs) > 0:
#          return (xs_pad, intermediate_outs), olens, None
#      return xs_pad, olens, None
# )


class CausalConvolution(torch.nn.Module):
    """ConformerConvolution module definition.

    Args:
        channels: The number of channels.
        kernel_size: Size of the convolving kernel.
        activation: Type of activation function.
        norm_args: Normalization module arguments.
        causal: Whether to use causal convolution (set to True if streaming).
    """
+ """ + + def __init__( + self, + channels: int, + kernel_size: int, + activation: torch.nn.Module = torch.nn.ReLU(), + norm_args: Dict = {}, + causal: bool = False, + ) -> None: + """Construct an ConformerConvolution object.""" + super().__init__() + + assert (kernel_size - 1) % 2 == 0 + + self.kernel_size = kernel_size + + self.pointwise_conv1 = torch.nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + ) + + if causal: + self.lorder = kernel_size - 1 + padding = 0 + else: + self.lorder = 0 + padding = (kernel_size - 1) // 2 + + self.depthwise_conv = torch.nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + ) + self.norm = torch.nn.BatchNorm1d(channels, **norm_args) + self.pointwise_conv2 = torch.nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + ) + + self.activation = activation + + def forward( + self, + x: torch.Tensor, + cache: Optional[torch.Tensor] = None, + right_context: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x: ConformerConvolution input sequences. (B, T, D_hidden) + cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden) + right_context: Number of frames in right context. + Returns: + x: ConformerConvolution output sequences. (B, T, D_hidden) + cache: ConformerConvolution output cache. 
(1, conv_kernel, D_hidden) + """ + x = self.pointwise_conv1(x.transpose(1, 2)) + x = torch.nn.functional.glu(x, dim=1) + + if self.lorder > 0: + if cache is None: + x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + else: + x = torch.cat([cache, x], dim=2) + + if right_context > 0: + cache = x[:, :, -(self.lorder + right_context) : -right_context] + else: + cache = x[:, :, -self.lorder :] + + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x).transpose(1, 2) + + return x, cache + +class ConformerChunkEncoder(AbsEncoder): + """Encoder module definition. + Args: + input_size: Input size. + body_conf: Encoder body configuration. + input_conf: Encoder input configuration. + main_conf: Encoder main configuration. + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + embed_vgg_like: bool = False, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 3, + macaron_style: bool = False, + rel_pos_type: str = "legacy", + pos_enc_layer_type: str = "rel_pos", + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + zero_triu: bool = False, + norm_type: str = "layer_norm", + cnn_module_kernel: int = 31, + conv_mod_norm_eps: float = 0.00001, + conv_mod_norm_momentum: float = 0.1, + simplified_att_score: bool = False, + dynamic_chunk_training: bool = False, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 25, + left_chunk_size: int = 0, + time_reduction_factor: int = 1, + unified_model_training: bool = False, + default_chunk_size: int = 16, + jitter_range: int = 4, + subsampling_factor: int = 1, + ) -> None: + """Construct an Encoder object.""" + 
super().__init__() + + assert check_argument_types() + + self.embed = StreamingConvInput( + input_size, + output_size, + subsampling_factor, + vgg_like=embed_vgg_like, + output_size=output_size, + ) + + self.pos_enc = StreamingRelPositionalEncoding( + output_size, + positional_dropout_rate, + ) + + activation = get_activation( + activation_type + ) + + pos_wise_args = ( + output_size, + linear_units, + positional_dropout_rate, + activation, + ) + + conv_mod_norm_args = { + "eps": conv_mod_norm_eps, + "momentum": conv_mod_norm_momentum, + } + + conv_mod_args = ( + output_size, + cnn_module_kernel, + activation, + conv_mod_norm_args, + dynamic_chunk_training or unified_model_training, + ) + + mult_att_args = ( + attention_heads, + output_size, + attention_dropout_rate, + simplified_att_score, + ) + + + fn_modules = [] + for _ in range(num_blocks): + module = lambda: ChunkEncoderLayer( + output_size, + RelPositionMultiHeadedAttentionChunk(*mult_att_args), + PositionwiseFeedForward(*pos_wise_args), + PositionwiseFeedForward(*pos_wise_args), + CausalConvolution(*conv_mod_args), + dropout_rate=dropout_rate, + ) + fn_modules.append(module) + + self.encoders = MultiBlocks( + [fn() for fn in fn_modules], + output_size, + ) + + self._output_size = output_size + + self.dynamic_chunk_training = dynamic_chunk_training + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + self.left_chunk_size = left_chunk_size + + self.unified_model_training = unified_model_training + self.default_chunk_size = default_chunk_size + self.jitter_range = jitter_range + + self.time_reduction_factor = time_reduction_factor + + def output_size(self) -> int: + return self._output_size + + def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: + """Return the corresponding number of sample for a given chunk size, in frames. + Where size is the number of features frames after applying subsampling. 
+ Args: + size: Number of frames after subsampling. + hop_length: Frontend's hop length + Returns: + : Number of raw samples + """ + return self.embed.get_size_before_subsampling(size) * hop_length + + def get_encoder_input_size(self, size: int) -> int: + """Return the corresponding number of sample for a given chunk size, in frames. + Where size is the number of features frames after applying subsampling. + Args: + size: Number of frames after subsampling. + Returns: + : Number of raw samples + """ + return self.embed.get_size_before_subsampling(size) + + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset encoder streaming cache. + Args: + left_context: Number of frames in left context. + device: Device ID. + """ + return self.encoders.reset_streaming_cache(left_context, device) + + def forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode input sequences. + Args: + x: Encoder input features. (B, T_in, F) + x_len: Encoder input features lengths. (B,) + Returns: + x: Encoder outputs. (B, T_out, D_enc) + x_len: Encoder outputs lenghts. 
(B,) + """ + short_status, limit_size = check_short_utt( + self.embed.subsampling_factor, x.size(1) + ) + + if short_status: + raise TooShortUttError( + f"has {x.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + x.size(1), + limit_size, + ) + + mask = make_source_mask(x_len) + + if self.unified_model_training: + chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + x_utt = self.encoders( + x, + pos_enc, + mask, + chunk_mask=None, + ) + x_chunk = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x_utt = x_utt[:,::self.time_reduction_factor,:] + x_chunk = x_chunk[:,::self.time_reduction_factor,:] + olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 + + return x_utt, x_chunk, olens + + elif self.dynamic_chunk_training: + max_len = x.size(1) + chunk_size = torch.randint(1, max_len, (1,)).item() + + if chunk_size > (max_len * self.short_chunk_threshold): + chunk_size = max_len + else: + chunk_size = (chunk_size % self.short_chunk_size) + 1 + + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + else: + x, mask = self.embed(x, mask, None) + pos_enc = self.pos_enc(x) + chunk_mask = None + x = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 + + return x, olens + + def simu_chunk_forward( + self, + x: torch.Tensor, + 
x_len: torch.Tensor, + chunk_size: int = 16, + left_context: int = 32, + right_context: int = 0, + ) -> torch.Tensor: + short_status, limit_size = check_short_utt( + self.embed.subsampling_factor, x.size(1) + ) + + if short_status: + raise TooShortUttError( + f"has {x.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + x.size(1), + limit_size, + ) + + mask = make_source_mask(x_len) + + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + + x = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + + return x + + def chunk_forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + processed_frames: torch.tensor, + chunk_size: int = 16, + left_context: int = 32, + right_context: int = 0, + ) -> torch.Tensor: + """Encode input sequences as chunks. + Args: + x: Encoder input features. (1, T_in, F) + x_len: Encoder input features lengths. (1,) + processed_frames: Number of frames already seen. + left_context: Number of frames in left context. + right_context: Number of frames in right context. + Returns: + x: Encoder outputs. 
(B, T_out, D_enc) + """ + mask = make_source_mask(x_len) + x, mask = self.embed(x, mask, None) + + if left_context > 0: + processed_mask = ( + torch.arange(left_context, device=x.device) + .view(1, left_context) + .flip(1) + ) + processed_mask = processed_mask >= processed_frames + mask = torch.cat([processed_mask, mask], dim=1) + pos_enc = self.pos_enc(x, left_context=left_context) + x = self.encoders.chunk_forward( + x, + pos_enc, + mask, + chunk_size=chunk_size, + left_context=left_context, + right_context=right_context, + ) + + if right_context > 0: + x = x[:, 0:-right_context, :] + + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + return x diff --git a/funasr/models/joint_net/joint_network.py b/funasr/models/joint_net/joint_network.py new file mode 100644 index 000000000..ed827c420 --- /dev/null +++ b/funasr/models/joint_net/joint_network.py @@ -0,0 +1,61 @@ +"""Transducer joint network implementation.""" + +import torch + +from funasr.modules.nets_utils import get_activation + + +class JointNetwork(torch.nn.Module): + """Transducer joint network module. + + Args: + output_size: Output size. + encoder_size: Encoder output size. + decoder_size: Decoder output size.. + joint_space_size: Joint space size. + joint_act_type: Type of activation for joint network. + **activation_parameters: Parameters for the activation function. 
+ + """ + + def __init__( + self, + output_size: int, + encoder_size: int, + decoder_size: int, + joint_space_size: int = 256, + joint_activation_type: str = "tanh", + ) -> None: + """Construct a JointNetwork object.""" + super().__init__() + + self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size) + self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False) + + self.lin_out = torch.nn.Linear(joint_space_size, output_size) + + self.joint_activation = get_activation( + joint_activation_type + ) + + def forward( + self, + enc_out: torch.Tensor, + dec_out: torch.Tensor, + project_input: bool = True, + ) -> torch.Tensor: + """Joint computation of encoder and decoder hidden state sequences. + + Args: + enc_out: Expanded encoder output state sequences (B, T, 1, D_enc) + dec_out: Expanded decoder output state sequences (B, 1, U, D_dec) + + Returns: + joint_out: Joint output state sequences. (B, T, U, D_out) + + """ + if project_input: + joint_out = self.joint_activation(self.lin_enc(enc_out) + self.lin_dec(dec_out)) + else: + joint_out = self.joint_activation(enc_out + dec_out) + return self.lin_out(joint_out) diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py index 31d5a8775..62020796e 100644 --- a/funasr/modules/attention.py +++ b/funasr/modules/attention.py @@ -11,7 +11,7 @@ import math import numpy import torch from torch import nn - +from typing import Optional, Tuple class MultiHeadedAttention(nn.Module): """Multi-Head Attention layer. @@ -741,3 +741,221 @@ class MultiHeadSelfAttention(nn.Module): scores = torch.matmul(q_h, k_h.transpose(-2, -1)) att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) return att_outs + +class RelPositionMultiHeadedAttentionChunk(torch.nn.Module): + """RelPositionMultiHeadedAttention definition. + Args: + num_heads: Number of attention heads. + embed_size: Embedding size. + dropout_rate: Dropout rate. 
+ """ + + def __init__( + self, + num_heads: int, + embed_size: int, + dropout_rate: float = 0.0, + simplified_attention_score: bool = False, + ) -> None: + """Construct an MultiHeadedAttention object.""" + super().__init__() + + self.d_k = embed_size // num_heads + self.num_heads = num_heads + + assert self.d_k * num_heads == embed_size, ( + "embed_size (%d) must be divisible by num_heads (%d)", + (embed_size, num_heads), + ) + + self.linear_q = torch.nn.Linear(embed_size, embed_size) + self.linear_k = torch.nn.Linear(embed_size, embed_size) + self.linear_v = torch.nn.Linear(embed_size, embed_size) + + self.linear_out = torch.nn.Linear(embed_size, embed_size) + + if simplified_attention_score: + self.linear_pos = torch.nn.Linear(embed_size, num_heads) + + self.compute_att_score = self.compute_simplified_attention_score + else: + self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False) + + self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) + self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + self.compute_att_score = self.compute_attention_score + + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.attn = None + + def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: + """Compute relative positional encoding. + Args: + x: Input sequence. (B, H, T_1, 2 * T_1 - 1) + left_context: Number of frames in left context. + Returns: + x: Output sequence. 
(B, H, T_1, T_2) + """ + batch_size, n_heads, time1, n = x.shape + time2 = time1 + left_context + + batch_stride, n_heads_stride, time1_stride, n_stride = x.stride() + + return x.as_strided( + (batch_size, n_heads, time1, time2), + (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride), + storage_offset=(n_stride * (time1 - 1)), + ) + + def compute_simplified_attention_score( + self, + query: torch.Tensor, + key: torch.Tensor, + pos_enc: torch.Tensor, + left_context: int = 0, + ) -> torch.Tensor: + """Simplified attention score computation. + Reference: https://github.com/k2-fsa/icefall/pull/458 + Args: + query: Transformed query tensor. (B, H, T_1, d_k) + key: Transformed key tensor. (B, H, T_2, d_k) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + left_context: Number of frames in left context. + Returns: + : Attention score. (B, H, T_1, T_2) + """ + pos_enc = self.linear_pos(pos_enc) + + matrix_ac = torch.matmul(query, key.transpose(2, 3)) + + matrix_bd = self.rel_shift( + pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1), + left_context=left_context, + ) + + return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + + def compute_attention_score( + self, + query: torch.Tensor, + key: torch.Tensor, + pos_enc: torch.Tensor, + left_context: int = 0, + ) -> torch.Tensor: + """Attention score computation. + Args: + query: Transformed query tensor. (B, H, T_1, d_k) + key: Transformed key tensor. (B, H, T_2, d_k) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + left_context: Number of frames in left context. + Returns: + : Attention score. 
(B, H, T_1, T_2) + """ + p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k) + + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1)) + matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) + + return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + Args: + query: Query tensor. (B, T_1, size) + key: Key tensor. (B, T_2, size) + v: Value tensor. (B, T_2, size) + Returns: + q: Transformed query tensor. (B, H, T_1, d_k) + k: Transformed key tensor. (B, H, T_2, d_k) + v: Transformed value tensor. (B, H, T_2, d_k) + """ + n_batch = query.size(0) + + q = ( + self.linear_q(query) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + k = ( + self.linear_k(key) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + v = ( + self.linear_v(value) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + + return q, k, v + + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Compute attention context vector. + Args: + value: Transformed value. (B, H, T_2, d_k) + scores: Attention score. (B, H, T_1, T_2) + mask: Source mask. (B, T_2) + chunk_mask: Chunk mask. (T_1, T_1) + Returns: + attn_output: Transformed value weighted by attention score. 
(B, T_1, H * d_k) + """ + batch_size = scores.size(0) + mask = mask.unsqueeze(1).unsqueeze(2) + if chunk_mask is not None: + mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask + scores = scores.masked_fill(mask, float("-inf")) + self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) + + attn_output = self.dropout(self.attn) + attn_output = torch.matmul(attn_output, value) + + attn_output = self.linear_out( + attn_output.transpose(1, 2) + .contiguous() + .view(batch_size, -1, self.num_heads * self.d_k) + ) + + return attn_output + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + left_context: int = 0, + ) -> torch.Tensor: + """Compute scaled dot product attention with rel. positional encoding. + Args: + query: Query tensor. (B, T_1, size) + key: Key tensor. (B, T_2, size) + value: Value tensor. (B, T_2, size) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + mask: Source mask. (B, T_2) + chunk_mask: Chunk mask. (T_1, T_1) + left_context: Number of frames in left context. + Returns: + : Output tensor. 
(B, T_1, H * d_k) + """ + q, k, v = self.forward_qkv(query, key, value) + scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) + return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask) diff --git a/funasr/modules/beam_search/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py new file mode 100644 index 000000000..3eb8e08d0 --- /dev/null +++ b/funasr/modules/beam_search/beam_search_transducer.py @@ -0,0 +1,704 @@ +"""Search algorithms for Transducer models.""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from funasr.models.joint_net.joint_network import JointNetwork + + +@dataclass +class Hypothesis: + """Default hypothesis definition for Transducer search algorithms. + + Args: + score: Total log-probability. + yseq: Label sequence as integer ID sequence. + dec_state: RNNDecoder or StatelessDecoder state. + ((N, 1, D_dec), (N, 1, D_dec) or None) or None + lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None + + """ + + score: float + yseq: List[int] + dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None + lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None + + +@dataclass +class ExtendedHypothesis(Hypothesis): + """Extended hypothesis definition for NSC beam search and mAES. + + Args: + : Hypothesis dataclass arguments. + dec_out: Decoder output sequence. (B, D_dec) + lm_score: Log-probabilities of the LM for given label. (vocab_size) + + """ + + dec_out: torch.Tensor = None + lm_score: torch.Tensor = None + + +class BeamSearchTransducer: + """Beam search implementation for Transducer. + + Args: + decoder: Decoder module. + joint_network: Joint network module. + beam_size: Size of the beam. + lm: LM class. + lm_weight: LM weight for soft fusion. + search_type: Search algorithm to use during inference. + max_sym_exp: Number of maximum symbol expansions at each time step. 
(TSD) + u_max: Maximum expected target sequence length. (ALSD) + nstep: Number of maximum expansion steps at each time step. (mAES) + expansion_gamma: Allowed logp difference for prune-by-value method. (mAES) + expansion_beta: + Number of additional candidates for expanded hypotheses selection. (mAES) + score_norm: Normalize final scores by length. + nbest: Number of final hypothesis. + streaming: Whether to perform chunk-by-chunk beam search. + + """ + + def __init__( + self, + decoder, + joint_network: JointNetwork, + beam_size: int, + lm: Optional[torch.nn.Module] = None, + lm_weight: float = 0.1, + search_type: str = "default", + max_sym_exp: int = 3, + u_max: int = 50, + nstep: int = 2, + expansion_gamma: float = 2.3, + expansion_beta: int = 2, + score_norm: bool = False, + nbest: int = 1, + streaming: bool = False, + ) -> None: + """Construct a BeamSearchTransducer object.""" + super().__init__() + + self.decoder = decoder + self.joint_network = joint_network + + self.vocab_size = decoder.vocab_size + + assert beam_size <= self.vocab_size, ( + "beam_size (%d) should be smaller than or equal to vocabulary size (%d)." + % ( + beam_size, + self.vocab_size, + ) + ) + self.beam_size = beam_size + + if search_type == "default": + self.search_algorithm = self.default_beam_search + elif search_type == "tsd": + assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % ( + max_sym_exp + ) + self.max_sym_exp = max_sym_exp + + self.search_algorithm = self.time_sync_decoding + elif search_type == "alsd": + assert not streaming, "ALSD is not available in streaming mode." + + assert u_max >= 0, "u_max should be a positive integer, a portion of max_T." + self.u_max = u_max + + self.search_algorithm = self.align_length_sync_decoding + elif search_type == "maes": + assert self.vocab_size >= beam_size + expansion_beta, ( + "beam_size (%d) + expansion_beta (%d) " + " should be smaller than or equal to vocab size (%d)." 
+ % (beam_size, expansion_beta, self.vocab_size) + ) + self.max_candidates = beam_size + expansion_beta + + self.nstep = nstep + self.expansion_gamma = expansion_gamma + + self.search_algorithm = self.modified_adaptive_expansion_search + else: + raise NotImplementedError( + "Specified search type (%s) is not supported." % search_type + ) + + self.use_lm = lm is not None + + if self.use_lm: + assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported." + + self.sos = self.vocab_size - 1 + + self.lm = lm + self.lm_weight = lm_weight + + self.score_norm = score_norm + self.nbest = nbest + + self.reset_inference_cache() + + def __call__( + self, + enc_out: torch.Tensor, + is_final: bool = True, + ) -> List[Hypothesis]: + """Perform beam search. + + Args: + enc_out: Encoder output sequence. (T, D_enc) + is_final: Whether enc_out is the final chunk of data. + + Returns: + nbest_hyps: N-best decoding results + + """ + self.decoder.set_device(enc_out.device) + + hyps = self.search_algorithm(enc_out) + + if is_final: + self.reset_inference_cache() + + return self.sort_nbest(hyps) + + self.search_cache = hyps + + return hyps + + def reset_inference_cache(self) -> None: + """Reset cache for decoder scoring and streaming.""" + self.decoder.score_cache = {} + self.search_cache = None + + def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]: + """Sort in-place hypotheses by score or score given sequence length. + + Args: + hyps: Hypothesis. + + Return: + hyps: Sorted hypothesis. + + """ + if self.score_norm: + hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True) + else: + hyps.sort(key=lambda x: x.score, reverse=True) + + return hyps[: self.nbest] + + def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]: + """Recombine hypotheses with same label ID sequence. + + Args: + hyps: Hypotheses. + + Returns: + final: Recombined hypotheses. 
+ + """ + final = {} + + for hyp in hyps: + str_yseq = "_".join(map(str, hyp.yseq)) + + if str_yseq in final: + final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score) + else: + final[str_yseq] = hyp + + return [*final.values()] + + def select_k_expansions( + self, + hyps: List[ExtendedHypothesis], + topk_idx: torch.Tensor, + topk_logp: torch.Tensor, + ) -> List[ExtendedHypothesis]: + """Return K hypotheses candidates for expansion from a list of hypothesis. + + K candidates are selected according to the extended hypotheses probabilities + and a prune-by-value method. Where K is equal to beam_size + beta. + + Args: + hyps: Hypotheses. + topk_idx: Indices of candidates hypothesis. + topk_logp: Log-probabilities of candidates hypothesis. + + Returns: + k_expansions: Best K expansion hypotheses candidates. + + """ + k_expansions = [] + + for i, hyp in enumerate(hyps): + hyp_i = [ + (int(k), hyp.score + float(v)) + for k, v in zip(topk_idx[i], topk_logp[i]) + ] + k_best_exp = max(hyp_i, key=lambda x: x[1])[1] + + k_expansions.append( + sorted( + filter( + lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i + ), + key=lambda x: x[1], + reverse=True, + ) + ) + + return k_expansions + + def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor: + """Make batch of inputs with left padding for LM scoring. + + Args: + hyps_seq: Hypothesis sequences. + + Returns: + : Padded batch of sequences. + + """ + max_len = max([len(h) for h in hyps_seq]) + + return torch.LongTensor( + [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq], + device=self.decoder.device, + ) + + def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]: + """Beam search implementation without prefix search. + + Modified from https://arxiv.org/pdf/1211.3711.pdf + + Args: + enc_out: Encoder output sequence. (T, D) + + Returns: + nbest_hyps: N-best hypothesis. 
+ + """ + beam_k = min(self.beam_size, (self.vocab_size - 1)) + max_t = len(enc_out) + + if self.search_cache is not None: + kept_hyps = self.search_cache + else: + kept_hyps = [ + Hypothesis( + score=0.0, + yseq=[0], + dec_state=self.decoder.init_state(1), + ) + ] + + for t in range(max_t): + hyps = kept_hyps + kept_hyps = [] + + while True: + max_hyp = max(hyps, key=lambda x: x.score) + hyps.remove(max_hyp) + + label = torch.full( + (1, 1), + max_hyp.yseq[-1], + dtype=torch.long, + device=self.decoder.device, + ) + dec_out, state = self.decoder.score( + label, + max_hyp.yseq, + max_hyp.dec_state, + ) + + logp = torch.log_softmax( + self.joint_network(enc_out[t : t + 1, :], dec_out), + dim=-1, + ).squeeze(0) + top_k = logp[1:].topk(beam_k, dim=-1) + + kept_hyps.append( + Hypothesis( + score=(max_hyp.score + float(logp[0:1])), + yseq=max_hyp.yseq, + dec_state=max_hyp.dec_state, + lm_state=max_hyp.lm_state, + ) + ) + + if self.use_lm: + lm_scores, lm_state = self.lm.score( + torch.LongTensor( + [self.sos] + max_hyp.yseq[1:], device=self.decoder.device + ), + max_hyp.lm_state, + None, + ) + else: + lm_state = max_hyp.lm_state + + for logp, k in zip(*top_k): + score = max_hyp.score + float(logp) + + if self.use_lm: + score += self.lm_weight * lm_scores[k + 1] + + hyps.append( + Hypothesis( + score=score, + yseq=max_hyp.yseq + [int(k + 1)], + dec_state=state, + lm_state=lm_state, + ) + ) + + hyps_max = float(max(hyps, key=lambda x: x.score).score) + kept_most_prob = sorted( + [hyp for hyp in kept_hyps if hyp.score > hyps_max], + key=lambda x: x.score, + ) + if len(kept_most_prob) >= self.beam_size: + kept_hyps = kept_most_prob + break + + return kept_hyps + + def align_length_sync_decoding( + self, + enc_out: torch.Tensor, + ) -> List[Hypothesis]: + """Alignment-length synchronous beam search implementation. + + Based on https://ieeexplore.ieee.org/document/9053040 + + Args: + h: Encoder output sequences. (T, D) + + Returns: + nbest_hyps: N-best hypothesis. 
+ + """ + t_max = int(enc_out.size(0)) + u_max = min(self.u_max, (t_max - 1)) + + B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))] + final = [] + + if self.use_lm: + B[0].lm_state = self.lm.zero_state() + + for i in range(t_max + u_max): + A = [] + + B_ = [] + B_enc_out = [] + for hyp in B: + u = len(hyp.yseq) - 1 + t = i - u + + if t > (t_max - 1): + continue + + B_.append(hyp) + B_enc_out.append((t, enc_out[t])) + + if B_: + beam_enc_out = torch.stack([b[1] for b in B_enc_out]) + beam_dec_out, beam_state = self.decoder.batch_score(B_) + + beam_logp = torch.log_softmax( + self.joint_network(beam_enc_out, beam_dec_out), + dim=-1, + ) + beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1) + + if self.use_lm: + beam_lm_scores, beam_lm_states = self.lm.batch_score( + self.create_lm_batch_inputs([b.yseq for b in B_]), + [b.lm_state for b in B_], + None, + ) + + for i, hyp in enumerate(B_): + new_hyp = Hypothesis( + score=(hyp.score + float(beam_logp[i, 0])), + yseq=hyp.yseq[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + ) + + A.append(new_hyp) + + if B_enc_out[i][0] == (t_max - 1): + final.append(new_hyp) + + for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1): + new_hyp = Hypothesis( + score=(hyp.score + float(logp)), + yseq=(hyp.yseq[:] + [int(k)]), + dec_state=self.decoder.select_state(beam_state, i), + lm_state=hyp.lm_state, + ) + + if self.use_lm: + new_hyp.score += self.lm_weight * beam_lm_scores[i, k] + new_hyp.lm_state = beam_lm_states[i] + + A.append(new_hyp) + + B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size] + B = self.recombine_hyps(B) + + if final: + return final + + return B + + def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]: + """Time synchronous beam search implementation. + + Based on https://ieeexplore.ieee.org/document/9053040 + + Args: + enc_out: Encoder output sequence. (T, D) + + Returns: + nbest_hyps: N-best hypothesis. 
+ + """ + if self.search_cache is not None: + B = self.search_cache + else: + B = [ + Hypothesis( + yseq=[0], + score=0.0, + dec_state=self.decoder.init_state(1), + ) + ] + + if self.use_lm: + B[0].lm_state = self.lm.zero_state() + + for enc_out_t in enc_out: + A = [] + C = B + + enc_out_t = enc_out_t.unsqueeze(0) + + for v in range(self.max_sym_exp): + D = [] + + beam_dec_out, beam_state = self.decoder.batch_score(C) + + beam_logp = torch.log_softmax( + self.joint_network(enc_out_t, beam_dec_out), + dim=-1, + ) + beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1) + + seq_A = [h.yseq for h in A] + + for i, hyp in enumerate(C): + if hyp.yseq not in seq_A: + A.append( + Hypothesis( + score=(hyp.score + float(beam_logp[i, 0])), + yseq=hyp.yseq[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + ) + ) + else: + dict_pos = seq_A.index(hyp.yseq) + + A[dict_pos].score = np.logaddexp( + A[dict_pos].score, (hyp.score + float(beam_logp[i, 0])) + ) + + if v < (self.max_sym_exp - 1): + if self.use_lm: + beam_lm_scores, beam_lm_states = self.lm.batch_score( + self.create_lm_batch_inputs([c.yseq for c in C]), + [c.lm_state for c in C], + None, + ) + + for i, hyp in enumerate(C): + for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1): + new_hyp = Hypothesis( + score=(hyp.score + float(logp)), + yseq=(hyp.yseq + [int(k)]), + dec_state=self.decoder.select_state(beam_state, i), + lm_state=hyp.lm_state, + ) + + if self.use_lm: + new_hyp.score += self.lm_weight * beam_lm_scores[i, k] + new_hyp.lm_state = beam_lm_states[i] + + D.append(new_hyp) + + C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size] + + B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size] + + return B + + def modified_adaptive_expansion_search( + self, + enc_out: torch.Tensor, + ) -> List[ExtendedHypothesis]: + """Modified version of Adaptive Expansion Search (mAES). 
+ + Based on AES (https://ieeexplore.ieee.org/document/9250505) and + NSC (https://arxiv.org/abs/2201.05420). + + Args: + enc_out: Encoder output sequence. (T, D_enc) + + Returns: + nbest_hyps: N-best hypothesis. + + """ + if self.search_cache is not None: + kept_hyps = self.search_cache + else: + init_tokens = [ + ExtendedHypothesis( + yseq=[0], + score=0.0, + dec_state=self.decoder.init_state(1), + ) + ] + + beam_dec_out, beam_state = self.decoder.batch_score( + init_tokens, + ) + + if self.use_lm: + beam_lm_scores, beam_lm_states = self.lm.batch_score( + self.create_lm_batch_inputs([h.yseq for h in init_tokens]), + [h.lm_state for h in init_tokens], + None, + ) + + lm_state = beam_lm_states[0] + lm_score = beam_lm_scores[0] + else: + lm_state = None + lm_score = None + + kept_hyps = [ + ExtendedHypothesis( + yseq=[0], + score=0.0, + dec_state=self.decoder.select_state(beam_state, 0), + dec_out=beam_dec_out[0], + lm_state=lm_state, + lm_score=lm_score, + ) + ] + + for enc_out_t in enc_out: + hyps = kept_hyps + kept_hyps = [] + + beam_enc_out = enc_out_t.unsqueeze(0) + + list_b = [] + for n in range(self.nstep): + beam_dec_out = torch.stack([h.dec_out for h in hyps]) + + beam_logp, beam_idx = torch.log_softmax( + self.joint_network(beam_enc_out, beam_dec_out), + dim=-1, + ).topk(self.max_candidates, dim=-1) + + k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp) + + list_exp = [] + for i, hyp in enumerate(hyps): + for k, new_score in k_expansions[i]: + new_hyp = ExtendedHypothesis( + yseq=hyp.yseq[:], + score=new_score, + dec_out=hyp.dec_out, + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + lm_score=hyp.lm_score, + ) + + if k == 0: + list_b.append(new_hyp) + else: + new_hyp.yseq.append(int(k)) + + if self.use_lm: + new_hyp.score += self.lm_weight * float(hyp.lm_score[k]) + + list_exp.append(new_hyp) + + if not list_exp: + kept_hyps = sorted( + self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True + )[: self.beam_size] + + break + 
else: + beam_dec_out, beam_state = self.decoder.batch_score( + list_exp, + ) + + if self.use_lm: + beam_lm_scores, beam_lm_states = self.lm.batch_score( + self.create_lm_batch_inputs([h.yseq for h in list_exp]), + [h.lm_state for h in list_exp], + None, + ) + + if n < (self.nstep - 1): + for i, hyp in enumerate(list_exp): + hyp.dec_out = beam_dec_out[i] + hyp.dec_state = self.decoder.select_state(beam_state, i) + + if self.use_lm: + hyp.lm_state = beam_lm_states[i] + hyp.lm_score = beam_lm_scores[i] + + hyps = list_exp[:] + else: + beam_logp = torch.log_softmax( + self.joint_network(beam_enc_out, beam_dec_out), + dim=-1, + ) + + for i, hyp in enumerate(list_exp): + hyp.score += float(beam_logp[i, 0]) + + hyp.dec_out = beam_dec_out[i] + hyp.dec_state = self.decoder.select_state(beam_state, i) + + if self.use_lm: + hyp.lm_state = beam_lm_states[i] + hyp.lm_score = beam_lm_scores[i] + + kept_hyps = sorted( + self.recombine_hyps(list_b + list_exp), + key=lambda x: x.score, + reverse=True, + )[: self.beam_size] + + return kept_hyps diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py index 92f90796a..f430fcb43 100644 --- a/funasr/modules/e2e_asr_common.py +++ b/funasr/modules/e2e_asr_common.py @@ -6,6 +6,8 @@ """Common functions for ASR.""" +from typing import List, Optional, Tuple + import json import logging import sys @@ -13,7 +15,10 @@ import sys from itertools import groupby import numpy as np import six +import torch +from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer +from funasr.models.joint_net.joint_network import JointNetwork def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): """End detection. @@ -247,3 +252,148 @@ class ErrorCalculator(object): word_eds.append(editdistance.eval(hyp_words, ref_words)) word_ref_lens.append(len(ref_words)) return float(sum(word_eds)) / sum(word_ref_lens) + +class ErrorCalculatorTransducer: + """Calculate CER and WER for transducer models. 
+ Args: + decoder: Decoder module. + joint_network: Joint Network module. + token_list: List of token units. + sym_space: Space symbol. + sym_blank: Blank symbol. + report_cer: Whether to compute CER. + report_wer: Whether to compute WER. + """ + + def __init__( + self, + decoder, + joint_network: JointNetwork, + token_list: List[int], + sym_space: str, + sym_blank: str, + report_cer: bool = False, + report_wer: bool = False, + ) -> None: + """Construct an ErrorCalculatorTransducer object.""" + super().__init__() + + self.beam_search = BeamSearchTransducer( + decoder=decoder, + joint_network=joint_network, + beam_size=1, + search_type="default", + score_norm=False, + ) + + self.decoder = decoder + + self.token_list = token_list + self.space = sym_space + self.blank = sym_blank + + self.report_cer = report_cer + self.report_wer = report_wer + + def __call__( + self, encoder_out: torch.Tensor, target: torch.Tensor + ) -> Tuple[Optional[float], Optional[float]]: + """Calculate sentence-level WER or/and CER score for Transducer model. + Args: + encoder_out: Encoder output sequences. (B, T, D_enc) + target: Target label ID sequences. (B, L) + Returns: + : Sentence-level CER score. + : Sentence-level WER score. + """ + cer, wer = None, None + + batchsize = int(encoder_out.size(0)) + + encoder_out = encoder_out.to(next(self.decoder.parameters()).device) + + batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)] + pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest] + + char_pred, char_target = self.convert_to_char(pred, target) + + if self.report_cer: + cer = self.calculate_cer(char_pred, char_target) + + if self.report_wer: + wer = self.calculate_wer(char_pred, char_target) + + return cer, wer + + def convert_to_char( + self, pred: torch.Tensor, target: torch.Tensor + ) -> Tuple[List, List]: + """Convert label ID sequences to character sequences. + Args: + pred: Prediction label ID sequences. (B, U) + target: Target label ID sequences. 
(B, L) + Returns: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) + """ + char_pred, char_target = [], [] + + for i, pred_i in enumerate(pred): + char_pred_i = [self.token_list[int(h)] for h in pred_i] + char_target_i = [self.token_list[int(r)] for r in target[i]] + + char_pred_i = "".join(char_pred_i).replace(self.space, " ") + char_pred_i = char_pred_i.replace(self.blank, "") + + char_target_i = "".join(char_target_i).replace(self.space, " ") + char_target_i = char_target_i.replace(self.blank, "") + + char_pred.append(char_pred_i) + char_target.append(char_target_i) + + return char_pred, char_target + + def calculate_cer( + self, char_pred: torch.Tensor, char_target: torch.Tensor + ) -> float: + """Calculate sentence-level CER score. + Args: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) + Returns: + : Average sentence-level CER score. + """ + import editdistance + + distances, lens = [], [] + + for i, char_pred_i in enumerate(char_pred): + pred = char_pred_i.replace(" ", "") + target = char_target[i].replace(" ", "") + distances.append(editdistance.eval(pred, target)) + lens.append(len(target)) + + return float(sum(distances)) / sum(lens) + + def calculate_wer( + self, char_pred: torch.Tensor, char_target: torch.Tensor + ) -> float: + """Calculate sentence-level WER score. + Args: + char_pred: Prediction character sequences. (B, ?) + char_target: Target character sequences. (B, ?) 
+ Returns: + : Average sentence-level WER score + """ + import editdistance + + distances, lens = [], [] + + for i, char_pred_i in enumerate(char_pred): + pred = char_pred_i.replace("▁", " ").split() + target = char_target[i].replace("▁", " ").split() + + distances.append(editdistance.eval(pred, target)) + lens.append(len(target)) + + return float(sum(distances)) / sum(lens) diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py index 4b292a79b..c347e24f1 100644 --- a/funasr/modules/embedding.py +++ b/funasr/modules/embedding.py @@ -440,4 +440,79 @@ class StreamSinusoidalPositionEncoder(torch.nn.Module): outputs = F.pad(outputs, (pad_left, pad_right)) outputs = outputs.transpose(1, 2) return outputs - + +class StreamingRelPositionalEncoding(torch.nn.Module): + """Relative positional encoding. + Args: + size: Module size. + max_len: Maximum input length. + dropout_rate: Dropout rate. + """ + + def __init__( + self, size: int, dropout_rate: float = 0.0, max_len: int = 5000 + ) -> None: + """Construct a RelativePositionalEncoding object.""" + super().__init__() + + self.size = size + + self.pe = None + self.dropout = torch.nn.Dropout(p=dropout_rate) + + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None: + """Reset positional encoding. + Args: + x: Input sequences. (B, T, ?) + left_context: Number of frames in left context. 
+ """ + time1 = x.size(1) + left_context + + if self.pe is not None: + if self.pe.size(1) >= time1 * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(device=x.device, dtype=x.dtype) + return + + pe_positive = torch.zeros(time1, self.size) + pe_negative = torch.zeros(time1, self.size) + + position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.size, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.size) + ) + + pe_positive[:, 0::2] = torch.sin(position * div_term) + pe_positive[:, 1::2] = torch.cos(position * div_term) + pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) + + pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) + pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) + pe_negative = pe_negative[1:].unsqueeze(0) + + self.pe = torch.cat([pe_positive, pe_negative], dim=1).to( + dtype=x.dtype, device=x.device + ) + + def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: + """Compute positional encoding. + Args: + x: Input sequences. (B, T, ?) + left_context: Number of frames in left context. + Returns: + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?) + """ + self.extend_pe(x, left_context=left_context) + + time1 = x.size(1) + left_context + + pos_enc = self.pe[ + :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1) + ] + pos_enc = self.dropout(pos_enc) + + return pos_enc diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py index 6d77d69a6..5d4fe1c85 100644 --- a/funasr/modules/nets_utils.py +++ b/funasr/modules/nets_utils.py @@ -3,7 +3,7 @@ """Network related utility tools.""" import logging -from typing import Dict +from typing import Dict, List, Tuple import numpy as np import torch @@ -506,3 +506,196 @@ def get_activation(act): } return activation_funcs[act]() + +class TooShortUttError(Exception): + """Raised when the utt is too short for subsampling. 
+ + Args: + message: Error message to display. + actual_size: The size that cannot pass the subsampling. + limit: The size limit for subsampling. + + """ + + def __init__(self, message: str, actual_size: int, limit: int) -> None: + """Construct a TooShortUttError module.""" + super().__init__(message) + + self.actual_size = actual_size + self.limit = limit + + +def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]: + """Check if the input is too short for subsampling. + + Args: + sub_factor: Subsampling factor for Conv2DSubsampling. + size: Input size. + + Returns: + : Whether an error should be sent. + : Size limit for specified subsampling factor. + + """ + if sub_factor == 2 and size < 3: + return True, 7 + elif sub_factor == 4 and size < 7: + return True, 7 + elif sub_factor == 6 and size < 11: + return True, 11 + + return False, -1 + + +def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]: + """Get conv2D second layer parameters for given subsampling factor. + + Args: + sub_factor: Subsampling factor (1/X). + input_size: Input size. + + Returns: + : Kernel size for second convolution. + : Stride for second convolution. + : Conv2DSubsampling output size. + + """ + if sub_factor == 2: + return 3, 1, (((input_size - 1) // 2 - 2)) + elif sub_factor == 4: + return 3, 2, (((input_size - 1) // 2 - 1) // 2) + elif sub_factor == 6: + return 5, 3, (((input_size - 1) // 2 - 2) // 3) + else: + raise ValueError( + "subsampling_factor parameter should be set to either 2, 4 or 6." + ) + + +def make_chunk_mask( + size: int, + chunk_size: int, + left_chunk_size: int = 0, + device: torch.device = None, +) -> torch.Tensor: + """Create chunk mask for the subsequent steps (size, size). + + Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + + Args: + size: Size of the source mask. + chunk_size: Number of frames in chunk. + left_chunk_size: Size of the left context in chunks (0 means full context). 
+ device: Device for the mask tensor. + + Returns: + mask: Chunk mask. (size, size) + + """ + mask = torch.zeros(size, size, device=device, dtype=torch.bool) + + for i in range(size): + if left_chunk_size <= 0: + start = 0 + else: + start = max((i // chunk_size - left_chunk_size) * chunk_size, 0) + + end = min((i // chunk_size + 1) * chunk_size, size) + mask[i, start:end] = True + + return ~mask + +def make_source_mask(lengths: torch.Tensor) -> torch.Tensor: + """Create source mask for given lengths. + + Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + + Args: + lengths: Sequence lengths. (B,) + + Returns: + : Mask for the sequence lengths. (B, max_len) + + """ + max_len = lengths.max() + batch_size = lengths.size(0) + + expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths) + + return expanded_lengths >= lengths.unsqueeze(1) + + +def get_transducer_task_io( + labels: torch.Tensor, + encoder_out_lens: torch.Tensor, + ignore_id: int = -1, + blank_id: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get Transducer loss I/O. + + Args: + labels: Label ID sequences. (B, L) + encoder_out_lens: Encoder output lengths. (B,) + ignore_id: Padding symbol ID. + blank_id: Blank symbol ID. + + Returns: + decoder_in: Decoder inputs. (B, U) + target: Target label ID sequences. (B, U) + t_len: Time lengths. (B,) + u_len: Label lengths. (B,) + + """ + + def pad_list(labels: List[torch.Tensor], padding_value: int = 0): + """Create padded batch of labels from a list of labels sequences. + + Args: + labels: Labels sequences. [B x (?)] + padding_value: Padding value. + + Returns: + labels: Batch of padded labels sequences. 
(B,) + + """ + batch_size = len(labels) + + padded = ( + labels[0] + .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:]) + .fill_(padding_value) + ) + + for i in range(batch_size): + padded[i, : labels[i].size(0)] = labels[i] + + return padded + + device = labels.device + + labels_unpad = [y[y != ignore_id] for y in labels] + blank = labels[0].new([blank_id]) + + decoder_in = pad_list( + [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id + ).to(device) + + target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device) + + encoder_out_lens = list(map(int, encoder_out_lens)) + t_len = torch.IntTensor(encoder_out_lens).to(device) + + u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device) + + return decoder_in, target, t_len, u_len + +def pad_to_len(t: torch.Tensor, pad_len: int, dim: int): + """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros.""" + if t.size(dim) == pad_len: + return t + else: + pad_size = list(t.shape) + pad_size[dim] = pad_len - t.size(dim) + return torch.cat( + [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim + ) diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py index a3d2676a8..2b2dac8f3 100644 --- a/funasr/modules/repeat.py +++ b/funasr/modules/repeat.py @@ -6,6 +6,8 @@ """Repeat the same layer definition.""" +from typing import Dict, List, Optional + import torch @@ -31,3 +33,92 @@ def repeat(N, fn): """ return MultiSequential(*[fn(n) for n in range(N)]) + + +class MultiBlocks(torch.nn.Module): + """MultiBlocks definition. + Args: + block_list: Individual blocks of the encoder architecture. + output_size: Architecture output size. + norm_class: Normalization module class. + norm_args: Normalization module arguments. 
+ """ + + def __init__( + self, + block_list: List[torch.nn.Module], + output_size: int, + norm_class: torch.nn.Module = torch.nn.LayerNorm, + ) -> None: + """Construct a MultiBlocks object.""" + super().__init__() + + self.blocks = torch.nn.ModuleList(block_list) + self.norm_blocks = norm_class(output_size) + + self.num_blocks = len(block_list) + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset encoder streaming cache. + Args: + left_context: Number of left frames during chunk-by-chunk inference. + device: Device to use for cache tensor. + """ + for idx in range(self.num_blocks): + self.blocks[idx].reset_streaming_cache(left_context, device) + + def forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward each block of the encoder architecture. + Args: + x: MultiBlocks input sequences. (B, T, D_block_1) + pos_enc: Positional embedding sequences. + mask: Source mask. (B, T) + chunk_mask: Chunk mask. (T_2, T_2) + Returns: + x: Output sequences. (B, T, D_block_N) + """ + for block_index, block in enumerate(self.blocks): + x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask) + + x = self.norm_blocks(x) + + return x + + def chunk_forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_size: int = 0, + left_context: int = 0, + right_context: int = 0, + ) -> torch.Tensor: + """Forward each block of the encoder architecture. + Args: + x: MultiBlocks input sequences. (B, T, D_block_1) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att) + mask: Source mask. (B, T_2) + left_context: Number of frames in left context. + right_context: Number of frames in right context. + Returns: + x: MultiBlocks output sequences. 
(B, T, D_block_N) + """ + for block_idx, block in enumerate(self.blocks): + x, pos_enc = block.chunk_forward( + x, + pos_enc, + mask, + chunk_size=chunk_size, + left_context=left_context, + right_context=right_context, + ) + + x = self.norm_blocks(x) + + return x diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py index d492ccf61..623be65bc 100644 --- a/funasr/modules/subsampling.py +++ b/funasr/modules/subsampling.py @@ -11,6 +11,10 @@ import torch.nn.functional as F from funasr.modules.embedding import PositionalEncoding import logging from funasr.modules.streaming_utils.utils import sequence_mask +from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len +from typing import Optional, Tuple, Union +import math + class TooShortUttError(Exception): """Raised when the utt is too short for subsampling. @@ -407,3 +411,201 @@ class Conv1dSubsampling(torch.nn.Module): var_dict_tf[name_tf].shape)) return var_dict_torch_update +class StreamingConvInput(torch.nn.Module): + """Streaming ConvInput module definition. + Args: + input_size: Input size. + conv_size: Convolution size. + subsampling_factor: Subsampling factor. + vgg_like: Whether to use a VGG-like network. + output_size: Block output dimension. 
+ """ + + def __init__( + self, + input_size: int, + conv_size: Union[int, Tuple], + subsampling_factor: int = 4, + vgg_like: bool = True, + output_size: Optional[int] = None, + ) -> None: + """Construct a ConvInput object.""" + super().__init__() + if vgg_like: + if subsampling_factor == 1: + conv_size1, conv_size2 = conv_size + + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((1, 2)), + torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((1, 2)), + ) + + output_proj = conv_size2 * ((input_size // 2) // 2) + + self.subsampling_factor = 1 + + self.stride_1 = 1 + + self.create_new_mask = self.create_new_vgg_mask + + else: + conv_size1, conv_size2 = conv_size + + kernel_1 = int(subsampling_factor / 2) + + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((kernel_1, 2)), + torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((2, 2)), + ) + + output_proj = conv_size2 * ((input_size // 2) // 2) + + self.subsampling_factor = subsampling_factor + + self.create_new_mask = self.create_new_vgg_mask + + self.stride_1 = kernel_1 + + else: + if subsampling_factor == 1: + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]), + torch.nn.ReLU(), + ) + + output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2) + + self.subsampling_factor = subsampling_factor + 
self.kernel_2 = 3 + self.stride_2 = 1 + + self.create_new_mask = self.create_new_conv2d_mask + + else: + kernel_2, stride_2, conv_2_output_size = sub_factor_to_params( + subsampling_factor, + input_size, + ) + + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, conv_size, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2), + torch.nn.ReLU(), + ) + + output_proj = conv_size * conv_2_output_size + + self.subsampling_factor = subsampling_factor + self.kernel_2 = kernel_2 + self.stride_2 = stride_2 + + self.create_new_mask = self.create_new_conv2d_mask + + self.vgg_like = vgg_like + self.min_frame_length = 7 + + if output_size is not None: + self.output = torch.nn.Linear(output_proj, output_size) + self.output_size = output_size + else: + self.output = None + self.output_size = output_proj + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode input sequences. + Args: + x: ConvInput input sequences. (B, T, D_feats) + mask: Mask of input sequences. (B, 1, T) + Returns: + x: ConvInput output sequences. (B, sub(T), D_out) + mask: Mask of output sequences. (B, 1, sub(T)) + """ + if mask is not None: + mask = self.create_new_mask(mask) + olens = max(mask.eq(0).sum(1)) + + b, t, f = x.size() + x = x.unsqueeze(1) # (b. 1. t. 
f) + + if chunk_size is not None: + max_input_length = int( + chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) )) + ) + x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x) + x = list(x) + x = torch.stack(x, dim=0) + N_chunks = max_input_length // ( chunk_size * self.subsampling_factor) + x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f) + + x = self.conv(x) + + _, c, _, f = x.size() + if chunk_size is not None: + x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:] + else: + x = x.transpose(1, 2).contiguous().view(b, -1, c * f) + + if self.output is not None: + x = self.output(x) + + return x, mask[:,:olens][:,:x.size(1)] + + def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor: + """Create a new mask for VGG output sequences. + Args: + mask: Mask of input sequences. (B, T) + Returns: + mask: Mask of output sequences. (B, sub(T)) + """ + if self.subsampling_factor > 1: + vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 )) + mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2] + + vgg2_t_len = mask.size(1) - (mask.size(1) % 2) + mask = mask[:, :vgg2_t_len][:, ::2] + else: + mask = mask + + return mask + + def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor: + """Create new conformer mask for Conv2d output sequences. + Args: + mask: Mask of input sequences. (B, T) + Returns: + mask: Mask of output sequences. (B, sub(T)) + """ + if self.subsampling_factor > 1: + return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2] + else: + return mask + + def get_size_before_subsampling(self, size: int) -> int: + """Return the original size before subsampling for a given size. + Args: + size: Number of frames after subsampling. + Returns: + : Number of frames before subsampling. 
+ """ + return size * self.subsampling_factor diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py index 52a0ce753..d52c9c383 100644 --- a/funasr/tasks/asr.py +++ b/funasr/tasks/asr.py @@ -38,13 +38,16 @@ from funasr.models.decoder.transformer_decoder import ( from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN from funasr.models.decoder.transformer_decoder import TransformerDecoder from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder +from funasr.models.decoder.rnnt_decoder import RNNTDecoder +from funasr.models.joint_net.joint_network import JointNetwork from funasr.models.e2e_asr import ESPnetASRModel from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer from funasr.models.e2e_tp import TimestampPredictor from funasr.models.e2e_asr_mfcca import MFCCA from funasr.models.e2e_uni_asr import UniASR +from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel from funasr.models.encoder.abs_encoder import AbsEncoder -from funasr.models.encoder.conformer_encoder import ConformerEncoder +from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder from funasr.models.encoder.data2vec_encoder import Data2VecEncoder from funasr.models.encoder.rnn_encoder import RNNEncoder from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt @@ -151,6 +154,7 @@ encoder_choices = ClassChoices( sanm_chunk_opt=SANMEncoderChunkOpt, data2vec_encoder=Data2VecEncoder, mfcca_enc=MFCCAEncoder, + chunk_conformer=ConformerChunkEncoder, ), type_check=AbsEncoder, default="rnn", @@ -208,6 +212,16 @@ decoder_choices2 = ClassChoices( type_check=AbsDecoder, default="rnn", ) + +rnnt_decoder_choices = ClassChoices( + "rnnt_decoder", + classes=dict( + rnnt=RNNTDecoder, + ), + type_check=RNNTDecoder, + default="rnnt", +) + predictor_choices = ClassChoices( name="predictor", classes=dict( @@ -1332,3 
class ASRTransducerTask(AbsTask):
    """ASR Transducer (RNN-T) task definition.

    Wires a frontend, spectrogram augmentation, normalization, an encoder,
    an RNN-T decoder (prediction network), an optional attention decoder and
    a joint network into a ``TransducerModel`` — or a
    ``UnifiedTransducerModel`` when the encoder is configured for unified
    streaming/non-streaming training.
    """

    # A single optimizer updates all modules.
    num_optimizers: int = 1

    # Sub-module registries exposed on the CLI as --<name>/--<name>_conf.
    class_choices_list = [
        frontend_choices,
        specaug_choices,
        normalize_choices,
        encoder_choices,
        rnnt_decoder_choices,
    ]

    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Add Transducer task arguments.

        Args:
            parser: Transducer arguments parser.
        """
        group = parser.add_argument_group(description="Task related.")

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="Integer-string mapper for tokens.",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            # Fix: original help text was truncated ("... using ").
            help="Whether to split text into tokens by space.",
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of dimensions for input features.",
        )
        group.add_argument(
            "--init",
            type=str_or_none,
            default=None,
            help="Type of model initialization to use.",
        )
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(TransducerModel),
            help="The keyword arguments for the model class.",
        )
        group.add_argument(
            "--joint_network_conf",
            action=NestedDictAction,
            default={},
            help="The keyword arguments for the joint network class.",
        )

        group = parser.add_argument_group(description="Preprocess related.")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Whether to apply preprocessing to input data.",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The type of tokens to use during tokenization.",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The path of the sentencepiece model.",
        )
        # Consistency: the following were added via `parser` in the original;
        # they belong to the preprocess group (identical parsing behavior,
        # only the --help grouping changes).
        group.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="The 'non_linguistic_symbols' file path.",
        )
        group.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Text cleaner to use.",
        )
        group.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="g2p method to use if --token_type=phn.",
        )
        group.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Normalization value for maximum amplitude scaling.",
        )
        group.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The RIR SCP file path.",
        )
        group.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of the applied RIR convolution.",
        )
        group.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The path of noise SCP file.",
        )
        group.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of the applied noise addition.",
        )
        group.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of the noise decibel level.",
        )

        # Append --<name> and --<name>_conf for every registered sub-module,
        # e.g. --encoder and --encoder_conf.
        for class_choices in cls.class_choices_list:
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Build collate function.

        Args:
            args: Task arguments.
            train: Training mode.

        Returns:
            Callable collate function padding speech with 0.0 and
            token ids with -1 (ignored by the loss).
        """
        assert check_argument_types()

        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        """Build pre-processing function.

        Args:
            args: Task arguments.
            train: Training mode.

        Returns:
            Callable pre-processing function, or None when
            --use_preprocessor is false.
        """
        assert check_argument_types()

        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                # Older configs may lack these options; fall back to their
                # documented defaults (getattr replaces the hasattr chains).
                split_with_space=getattr(args, "split_with_space", False),
                rir_scp=getattr(args, "rir_scp", None),
                rir_apply_prob=getattr(args, "rir_apply_prob", 1.0),
                noise_scp=getattr(args, "noise_scp", None),
                noise_apply_prob=getattr(args, "noise_apply_prob", 1.0),
                noise_db_range=getattr(args, "noise_db_range", "13_15"),
                # Fix: the original guarded this on hasattr(args, "rir_scp"),
                # i.e. it tested the wrong attribute.
                speech_volume_normalize=getattr(
                    args, "speech_volume_normalize", None
                ),
            )
        else:
            retval = None

        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Required data depending on task mode.

        Args:
            train: Training mode.
            inference: Inference mode.

        Returns:
            retval: Required task data ("text" is only needed outside
            inference).
        """
        if not inference:
            retval = ("speech", "text")
        else:
            retval = ("speech",)

        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Optional data depending on task mode.

        Args:
            train: Training mode.
            inference: Inference mode.

        Returns:
            retval: Optional task data (none for this task).
        """
        retval = ()
        assert check_return_type(retval)

        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
        """Build the ASR Transducer model from task arguments.

        Args:
            args: Task arguments.

        Returns:
            model: ASR Transducer model.

        Raises:
            RuntimeError: If args.token_list is neither str nor list.
            NotImplementedError: If args.init is set (not supported yet).
        """
        assert check_argument_types()

        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwrite token_list with the materialized list to keep the
            # dumped config "portable" (independent of the file path).
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")

        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. Frontend
        if args.input_size is None:
            # Extract features inside the model (raw waveform input).
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Features are pre-extracted and provided by the data loader.
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Encoder
        if getattr(args, "encoder", None) is not None:
            encoder_class = encoder_choices.get_class(args.encoder)
            encoder = encoder_class(input_size, **args.encoder_conf)
        else:
            # NOTE(review): `Encoder` is not imported in this module, so this
            # fallback would raise NameError if ever reached. It appears
            # unreachable while encoder_choices defines a default — confirm
            # before removing.
            encoder = Encoder(input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()

        # 5. Decoder (RNN-T prediction network)
        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
        decoder = rnnt_decoder_class(
            vocab_size,
            **args.rnnt_decoder_conf,
        )
        decoder_output_size = decoder.output_size

        # Optional auxiliary attention decoder.
        if getattr(args, "decoder", None) is not None:
            # Fix: the original read args.att_decoder here, which does not
            # exist when the guard on args.decoder passes.
            att_decoder_class = decoder_choices.get_class(args.decoder)

            att_decoder = att_decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.decoder_conf,
            )
        else:
            att_decoder = None

        # 6. Joint network combining encoder and decoder representations.
        joint_network = JointNetwork(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **args.joint_network_conf,
        )

        # 7. Build model. Encoders without the unified-training flag fall
        # back to the plain transducer (getattr avoids an AttributeError
        # the original raised for such encoders).
        if getattr(encoder, "unified_model_training", False):
            model_class = UnifiedTransducerModel
        else:
            model_class = TransducerModel
        model = model_class(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
            decoder=decoder,
            att_decoder=att_decoder,
            joint_network=joint_network,
            **args.model_conf,
        )

        # 8. Initialize model
        if args.init is not None:
            raise NotImplementedError(
                "Currently not supported.",
                "Initialization part will be reworked in a short future.",
            )

        return model