mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
commit de0ecb446f
@ -0,0 +1,6 @@
beam_size: 1
penalty: 0.0
maxlenratio: 0.0
minlenratio: 0.0
ctc_weight: 0.0
lm_weight: 0.15
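With beam_size: 1, ctc_weight: 0.0, and no length penalties, this decode config amounts to attention-only greedy 1-best search, matching the decode_asr_transformer_noctc_1best.yaml name it is given in run.sh below. A minimal sketch of loading and checking such a config (assumes PyYAML; not part of the repo):

import yaml

# Hedged sketch: read the decode config above and confirm it selects
# greedy, CTC-free decoding. The path is the one run.sh points at.
with open("conf/decode_asr_transformer_noctc_1best.yaml") as f:
    decode_conf = yaml.safe_load(f)
assert decode_conf["beam_size"] == 1      # greedy search
assert decode_conf["ctc_weight"] == 0.0   # no CTC joint decoding
# lm_weight (0.15) only takes effect when an external LM is plugged in.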
@ -0,0 +1,104 @@
# network architecture
# encoder related
encoder: data2vec_encoder
encoder_conf:
    extractor_mode: layer_norm
    encoder_layerdrop: 0.1
    dropout_input: 0.0
    dropout_features: 0.0
    feature_grad_mult: 0.0
    encoder_embed_dim: 768

    mask_prob: 0.65
    mask_length: 10

    loss_beta: 0
    loss_scale: null

    instance_norm_target_layer: true
    average_top_k_layers: 8

    pos_conv_depth: 5
    conv_pos: 95

    ema_decay: 0.999
    ema_end_decay: 0.9999
    ema_anneal_end_step: 30000
    ema_transformer_only: true
    ema_layers_only: true

    require_same_masks: true
    mask_dropout: 0


# decoder related
decoder: paraformer_decoder_san
decoder_conf:
    attention_heads: 12
    linear_units: 3072
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

model: paraformer
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1
    length_normalized_loss: false
    predictor_weight: 1.0
    sampling_ratio: 0.4

# minibatch related
batch_type: length
batch_bins: 25000
num_workers: 16

# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 50
val_scheduler_criterion:
    - valid
    - acc
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 10

optim: adam
optim_conf:
    lr: 0.00002
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 30000

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2

predictor: cif_predictor
predictor_conf:
    idim: 768
    threshold: 1.0
    l_order: 1
    r_order: 1


log_interval: 50
unused_parameters: true
normalize: None
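batch_type: length with batch_bins: 25000 builds variable-size minibatches bounded by the total feature-frame count rather than by a fixed number of utterances. A rough, illustrative re-implementation of that packing rule (not FunASR's actual code):

# Illustrative sketch of length-based batching: pack utterances, sorted by
# length, until the summed frame count would exceed batch_bins.
def make_length_batches(utt_lengths, batch_bins=25000):
    batches, cur, cur_bins = [], [], 0
    for utt, n_frames in sorted(utt_lengths.items(), key=lambda kv: kv[1]):
        if cur and cur_bins + n_frames > batch_bins:
            batches.append(cur)
            cur, cur_bins = [], 0
        cur.append(utt)
        cur_bins += n_frames
    if cur:
        batches.append(cur)
    return batches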
252  egs/aishell/data2vec_paraformer_finetune/run.sh  Executable file
@ -0,0 +1,252 @@
#!/usr/bin/env bash

. ./path.sh || exit 1;

# machines configuration
CUDA_VISIBLE_DEVICES="0,1"
gpu_num=2
count=1
gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl

# general configuration
feats_dir="../DATA"  # feature output directory, for large data
exp_dir="."
lang=zh
dumpdir=dump/fbank
feats_type=fbank
token_type=char
scp=feats.scp
type=kaldi_ark
stage=0
stop_stage=4

# feature configuration
feats_dim=80
sample_frequency=16000
nj=32
speed_perturb="0.9,1.0,1.1"

# data
data_aishell=

# exp tag
tag=""

model_name=damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch
init_param="$HOME/.cache/modelscope/hub/$model_name/basemodel.pb"

. utils/parse_options.sh || exit 1;

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="dev test"

asr_config=conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"

inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pth

# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')

if ${gpu_inference}; then
    inference_nj=$[${ngpu}*${njob}]
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # Data preparation
    local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
    for x in train dev test; do
        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
            > ${feats_dir}/data/${x}/text
        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
    done
fi

feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature Generation"
    # compute fbank features
    fbankdir=${feats_dir}/fbank
    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
    utils/fix_data_feat.sh ${fbankdir}/train
    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
        ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
    utils/fix_data_feat.sh ${fbankdir}/dev
    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
        ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
    utils/fix_data_feat.sh ${fbankdir}/test

    # compute global cmvn
    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train

    # apply cmvn
    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
        ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
        ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
        ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}

    cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
    cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
    cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}

    utils/fix_data_feat.sh ${feat_train_dir}
    utils/fix_data_feat.sh ${feat_dev_dir}
    utils/fix_data_feat.sh ${feat_test_dir}

    # generate ark list
    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
fi

token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p ${feats_dir}/data/${lang}_token_list/char/

    echo "make a dictionary"
    echo "<blank>" > ${token_list}
    echo "<s>" >> ${token_list}
    echo "</s>" >> ${token_list}
    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
    num_token=$(cat ${token_list} | wc -l)
    echo "<unk>" >> ${token_list}
    vocab_size=$(cat ${token_list} | wc -l)
    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
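The two awk calls in stage 2 append ",${vocab_size}" to each line of text_shape, producing the text_shape.char files the trainer reads. With a hypothetical vocabulary of 4235 tokens, one line changes like this (utterance id and token count are made up):

# text_shape line (utterance id, token count):
#   BAC009S0002W0122 14
# same line in text_shape.char after awk -v v=,4235 '{print $0v}':
#   BAC009S0002W0122 14,4235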

# Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: Training"
    python utils/download_model.py --model_name ${model_name}  # download pretrained model from ModelScope
    mkdir -p ${exp_dir}/exp/${model_dir}
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ]; then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
            asr_train_paraformer.py \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
                --init_param ${init_param} \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --input_size $feats_dim \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi

# Testing Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        if [ -d ${_dir} ]; then
            echo "${_dir} already exists. If you want to decode again, please delete this dir first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/${dumpdir}/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode paraformer \
                ${_opts}

        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi
1  egs/aishell/data2vec_paraformer_finetune/utils  Symbolic link
@ -0,0 +1 @@
../../aishell/transformer/utils
66  egs/aishell/data2vec_transformer_finetune/local/aishell_data_prep.sh  Executable file
@ -0,0 +1,66 @@
#!/bin/bash

# Copyright 2017 Xingyu Na
# Apache 2.0

#. ./path.sh || exit 1;

if [ $# != 3 ]; then
    echo "Usage: $0 <audio-path> <text-path> <output-path>"
    echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
    exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3

train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp

mkdir -p $train_dir
mkdir -p $dev_dir
mkdir -p $test_dir
mkdir -p $tmp_dir

# data directory check
if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
    echo "Error: $0 requires an audio directory and a transcript file"
    exit 1;
fi

# find wav audio file for train, dev and test resp.
find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
n=`cat $tmp_dir/wav.flist | wc -l`
[ $n -ne 141925 ] && \
    echo Warning: expected 141925 data files, found $n

grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;

rm -r $tmp_dir

# Transcriptions preparation
for dir in $train_dir $dev_dir $test_dir; do
    echo Preparing $dir transcriptions
    sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
    paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
    utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
    awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
    utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
    sort -u $dir/transcripts.txt > $dir/text
done

mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test

for f in wav.scp text; do
    cp $train_dir/$f $output_dir/data/train/$f || exit 1;
    cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
    cp $test_dir/$f $output_dir/data/test/$f || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
53  egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh  Executable file
@ -0,0 +1,53 @@
#!/usr/bin/env bash
# Copyright 2018 AIShell-Foundation (Authors: Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
#           2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
# Apache 2.0

# transform raw AISHELL-2 data to kaldi format

. ./path.sh || exit 1;

tmp=
dir=

if [ $# != 3 ]; then
    echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
    echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
    exit 1;
fi

corpus=$1
tmp=$2
dir=$3

echo "prepare_data.sh: Preparing data in $corpus"

mkdir -p $tmp
mkdir -p $dir

# corpus check
if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
    echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
    exit 1;
fi

# validate utt-key list, IC0803W0380 is a bad utterance
awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list

# wav.scp
awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp

# text
utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text

# copy prepared resources from tmp_dir to target dir
mkdir -p $dir
for f in wav.scp text; do
    cp $tmp/$f $dir/$f || exit 1;
done

echo "local/prepare_data.sh succeeded"
exit 0;
5  egs/aishell/data2vec_transformer_finetune/path.sh  Executable file
@ -0,0 +1,5 @@
export FUNASR_DIR=$PWD/../../..

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
@ -0,0 +1,79 @@
# network architecture
# encoder related
encoder: data2vec_encoder
encoder_conf:
    extractor_mode: layer_norm
    encoder_layerdrop: 0.05
    dropout_input: 0.0
    dropout_features: 0.0
    feature_grad_mult: 1.0
    encoder_embed_dim: 768

    mask_prob: 0.65
    mask_length: 10

    loss_beta: 0
    loss_scale: null

    instance_norm_target_layer: true
    average_top_k_layers: 8

    pos_conv_depth: 5
    conv_pos: 95

    ema_decay: 0.999
    ema_end_decay: 0.9999
    ema_anneal_end_step: 30000
    ema_transformer_only: true
    ema_layers_only: true

    require_same_masks: true
    mask_dropout: 0

log_interval: 50
normalize: None

# minibatch related
batch_type: length
batch_bins: 64000
num_workers: 16

# optimization related
accum_grad: 1
grad_clip: 5
patience: none
max_epoch: 600
val_scheduler_criterion:
    - valid
    - acc
best_model_criterion:
-   - valid
    - loss
    - min
keep_nbest_models: 50
unused_parameters: true

optim: fairseq_adam
optim_conf:
    lr: 0.0005
    adam_betas: [0.9,0.98]
    adam_eps: 1.0e-06
    weight_decay: 0.01

scheduler: tri_stage
scheduler_conf:
    phase_ratio: [0.03,0.9,0.07]

# for dataset
dataset_conf:
    batch_mode: clipping
    data_names: speech,none
    data_types: kaldi_ark,none
    shuffle: true
    shuffle_conf:
        shuffle_size: 12800
        sort_size: 12800
    batch_conf:
        batch_type: token
        batch_size: 64000
    num_workers: 8
53  egs/aishell2/data2vec_pretrain/local/prepare_data.sh  Executable file
@ -0,0 +1,53 @@
#!/usr/bin/env bash
# Copyright 2018 AIShell-Foundation (Authors: Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
#           2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
# Apache 2.0

# transform raw AISHELL-2 data to kaldi format

. ./path.sh || exit 1;

tmp=
dir=

if [ $# != 3 ]; then
    echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
    echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
    exit 1;
fi

corpus=$1
tmp=$2
dir=$3

echo "prepare_data.sh: Preparing data in $corpus"

mkdir -p $tmp
mkdir -p $dir

# corpus check
if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
    echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
    exit 1;
fi

# validate utt-key list, IC0803W0380 is a bad utterance
awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list

# wav.scp
awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp

# text
tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text

# copy prepared resources from tmp_dir to target dir
mkdir -p $dir
for f in wav.scp text; do
    cp $tmp/$f $dir/$f || exit 1;
done

echo "local/prepare_data.sh succeeded"
exit 0;
6  egs/aishell2/data2vec_pretrain/path.sh  Executable file
@ -0,0 +1,6 @@
export FUNASR_DIR=$PWD/../../..

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../:$PYTHONPATH
export PATH=$FUNASR_DIR/funasr/bin:$PATH
172  egs/aishell2/data2vec_pretrain/run.sh  Executable file
@ -0,0 +1,172 @@
#!/usr/bin/env bash

. ./path.sh || exit 1;

# machines configuration
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
gpu_num=8
count=1

train_cmd=tools/run.pl

# general configuration
feats_dir="../DATA"  # feature output directory
exp_dir="."
lang=zh
dumpdir=dump/fbank
feats_type=fbank
token_type=char
dataset_type=large
stage=0
stop_stage=4

# feature configuration
feats_dim=80
sample_frequency=16000
nj=100
speed_perturb="0.9,1.0,1.1"

# data
tr_dir=
dev_tst_dir=

# exp tag
tag="exp1"

. utils/parse_options.sh || exit 1;

# Set bash to 'debug' mode, it will exit on:
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev_ios

asr_config=conf/train_pretrain_transformer.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # For training set
    local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
    # For dev and test set
    for x in Android iOS Mic; do
        local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
        local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
    done
    # Normalize text to lower case and remove spaces
    for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
        mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
        paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
            | tr 'A-Z' 'a-z' | tr -d " ") \
            > ${feats_dir}/data/${x}/text
        tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
    done
fi

feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature Generation"
    # compute fbank features
    fbankdir=${feats_dir}/fbank
    steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
    tools/fix_data_feat.sh ${fbankdir}/train
    for x in android ios mic; do
        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
            ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
        tools/fix_data_feat.sh ${fbankdir}/dev_${x}
        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
            ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
        tools/fix_data_feat.sh ${fbankdir}/test_${x}
    done

    # compute global cmvn
    steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train

    # apply cmvn
    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
        ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
        ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
    for x in android ios mic; do
        steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
            ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
    done

    cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
    tools/fix_data_feat.sh ${feat_train_dir}
    cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
    tools/fix_data_feat.sh ${feat_dev_dir}
    for x in android ios mic; do
        cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
        tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
    done
fi

token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p ${feats_dir}/data/${lang}_token_list/char/

    echo "make a dictionary"
    echo "<blank>" > ${token_list}
    echo "<s>" >> ${token_list}
    echo "</s>" >> ${token_list}
    tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
    num_token=$(cat ${token_list} | wc -l)
    echo "<unk>" >> ${token_list}
    vocab_size=$(cat ${token_list} | wc -l)
    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
fi

# Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: Training"
    mkdir -p ${exp_dir}/exp/${model_dir}
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ]; then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
            data2vec_train.py \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --dataset_type $dataset_type \
                --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
                --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --input_size $feats_dim \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi
1  egs/aishell2/data2vec_pretrain/utils  Symbolic link
@ -0,0 +1 @@
../../aishell/transformer/utils
@ -181,7 +181,7 @@ class Speech2Text:
         self.nbest = nbest
         self.frontend = frontend
         self.encoder_downsampling_factor = 1
-        if asr_train_args.encoder_conf["input_layer"] == "conv2d":
+        if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
             self.encoder_downsampling_factor = 4

     @torch.no_grad()

45  funasr/bin/data2vec_train.py  Executable file
@ -0,0 +1,45 @@
#!/usr/bin/env python3

import os

from funasr.tasks.data2vec import Data2VecTask


def parse_args():
    parser = Data2VecTask.get_parser()
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )
    args = parser.parse_args()
    return args


def main(args=None, cmd=None):
    # for data2vec Training
    Data2VecTask.main(args=args, cmd=cmd)


if __name__ == '__main__':
    args = parse_args()

    # setup local gpu_id
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP settings
    if args.ngpu > 1:
        args.distributed = True
    else:
        args.distributed = False
    assert args.num_worker_count == 1

    # re-compute batch size: when dataset type is small
    if args.dataset_type == "small":
        if args.batch_size is not None:
            args.batch_size = args.batch_size * args.ngpu
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * args.ngpu

    main(args=args)
@ -78,6 +78,58 @@ def common_collate_fn(
            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    assert check_return_type(output)
    return output

def crop_to_max_size(feature, target_size):
    size = len(feature)
    diff = size - target_size
    if diff <= 0:
        return feature

    start = np.random.randint(0, diff + 1)
    end = size - diff + start
    return feature[start:end]


def clipping_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    max_sample_size=None,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    # mainly for pre-training
    assert check_argument_types()
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        array_list = [d[key] for d in data]
        tensor_list = [torch.from_numpy(a) for a in array_list]
        sizes = [len(s) for s in tensor_list]
        if max_sample_size is None:
            target_size = min(sizes)
        else:
            target_size = min(min(sizes), max_sample_size)
        tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
        for i, (source, size) in enumerate(zip(tensor_list, sizes)):
            diff = size - target_size
            if diff == 0:
                tensor[i] = source
            else:
                tensor[i] = crop_to_max_size(source, target_size)
        output[key] = tensor

        if key not in not_sequence:
            lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    assert check_return_type(output)
    return output
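A small, hypothetical usage sketch of the new clipping_collate_fn: features of unequal length are randomly cropped to the shortest sample in the batch, so no padding is needed. The utterance ids and shapes are invented for illustration.

import numpy as np

batch = [
    ("utt1", {"speech": np.zeros((120, 80), dtype=np.float32)}),
    ("utt2", {"speech": np.zeros((100, 80), dtype=np.float32)}),
]
uttids, output = clipping_collate_fn(batch)
# Both utterances end up cropped to the minimum length, 100 frames:
assert output["speech"].shape == (2, 100, 80)
assert (output["speech_lengths"] == 100).all()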
@ -35,15 +35,16 @@ def load_seg_dict(seg_dict_file):

 class ArkDataLoader(AbsIterFactory):
     def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"):
-        symbol_table = read_symbol_table(dict_file)
+        symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
         if seg_dict_file is not None:
             seg_dict = load_seg_dict(seg_dict_file)
         else:
             seg_dict = None
         self.dataset_conf = dataset_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
+        batch_mode = self.dataset_conf.get("batch_mode", "padding")
         self.dataset = Dataset(data_list, symbol_table, seg_dict,
-                               self.dataset_conf, mode=mode)
+                               self.dataset_conf, mode=mode, batch_mode=batch_mode)

     def build_iter(self, epoch, shuffle=True):
         self.dataset.set_epoch(epoch)
@ -24,7 +24,8 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
         batch_size=8000,
         len_fn=_default_len_fn,
         buffer_size=10240,
-        sort_size=500
+        sort_size=500,
+        batch_mode="padding",
     ):
         assert batch_size > 0, "Batch size is required to be larger than 0!"
         assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
@ -35,6 +36,7 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
         self.batch_size = batch_size
         self.buffer_size = buffer_size
         self.sort_size = sort_size
+        self.batch_mode = batch_mode

     def set_epoch(self, epoch):
         self.epoch = epoch
@ -44,55 +46,137 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
        batch = []
        bucket = []
        max_lengths = 0
        min_lengths = 999999
        batch_lengths = 0

        if self.buffer_size == -1:
            for d in self.datapipe:
                if d[0] > self.batch_size:
                    continue
                buffer.append(d)
            buffer.sort()
            for sample in buffer:
                length, _, token = sample
                if length > max_lengths:
                    max_lengths = length
                batch_lengths = max_lengths * (len(batch) + 1)
                if batch_lengths > self.batch_size:
                    bucket.append(batch)
                    batch = []
                    max_lengths = length
                batch.append(token)
            random.shuffle(bucket)
            if bucket:
                for batch_sample in bucket:
                    yield batch_sample
            if batch:
                yield batch

        elif self.buffer_size == 0:
            for d in self.datapipe:
                if d[0] > self.batch_size:
                    continue
                length, _, token = d
                if length > self.batch_size:
                    continue
                if length > max_lengths:
                    max_lengths = length
                batch_lengths = max_lengths * (len(batch) + 1)
                if batch_lengths > self.batch_size:
                    yield batch
                    batch = []
                    max_lengths = length
                batch.append(token)
            if batch:
                yield batch

        else:
            if self.batch_mode == "clipping":
                assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 0"
                for d in self.datapipe:
                    if d[0] > self.batch_size:
                        continue
                    buffer.append(d)
                    if len(buffer) == self.buffer_size:
                        random.shuffle(buffer)
                        for sample in buffer:
                            bucket.append(sample)
                            if len(bucket) == self.sort_size:
                                bucket.sort()
                                for x in bucket:
                                    length, _, token = x
                                    if length < min_lengths:
                                        min_lengths = length
                                    batch_lengths = min_lengths * (len(batch) + 1)
                                    if batch_lengths > self.batch_size:
                                        yield batch
                                        batch = []
                                        min_lengths = length
                                    batch.append(token)
                                bucket = []
                        buffer = []

                if buffer:
                    random.shuffle(buffer)
                    for sample in buffer:
                        bucket.append(sample)
                        if len(bucket) == self.sort_size:
                            bucket.sort()
                            for x in bucket:
                                length, _, token = x
                                if length < min_lengths:
                                    min_lengths = length
                                batch_lengths = min_lengths * (len(batch) + 1)
                                if batch_lengths > self.batch_size:
                                    yield batch
                                    batch = []
                                    min_lengths = length
                                batch.append(token)
                            bucket = []
                    buffer = []

                if bucket:
                    bucket.sort()
                    for x in bucket:
                        length, _, token = x
                        if length < min_lengths:
                            min_lengths = length
                        batch_lengths = min_lengths * (len(batch) + 1)
                        if batch_lengths > self.batch_size:
                            yield batch
                            batch = []
                            min_lengths = length
                        batch.append(token)
                    bucket = []

                if batch:
                    yield batch

            else:
                if self.buffer_size == -1:
                    for d in self.datapipe:
                        if d[0] > self.batch_size:
                            continue
                        buffer.append(d)
                    buffer.sort()
                    for sample in buffer:
                        length, _, token = sample
                        if length > max_lengths:
                            max_lengths = length
                        batch_lengths = max_lengths * (len(batch) + 1)
                        if batch_lengths > self.batch_size:
                            bucket.append(batch)
                            batch = []
                            max_lengths = length
                        batch.append(token)
                    random.shuffle(bucket)
                    if bucket:
                        for batch_sample in bucket:
                            yield batch_sample
                    if batch:
                        yield batch

                elif self.buffer_size == 0:
                    for d in self.datapipe:
                        if d[0] > self.batch_size:
                            continue
                        length, _, token = d
                        if length > self.batch_size:
                            continue
                        if length > max_lengths:
                            max_lengths = length
                        batch_lengths = max_lengths * (len(batch) + 1)
                        if batch_lengths > self.batch_size:
                            yield batch
                            batch = []
                            max_lengths = length
                        batch.append(token)
                    if batch:
                        yield batch

                else:
                    for d in self.datapipe:
                        if d[0] > self.batch_size:
                            continue
                        buffer.append(d)
                        if len(buffer) == self.buffer_size:
                            random.shuffle(buffer)
                            for sample in buffer:
                                bucket.append(sample)
                                if len(bucket) == self.sort_size:
                                    bucket.sort()
                                    for x in bucket:
                                        length, _, token = x
                                        if length > max_lengths:
                                            max_lengths = length
                                        batch_lengths = max_lengths * (len(batch) + 1)
                                        if batch_lengths > self.batch_size:
                                            yield batch
                                            batch = []
                                            max_lengths = length
                                        batch.append(token)
                                    bucket = []
                            buffer = []

                    if buffer:
                        random.shuffle(buffer)
                        for sample in buffer:
                            bucket.append(sample)
@ -111,38 +195,19 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
                            bucket = []
                            buffer = []

                    if buffer:
                        random.shuffle(buffer)
                        for sample in buffer:
                            bucket.append(sample)
                            if len(bucket) == self.sort_size:
                                bucket.sort()
                                for x in bucket:
                                    length, _, token = x
                                    if length > max_lengths:
                                        max_lengths = length
                                    batch_lengths = max_lengths * (len(batch) + 1)
                                    if batch_lengths > self.batch_size:
                                        yield batch
                                        batch = []
                                        max_lengths = length
                                    batch.append(token)
                                bucket = []
                        buffer = []
                    if bucket:
                        bucket.sort()
                        for x in bucket:
                            length, _, token = x
                            if length > max_lengths:
                                max_lengths = length
                            batch_lengths = max_lengths * (len(batch) + 1)
                            if batch_lengths > self.batch_size:
                                yield batch
                                batch = []
                                max_lengths = length
                            batch.append(token)
                        bucket = []

                    if bucket:
                        bucket.sort()
                        for x in bucket:
                            length, _, token = x
                            if length > max_lengths:
                                max_lengths = length
                            batch_lengths = max_lengths * (len(batch) + 1)
                            if batch_lengths > self.batch_size:
                                yield batch
                                batch = []
                                max_lengths = length
                            batch.append(token)
                        bucket = []

                    if batch:
                        yield batch
                    if batch:
                        yield batch
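In the clipping branch the batch budget is tracked with min_lengths instead of max_lengths: every sample will later be cropped to the shortest one in the batch, so the effective size of a candidate batch is min_length * (len(batch) + 1). An illustrative re-implementation of just that accounting rule (not the repo's code):

# Sketch: group sorted lengths under a token budget using the clipping rule.
def clip_batches(lengths, batch_size=300):
    batches, batch, min_len = [], [], float("inf")
    for n in sorted(lengths):
        min_len = min(min_len, n)
        if batch and min_len * (len(batch) + 1) > batch_size:
            batches.append(batch)
            batch, min_len = [], n
        batch.append(n)
    if batch:
        batches.append(batch)
    return batches

print(clip_batches([80, 90, 100, 120]))  # -> [[80, 90, 100], [120]]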
@ -13,6 +13,7 @@ from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
 from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
 from funasr.datasets.large_datasets.utils.filter import filter
 from funasr.datasets.large_datasets.utils.padding import padding
+from funasr.datasets.large_datasets.utils.clipping import clipping
 from funasr.datasets.large_datasets.utils.tokenize import tokenize


@ -101,6 +102,8 @@ class AudioDataset(IterableDataset):
             elif data_type == "text" or data_type == "sound":
                 text_reader = open(data_file, "r")
                 reader_list.append(text_reader)
+            elif data_type == "none":
+                continue
             else:
                 raise TypeError("Data type {} is not supported".format(data_type))

@ -143,7 +146,8 @@ def Dataset(data_list_file,
             dict,
             seg_dict,
             conf,
-            mode="train"):
+            mode="train",
+            batch_mode="padding"):
     scp_lists = read_lists(data_list_file)
     shuffle = conf.get('shuffle', True)
     data_names = conf.get("data_names", "speech,text")
@ -154,9 +158,10 @@ def Dataset(data_list_file,
     filter_fn = partial(filter, **filter_conf)
     dataset = FilterIterDataPipe(dataset, fn=filter_fn)

-    vocab = {'vocab': dict, 'seg_dict': seg_dict}
-    tokenize_fn = partial(tokenize, **vocab)
-    dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
+    if "text" in data_names:
+        vocab = {'vocab': dict, 'seg_dict': seg_dict}
+        tokenize_fn = partial(tokenize, **vocab)
+        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)

     if shuffle:
         buffer_conf = conf.get('shuffle_conf', {})
@ -180,8 +185,9 @@ def Dataset(data_list_file,
                 batch_size=batch_size,
                 len_fn=len_fn,
                 buffer_size=buffer_size,
-                sort_size=sort_size)
+                sort_size=sort_size,
+                batch_mode=batch_mode)

-    dataset = MapperIterDataPipe(dataset, fn=padding)
+    dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping)

     return dataset
40  funasr/datasets/large_datasets/utils/clipping.py  Normal file
@ -0,0 +1,40 @@
import numpy as np
import torch

from funasr.datasets.collate_fn import crop_to_max_size


def clipping(data):
    assert isinstance(data, list)
    assert "key" in data[0]

    keys = [x["key"] for x in data]

    batch = {}
    data_names = data[0].keys()
    for data_name in data_names:
        if data_name == "key":
            continue
        else:
            if data[0][data_name].dtype.kind == "i":
                tensor_type = torch.int64
            else:
                tensor_type = torch.float32

            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)

            length_clip = min(tensor_lengths)
            tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
            for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
                diff = length - length_clip
                assert diff >= 0
                if diff == 0:
                    tensor_clip[i] = tensor
                else:
                    tensor_clip[i] = crop_to_max_size(tensor, length_clip)

            batch[data_name] = tensor_clip
            batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)

    return keys, batch
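And a hypothetical call to the map-style clipping() used on the large-dataset path, mirroring the collate example earlier; keys and shapes are invented:

import numpy as np

data = [
    {"key": "utt1", "speech": np.random.randn(120, 80).astype(np.float32)},
    {"key": "utt2", "speech": np.random.randn(100, 80).astype(np.float32)},
]
keys, batch = clipping(data)
assert keys == ["utt1", "utt2"]
assert batch["speech"].shape == (2, 100, 80)  # cropped to the shortest, not padded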
@ -6,13 +6,21 @@ def filter(data,
            speech_length_max=15000,
            token_length_min=0,
            token_length_max=200):
-    assert "speech" in data
-    assert "text" in data
+    assert "speech" in data or "text" in data

-    if "sampling_rate" in data:
-        speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
-    else:
-        speech_length = data["speech"].shape[0]
-    num_tokens = len(data['text'])
-
-    return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+    if "speech" in data and "text" in data:
+        if "sampling_rate" in data:
+            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+        else:
+            speech_length = data["speech"].shape[0]
+        num_tokens = len(data['text'])
+        return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+    elif "speech" in data:
+        if "sampling_rate" in data:
+            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+        else:
+            speech_length = data["speech"].shape[0]
+        return speech_length_min < speech_length < speech_length_max
+    else:
+        num_tokens = len(data['text'])
+        return token_length_min < num_tokens < token_length_max
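The reworked filter now also accepts speech-only samples (the pretraining case, where data_names is speech,none) and text-only samples. A quick illustration with made-up inputs (note this module-level filter shadows the Python builtin):

import numpy as np

paired = {"speech": np.zeros((5000, 80)), "text": [1, 2, 3]}
speech_only = {"speech": np.zeros((5000, 80))}  # pretraining sample
text_only = {"text": [1, 2, 3]}

assert filter(paired)       # checks speech length and token count
assert filter(speech_only)  # checks speech length only
assert filter(text_only)    # checks token count only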
160  funasr/models/data2vec.py  Normal file
@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Optional
from typing import Tuple

import torch
from typeguard import check_argument_types

from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class Data2VecPretrainModel(AbsESPnetModel):
    """Data2Vec Pretrain model"""

    def __init__(
        self,
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
    ):
        assert check_argument_types()

        super().__init__()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        self.num_updates = 0

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape)

        self.encoder.set_num_updates(self.num_updates)

        # 1. Encoder
        encoder_out = self.encode(speech, speech_lengths)

        losses = encoder_out["losses"]
        loss = sum(losses.values())
        sample_size = encoder_out["sample_size"]
        loss = loss.sum() / sample_size

        target_var = float(encoder_out["target_var"])
        pred_var = float(encoder_out["pred_var"])
        ema_decay = float(encoder_out["ema_decay"])

        stats = dict(
            loss=torch.clone(loss.detach()),
            target_var=target_var,
            pred_var=pred_var,
            ema_decay=ema_decay,
        )

        loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Frontend + Encoder.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        if min(speech_lengths) == max(speech_lengths):  # for clipping, set speech_lengths as None
            speech_lengths = None
        encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)

        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def set_num_updates(self, num_updates):
        self.num_updates = num_updates

    def get_num_updates(self):
        return self.num_updates
148  funasr/optimizers/fairseq_adam.py  Normal file
@ -0,0 +1,148 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import math

import torch
import torch.optim


class FairseqAdam(torch.optim.Optimizer):
    r"""Implements Adam algorithm.

    This implementation is modified from torch.optim.Adam based on:
    `Fixed Weight Decay Regularization in Adam`
    (see https://arxiv.org/abs/1711.05101)

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        adam_betas=(0.9, 0.999),
        adam_eps=1e-8,
        weight_decay=0,
        amsgrad=False,
    ):
        defaults = dict(
            lr=lr, betas=adam_betas, eps=adam_eps, weight_decay=weight_decay, amsgrad=amsgrad
        )
        super(FairseqAdam, self).__init__(params, defaults)
        self.optimizer_lr = lr

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsgrad = group.get("amsgrad", False)

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
                    if amsgrad:
                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
                            p_data_fp32
                        )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
                    )

                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.param_groups:
            param_group["lr"] = lr
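A usage sketch for the optimizer above; the toy module and hyper-parameter values are illustrative, not from this commit. Note that the constructor takes adam_betas/adam_eps rather than torch.optim.Adam's betas/eps:

import torch
from funasr.optimizers.fairseq_adam import FairseqAdam

module = torch.nn.Linear(10, 10)
optimizer = FairseqAdam(
    module.parameters(),
    lr=2e-5,
    adam_betas=(0.9, 0.98),
    adam_eps=1e-8,
    weight_decay=0.01,
)

loss = module(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()        # moments are kept in fp32 even for fp16/bf16 parameters
optimizer.set_lr(1e-5)  # schedulers may also overwrite the rate directly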
108
funasr/schedulers/tri_stage_scheduler.py
Normal file
@ -0,0 +1,108 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional, List

import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types

from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler


class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        last_epoch: int = -1,
        phase_ratio: Optional[List[float]] = None,
        init_lr_scale: float = 0.01,
        final_lr_scale: float = 0.01,
    ):
        assert check_argument_types()
        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.phase_ratio = phase_ratio
        self.init_lr_scale = init_lr_scale
        self.final_lr_scale = final_lr_scale
        self.optimizer_lr = self.optimizer.defaults["lr"]

    def init_tri_stage_scheudler(self, max_update):
        self.max_update = max_update
        self.peak_lr = self.optimizer_lr
        self.init_lr = self.init_lr_scale * self.optimizer_lr
        self.final_lr = self.final_lr_scale * self.optimizer_lr

        assert self.max_update > 0
        assert sum(self.phase_ratio) == 1, "phase ratios must add up to 1"
        assert len(self.phase_ratio) == 3
        self.warmup_steps = int(self.max_update * self.phase_ratio[0])
        self.hold_steps = int(self.max_update * self.phase_ratio[1])
        self.decay_steps = int(self.max_update * self.phase_ratio[2])

        self.warmup_rate = (
            (self.peak_lr - self.init_lr) / self.warmup_steps
            if self.warmup_steps != 0
            else 0
        )
        self.decay_factor = -math.log(self.final_lr_scale) / self.decay_steps

        # initial learning rate
        self.lr = self.init_lr

        # set the lr before calling super().__init__(),
        # because _LRScheduler.__init__() also invokes step()
        self.set_optimizer_lr(self.lr)
        super().__init__(self.optimizer, self.last_epoch)

    def _decide_stage(self, update_step):
        """
        return stage, and the corresponding steps within the current stage
        """
        if update_step < self.warmup_steps:
            # warmup stage
            return 0, update_step

        offset = self.warmup_steps

        if update_step < offset + self.hold_steps:
            # hold stage
            return 1, update_step - offset

        offset += self.hold_steps

        if update_step <= offset + self.decay_steps:
            # decay stage
            return 2, update_step - offset

        offset += self.decay_steps

        # still here? constant lr stage
        return 3, update_step - offset

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        stage, steps_in_stage = self._decide_stage(num_updates)
        if stage == 0:
            self.lr = self.init_lr + self.warmup_rate * steps_in_stage
        elif stage == 1:
            self.lr = self.peak_lr
        elif stage == 2:
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        elif stage == 3:
            self.lr = self.final_lr
        else:
            raise ValueError("Undefined stage")
        self.set_optimizer_lr(self.lr)

    def set_optimizer_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        step_num = self.last_epoch + 1
        self.step_update(step_num)
        return [self.lr]
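A driving sketch for the scheduler above, assuming per-update stepping; max_update and phase_ratio values are illustrative. init_tri_stage_scheudler (spelled as in the API) must be called before the first update, since it derives the stage boundaries and invokes _LRScheduler.__init__():

import torch
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.schedulers.tri_stage_scheduler import TriStageLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = FairseqAdam(params, lr=2e-5)

scheduler = TriStageLR(optimizer, phase_ratio=[0.1, 0.4, 0.5])
scheduler.init_tri_stage_scheudler(max_update=100000)

for step in range(100000):
    # ... forward/backward and optimizer.step() would go here ...
    scheduler.step_update(step)  # stage 0: warmup, 1: hold, 2: exp decay, 3: final lr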
@ -44,11 +44,13 @@ from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
from funasr.optimizers.sgd import SGD
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.samplers.build_batch_sampler import BATCH_TYPES
from funasr.samplers.build_batch_sampler import build_batch_sampler
from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.warmup_lr import WarmupLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@ -83,6 +85,7 @@ else:

optim_classes = dict(
    adam=torch.optim.Adam,
    fairseq_adam=FairseqAdam,
    adamw=torch.optim.AdamW,
    sgd=SGD,
    adadelta=torch.optim.Adadelta,
@ -149,6 +152,7 @@ scheduler_classes = dict(
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    noamlr=NoamLR,
    warmuplr=WarmupLR,
    tri_stage=TriStageLR,
    cycliclr=torch.optim.lr_scheduler.CyclicLR,
    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
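With these registrations, a training config that sets optim: fairseq_adam and scheduler: tri_stage resolves through the two tables. Roughly like this (simplified; the actual resolution happens inside AbsTask when it builds the optimizer and scheduler, and the lr/phase_ratio values are illustrative):

optim_class = optim_classes["fairseq_adam"]        # -> FairseqAdam
scheduler_class = scheduler_classes["tri_stage"]   # -> TriStageLR

optimizer = optim_class(model.parameters(), lr=2e-5)
scheduler = scheduler_class(optimizer, phase_ratio=[0.1, 0.4, 0.5])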
376
funasr/tasks/data2vec.py
Normal file
@ -0,0 +1,376 @@
import argparse
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.data2vec import Data2VecPretrainModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none

frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
    type_check=AbsFrontend,
    default="default",
)
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(specaug=SpecAug),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
preencoder_choices = ClassChoices(
    name="preencoder",
    classes=dict(
        sinc=LightweightSincConvs,
    ),
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        data2vec_encoder=Data2VecEncoder,
    ),
    type_check=AbsEncoder,
    default="data2vec_encoder",
)
model_choices = ClassChoices(
    "model",
    classes=dict(
        data2vec=Data2VecPretrainModel,
    ),
    default="data2vec",
)


class Data2VecTask(AbsTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --model and --model_conf
        model_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )

        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimensions of the feature",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default=None,
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized at the specified token level",
        )

        group.add_argument(
            "--feats_type",
            type=str,
            default="fbank",
            help="feats type, e.g. fbank, wav, ark_wav (needs scale normalization)",
        )

        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            help="non_linguistic_symbols file path",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        parser.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify g2p method if --token_type=phn",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of rir scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying RIR convolution.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying noise adding.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )
        parser.add_argument(
            "--pred_masked_weight",
            type=float,
            default=1.0,
            help="weight for predictive loss for masked frames",
        )
        parser.add_argument(
            "--pred_nomask_weight",
            type=float,
            default=0.0,
            help="weight for predictive loss for unmasked frames",
        )
        parser.add_argument(
            "--loss_weights",
            type=float,
            default=0.0,
            help="weights for additional loss terms (not first one)",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        assert check_argument_types()
        return CommonCollateFn(clipping=True)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                # NOTE(kamo): Check attribute existence for backward compatibility
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
                else 1.0,
                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
                noise_apply_prob=args.noise_apply_prob
                if hasattr(args, "noise_apply_prob")
                else 1.0,
                noise_db_range=args.noise_db_range
                if hasattr(args, "noise_db_range")
                else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "speech_volume_normalize")
                else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        # for pre-training
        retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(
            input_size=input_size,
            **args.encoder_conf,
        )

        # 6. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("data2vec")
        model = model_class(
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
        )

        # 7. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
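For reference, a hand-built namespace exercising build_model above. Real training fills these fields through the CLI flags registered in add_task_arguments, and whether an empty encoder_conf suffices depends on Data2VecEncoder's defaults, so treat this purely as a sketch:

import argparse
from funasr.tasks.data2vec import Data2VecTask

args = argparse.Namespace(
    input_size=80,                 # features come from the data loader, no frontend
    frontend=None, frontend_conf={},
    specaug=None, specaug_conf={},
    normalize=None, normalize_conf={},
    preencoder=None, preencoder_conf={},
    encoder="data2vec_encoder", encoder_conf={},
    model="data2vec",
    init=None,
)
model = Data2VecTask.build_model(args)  # frontend -> specaug -> normalize -> encoder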