diff --git a/egs/aishell/e_branchformer/conf/decode_asr_transformer.yaml b/egs/aishell/e_branchformer/conf/decode_asr_transformer.yaml new file mode 100644 index 000000000..e87a2937f --- /dev/null +++ b/egs/aishell/e_branchformer/conf/decode_asr_transformer.yaml @@ -0,0 +1,6 @@ +beam_size: 10 +penalty: 0.0 +maxlenratio: 0.0 +minlenratio: 0.0 +ctc_weight: 0.4 +lm_weight: 0.0 diff --git a/egs/aishell/e_branchformer/conf/train_asr_e_branchformer.yaml b/egs/aishell/e_branchformer/conf/train_asr_e_branchformer.yaml new file mode 100644 index 000000000..a30e9a23b --- /dev/null +++ b/egs/aishell/e_branchformer/conf/train_asr_e_branchformer.yaml @@ -0,0 +1,101 @@ +# network architecture +# encoder related +encoder: e_branchformer +encoder_conf: + output_size: 256 + attention_heads: 4 + attention_layer_type: rel_selfattn + pos_enc_layer_type: rel_pos + rel_pos_type: latest + cgmlp_linear_units: 1024 + cgmlp_conv_kernel: 31 + use_linear_after_conv: false + gate_activation: identity + num_blocks: 12 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: conv2d + layer_drop_rate: 0.0 + linear_units: 1024 + positionwise_layer_type: linear + use_ffn: true + macaron_ffn: true + merge_conv_kernel: 31 + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0. + src_attention_dropout_rate: 0. + +# frontend related +frontend: wav_frontend +frontend_conf: + fs: 16000 + window: hamming + n_mels: 80 + frame_length: 25 + frame_shift: 10 + lfr_m: 1 + lfr_n: 1 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +# optimization related +accum_grad: 1 +grad_clip: 5 +max_epoch: 180 +best_model_criterion: +- - valid + - acc + - max +keep_nbest_models: 10 + +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 0.000001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 35000 + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 27 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_ratio_range: + - 0. + - 0.05 + num_time_mask: 10 + +dataset_conf: + data_names: speech,text + data_types: sound,text + shuffle: True + shuffle_conf: + shuffle_size: 2048 + sort_size: 500 + batch_conf: + batch_type: token + batch_size: 10000 + num_workers: 8 + +log_interval: 50 +normalize: None \ No newline at end of file diff --git a/egs/aishell/e_branchformer/local/aishell_data_prep.sh b/egs/aishell/e_branchformer/local/aishell_data_prep.sh new file mode 100755 index 000000000..83f489b3c --- /dev/null +++ b/egs/aishell/e_branchformer/local/aishell_data_prep.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Copyright 2017 Xingyu Na +# Apache 2.0 + +#. ./path.sh || exit 1; + +if [ $# != 3 ]; then + echo "Usage: $0 " + echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data" + exit 1; +fi + +aishell_audio_dir=$1 +aishell_text=$2/aishell_transcript_v0.8.txt +output_dir=$3 + +train_dir=$output_dir/data/local/train +dev_dir=$output_dir/data/local/dev +test_dir=$output_dir/data/local/test +tmp_dir=$output_dir/data/local/tmp + +mkdir -p $train_dir +mkdir -p $dev_dir +mkdir -p $test_dir +mkdir -p $tmp_dir + +# data directory check +if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then + echo "Error: $0 requires two directory arguments" + exit 1; +fi + +# find wav audio file for train, dev and test resp. +find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist +n=`cat $tmp_dir/wav.flist | wc -l` +[ $n -ne 141925 ] && \ + echo Warning: expected 141925 data data files, found $n + +grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1; +grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1; +grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1; + +rm -r $tmp_dir + +# Transcriptions preparation +for dir in $train_dir $dev_dir $test_dir; do + echo Preparing $dir transcriptions + sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list + paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all + utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt + awk '{print $1}' $dir/transcripts.txt > $dir/utt.list + utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp + sort -u $dir/transcripts.txt > $dir/text +done + +mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test + +for f in wav.scp text; do + cp $train_dir/$f $output_dir/data/train/$f || exit 1; + cp $dev_dir/$f $output_dir/data/dev/$f || exit 1; + cp $test_dir/$f $output_dir/data/test/$f || exit 1; +done + +echo "$0: AISHELL data preparation succeeded" +exit 0; diff --git a/egs/aishell/e_branchformer/local/download_and_untar.sh b/egs/aishell/e_branchformer/local/download_and_untar.sh new file mode 100755 index 000000000..d98255915 --- /dev/null +++ b/egs/aishell/e_branchformer/local/download_and_untar.sh @@ -0,0 +1,105 @@ +#!/usr/bin/env bash + +# Copyright 2014 Johns Hopkins University (author: Daniel Povey) +# 2017 Xingyu Na +# Apache 2.0 + +remove_archive=false + +if [ "$1" == --remove-archive ]; then + remove_archive=true + shift +fi + +if [ $# -ne 3 ]; then + echo "Usage: $0 [--remove-archive] " + echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell" + echo "With --remove-archive it will remove the archive after successfully un-tarring it." + echo " can be one of: data_aishell, resource_aishell." +fi + +data=$1 +url=$2 +part=$3 + +if [ ! -d "$data" ]; then + echo "$0: no such directory $data" + exit 1; +fi + +part_ok=false +list="data_aishell resource_aishell" +for x in $list; do + if [ "$part" == $x ]; then part_ok=true; fi +done +if ! $part_ok; then + echo "$0: expected to be one of $list, but got '$part'" + exit 1; +fi + +if [ -z "$url" ]; then + echo "$0: empty URL base." + exit 1; +fi + +if [ -f $data/$part/.complete ]; then + echo "$0: data part $part was already successfully extracted, nothing to do." + exit 0; +fi + +# sizes of the archive files in bytes. +sizes="15582913665 1246920" + +if [ -f $data/$part.tgz ]; then + size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}') + size_ok=false + for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done + if ! $size_ok; then + echo "$0: removing existing file $data/$part.tgz because its size in bytes $size" + echo "does not equal the size of one of the archives." + rm $data/$part.tgz + else + echo "$data/$part.tgz exists and appears to be complete." + fi +fi + +if [ ! -f $data/$part.tgz ]; then + if ! command -v wget >/dev/null; then + echo "$0: wget is not installed." + exit 1; + fi + full_url=$url/$part.tgz + echo "$0: downloading data from $full_url. This may take some time, please be patient." + + cd $data || exit 1 + if ! wget --no-check-certificate $full_url; then + echo "$0: error executing wget $full_url" + exit 1; + fi +fi + +cd $data || exit 1 + +if ! tar -xvzf $part.tgz; then + echo "$0: error un-tarring archive $data/$part.tgz" + exit 1; +fi + +touch $data/$part/.complete + +if [ $part == "data_aishell" ]; then + cd $data/$part/wav || exit 1 + for wav in ./*.tar.gz; do + echo "Extracting wav from $wav" + tar -zxf $wav && rm $wav + done +fi + +echo "$0: Successfully downloaded and un-tarred $data/$part.tgz" + +if $remove_archive; then + echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied." + rm $data/$part.tgz +fi + +exit 0; diff --git a/egs/aishell/e_branchformer/path.sh b/egs/aishell/e_branchformer/path.sh new file mode 100755 index 000000000..7972642d0 --- /dev/null +++ b/egs/aishell/e_branchformer/path.sh @@ -0,0 +1,5 @@ +export FUNASR_DIR=$PWD/../../.. + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PATH=$FUNASR_DIR/funasr/bin:$PATH diff --git a/egs/aishell/e_branchformer/run.sh b/egs/aishell/e_branchformer/run.sh new file mode 100755 index 000000000..bcba2d75f --- /dev/null +++ b/egs/aishell/e_branchformer/run.sh @@ -0,0 +1,225 @@ +#!/usr/bin/env bash + +. ./path.sh || exit 1; + +# machines configuration +CUDA_VISIBLE_DEVICES="0,1,2,3" +gpu_num=4 +count=1 +gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding +# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob +njob=5 +train_cmd=utils/run.pl +infer_cmd=utils/run.pl + +# general configuration +feats_dir="../DATA" #feature output dictionary +exp_dir="." +lang=zh +token_type=char +type=sound +scp=wav.scp +speed_perturb="0.9 1.0 1.1" +stage=0 +stop_stage=5 + +# feature configuration +feats_dim=80 +nj=64 + +# data +raw_data=../raw_data +data_url=www.openslr.org/resources/33 + +# exp tag +tag="exp1" + +. utils/parse_options.sh || exit 1; + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +train_set=train +valid_set=dev +test_sets="dev test" + +asr_config=conf/train_asr_e_branchformer.yaml +model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}" + +inference_config=conf/decode_asr_transformer.yaml +inference_asr_model=valid.acc.ave_10best.pb + +# you can set gpu num for decoding here +gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default +ngpu=$(echo $gpuid_list | awk -F "," '{print NF}') + +if ${gpu_inference}; then + inference_nj=$[${ngpu}*${njob}] + _ngpu=1 +else + inference_nj=$njob + _ngpu=0 +fi + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "stage -1: Data Download" + local/download_and_untar.sh ${raw_data} ${data_url} data_aishell + local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: Data preparation" + # Data preparation + local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir} + for x in train dev test; do + cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org + paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \ + > ${feats_dir}/data/${x}/text + utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org + mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text + done +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "stage 1: Feature and CMVN Generation" + utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0 +fi + +token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt +echo "dictionary: ${token_list}" +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "stage 2: Dictionary Preparation" + mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/ + + echo "make a dictionary" + echo "" > ${token_list} + echo "" >> ${token_list} + echo "" >> ${token_list} + utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \ + | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list} + echo "" >> ${token_list} +fi + +# LM Training Stage +world_size=$gpu_num # run on one machine +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "stage 3: LM Training" +fi + +# ASR Training Stage +world_size=$gpu_num # run on one machine +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "stage 4: ASR Training" + mkdir -p ${exp_dir}/exp/${model_dir} + mkdir -p ${exp_dir}/exp/${model_dir}/log + INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init + if [ -f $INIT_FILE ];then + rm -f $INIT_FILE + fi + init_method=file://$(readlink -f $INIT_FILE) + echo "$0: init method is $init_method" + for ((i = 0; i < $gpu_num; ++i)); do + { + rank=$i + local_rank=$i + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + train.py \ + --task_name asr \ + --gpu_id $gpu_id \ + --use_preprocessor true \ + --token_type $token_type \ + --token_list $token_list \ + --data_dir ${feats_dir}/data \ + --train_set ${train_set} \ + --valid_set ${valid_set} \ + --data_file_names "wav.scp,text" \ + --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \ + --speed_perturb ${speed_perturb} \ + --resume true \ + --output_dir ${exp_dir}/exp/${model_dir} \ + --config $asr_config \ + --ngpu $gpu_num \ + --num_worker_count $count \ + --dist_init_method $init_method \ + --dist_world_size $world_size \ + --dist_rank $rank \ + --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1 + } & + done + wait +fi + +# Testing Stage +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + echo "stage 5: Inference" + for dset in ${test_sets}; do + asr_exp=${exp_dir}/exp/${model_dir} + inference_tag="$(basename "${inference_config}" .yaml)" + _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}" + _logdir="${_dir}/logdir" + if [ -d ${_dir} ]; then + echo "${_dir} is already exists. if you want to decode again, please delete this dir first." + exit 0 + fi + mkdir -p "${_logdir}" + _data="${feats_dir}/data/${dset}" + key_file=${_data}/${scp} + num_scp_file="$(<${key_file} wc -l)" + _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file") + split_scps= + for n in $(seq "${_nj}"); do + split_scps+=" ${_logdir}/keys.${n}.scp" + done + # shellcheck disable=SC2086 + utils/split_scp.pl "${key_file}" ${split_scps} + _opts= + if [ -n "${inference_config}" ]; then + _opts+="--config ${inference_config} " + fi + ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ + python -m funasr.bin.asr_inference_launch \ + --batch_size 1 \ + --ngpu "${_ngpu}" \ + --njob ${njob} \ + --gpuid_list ${gpuid_list} \ + --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \ + --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \ + --key_file "${_logdir}"/keys.JOB.scp \ + --asr_train_config "${asr_exp}"/config.yaml \ + --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ + --output_dir "${_logdir}"/output.JOB \ + --mode asr \ + ${_opts} + + for f in token token_int score text; do + if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then + for i in $(seq "${_nj}"); do + cat "${_logdir}/output.${i}/1best_recog/${f}" + done | sort -k1 >"${_dir}/${f}" + fi + done + python utils/proce_text.py ${_dir}/text ${_dir}/text.proc + python utils/proce_text.py ${_data}/text ${_data}/text.proc + python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer + tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt + cat ${_dir}/text.cer.txt + done +fi + +# Prepare files for ModelScope fine-tuning and inference +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + echo "stage 6: ModelScope Preparation" + cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn + vocab_size=$(cat ${token_list} | wc -l) + python utils/gen_modelscope_configuration.py \ + --am_model_name $inference_asr_model \ + --mode asr \ + --model_name conformer \ + --dataset aishell \ + --output_dir $exp_dir/exp/$model_dir \ + --vocab_size $vocab_size \ + --tag $tag +fi \ No newline at end of file diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py index b0734fff4..5e9344469 100644 --- a/funasr/build_utils/build_asr_model.py +++ b/funasr/build_utils/build_asr_model.py @@ -40,6 +40,7 @@ from funasr.models.encoder.resnet34_encoder import ResNet34Diar from funasr.models.encoder.rnn_encoder import RNNEncoder from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt from funasr.models.encoder.branchformer_encoder import BranchformerEncoder +from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder from funasr.models.encoder.transformer_encoder import TransformerEncoder from funasr.models.frontend.default import DefaultFrontend from funasr.models.frontend.default import MultiChannelFrontend @@ -115,6 +116,7 @@ encoder_choices = ClassChoices( sanm_chunk_opt=SANMEncoderChunkOpt, data2vec_encoder=Data2VecEncoder, branchformer=BranchformerEncoder, + e_branchformer=EBranchformerEncoder, mfcca_enc=MFCCAEncoder, chunk_conformer=ConformerChunkEncoder, ), diff --git a/funasr/models/encoder/e_branchformer_encoder.py b/funasr/models/encoder/e_branchformer_encoder.py new file mode 100644 index 000000000..65e481db8 --- /dev/null +++ b/funasr/models/encoder/e_branchformer_encoder.py @@ -0,0 +1,467 @@ +# Copyright 2022 Kwangyoun Kim (ASAPP inc.) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""E-Branchformer encoder definition. +Reference: + Kwangyoun Kim, Felix Wu, Yifan Peng, Jing Pan, + Prashant Sridhar, Kyu J. Han, Shinji Watanabe, + "E-Branchformer: Branchformer with Enhanced merging + for speech recognition," in SLT 2022. +""" + +import logging +from typing import List, Optional, Tuple + +import torch +from typeguard import check_argument_types + +from funasr.models.ctc import CTC +from funasr.models.encoder.abs_encoder import AbsEncoder +from funasr.modules.cgmlp import ConvolutionalGatingMLP +from funasr.modules.fastformer import FastSelfAttention +from funasr.modules.nets_utils import get_activation, make_pad_mask +from funasr.modules.attention import ( # noqa: H301 + LegacyRelPositionMultiHeadedAttention, + MultiHeadedAttention, + RelPositionMultiHeadedAttention, +) +from funasr.modules.embedding import ( # noqa: H301 + LegacyRelPositionalEncoding, + PositionalEncoding, + RelPositionalEncoding, + ScaledPositionalEncoding, +) +from funasr.modules.layer_norm import LayerNorm +from funasr.modules.positionwise_feed_forward import ( + PositionwiseFeedForward, +) +from funasr.modules.repeat import repeat +from funasr.modules.subsampling import ( + Conv2dSubsampling, + Conv2dSubsampling2, + Conv2dSubsampling6, + Conv2dSubsampling8, + TooShortUttError, + check_short_utt, +) + + +class EBranchformerEncoderLayer(torch.nn.Module): + """E-Branchformer encoder layer module. + + Args: + size (int): model dimension + attn: standard self-attention or efficient attention + cgmlp: ConvolutionalGatingMLP + feed_forward: feed-forward module, optional + feed_forward: macaron-style feed-forward module, optional + dropout_rate (float): dropout probability + merge_conv_kernel (int): kernel size of the depth-wise conv in merge module + """ + + def __init__( + self, + size: int, + attn: torch.nn.Module, + cgmlp: torch.nn.Module, + feed_forward: Optional[torch.nn.Module], + feed_forward_macaron: Optional[torch.nn.Module], + dropout_rate: float, + merge_conv_kernel: int = 3, + ): + super().__init__() + + self.size = size + self.attn = attn + self.cgmlp = cgmlp + + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.ff_scale = 1.0 + if self.feed_forward is not None: + self.norm_ff = LayerNorm(size) + if self.feed_forward_macaron is not None: + self.ff_scale = 0.5 + self.norm_ff_macaron = LayerNorm(size) + + self.norm_mha = LayerNorm(size) # for the MHA module + self.norm_mlp = LayerNorm(size) # for the MLP module + self.norm_final = LayerNorm(size) # for the final output of the block + + self.dropout = torch.nn.Dropout(dropout_rate) + + self.depthwise_conv_fusion = torch.nn.Conv1d( + size + size, + size + size, + kernel_size=merge_conv_kernel, + stride=1, + padding=(merge_conv_kernel - 1) // 2, + groups=size + size, + bias=True, + ) + self.merge_proj = torch.nn.Linear(size + size, size) + + def forward(self, x_input, mask, cache=None): + """Compute encoded features. + + Args: + x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb. + - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)]. + - w/o pos emb: Tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, 1, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + """ + + if cache is not None: + raise NotImplementedError("cache is not None, which is not tested") + + if isinstance(x_input, tuple): + x, pos_emb = x_input[0], x_input[1] + else: + x, pos_emb = x_input, None + + if self.feed_forward_macaron is not None: + residual = x + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + + # Two branches + x1 = x + x2 = x + + # Branch 1: multi-headed attention module + x1 = self.norm_mha(x1) + + if isinstance(self.attn, FastSelfAttention): + x_att = self.attn(x1, mask) + else: + if pos_emb is not None: + x_att = self.attn(x1, x1, x1, pos_emb, mask) + else: + x_att = self.attn(x1, x1, x1, mask) + + x1 = self.dropout(x_att) + + # Branch 2: convolutional gating mlp + x2 = self.norm_mlp(x2) + + if pos_emb is not None: + x2 = (x2, pos_emb) + x2 = self.cgmlp(x2, mask) + if isinstance(x2, tuple): + x2 = x2[0] + + x2 = self.dropout(x2) + + # Merge two branches + x_concat = torch.cat([x1, x2], dim=-1) + x_tmp = x_concat.transpose(1, 2) + x_tmp = self.depthwise_conv_fusion(x_tmp) + x_tmp = x_tmp.transpose(1, 2) + x = x + self.dropout(self.merge_proj(x_concat + x_tmp)) + + if self.feed_forward is not None: + # feed forward module + residual = x + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + + x = self.norm_final(x) + + if pos_emb is not None: + return (x, pos_emb), mask + + return x, mask + + +class EBranchformerEncoder(AbsEncoder): + """E-Branchformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + attention_layer_type: str = "rel_selfattn", + pos_enc_layer_type: str = "rel_pos", + rel_pos_type: str = "latest", + cgmlp_linear_units: int = 2048, + cgmlp_conv_kernel: int = 31, + use_linear_after_conv: bool = False, + gate_activation: str = "identity", + num_blocks: int = 12, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: Optional[str] = "conv2d", + zero_triu: bool = False, + padding_idx: int = -1, + layer_drop_rate: float = 0.0, + max_pos_emb_len: int = 5000, + use_ffn: bool = False, + macaron_ffn: bool = False, + ffn_activation_type: str = "swish", + linear_units: int = 2048, + positionwise_layer_type: str = "linear", + merge_conv_kernel: int = 3, + interctc_layer_idx=None, + interctc_use_conditioning: bool = False, + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if rel_pos_type == "legacy": + if pos_enc_layer_type == "rel_pos": + pos_enc_layer_type = "legacy_rel_pos" + if attention_layer_type == "rel_selfattn": + attention_layer_type = "legacy_rel_selfattn" + elif rel_pos_type == "latest": + assert attention_layer_type != "legacy_rel_selfattn" + assert pos_enc_layer_type != "legacy_rel_pos" + else: + raise ValueError("unknown rel_pos_type: " + rel_pos_type) + + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert attention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "legacy_rel_pos": + assert attention_layer_type == "legacy_rel_selfattn" + pos_enc_class = LegacyRelPositionalEncoding + logging.warning( + "Using legacy_rel_pos and it will be deprecated in the future." + ) + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "conv2d2": + self.embed = Conv2dSubsampling2( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer is None: + if input_size == output_size: + self.embed = None + else: + self.embed = torch.nn.Linear(input_size, output_size) + else: + raise ValueError("unknown input_layer: " + input_layer) + + activation = get_activation(ffn_activation_type) + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type is None: + logging.warning("no macaron ffn") + else: + raise ValueError("Support only linear.") + + if attention_layer_type == "selfattn": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + elif attention_layer_type == "legacy_rel_selfattn": + assert pos_enc_layer_type == "legacy_rel_pos" + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + logging.warning( + "Using legacy_rel_selfattn and it will be deprecated in the future." + ) + elif attention_layer_type == "rel_selfattn": + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + zero_triu, + ) + elif attention_layer_type == "fast_selfattn": + assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"] + encoder_selfattn_layer = FastSelfAttention + encoder_selfattn_layer_args = ( + output_size, + attention_heads, + attention_dropout_rate, + ) + else: + raise ValueError("unknown encoder_attn_layer: " + attention_layer_type) + + cgmlp_layer = ConvolutionalGatingMLP + cgmlp_layer_args = ( + output_size, + cgmlp_linear_units, + cgmlp_conv_kernel, + dropout_rate, + use_linear_after_conv, + gate_activation, + ) + + self.encoders = repeat( + num_blocks, + lambda lnum: EBranchformerEncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + cgmlp_layer(*cgmlp_layer_args), + positionwise_layer(*positionwise_layer_args) if use_ffn else None, + positionwise_layer(*positionwise_layer_args) + if use_ffn and macaron_ffn + else None, + dropout_rate, + merge_conv_kernel, + ), + layer_drop_rate, + ) + self.after_norm = LayerNorm(output_size) + + if interctc_layer_idx is None: + interctc_layer_idx = [] + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + max_layer: int = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Calculate forward propagation. + + Args: + xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). + ilens (torch.Tensor): Input length (#batch). + prev_states (torch.Tensor): Not to be used now. + ctc (CTC): Intermediate CTC module. + max_layer (int): Layer depth below which InterCTC is applied. + Returns: + torch.Tensor: Output tensor (#batch, L, output_size). + torch.Tensor: Output length (#batch). + torch.Tensor: Not to be used now. + """ + + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if ( + isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8) + ): + short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f"has {xs_pad.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + elif self.embed is not None: + xs_pad = self.embed(xs_pad) + + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + if max_layer is not None and 0 <= max_layer < len(self.encoders): + for layer_idx, encoder_layer in enumerate(self.encoders): + xs_pad, masks = encoder_layer(xs_pad, masks) + if layer_idx >= max_layer: + break + else: + xs_pad, masks = self.encoders(xs_pad, masks) + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + xs_pad, masks = encoder_layer(xs_pad, masks) + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + + if isinstance(xs_pad, tuple): + xs_pad = list(xs_pad) + xs_pad[0] = xs_pad[0] + self.conditioning_layer(ctc_out) + xs_pad = tuple(xs_pad) + else: + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if isinstance(xs_pad, tuple): + xs_pad = xs_pad[0] + + xs_pad = self.after_norm(xs_pad) + olens = masks.squeeze(1).sum(1) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens, None