TOLD/SOND: update finetune and train recipe

This commit is contained in:
志浩 2023-08-02 10:59:31 +08:00
parent 5cfdcfc45a
commit bee8346c4b
2 changed files with 31 additions and 72 deletions

View File

@ -8,13 +8,18 @@
# [2] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, EMNLP 2022
# We recommend you run this script stage by stage.
# This recipe includes:
# 1. downloading a pretrained model on the simulated data from switchboard and NIST,
# 2. finetuning the pretrained model on Callhome1.
# Finally, you will get a slightly better DER result 9.95% on Callhome2 than that in the paper 10.14%.
# environment configuration
if [ ! -e utils ]; then
ln -s ../../../aishell/transformer/utils ./utils
fi
# machines configuration
gpu_devices="0,1,2,3"
gpu_devices="0,1,2,3" # for V100-16G, need 4 gpus.
gpu_num=4
count=1
@ -76,10 +81,14 @@ fi
# Download required resources
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Stage 0: Download required resources."
wget told_finetune_resources.zip
if [ ! -e told_finetune_resources.tar.gz ]; then
# MD5SUM: abc7424e4e86ce6f040e9cba4178123b
wget --no-check-certificate https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/Speaker_Diar/told_finetune_resources.tar.gz
tar zxf told_finetune_resources.tar.gz
fi
fi
# Finetune model on callhome1
# Finetune model on callhome1, this will take about 1.5 hours.
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Stage 1: Finetune pretrained model on callhome1."
world_size=$gpu_num # run on one machine
@ -230,11 +239,11 @@ fi
# Then find the wav files to construct wav.scp and put it at data/callhome2/wav.scp.
# After iteratively perform SOAP, you will get DER results like:
# iters : oracle_vad | system_vad
# iter_0: 9.68 | 10.51
# iter_1: 9.26 | 10.14 (reported in the paper)
# iter_2: 9.18 | 10.08
# iter_3: 9.24 | 10.15
# iter_4: 9.27 | 10.17
# iter_0: 9.63 | 10.43
# iter_1: 9.17 | 10.03
# iter_2: 9.11 | 9.98
# iter_3: 9.08 | 9.96
# iter_4: 9.07 | 9.95
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
if [ ! -e ${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ]; then
git lfs install

View File

@ -8,6 +8,15 @@
# [2] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, EMNLP 2022
# We recommend you run this script stage by stage.
# [developing] This recipe includes:
# 1. simulating data with switchboard and NIST.
# 2. training the model from scratch for 3 stages:
# 2-1. pre-train on simu_swbd_sre
# 2-2. train on simu_swbd_sre
# 2-3. finetune on callhome1
# 3. evaluating model with the results from the first stage EEND-OLA,
# Finally, you will get a similar DER result claimed in the paper.
# environment configuration
kaldi_root=
@ -26,8 +35,8 @@ if [ ! -e utils ]; then
fi
# machines configuration
gpu_devices="6,7"
gpu_num=2
gpu_devices="4,5,6,7" # for V100-16G, use 4 GPUs
gpu_num=4
count=1
# general configuration
@ -417,7 +426,7 @@ if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
rank=$i
local_rank=$i
gpu_id=$(echo $gpu_devices | cut -d',' -f$[$i+1])
diar_train.py \
python -m funasr.bin.diar_train \
--gpu_id $gpu_id \
--use_preprocessor false \
--token_type char \
@ -565,7 +574,7 @@ if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then
rank=$i
local_rank=$i
gpu_id=$(echo $gpu_devices | cut -d',' -f$[$i+1])
diar_train.py \
python -m funasr.bin.diar_train \
--gpu_id $gpu_id \
--use_preprocessor false \
--token_type char \
@ -710,7 +719,7 @@ if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then
rank=$i
local_rank=$i
gpu_id=$(echo $gpu_devices | cut -d',' -f$[$i+1])
diar_train.py \
python -m funasr.bin.diar_train \
--gpu_id $gpu_id \
--use_preprocessor false \
--token_type char \
@ -942,62 +951,3 @@ if [ ${stage} -le 19 ] && [ ${stop_stage} -ge 19 ]; then
echo "Done."
done
fi
if [ ${stage} -le 30 ] && [ ${stop_stage} -ge 30 ]; then
echo "stage 30: training phase 1, pretraining on simulated data"
world_size=$gpu_num # run on one machine
mkdir -p ${expdir}/${model_dir}
mkdir -p ${expdir}/${model_dir}/log
mkdir -p /tmp/${model_dir}
INIT_FILE=/tmp/${model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_opt=""
if [ ! -z "${init_param}" ]; then
init_opt="--init_param ${init_param}"
echo ${init_opt}
fi
freeze_opt=""
if [ ! -z "${freeze_param}" ]; then
freeze_opt="--freeze_param ${freeze_param}"
echo ${freeze_opt}
fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
{
rank=$i
local_rank=$i
gpu_id=$(echo $gpu_devices | cut -d',' -f$[$i+1])
diar_train.py \
--gpu_id $gpu_id \
--use_preprocessor false \
--token_type char \
--token_list $token_list \
--dataset_type large \
--train_data_file ${datadir}/${train_set}/dumped_files/data_file.list \
--valid_data_file ${datadir}/${valid_set}/dumped_files/data_file.list \
--init_param ${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/sv.pth:encoder:encoder \
--freeze_param encoder \
${init_opt} \
${freeze_opt} \
--ignore_init_mismatch true \
--resume true \
--output_dir ${expdir}/${model_dir} \
--config $train_config \
--ngpu $gpu_num \
--num_worker_count $count \
--multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${expdir}/${model_dir}/log/train.log.$i 2>&1
} &
done
echo "Training log can be found at ${expdir}/${model_dir}/log/train.log.*"
wait
fi