commit 964a50d246 (parent 60fdbbf63d)
Author: 嘉渊
Date: 2023-07-05 20:36:27 +08:00
11 changed files with 1335 additions and 0 deletions

egs/aishell/branchformer/README.md
@@ -0,0 +1,17 @@
# Branchformer Result
## Training Config
- Feature info: 80-dim fbank, global CMVN, speed perturb (0.9, 1.0, 1.1), SpecAugment
- Train info: lr 1e-3, batch_bins 25000000, 2 GPUs (Tesla V100), accum_grad 1, 60 epochs
- Train config: conf/train_asr_branchformer.yaml
- LM config: LM was not used
- Model size: 46M
## Results (CER)
- Decode config: conf/decode_asr_transformer.yaml (ctc_weight: 0.4)

| testset | CER (%) |
|:-------:|:-------:|
| dev     | 4.42    |
| test    | 4.87    |

egs/aishell/branchformer/conf/decode_asr_transformer.yaml
@@ -0,0 +1,6 @@
beam_size: 10
penalty: 0.0
maxlenratio: 0.0
minlenratio: 0.0
ctc_weight: 0.4
lm_weight: 0.0
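
With `ctc_weight: 0.4` and `lm_weight: 0.0`, beam search ranks each hypothesis by the usual hybrid CTC/attention combination (a sketch of the standard decoding objective, stated for orientation rather than as FunASR's exact internals):

$$\log p(Y \mid X) = (1-\lambda)\,\log p_{\mathrm{att}}(Y \mid X) + \lambda\,\log p_{\mathrm{ctc}}(Y \mid X) + \gamma\,\log p_{\mathrm{lm}}(Y), \qquad \lambda = \texttt{ctc\_weight},\ \gamma = \texttt{lm\_weight}$$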

egs/aishell/branchformer/conf/train_asr_branchformer.yaml
@@ -0,0 +1,86 @@
# network architecture
# encoder related
encoder: branchformer
encoder_conf:
    output_size: 256
    use_attn: true
    attention_heads: 4
    attention_layer_type: rel_selfattn
    pos_enc_layer_type: rel_pos
    rel_pos_type: latest
    use_cgmlp: true
    cgmlp_linear_units: 2048
    cgmlp_conv_kernel: 31
    use_linear_after_conv: false
    gate_activation: identity
    merge_method: concat
    cgmlp_weight: 0.5            # used only if merge_method is "fixed_ave"
    attn_branch_drop_rate: 0.0   # used only if merge_method is "learned_ave"
    num_blocks: 24
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d
    stochastic_depth_rate: 0.0

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.
    src_attention_dropout_rate: 0.

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1              # label smoothing option
    length_normalized_loss: false

# minibatch related
batch_type: numel
batch_bins: 25000000

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 60
val_scheduler_criterion:
    - valid
    - acc
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 10

optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 35000

num_workers: 4            # num of workers of data loader
use_amp: true             # automatic mixed precision
unused_parameters: false  # set as true if some params are unused in DDP

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 10
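
The `warmuplr` scheduler above is an inverse-square-root schedule with linear warmup. A minimal sketch, assuming FunASR follows the ESPnet-style `WarmupLR` definition (the function name here is hypothetical):

```python
def warmup_lr(step: int, base_lr: float = 0.001, warmup_steps: int = 35000) -> float:
    """ESPnet-style WarmupLR (assumed): ramp linearly to base_lr at
    warmup_steps, then decay proportionally to step**-0.5."""
    assert step >= 1
    return base_lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)

# warmup_lr(35000) -> 0.001 (peak); warmup_lr(140000) -> 0.0005
```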

egs/aishell/branchformer/local/aishell_data_prep.sh
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2017 Xingyu Na
# Apache 2.0
#. ./path.sh || exit 1;
if [ $# != 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3

train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp

mkdir -p $train_dir
mkdir -p $dev_dir
mkdir -p $test_dir
mkdir -p $tmp_dir

# data directory check
if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
  echo "Error: $aishell_audio_dir or $aishell_text does not exist"
  exit 1;
fi

# find wav audio files for train, dev and test resp.
find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
n=$(wc -l < $tmp_dir/wav.flist)
[ $n -ne 141925 ] && \
  echo "Warning: expected 141925 wav files, found $n"

grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;

rm -r $tmp_dir

# Transcriptions preparation
for dir in $train_dir $dev_dir $test_dir; do
  echo "Preparing $dir transcriptions"
  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
  paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
  utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
  utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
  sort -u $dir/transcripts.txt > $dir/text
done

mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test

for f in wav.scp text; do
  cp $train_dir/$f $output_dir/data/train/$f || exit 1;
  cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
  cp $test_dir/$f $output_dir/data/test/$f || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
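
Each split ends up with a Kaldi-style `wav.scp` (utterance ID plus absolute wav path) and `text` (utterance ID plus transcript). A hypothetical pair of lines, with an illustrative utterance ID and transcript:

```
BAC009S0002W0122 /export/a05/xna/data/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
BAC009S0002W0122 而 对 楼市 成交 抑制 作用 最 大 的 限购
```

The transcript still carries word-level spaces at this point; stage 0 of run.sh removes them and re-tokenizes to characters.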

egs/aishell/branchformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
#!/usr/bin/env bash
# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
# 2017 Xingyu Na
# Apache 2.0
remove_archive=false

if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi

if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
  exit 1;
fi

data=$1
url=$2
part=$3

if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi

part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == $x ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi

if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi

if [ -f $data/$part/.complete ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi

# sizes of the archive files in bytes
sizes="15582913665 1246920"

if [ -f $data/$part.tgz ]; then
  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm $data/$part.tgz
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi

if [ ! -f $data/$part.tgz ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url. This may take some time, please be patient."
  # download into $data without changing directory, so the cd below still
  # resolves correctly when $data is a relative path
  if ! wget -P $data --no-check-certificate $full_url; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi

cd $data || exit 1

if ! tar -xvzf $part.tgz; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi

touch $part/.complete

if [ $part == "data_aishell" ]; then
  cd $part/wav || exit 1
  for wav in ./*.tar.gz; do
    echo "Extracting wav from $wav"
    tar -zxf $wav && rm $wav
  done
fi

echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"

if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm $data/$part.tgz
fi

exit 0;
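
As invoked from run.sh stage -1 with the recipe defaults (`--remove-archive` may be prepended to delete each tarball after extraction):

```
local/download_and_untar.sh ../raw_data www.openslr.org/resources/33 data_aishell
local/download_and_untar.sh ../raw_data www.openslr.org/resources/33 resource_aishell
```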

egs/aishell/branchformer/path.sh
@@ -0,0 +1,5 @@
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH

egs/aishell/branchformer/run.sh (executable)
@@ -0,0 +1,225 @@
#!/usr/bin/env bash

. ./path.sh || exit 1;

# machines configuration
CUDA_VISIBLE_DEVICES="0,1"
gpu_num=2
count=1
gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl

# general configuration
feats_dir="../DATA"  # feature output directory
exp_dir="."
lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=0
stop_stage=5

# feature configuration
feats_dim=80
nj=64

# data
raw_data=../raw_data
data_url=www.openslr.org/resources/33

# exp tag
tag="exp1"

. utils/parse_options.sh || exit 1;

# Set bash to 'debug' mode: it will exit on
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="dev test"

asr_config=conf/train_asr_branchformer.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"

inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb

# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')

if ${gpu_inference}; then
    inference_nj=$((ngpu * njob))
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    echo "stage -1: Data Download"
    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
    # remove the spaces between words, then re-tokenize the transcript into characters
    for x in train dev test; do
        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) \
                     <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
            > ${feats_dir}/data/${x}/text
        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
    done
fi
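
# For illustration (utterance ID and transcript are hypothetical), stage 0 maps
# a prepared line such as
#   BAC009S0002W0122 而 对 楼市 成交 抑制 作用 最 大 的 限购
# first to the space-free form
#   BAC009S0002W0122 而对楼市成交抑制作用最大的限购
# and text2token.py then re-splits it into single characters:
#   BAC009S0002W0122 而 对 楼 市 成 交 抑 制 作 用 最 大 的 限 购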
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature and CMVN Generation"
    utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj \
        --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
fi

token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/

    echo "make a dictionary"
    echo "<blank>" > ${token_list}
    echo "<s>" >> ${token_list}
    echo "</s>" >> ${token_list}
    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text \
        | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' >> ${token_list}
    echo "<unk>" >> ${token_list}
fi
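
# The resulting tokens.txt holds one token per line, with the special symbols at
# fixed positions (the character entries below are illustrative):
#   <blank>
#   <s>
#   </s>
#   一
#   丁
#   ...
#   <unk>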
# LM Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: LM Training"
    # no LM is trained in this recipe
fi

# ASR Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: ASR Training"
    mkdir -p ${exp_dir}/exp/${model_dir}
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ]; then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    # launch one training process per GPU; the file:// URL is the DDP rendezvous point
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
            train.py \
                --task_name asr \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --token_type $token_type \
                --token_list $token_list \
                --data_dir ${feats_dir}/data \
                --train_set ${train_set} \
                --valid_set ${valid_set} \
                --data_file_names "wav.scp,text" \
                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
                --speed_perturb ${speed_perturb} \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi

# Testing Stage
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    echo "stage 5: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        if [ -d ${_dir} ]; then
            echo "${_dir} already exists. If you want to decode again, please delete this dir first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/data/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode asr \
                ${_opts}

        # merge the per-job outputs, then score CER against the reference text
        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 > "${_dir}/${f}"
            fi
        done
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi

# Prepare files for ModelScope fine-tuning and inference
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
    echo "stage 6: ModelScope Preparation"
    cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn
    vocab_size=$(wc -l < ${token_list})
    python utils/gen_modelscope_configuration.py \
        --am_model_name $inference_asr_model \
        --mode asr \
        --model_name conformer \
        --dataset aishell \
        --output_dir $exp_dir/exp/$model_dir \
        --vocab_size $vocab_size \
        --tag $tag
fi
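
A typical end-to-end invocation with the defaults above (flags are parsed by utils/parse_options.sh):

```
bash run.sh --stage -1 --stop_stage 5 --gpu_num 2 --tag exp1
```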

egs/aishell/branchformer/utils (symlink)
@@ -0,0 +1 @@
../transformer/utils

funasr/models/encoder/branchformer_encoder.py
@@ -0,0 +1,547 @@
# Copyright 2022 Yifan Peng (Carnegie Mellon University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Branchformer encoder definition.
Reference:
Yifan Peng, Siddharth Dalmia, Ian Lane, and Shinji Watanabe,
Branchformer: Parallel MLP-Attention Architectures to Capture
Local and Global Context for Speech Recognition and Understanding,
in Proceedings of ICML, 2022.
"""
import logging
from typing import List, Optional, Tuple, Union
import numpy
import torch
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.cgmlp import ConvolutionalGatingMLP
from funasr.modules.fastformer import FastSelfAttention
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import (  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,
    MultiHeadedAttention,
    RelPositionMultiHeadedAttention,
)
from funasr.modules.embedding import (  # noqa: H301
    LegacyRelPositionalEncoding,
    PositionalEncoding,
    RelPositionalEncoding,
    ScaledPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
    TooShortUttError,
    check_short_utt,
)

class BranchformerEncoderLayer(torch.nn.Module):
    """Branchformer encoder layer module.

    Args:
        size (int): model dimension
        attn: standard self-attention or efficient attention, optional
        cgmlp: ConvolutionalGatingMLP, optional
        dropout_rate (float): dropout probability
        merge_method (str): concat, learned_ave, fixed_ave
        cgmlp_weight (float): weight of the cgmlp branch, between 0 and 1,
            used if merge_method is fixed_ave
        attn_branch_drop_rate (float): probability of dropping the attn branch,
            used if merge_method is learned_ave
        stochastic_depth_rate (float): stochastic depth probability
    """

    def __init__(
        self,
        size: int,
        attn: Optional[torch.nn.Module],
        cgmlp: Optional[torch.nn.Module],
        dropout_rate: float,
        merge_method: str,
        cgmlp_weight: float = 0.5,
        attn_branch_drop_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
    ):
        super().__init__()
        assert (attn is not None) or (
            cgmlp is not None
        ), "At least one branch should be valid"

        self.size = size
        self.attn = attn
        self.cgmlp = cgmlp
        self.merge_method = merge_method
        self.cgmlp_weight = cgmlp_weight
        self.attn_branch_drop_rate = attn_branch_drop_rate
        self.stochastic_depth_rate = stochastic_depth_rate
        self.use_two_branches = (attn is not None) and (cgmlp is not None)

        if attn is not None:
            self.norm_mha = LayerNorm(size)  # for the MHA module
        if cgmlp is not None:
            self.norm_mlp = LayerNorm(size)  # for the MLP module
        self.norm_final = LayerNorm(size)  # for the final output of the block

        self.dropout = torch.nn.Dropout(dropout_rate)

        if self.use_two_branches:
            if merge_method == "concat":
                self.merge_proj = torch.nn.Linear(size + size, size)
            elif merge_method == "learned_ave":
                # attention-based pooling for two branches
                self.pooling_proj1 = torch.nn.Linear(size, 1)
                self.pooling_proj2 = torch.nn.Linear(size, 1)
                # linear projections for calculating merging weights
                self.weight_proj1 = torch.nn.Linear(size, 1)
                self.weight_proj2 = torch.nn.Linear(size, 1)
                # linear projection after weighted average
                self.merge_proj = torch.nn.Linear(size, size)
            elif merge_method == "fixed_ave":
                assert (
                    0.0 <= cgmlp_weight <= 1.0
                ), "cgmlp weight should be between 0.0 and 1.0"
                # remove the other branch if only one branch is used
                if cgmlp_weight == 0.0:
                    self.use_two_branches = False
                    self.cgmlp = None
                    self.norm_mlp = None
                elif cgmlp_weight == 1.0:
                    self.use_two_branches = False
                    self.attn = None
                    self.norm_mha = None
                # linear projection after weighted average
                self.merge_proj = torch.nn.Linear(size, size)
            else:
                raise ValueError(f"unknown merge method: {merge_method}")
        else:
            self.merge_proj = torch.nn.Identity()
    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if cache is not None:
            raise NotImplementedError("cache is not None, which is not tested")

        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask

        # Two branches
        x1 = x
        x2 = x

        # Branch 1: multi-headed attention module
        if self.attn is not None:
            x1 = self.norm_mha(x1)
            if isinstance(self.attn, FastSelfAttention):
                x_att = self.attn(x1, mask)
            else:
                if pos_emb is not None:
                    x_att = self.attn(x1, x1, x1, pos_emb, mask)
                else:
                    x_att = self.attn(x1, x1, x1, mask)
            x1 = self.dropout(x_att)

        # Branch 2: convolutional gating mlp
        if self.cgmlp is not None:
            x2 = self.norm_mlp(x2)
            if pos_emb is not None:
                x2 = (x2, pos_emb)
            x2 = self.cgmlp(x2, mask)
            if isinstance(x2, tuple):
                x2 = x2[0]
            x2 = self.dropout(x2)

        # Merge two branches
        if self.use_two_branches:
            if self.merge_method == "concat":
                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(torch.cat([x1, x2], dim=-1))
                )
            elif self.merge_method == "learned_ave":
                if (
                    self.training
                    and self.attn_branch_drop_rate > 0
                    and torch.rand(1).item() < self.attn_branch_drop_rate
                ):
                    # Drop the attn branch
                    w1, w2 = 0.0, 1.0
                else:
                    # branch1
                    score1 = (
                        self.pooling_proj1(x1).transpose(1, 2) / self.size**0.5
                    )  # (batch, 1, time)
                    if mask is not None:
                        min_value = float(
                            numpy.finfo(
                                torch.tensor(0, dtype=score1.dtype).numpy().dtype
                            ).min
                        )
                        score1 = score1.masked_fill(mask.eq(0), min_value)
                        score1 = torch.softmax(score1, dim=-1).masked_fill(
                            mask.eq(0), 0.0
                        )
                    else:
                        score1 = torch.softmax(score1, dim=-1)
                    pooled1 = torch.matmul(score1, x1).squeeze(1)  # (batch, size)
                    weight1 = self.weight_proj1(pooled1)  # (batch, 1)

                    # branch2
                    score2 = (
                        self.pooling_proj2(x2).transpose(1, 2) / self.size**0.5
                    )  # (batch, 1, time)
                    if mask is not None:
                        min_value = float(
                            numpy.finfo(
                                torch.tensor(0, dtype=score2.dtype).numpy().dtype
                            ).min
                        )
                        score2 = score2.masked_fill(mask.eq(0), min_value)
                        score2 = torch.softmax(score2, dim=-1).masked_fill(
                            mask.eq(0), 0.0
                        )
                    else:
                        score2 = torch.softmax(score2, dim=-1)
                    pooled2 = torch.matmul(score2, x2).squeeze(1)  # (batch, size)
                    weight2 = self.weight_proj2(pooled2)  # (batch, 1)

                    # normalize weights of two branches
                    merge_weights = torch.softmax(
                        torch.cat([weight1, weight2], dim=-1), dim=-1
                    )  # (batch, 2)
                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
                        -1
                    )  # (batch, 2, 1, 1)
                    w1, w2 = merge_weights[:, 0], merge_weights[:, 1]  # (batch, 1, 1)

                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(w1 * x1 + w2 * x2)
                )
            elif self.merge_method == "fixed_ave":
                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(
                        (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
                    )
                )
            else:
                raise RuntimeError(f"unknown merge method: {self.merge_method}")
        else:
            if self.attn is None:
                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x2))
            elif self.cgmlp is None:
                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x1))
            else:
                # This should not happen
                raise RuntimeError("Both branches are not None, which is unexpected.")

        x = self.norm_final(x)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask

class BranchformerEncoder(AbsEncoder):
    """Branchformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        use_attn: bool = True,
        attention_heads: int = 4,
        attention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        rel_pos_type: str = "latest",
        use_cgmlp: bool = True,
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        merge_method: str = "concat",
        cgmlp_weight: Union[float, List[float]] = 0.5,
        attn_branch_drop_rate: Union[float, List[float]] = 0.0,
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        zero_triu: bool = False,
        padding_idx: int = -1,
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if attention_layer_type == "rel_selfattn":
                attention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert attention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert attention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert attention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        if attention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif attention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif attention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        elif attention_layer_type == "fast_selfattn":
            assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
            encoder_selfattn_layer = FastSelfAttention
            encoder_selfattn_layer_args = (
                output_size,
                attention_heads,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)

        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
        )

        if isinstance(stochastic_depth_rate, float):
            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
        if len(stochastic_depth_rate) != num_blocks:
            raise ValueError(
                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
                f"should be equal to num_blocks ({num_blocks})"
            )

        if isinstance(cgmlp_weight, float):
            cgmlp_weight = [cgmlp_weight] * num_blocks
        if len(cgmlp_weight) != num_blocks:
            raise ValueError(
                f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to "
                f"num_blocks ({num_blocks})"
            )

        if isinstance(attn_branch_drop_rate, float):
            attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks
        if len(attn_branch_drop_rate) != num_blocks:
            raise ValueError(
                f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
                f"should be equal to num_blocks ({num_blocks})"
            )

        self.encoders = repeat(
            num_blocks,
            lambda lnum: BranchformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args)
                if use_attn
                else None,
                cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
                dropout_rate,
                merge_method,
                cgmlp_weight[lnum],
                attn_branch_drop_rate[lnum],
                stochastic_depth_rate[lnum],
            ),
        )
        self.after_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        elif self.embed is not None:
            xs_pad = self.embed(xs_pad)

        xs_pad, masks = self.encoders(xs_pad, masks)

        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
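
For orientation, a minimal forward pass through the encoder might look like this; the shapes are illustrative and the import path is assumed rather than taken from this diff:

```python
import torch
from funasr.models.encoder.branchformer_encoder import BranchformerEncoder  # path assumed

encoder = BranchformerEncoder(
    input_size=80,       # 80-dim fbank, as in the recipe
    output_size=256,
    num_blocks=24,
    merge_method="concat",
)
xs = torch.randn(2, 200, 80)      # (batch, frames, feat_dim)
ilens = torch.tensor([200, 180])  # true lengths before padding
out, olens, _ = encoder(xs, ilens)
print(out.shape, olens)  # the conv2d input layer subsamples time by 4
```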

funasr/modules/cgmlp.py
@@ -0,0 +1,124 @@
"""MLP with convolutional gating (cgMLP) definition.
References:
https://openreview.net/forum?id=RA-zVvZLYIy
https://arxiv.org/abs/2105.08050
"""
import torch
from funasr.modules.nets_utils import get_activation
from funasr.modules.layer_norm import LayerNorm
class ConvolutionalSpatialGatingUnit(torch.nn.Module):
"""Convolutional Spatial Gating Unit (CSGU)."""
def __init__(
self,
size: int,
kernel_size: int,
dropout_rate: float,
use_linear_after_conv: bool,
gate_activation: str,
):
super().__init__()
n_channels = size // 2 # split input channels
self.norm = LayerNorm(n_channels)
self.conv = torch.nn.Conv1d(
n_channels,
n_channels,
kernel_size,
1,
(kernel_size - 1) // 2,
groups=n_channels,
)
if use_linear_after_conv:
self.linear = torch.nn.Linear(n_channels, n_channels)
else:
self.linear = None
if gate_activation == "identity":
self.act = torch.nn.Identity()
else:
self.act = get_activation(gate_activation)
self.dropout = torch.nn.Dropout(dropout_rate)
def espnet_initialization_fn(self):
torch.nn.init.normal_(self.conv.weight, std=1e-6)
torch.nn.init.ones_(self.conv.bias)
if self.linear is not None:
torch.nn.init.normal_(self.linear.weight, std=1e-6)
torch.nn.init.ones_(self.linear.bias)
def forward(self, x, gate_add=None):
"""Forward method
Args:
x (torch.Tensor): (N, T, D)
gate_add (torch.Tensor): (N, T, D/2)
Returns:
out (torch.Tensor): (N, T, D/2)
"""
x_r, x_g = x.chunk(2, dim=-1)
x_g = self.norm(x_g) # (N, T, D/2)
x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2) # (N, T, D/2)
if self.linear is not None:
x_g = self.linear(x_g)
if gate_add is not None:
x_g = x_g + gate_add
x_g = self.act(x_g)
out = x_r * x_g # (N, T, D/2)
out = self.dropout(out)
return out
class ConvolutionalGatingMLP(torch.nn.Module):
"""Convolutional Gating MLP (cgMLP)."""
def __init__(
self,
size: int,
linear_units: int,
kernel_size: int,
dropout_rate: float,
use_linear_after_conv: bool,
gate_activation: str,
):
super().__init__()
self.channel_proj1 = torch.nn.Sequential(
torch.nn.Linear(size, linear_units), torch.nn.GELU()
)
self.csgu = ConvolutionalSpatialGatingUnit(
size=linear_units,
kernel_size=kernel_size,
dropout_rate=dropout_rate,
use_linear_after_conv=use_linear_after_conv,
gate_activation=gate_activation,
)
self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)
def forward(self, x, mask):
if isinstance(x, tuple):
xs_pad, pos_emb = x
else:
xs_pad, pos_emb = x, None
xs_pad = self.channel_proj1(xs_pad) # size -> linear_units
xs_pad = self.csgu(xs_pad) # linear_units -> linear_units/2
xs_pad = self.channel_proj2(xs_pad) # linear_units/2 -> size
if pos_emb is not None:
out = (xs_pad, pos_emb)
else:
out = xs_pad
return out
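
In equation form, splitting the projected features channel-wise into halves $X_r$ and $X_g$, the module above computes (ignoring biases and dropout):

$$\mathrm{CSGU}(X) = X_r \odot \sigma\big(\mathrm{DepthwiseConv1d}(\mathrm{LayerNorm}(X_g))\big), \qquad \mathrm{cgMLP}(X) = W_2\,\mathrm{CSGU}\big(\mathrm{GELU}(W_1 X)\big)$$

where $\sigma$ is the configured `gate_activation` (the identity in this recipe).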

funasr/modules/fastformer.py
@@ -0,0 +1,153 @@
"""Fastformer attention definition.
Reference:
Wu et al., "Fastformer: Additive Attention Can Be All You Need"
https://arxiv.org/abs/2108.09084
https://github.com/wuch15/Fastformer
"""
import numpy
import torch
class FastSelfAttention(torch.nn.Module):
"""Fast self-attention used in Fastformer."""
def __init__(
self,
size,
attention_heads,
dropout_rate,
):
super().__init__()
if size % attention_heads != 0:
raise ValueError(
f"Hidden size ({size}) is not an integer multiple "
f"of attention heads ({attention_heads})"
)
self.attention_head_size = size // attention_heads
self.num_attention_heads = attention_heads
self.query = torch.nn.Linear(size, size)
self.query_att = torch.nn.Linear(size, attention_heads)
self.key = torch.nn.Linear(size, size)
self.key_att = torch.nn.Linear(size, attention_heads)
self.transform = torch.nn.Linear(size, size)
self.dropout = torch.nn.Dropout(dropout_rate)
def espnet_initialization_fn(self):
self.apply(self.init_weights)
def init_weights(self, module):
if isinstance(module, torch.nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, torch.nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def transpose_for_scores(self, x):
"""Reshape and transpose to compute scores.
Args:
x: (batch, time, size = n_heads * attn_dim)
Returns:
(batch, n_heads, time, attn_dim)
"""
new_x_shape = x.shape[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
return x.reshape(*new_x_shape).transpose(1, 2)
def forward(self, xs_pad, mask):
"""Forward method.
Args:
xs_pad: (batch, time, size = n_heads * attn_dim)
mask: (batch, 1, time), nonpadding is 1, padding is 0
Returns:
torch.Tensor: (batch, time, size)
"""
batch_size, seq_len, _ = xs_pad.shape
mixed_query_layer = self.query(xs_pad) # (batch, time, size)
mixed_key_layer = self.key(xs_pad) # (batch, time, size)
if mask is not None:
mask = mask.eq(0) # padding is 1, nonpadding is 0
# (batch, n_heads, time)
query_for_score = (
self.query_att(mixed_query_layer).transpose(1, 2)
/ self.attention_head_size**0.5
)
if mask is not None:
min_value = float(
numpy.finfo(
torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
).min
)
query_for_score = query_for_score.masked_fill(mask, min_value)
query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
else:
query_weight = torch.softmax(query_for_score, dim=-1)
query_weight = query_weight.unsqueeze(2) # (batch, n_heads, 1, time)
query_layer = self.transpose_for_scores(
mixed_query_layer
) # (batch, n_heads, time, attn_dim)
pooled_query = (
torch.matmul(query_weight, query_layer)
.transpose(1, 2)
.reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
) # (batch, 1, size = n_heads * attn_dim)
pooled_query = self.dropout(pooled_query)
pooled_query_repeat = pooled_query.repeat(1, seq_len, 1) # (batch, time, size)
mixed_query_key_layer = (
mixed_key_layer * pooled_query_repeat
) # (batch, time, size)
# (batch, n_heads, time)
query_key_score = (
self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
).transpose(1, 2)
if mask is not None:
min_value = float(
numpy.finfo(
torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
).min
)
query_key_score = query_key_score.masked_fill(mask, min_value)
query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
mask, 0.0
)
else:
query_key_weight = torch.softmax(query_key_score, dim=-1)
query_key_weight = query_key_weight.unsqueeze(2) # (batch, n_heads, 1, time)
key_layer = self.transpose_for_scores(
mixed_query_key_layer
) # (batch, n_heads, time, attn_dim)
pooled_key = torch.matmul(
query_key_weight, key_layer
) # (batch, n_heads, 1, attn_dim)
pooled_key = self.dropout(pooled_key)
# NOTE: value = query, due to param sharing
weighted_value = (pooled_key * query_layer).transpose(
1, 2
) # (batch, time, n_heads, attn_dim)
weighted_value = weighted_value.reshape(
weighted_value.shape[:-2]
+ (self.num_attention_heads * self.attention_head_size,)
) # (batch, time, size)
weighted_value = (
self.dropout(self.transform(weighted_value)) + mixed_query_layer
)
return weighted_value
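
A minimal shape check, with illustrative values (the import path mirrors the one used by the Branchformer encoder above):

```python
import torch
from funasr.modules.fastformer import FastSelfAttention

attn = FastSelfAttention(size=256, attention_heads=4, dropout_rate=0.1)
xs = torch.randn(2, 100, 256)                   # (batch, time, size)
mask = torch.ones(2, 1, 100, dtype=torch.long)  # nonpadding is 1
out = attn(xs, mask)
print(out.shape)  # torch.Size([2, 100, 256]); additive attention costs O(T), not O(T^2)
```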