Merge pull request #799 from alibaba-damo-academy/dev_dzh

Recipe for TOLD/SOND speaker diarization model
This commit is contained in:
Zhihao Du 2023-08-03 10:15:11 +08:00 committed by GitHub
commit c63486e0b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 82918 additions and 352 deletions

View File

@ -0,0 +1,133 @@
# SOND speaker diarization training config (EAND: ResNet34 encoder + SAN cd-scorer
# + conv ci-scorer + FSMN post-net). In this variant the speech encoder is frozen
# (freeze_encoder: true) and training uses Noam LR scheduling with warmup.
model: sond
model_conf:
lsm_weight: 0.0
length_normalized_loss: true
max_spk_num: 16
normalize_speech_speaker: true
speaker_discrimination_loss_weight: 0
inter_score_loss_weight: 0.1
model_regularizer_weight: 0.0
# keep the pretrained speech encoder weights fixed during this phase
freeze_encoder: true
# NOTE(review): "onfly" presumably means on-the-fly speaker shuffling — confirm in model code.
onfly_shuffle_speaker: false
# label aggregator
label_aggregator: label_aggregator_max_pool
label_aggregator_conf:
hop_length: 8
# speech encoder
encoder: resnet34_sp_l2reg
encoder_conf:
# pass by model, equal to feature dim
# input_size: 80
batchnorm_momentum: 0.01
pooling_type: "window_shift"
pool_size: 20
stride: 1
tf2torch_tensor_name_prefix_torch: encoder
tf2torch_tensor_name_prefix_tf: EAND/speech_encoder
speaker_encoder: null
speaker_encoder_conf: {}
ci_scorer: conv
ci_scorer_conf:
input_units: 512
num_layers: 3
num_units: 512
kernel_size: 1
dropout_rate: 0.0
position_encoder: null
out_units: 1
out_norm: false
auxiliary_states: false
tf2torch_tensor_name_prefix_torch: ci_scorer
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/ci_scorer
cd_scorer: san
cd_scorer_conf:
input_size: 512
output_size: 512
out_units: 1
attention_heads: 4
linear_units: 1024
num_blocks: 4
dropout_rate: 0.0
positional_dropout_rate: 0.0
attention_dropout_rate: 0.0
# use string "null" to remove input layer
input_layer: "null"
pos_enc_class: null
normalize_before: true
tf2torch_tensor_name_prefix_torch: cd_scorer
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/cd_scorer
# post net
decoder: fsmn
decoder_conf:
in_units: 32
out_units: 2517
filter_size: 31
fsmn_num_layers: 6
dnn_num_layers: 1
num_memory_units: 16
ffn_inner_dim: 512
dropout_rate: 0.0
tf2torch_tensor_name_prefix_torch: decoder
tf2torch_tensor_name_prefix_tf: EAND/post_net
input_size: 80
frontend: null
frontend_conf:
fs: 16000
window: povey
n_mels: 80
frame_length: 25
frame_shift: 10
filter_length_min: -1
filter_length_max: -1
lfr_m: 1
lfr_n: 1
dither: 0.0
snip_edges: false
# NOTE(review): "upsacle" looks like a typo for "upscale", but the key must match
# the frontend implementation exactly — verify before renaming.
upsacle_samples: false
# minibatch related
batch_type: unsorted
# NOTE(review): original comment said "16 samples" but batch_size is 8 — confirm intended value.
batch_size: 8
num_workers: 8
max_epoch: 20
num_iters_per_epoch: 10000
keep_nbest_models: 20
# optimization related
accum_grad: 1
grad_clip: 5.0
val_scheduler_criterion:
- valid
- der
- min
best_model_criterion:
- - valid
- der
- min
- - valid
- forward_steps
- max
optim: adamw
optim_conf:
# lr 1.0 is the Noam base rate; effective LR is scaled by the noamlr scheduler below.
lr: 1.0
betas: [0.9, 0.998]
weight_decay: 0
scheduler: noamlr
scheduler_conf:
model_size: 512
warmup_steps: 10000
# without spec aug
specaug: null
log_interval: 50
# without normalize
normalize: null

View File

@ -0,0 +1,133 @@
# SOND speaker diarization training config — same EAND architecture as the
# frozen-encoder variant, but here the whole model is trained
# (freeze_encoder: false) with a constant AdamW LR of 1e-4 (scheduler: null).
model: sond
model_conf:
lsm_weight: 0.0
length_normalized_loss: true
max_spk_num: 16
normalize_speech_speaker: true
speaker_discrimination_loss_weight: 0
inter_score_loss_weight: 0.1
model_regularizer_weight: 0.0
# train the speech encoder together with the rest of the model
freeze_encoder: false
# NOTE(review): "onfly" presumably means on-the-fly speaker shuffling — confirm in model code.
onfly_shuffle_speaker: false
# label aggregator
label_aggregator: label_aggregator_max_pool
label_aggregator_conf:
hop_length: 8
# speech encoder
encoder: resnet34_sp_l2reg
encoder_conf:
# pass by model, equal to feature dim
# input_size: 80
batchnorm_momentum: 0.01
pooling_type: "window_shift"
pool_size: 20
stride: 1
tf2torch_tensor_name_prefix_torch: encoder
tf2torch_tensor_name_prefix_tf: EAND/speech_encoder
speaker_encoder: null
speaker_encoder_conf: {}
ci_scorer: conv
ci_scorer_conf:
input_units: 512
num_layers: 3
num_units: 512
kernel_size: 1
dropout_rate: 0.0
position_encoder: null
out_units: 1
out_norm: false
auxiliary_states: false
tf2torch_tensor_name_prefix_torch: ci_scorer
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/ci_scorer
cd_scorer: san
cd_scorer_conf:
input_size: 512
output_size: 512
out_units: 1
attention_heads: 4
linear_units: 1024
num_blocks: 4
dropout_rate: 0.0
positional_dropout_rate: 0.0
attention_dropout_rate: 0.0
# use string "null" to remove input layer
input_layer: "null"
pos_enc_class: null
normalize_before: true
tf2torch_tensor_name_prefix_torch: cd_scorer
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/cd_scorer
# post net
decoder: fsmn
decoder_conf:
in_units: 32
out_units: 2517
filter_size: 31
fsmn_num_layers: 6
dnn_num_layers: 1
num_memory_units: 16
ffn_inner_dim: 512
dropout_rate: 0.0
tf2torch_tensor_name_prefix_torch: decoder
tf2torch_tensor_name_prefix_tf: EAND/post_net
input_size: 80
frontend: null
frontend_conf:
fs: 16000
window: povey
n_mels: 80
frame_length: 25
frame_shift: 10
filter_length_min: -1
filter_length_max: -1
lfr_m: 1
lfr_n: 1
dither: 0.0
snip_edges: false
# NOTE(review): "upsacle" looks like a typo for "upscale", but the key must match
# the frontend implementation exactly — verify before renaming.
upsacle_samples: false
# minibatch related
batch_type: unsorted
# 6 samples
batch_size: 6
num_workers: 8
max_epoch: 30
num_iters_per_epoch: 10000
keep_nbest_models: 30
# optimization related
accum_grad: 1
grad_clip: 5.0
val_scheduler_criterion:
- valid
- der
- min
best_model_criterion:
- - valid
- der
- min
- - valid
- forward_steps
- max
optim: adamw
optim_conf:
lr: 0.0001
betas: [0.9, 0.998]
weight_decay: 0
# constant learning rate: scheduler disabled, scheduler_conf below is unused
scheduler: null
scheduler_conf:
model_size: 512
warmup_steps: 10000
# without spec aug
specaug: null
log_interval: 50
# without normalize
normalize: null

View File

@ -0,0 +1,133 @@
# SOND speaker diarization finetuning config — same EAND architecture, but a
# short run (12 epochs x 300 iters) at a small constant LR of 1e-5, suitable
# for finetuning a pretrained model on a small corpus (e.g. callhome1).
model: sond
model_conf:
lsm_weight: 0.0
length_normalized_loss: true
max_spk_num: 16
normalize_speech_speaker: true
speaker_discrimination_loss_weight: 0
inter_score_loss_weight: 0.1
model_regularizer_weight: 0.0
# train the speech encoder together with the rest of the model
freeze_encoder: false
# NOTE(review): "onfly" presumably means on-the-fly speaker shuffling — confirm in model code.
onfly_shuffle_speaker: false
# label aggregator
label_aggregator: label_aggregator_max_pool
label_aggregator_conf:
hop_length: 8
# speech encoder
encoder: resnet34_sp_l2reg
encoder_conf:
# pass by model, equal to feature dim
# input_size: 80
batchnorm_momentum: 0.01
pooling_type: "window_shift"
pool_size: 20
stride: 1
tf2torch_tensor_name_prefix_torch: encoder
tf2torch_tensor_name_prefix_tf: EAND/speech_encoder
speaker_encoder: null
speaker_encoder_conf: {}
ci_scorer: conv
ci_scorer_conf:
input_units: 512
num_layers: 3
num_units: 512
kernel_size: 1
dropout_rate: 0.0
position_encoder: null
out_units: 1
out_norm: false
auxiliary_states: false
tf2torch_tensor_name_prefix_torch: ci_scorer
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/ci_scorer
cd_scorer: san
cd_scorer_conf:
input_size: 512
output_size: 512
out_units: 1
attention_heads: 4
linear_units: 1024
num_blocks: 4
dropout_rate: 0.0
positional_dropout_rate: 0.0
attention_dropout_rate: 0.0
# use string "null" to remove input layer
input_layer: "null"
pos_enc_class: null
normalize_before: true
tf2torch_tensor_name_prefix_torch: cd_scorer
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/cd_scorer
# post net
decoder: fsmn
decoder_conf:
in_units: 32
out_units: 2517
filter_size: 31
fsmn_num_layers: 6
dnn_num_layers: 1
num_memory_units: 16
ffn_inner_dim: 512
dropout_rate: 0.0
tf2torch_tensor_name_prefix_torch: decoder
tf2torch_tensor_name_prefix_tf: EAND/post_net
input_size: 80
frontend: null
frontend_conf:
fs: 16000
window: povey
n_mels: 80
frame_length: 25
frame_shift: 10
filter_length_min: -1
filter_length_max: -1
lfr_m: 1
lfr_n: 1
dither: 0.0
snip_edges: false
# NOTE(review): "upsacle" looks like a typo for "upscale", but the key must match
# the frontend implementation exactly — verify before renaming.
upsacle_samples: false
# minibatch related
batch_type: unsorted
# 6 samples
batch_size: 6
num_workers: 8
max_epoch: 12
num_iters_per_epoch: 300
keep_nbest_models: 5
# optimization related
accum_grad: 1
grad_clip: 5.0
val_scheduler_criterion:
- valid
- der
- min
best_model_criterion:
- - valid
- der
- min
- - valid
- forward_steps
- max
optim: adamw
optim_conf:
lr: 0.00001
betas: [0.9, 0.998]
weight_decay: 0
# constant learning rate: scheduler disabled, scheduler_conf below is unused
scheduler: null
scheduler_conf:
model_size: 512
warmup_steps: 10000
# without spec aug
specaug: null
log_interval: 50
# without normalize
normalize: null

View File

@ -0,0 +1,2 @@
# Inference post-processing defaults.
# NOTE(review): semantics inferred from names — smooth_size presumably is the
# label-smoothing window in frames (1 = no smoothing) and dur_threshold the
# minimum segment duration to keep (0 = keep everything); confirm in the
# inference code before relying on these descriptions.
smooth_size: 1
dur_threshold: 0

View File

@ -0,0 +1,4 @@
# Kaldi fbank extraction options: 8 kHz (telephone) audio, 80 mel bins,
# 25 ms frame length, and no edge snipping (frames may extend past signal ends).
--sample-frequency=8000
--num-mel-bins=80
--frame-length=25
--snip-edges=false

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,17 @@
0
1
2
4
8
16
32
64
128
256
512
1024
2048
4096
8192
16384
32768

View File

@ -0,0 +1,137 @@
0
1
2
3
4
5
6
8
9
10
12
16
17
18
20
24
32
33
34
36
40
48
64
65
66
68
72
80
96
128
129
130
132
136
144
160
192
256
257
258
260
264
272
288
320
384
512
513
514
516
520
528
544
576
640
768
1024
1025
1026
1028
1032
1040
1056
1088
1152
1280
1536
2048
2049
2050
2052
2056
2064
2080
2112
2176
2304
2560
3072
4096
4097
4098
4100
4104
4112
4128
4160
4224
4352
4608
5120
6144
8192
8193
8194
8196
8200
8208
8224
8256
8320
8448
8704
9216
10240
12288
16384
16385
16386
16388
16392
16400
16416
16448
16512
16640
16896
17408
18432
20480
24576
32768
32769
32770
32772
32776
32784
32800
32832
32896
33024
33280
33792
34816
36864
40960
49152

View File

@ -0,0 +1,697 @@
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
16
17
18
19
20
21
22
24
25
26
28
32
33
34
35
36
37
38
40
41
42
44
48
49
50
52
56
64
65
66
67
68
69
70
72
73
74
76
80
81
82
84
88
96
97
98
100
104
112
128
129
130
131
132
133
134
136
137
138
140
144
145
146
148
152
160
161
162
164
168
176
192
193
194
196
200
208
224
256
257
258
259
260
261
262
264
265
266
268
272
273
274
276
280
288
289
290
292
296
304
320
321
322
324
328
336
352
384
385
386
388
392
400
416
448
512
513
514
515
516
517
518
520
521
522
524
528
529
530
532
536
544
545
546
548
552
560
576
577
578
580
584
592
608
640
641
642
644
648
656
672
704
768
769
770
772
776
784
800
832
896
1024
1025
1026
1027
1028
1029
1030
1032
1033
1034
1036
1040
1041
1042
1044
1048
1056
1057
1058
1060
1064
1072
1088
1089
1090
1092
1096
1104
1120
1152
1153
1154
1156
1160
1168
1184
1216
1280
1281
1282
1284
1288
1296
1312
1344
1408
1536
1537
1538
1540
1544
1552
1568
1600
1664
1792
2048
2049
2050
2051
2052
2053
2054
2056
2057
2058
2060
2064
2065
2066
2068
2072
2080
2081
2082
2084
2088
2096
2112
2113
2114
2116
2120
2128
2144
2176
2177
2178
2180
2184
2192
2208
2240
2304
2305
2306
2308
2312
2320
2336
2368
2432
2560
2561
2562
2564
2568
2576
2592
2624
2688
2816
3072
3073
3074
3076
3080
3088
3104
3136
3200
3328
3584
4096
4097
4098
4099
4100
4101
4102
4104
4105
4106
4108
4112
4113
4114
4116
4120
4128
4129
4130
4132
4136
4144
4160
4161
4162
4164
4168
4176
4192
4224
4225
4226
4228
4232
4240
4256
4288
4352
4353
4354
4356
4360
4368
4384
4416
4480
4608
4609
4610
4612
4616
4624
4640
4672
4736
4864
5120
5121
5122
5124
5128
5136
5152
5184
5248
5376
5632
6144
6145
6146
6148
6152
6160
6176
6208
6272
6400
6656
7168
8192
8193
8194
8195
8196
8197
8198
8200
8201
8202
8204
8208
8209
8210
8212
8216
8224
8225
8226
8228
8232
8240
8256
8257
8258
8260
8264
8272
8288
8320
8321
8322
8324
8328
8336
8352
8384
8448
8449
8450
8452
8456
8464
8480
8512
8576
8704
8705
8706
8708
8712
8720
8736
8768
8832
8960
9216
9217
9218
9220
9224
9232
9248
9280
9344
9472
9728
10240
10241
10242
10244
10248
10256
10272
10304
10368
10496
10752
11264
12288
12289
12290
12292
12296
12304
12320
12352
12416
12544
12800
13312
14336
16384
16385
16386
16387
16388
16389
16390
16392
16393
16394
16396
16400
16401
16402
16404
16408
16416
16417
16418
16420
16424
16432
16448
16449
16450
16452
16456
16464
16480
16512
16513
16514
16516
16520
16528
16544
16576
16640
16641
16642
16644
16648
16656
16672
16704
16768
16896
16897
16898
16900
16904
16912
16928
16960
17024
17152
17408
17409
17410
17412
17416
17424
17440
17472
17536
17664
17920
18432
18433
18434
18436
18440
18448
18464
18496
18560
18688
18944
19456
20480
20481
20482
20484
20488
20496
20512
20544
20608
20736
20992
21504
22528
24576
24577
24578
24580
24584
24592
24608
24640
24704
24832
25088
25600
26624
28672
32768
32769
32770
32771
32772
32773
32774
32776
32777
32778
32780
32784
32785
32786
32788
32792
32800
32801
32802
32804
32808
32816
32832
32833
32834
32836
32840
32848
32864
32896
32897
32898
32900
32904
32912
32928
32960
33024
33025
33026
33028
33032
33040
33056
33088
33152
33280
33281
33282
33284
33288
33296
33312
33344
33408
33536
33792
33793
33794
33796
33800
33808
33824
33856
33920
34048
34304
34816
34817
34818
34820
34824
34832
34848
34880
34944
35072
35328
35840
36864
36865
36866
36868
36872
36880
36896
36928
36992
37120
37376
37888
38912
40960
40961
40962
40964
40968
40976
40992
41024
41088
41216
41472
41984
43008
45056
49152
49153
49154
49156
49160
49168
49184
49216
49280
49408
49664
50176
51200
53248
57344

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,567 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# This recipe aims at reimplementing the results of SOND on the Callhome corpus presented in
# [1] TOLD: A Novel Two-stage Overlap-aware Framework for Speaker Diarization, ICASSP 2023
# You can also use it on other datasets such as AliMeeting to reproduce the results in
# [2] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, EMNLP 2022
# We recommend you run this script stage by stage.
# This recipe includes:
# 1. downloading a pretrained model on the simulated data from switchboard and NIST,
# 2. finetuning the pretrained model on Callhome1.
# Finally, you will get a slightly better DER result 9.95% on Callhome2 than that in the paper 10.14%.
# environment configuration
# kaldi_root must be set by the user before running (empty by default on purpose,
# so the script fails fast with an instructive message below).
kaldi_root=
if [ -z "${kaldi_root}" ]; then
echo "We need kaldi to prepare dataset, extract fbank features, please install kaldi first and set kaldi_root."
echo "Kaldi installation guide can be found at https://kaldi-asr.org/"
exit;
fi
# Symlink Kaldi's callhome_diarization helper directories into this recipe.
if [ ! -e local ]; then
ln -s ${kaldi_root}/egs/callhome_diarization/v2/local ./local
fi
if [ ! -e utils ]; then
ln -s ${kaldi_root}/egs/callhome_diarization/v2/utils ./utils
fi
# callhome data root like path/to/NIST/LDC2001S97
# callhome_root must be set by the user before running.
callhome_root=
# Bug fix: this guard previously tested ${kaldi_root} (already validated above),
# so an empty callhome_root slipped through and data preparation failed later.
if [ -z "${callhome_root}" ]; then
echo "We need callhome corpus to prepare data."
exit;
fi
# machines configuration
gpu_devices="0,1,2,3" # for V100-16G, need 4 gpus.
gpu_num=4
count=1
# general configuration
stage=0
stop_stage=10
# number of jobs for data process
nj=16
sr=8000
# experiment configuration
lang=en
feats_type=fbank
datadir=data
dumpdir=dump
expdir=exp
train_cmd=utils/run.pl
# training related
tag=""
train_set=callhome1
valid_set=callhome1
train_config=conf/EAND_ResNet34_SAN_L4N512_None_FFN_FSMN_L6N512_bce_dia_loss_01_phase3.yaml
token_list=${datadir}/token_list/powerset_label_n16k4.txt
init_param=
freeze_param=
# inference related
inference_model=valid.der.ave_5best.pb
inference_config=conf/basic_inference.yaml
inference_tag=""
test_sets="callhome2"
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# number of jobs for inference
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=4
infer_cmd=utils/run.pl
told_max_iter=4
# parse_options.sh lets any of the variables above be overridden from the command line.
. utils/parse_options.sh || exit 1;
model_dir="$(basename "${train_config}" .yaml)_${feats_type}_${lang}${tag}"
# you can set gpu num for decoding here
gpuid_list=$gpu_devices # set gpus for decoding, the same as training stage by default
# count GPUs by splitting the comma-separated device list
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
inference_nj=$[${ngpu}*${njob}]
_ngpu=1
else
inference_nj=$njob
_ngpu=0
fi
# Prepare datasets
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Stage 0: Prepare callhome data."
local/make_callhome.sh ${callhome_root} ${datadir}/
# split ref.rttm
for dset in callhome1 callhome2; do
rm -rf ${datadir}/${dset}/ref.rttm
# collect per-recording RTTM lines for this split from the full reference
for name in `awk '{print $1}' ${datadir}/${dset}/wav.scp`; do
grep ${name} ${datadir}/callhome/fullref.rttm >> ${datadir}/${dset}/ref.rttm;
done
# filter out records which don't have rttm labels.
awk '{print $2}' ${datadir}/${dset}/ref.rttm | sort | uniq > ${datadir}/${dset}/uttid
mv ${datadir}/${dset}/wav.scp ${datadir}/${dset}/wav.scp.bak
awk '{if (NR==FNR){a[$1]=1}else{if (a[$1]==1){print $0}}}' ${datadir}/${dset}/uttid ${datadir}/${dset}/wav.scp.bak > ${datadir}/${dset}/wav.scp
# NOTE(review): plain mkdir (no -p) aborts if this stage is re-run and raw/ exists.
mkdir ${datadir}/${dset}/raw
mv ${datadir}/${dset}/{reco2num_spk,segments,spk2utt,utt2spk,uttid,wav.scp.bak} ${datadir}/${dset}/raw/
awk '{print $1,$1}' ${datadir}/${dset}/wav.scp > ${datadir}/${dset}/utt2spk
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Stage 1: Dump sph file to wav"
export PATH=${kaldi_root}/tools/sph2pipe/:${PATH}
if [ ! -f ${kaldi_root}/tools/sph2pipe/sph2pipe ]; then
echo "Can not find sph2pipe in ${kaldi_root}/tools/sph2pipe/,"
echo "please install sph2pipe and put it in the right place."
exit;
fi
for dset in callhome1 callhome2; do
echo "Stage 1: start to dump ${dset}."
mv ${datadir}/${dset}/wav.scp ${datadir}/${dset}/sph.scp
mkdir -p ${dumpdir}/${dset}/wavs
python -Wignore script/dump_pipe_wav.py ${datadir}/${dset}/sph.scp ${dumpdir}/${dset}/wavs \
--sr ${sr} --nj ${nj} --no_pbar
# rebuild wav.scp as "<basename> <absolute path>" from the dumped wav files
find `pwd`/${dumpdir}/${dset}/wavs -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/${dset}/wav.scp
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Stage 2: Extract non-overlap segments from callhome dataset"
for dset in callhome1 callhome2; do
echo "Stage 2: Extracting non-overlap segments for "${dset}
mkdir -p ${dumpdir}/${dset}/nonoverlap_0s
python -Wignore script/extract_nonoverlap_segments.py \
${datadir}/${dset}/wav.scp ${datadir}/${dset}/ref.rttm ${dumpdir}/${dset}/nonoverlap_0s \
--min_dur 0.1 --max_spk_num 8 --sr ${sr} --no_pbar --nj ${nj}
mkdir -p ${datadir}/${dset}/nonoverlap_0s
find ${dumpdir}/${dset}/nonoverlap_0s/ -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/${dset}/nonoverlap_0s/wav.scp
# the parent directory name of each segment wav is its speaker id
awk -F'[/.]' '{print $(NF-1),$(NF-2)}' ${datadir}/${dset}/nonoverlap_0s/wav.scp > ${datadir}/${dset}/nonoverlap_0s/utt2spk
echo "Done."
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Stage 3: Generate fbank features"
# source Kaldi's own path.sh so make_fbank.sh can find its binaries
home_path=$(pwd)
cd ${kaldi_root}/egs/callhome_diarization/v2 || exit
export train_cmd="run.pl"
export cmd="run.pl"
. ./path.sh
cd $home_path || exit
ln -s ${kaldi_root}/egs/callhome_diarization/v2/steps ./
for dset in callhome1 callhome2; do
utils/fix_data_dir.sh ${datadir}/${dset}
steps/make_fbank.sh --write-utt2num-frames true --fbank-config conf/fbank.conf --nj ${nj} --cmd "$train_cmd" \
${datadir}/${dset} ${expdir}/make_fbank/${dset} ${dumpdir}/${dset}/fbank
done
rm -f steps
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Stage 4: Extract speaker embeddings."
sv_exp_dir=exp/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch
if [ ! -e ${sv_exp_dir} ]; then
echo "start to download sv models"
git lfs install
git clone https://www.modelscope.cn/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch.git
mv speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ${expdir}/
echo "Done."
fi
for dset in callhome1/nonoverlap_0s callhome2/nonoverlap_0s; do
echo "Start to extract speaker embeddings for ${dset}"
key_file=${datadir}/${dset}/wav.scp
num_scp_file="$(<${key_file} wc -l)"
# never launch more jobs than there are utterances
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
_logdir=${dumpdir}/${dset}/xvecs
mkdir -p ${_logdir}
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/sv_inference.JOB.log \
python -m funasr.bin.sv_inference_launch \
--batch_size 1 \
--njob ${njob} \
--ngpu "${_ngpu}" \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${key_file},speech,sound" \
--key_file "${_logdir}"/keys.JOB.scp \
--sv_train_config ${sv_exp_dir}/sv.yaml \
--sv_model_file ${sv_exp_dir}/sv.pth \
--output_dir "${_logdir}"/output.JOB
cat ${_logdir}/output.*/xvector.scp | sort > ${datadir}/${dset}/utt2xvec
python script/calc_num_frames.py ${key_file} ${datadir}/${dset}/utt2num_frames
echo "Done."
done
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Stage 5: Generate label files."
for dset in callhome1 callhome2; do
echo "Stage 5: Generate labels for ${dset}."
python -Wignore script/calc_real_meeting_frame_labels.py \
${datadir}/${dset} ${dumpdir}/${dset}/labels \
--n_spk 8 --frame_shift 0.01 --nj 16 --sr 8000
find `pwd`/${dumpdir}/${dset}/labels/ -iname "*.lbl.mat" | awk -F'[/.]' '{print $(NF-2),$0}' | sort > ${datadir}/${dset}/labels.scp
done
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
echo "Stage 6: Make training and evaluation files."
# dump callhome1 data in training mode.
data_dir=${datadir}/callhome1/files_for_dump
# Fix: use mkdir -p (as everywhere else in this recipe) so re-running this
# stage does not abort on an already-existing directory.
mkdir -p ${data_dir}
# filter out zero duration segments
LC_ALL=C awk '{if ($5 > 0){print $0}}' ${datadir}/callhome1/ref.rttm > ${data_dir}/ref.rttm
cp ${datadir}/callhome1/{feats.scp,labels.scp} ${data_dir}/
cp ${datadir}/callhome1/nonoverlap_0s/{utt2spk,utt2xvec,utt2num_frames} ${data_dir}/
echo "Stage 6: start to dump for callhome1."
python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
--out ${dumpdir}/callhome1/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode train \
--chunk_size 1600 --chunk_shift 400 --add_mid_to_speaker true
mkdir -p ${datadir}/callhome1/dumped_files
# merge the per-part scp outputs of dump_meeting_chunks.py
cat ${dumpdir}/callhome1/dumped_files/data_parts*_feat.scp | sort > ${datadir}/callhome1/dumped_files/feats.scp
cat ${dumpdir}/callhome1/dumped_files/data_parts*_xvec.scp | sort > ${datadir}/callhome1/dumped_files/profile.scp
cat ${dumpdir}/callhome1/dumped_files/data_parts*_label.scp | sort > ${datadir}/callhome1/dumped_files/label.scp
mkdir -p ${expdir}/callhome1_states
# every chunk has a fixed length of 1600 frames; shuffle for training
awk '{print $1,"1600"}' ${datadir}/callhome1/dumped_files/feats.scp | shuf > ${expdir}/callhome1_states/speech_shape
python -Wignore script/convert_rttm_to_seg_file.py --rttm_scp ${data_dir}/ref.rttm --seg_file ${data_dir}/org_vad.txt
# dump callhome2 data in test mode.
data_dir=${datadir}/callhome2/files_for_dump
# Fix: mkdir -p here as well, for the same idempotency reason.
mkdir -p ${data_dir}
# filter out zero duration segments
LC_ALL=C awk '{if ($5 > 0){print $0}}' ${datadir}/callhome2/ref.rttm > ${data_dir}/ref.rttm
cp ${datadir}/callhome2/{feats.scp,labels.scp} ${data_dir}/
cp ${datadir}/callhome2/nonoverlap_0s/{utt2spk,utt2xvec,utt2num_frames} ${data_dir}/
echo "Stage 6: start to dump for callhome2."
python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
--out ${dumpdir}/callhome2/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode test \
--chunk_size 1600 --chunk_shift 400 --add_mid_to_speaker true
mkdir -p ${datadir}/callhome2/dumped_files
cat ${dumpdir}/callhome2/dumped_files/data_parts*_feat.scp | sort > ${datadir}/callhome2/dumped_files/feats.scp
cat ${dumpdir}/callhome2/dumped_files/data_parts*_xvec.scp | sort > ${datadir}/callhome2/dumped_files/profile.scp
cat ${dumpdir}/callhome2/dumped_files/data_parts*_label.scp | sort > ${datadir}/callhome2/dumped_files/label.scp
mkdir -p ${expdir}/callhome2_states
awk '{print $1,"1600"}' ${datadir}/callhome2/dumped_files/feats.scp | shuf > ${expdir}/callhome2_states/speech_shape
python -Wignore script/convert_rttm_to_seg_file.py --rttm_scp ${data_dir}/ref.rttm --seg_file ${data_dir}/org_vad.txt
fi
# Finetune model on callhome1, this will take about 1.5 hours.
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
echo "Stage 7: Finetune pretrained model on callhome1."
if [ ! -e ${expdir}/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch ]; then
echo "start to download pretrained models"
git lfs install
git clone https://www.modelscope.cn/damo/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch.git
mv speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch ${expdir}/
echo "Done."
fi
world_size=$gpu_num # run on one machine
mkdir -p ${expdir}/${model_dir}
mkdir -p ${expdir}/${model_dir}/log
mkdir -p /tmp/${model_dir}
# file-based rendezvous for torch distributed init; remove stale file from a previous run
INIT_FILE=/tmp/${model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_opt=""
if [ ! -z "${init_param}" ]; then
init_opt="--init_param ${init_param}"
echo ${init_opt}
fi
freeze_opt=""
if [ ! -z "${freeze_param}" ]; then
freeze_opt="--freeze_param ${freeze_param}"
echo ${freeze_opt}
fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
# launch one background training process per GPU (DDP, single machine)
for ((i = 0; i < $gpu_num; ++i)); do
{
rank=$i
local_rank=$i
gpu_id=$(echo $gpu_devices | cut -d',' -f$[$i+1])
python -m funasr.bin.diar_train \
--gpu_id $gpu_id \
--use_preprocessor false \
--token_type char \
--token_list $token_list \
--train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
--train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
--train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
--train_shape_file ${expdir}/${valid_set}_states/speech_shape \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
--valid_shape_file ${expdir}/${valid_set}_states/speech_shape \
--init_param ${expdir}/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch/sond.pth \
--unused_parameters true \
${init_opt} \
${freeze_opt} \
--ignore_init_mismatch true \
--resume true \
--output_dir ${expdir}/${model_dir} \
--config ${train_config} \
--ngpu $gpu_num \
--num_worker_count $count \
--multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${expdir}/${model_dir}/log/train.log.$i 2>&1
} &
done
echo "Training log can be found at ${expdir}/${model_dir}/log/train.log.*"
wait
fi
# evaluate for finetuned model
if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
echo "stage 8: evaluation for finetuned model ${inference_model}."
for dset in ${test_sets}; do
echo "Processing for $dset"
exp_model_dir=${expdir}/${model_dir}
_inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
_dir="${exp_model_dir}/${_inference_tag}/${inference_model}/${dset}"
_logdir="${_dir}/logdir"
if [ -d ${_dir} ]; then
echo "WARNING: ${_dir} is already exists."
fi
mkdir -p "${_logdir}"
_data="${datadir}/${dset}/dumped_files"
key_file=${_data}/feats.scp
num_scp_file="$(<${key_file} wc -l)"
# never launch more jobs than there are utterances
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
_opt=
if [ ! -z "${inference_config}" ]; then
_opt="--config ${inference_config}"
fi
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
echo "Inference log can be found at ${_logdir}/inference.*.log"
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
python -m funasr.bin.diar_inference_launch \
--batch_size 1 \
--ngpu "${_ngpu}" \
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
--data_path_and_name_and_type "${_data}/profile.scp,profile,kaldi_ark" \
--key_file "${_logdir}"/keys.JOB.scp \
--diar_train_config "${exp_model_dir}"/config.yaml \
--diar_model_file "${exp_model_dir}"/${inference_model} \
--output_dir "${_logdir}"/output.JOB \
--mode sond ${_opt}
done
fi
# Scoring for finetuned model, you may get a DER like:
# oracle_vad | system_vad
# 7.32 | 8.14
if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
echo "stage 9: Scoring finetuned models"
if [ ! -e dscore ]; then
git clone https://github.com/nryant/dscore.git
pip install intervaltree
# add intervaltree to setup.py
fi
for dset in ${test_sets}; do
echo "stage 9: Scoring for ${dset}"
diar_exp=${expdir}/${model_dir}
_data="${datadir}/${dset}"
_inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
_dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"
_logdir="${_dir}/logdir"
# merge per-job label outputs from stage 8, then convert frame labels to RTTM
cat ${_logdir}/*/labels.txt | sort > ${_dir}/labels.txt
cmd="python -Wignore script/convert_label_to_rttm.py ${_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${_dir}/sys.rttm \
--ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
echo ${cmd}
eval ${cmd}
# score twice: once with oracle (reference) VAD, once with system VAD, collar 0.25s
ref=${datadir}/${dset}/files_for_dump/ref.rttm
sys=${_dir}/sys.rttm.ref_vad
OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
ref=${datadir}/${dset}/files_for_dump/ref.rttm
sys=${_dir}/sys.rttm.sys_vad
SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
echo -e "${inference_model} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${_dir}/results.txt
done
fi
# In this stage, we need the raw waveform files of Callhome corpus.
# Due to the data license, we can't provide them, please get them additionally.
# And convert the sph files to wav files (use scripts/dump_pipe_wav.py).
# Then find the wav files to construct wav.scp and put it at data/callhome2/wav.scp.
# After iteratively perform SOAP, you will get DER results like:
# iters : oracle_vad | system_vad
# iter_0: 9.58 | 10.46
# iter_1: 9.22 | 10.15
# iter_2: 9.21 | 10.14
# iter_3: 9.30 | 10.24
# iter_4: 9.29 | 10.23
if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
if [ ! -e ${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ]; then
git lfs install
git clone https://www.modelscope.cn/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch.git
mv speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ${expdir}/
fi
for dset in ${test_sets}; do
echo "stage 10: Evaluating finetuned system on ${dset} set with medfilter_size=83 clustering=EEND-OLA"
sv_exp_dir=${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch
diar_exp=${expdir}/${model_dir}
_data="${datadir}/${dset}/dumped_files"
_inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
_dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"
# TOLD iterations: iteration 0 starts from EEND-OLA output, each later
# iteration refines the previous iteration's system RTTM.
for iter in `seq 0 ${told_max_iter}`; do
eval_dir=${_dir}/iter_${iter}
if [ $iter -eq 0 ]; then
prev_rttm=${expdir}/EEND-OLA/sys.rttm
else
prev_rttm=${_dir}/iter_$((${iter}-1))/sys.rttm.sys_vad
fi
echo "Use ${prev_rttm} as system outputs."
echo "Iteration ${iter}, step 1: extracting non-overlap segments"
cmd="python -Wignore script/extract_nonoverlap_segments.py ${datadir}/${dset}/wav.scp \
$prev_rttm ${eval_dir}/nonoverlap_segs/ --min_dur 0.1 --max_spk_num 16 --no_pbar --sr 8000"
# echo ${cmd}
eval ${cmd}
echo "Iteration ${iter}, step 2: make data directory"
mkdir -p ${eval_dir}/data
find `pwd`/${eval_dir}/nonoverlap_segs/ -iname "*.wav" | sort > ${eval_dir}/data/wav.flist
# "<basename> <path>" and "<basename> <parent dir = speaker id>"
awk -F'[/.]' '{print $(NF-1),$0}' ${eval_dir}/data/wav.flist > ${eval_dir}/data/wav.scp
awk -F'[/.]' '{print $(NF-1),$(NF-2)}' ${eval_dir}/data/wav.flist > ${eval_dir}/data/utt2spk
cp $prev_rttm ${eval_dir}/data/sys.rttm
home_path=`pwd`
echo "Iteration ${iter}, step 3: calc x-vector for each utt"
key_file=${eval_dir}/data/wav.scp
num_scp_file="$(<${key_file} wc -l)"
# never launch more jobs than there are utterances
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
_logdir=${eval_dir}/data/xvecs
mkdir -p ${_logdir}
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/sv_inference.JOB.log \
python -m funasr.bin.sv_inference_launch \
--njob ${njob} \
--batch_size 1 \
--ngpu "${_ngpu}" \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${key_file},speech,sound" \
--key_file "${_logdir}"/keys.JOB.scp \
--sv_train_config ${sv_exp_dir}/sv.yaml \
--sv_model_file ${sv_exp_dir}/sv.pth \
--output_dir "${_logdir}"/output.JOB
cat ${_logdir}/output.*/xvector.scp | sort > ${eval_dir}/data/utt2xvec
echo "Iteration ${iter}, step 4: dump x-vector record"
awk '{print $1}' ${_data}/feats.scp > ${eval_dir}/data/idx
python script/dump_speaker_profiles.py --dir ${eval_dir}/data \
--out ${eval_dir}/global_n16 --n_spk 16 --no_pbar --emb_type global
spk_profile=${eval_dir}/global_n16_parts00_xvec.scp
echo "Iteration ${iter}, step 5: perform NN diarization"
_logdir=${eval_dir}/diar
mkdir -p ${_logdir}
key_file=${_data}/feats.scp
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
_opt=
if [ ! -z "${inference_config}" ]; then
_opt="--config ${inference_config}"
fi
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
echo "Inference log can be found at ${_logdir}/inference.*.log"
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
python -m funasr.bin.diar_inference_launch \
--batch_size 1 \
--ngpu "${_ngpu}" \
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
--data_path_and_name_and_type "${spk_profile},profile,kaldi_ark" \
--key_file "${_logdir}"/keys.JOB.scp \
--diar_train_config ${diar_exp}/config.yaml \
--diar_model_file ${diar_exp}/${inference_model} \
--output_dir "${_logdir}"/output.JOB \
--mode sond ${_opt}
echo "Iteration ${iter}, step 6: calc diarization results"
cat ${_logdir}/output.*/labels.txt | sort > ${eval_dir}/labels.txt
cmd="python -Wignore script/convert_label_to_rttm.py ${eval_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${eval_dir}/sys.rttm \
--ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
# echo ${cmd}
eval ${cmd}
# score with oracle VAD and with system VAD, collar 0.25s
ref=${datadir}/${dset}/files_for_dump/ref.rttm
sys=${eval_dir}/sys.rttm.ref_vad
OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
ref=${datadir}/${dset}/files_for_dump/ref.rttm
sys=${eval_dir}/sys.rttm.sys_vad
SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
echo -e "${inference_model}/iter_${iter} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${eval_dir}/results.txt
done
echo "Done."
done
fi

View File

@ -0,0 +1,5 @@
# path.sh: environment setup sourced by run.sh in this recipe directory.
# FUNASR_DIR points at the repository root (three levels above this recipe).
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# Expose the FunASR command-line entry points on PATH.
export PATH=$FUNASR_DIR/funasr/bin:$PATH

View File

@ -0,0 +1,954 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# This recipe aims at reimplementing the results of SOND on the Callhome corpus, as presented in
# [1] TOLD: A Novel Two-stage Overlap-aware Framework for Speaker Diarization, ICASSP 2023
# You can also use it on other datasets, such as AliMeeting, to reproduce the results in
# [2] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, EMNLP 2022
# We recommend you run this script stage by stage.
# [developing] This recipe includes:
# 1. simulating data with switchboard and NIST.
# 2. training the model from scratch in 3 phases:
#    2-1. pre-train on simu_swbd_sre
#    2-2. train on simu_swbd_sre
#    2-3. finetune on callhome1
# 3. evaluating the model with the results from the first-stage EEND-OLA.
# Finally, you will get a DER result similar to the one claimed in the paper.
# environment configuration
kaldi_root=
if [ -z "${kaldi_root}" ]; then
echo "We need kaldi to prepare dataset, extract fbank features, please install kaldi first and set kaldi_root."
echo "Kaldi installation guide can be found at https://kaldi-asr.org/"
exit;
fi
# Reuse Kaldi's callhome_diarization helper scripts (local/ and utils/) via symlinks.
if [ ! -e local ]; then
ln -s ${kaldi_root}/egs/callhome_diarization/v2/local ./local
fi
if [ ! -e utils ]; then
ln -s ${kaldi_root}/egs/callhome_diarization/v2/utils ./utils
fi
# machines configuration
gpu_devices="4,5,6,7" # for V100-16G, use 4 GPUs
gpu_num=4
count=1
# general configuration
stage=3
stop_stage=3
# number of jobs for data process
nj=16
sr=8000
# dataset related
data_root=
callhome_root=path/to/NIST/LDC2001S97
# experiment configuration
lang=en
feats_type=fbank
datadir=data
dumpdir=dump
expdir=exp
train_cmd=utils/run.pl
# training related
tag=""
train_set=simu_swbd_sre
valid_set=callhome1
train_config=conf/EAND_ResNet34_SAN_L4N512_None_FFN_FSMN_L6N512_bce_dia_loss_01.yaml
token_list=${datadir}/token_list/powerset_label_n16k4.txt
init_param=
freeze_param=
# inference related
inference_model=valid.der.ave_5best.pth
inference_config=conf/basic_inference.yaml
inference_tag=""
test_sets="callhome1"
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# number of jobs for inference
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
infer_cmd=utils/run.pl
told_max_iter=2
. utils/parse_options.sh || exit 1;
model_dir="$(basename "${train_config}" .yaml)_${feats_type}_${lang}${tag}"
# you can set gpu num for decoding here
gpuid_list=$gpu_devices # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
# FIX: use the POSIX $((...)) arithmetic expansion; the old $[...] form is
# deprecated and undocumented in modern bash.
inference_nj=$((ngpu * njob))
_ngpu=1
else
inference_nj=$njob
_ngpu=0
fi
# Prepare datasets
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# 1. Prepare a collection of NIST SRE data.
# BUGFIX: was "echp", which is not a command and made this stage fail.
echo "Stage 0: Prepare a collection of NIST SRE data."
local/make_sre.sh $data_root ${datadir}
# 2.a Prepare SWB.
local/make_swbd2_phase1.pl ${data_root}/LDC98S75 \
${datadir}/swbd2_phase1_train
local/make_swbd2_phase2.pl $data_root/LDC99S79 \
${datadir}/swbd2_phase2_train
local/make_swbd2_phase3.pl $data_root/LDC2002S06 \
${datadir}/swbd2_phase3_train
local/make_swbd_cellular1.pl $data_root/LDC2001S13 \
${datadir}/swbd_cellular1_train
local/make_swbd_cellular2.pl $data_root/LDC2004S07 \
${datadir}/swbd_cellular2_train
# 2.b combine all swbd data.
utils/combine_data.sh ${datadir}/swbd \
${datadir}/swbd2_phase1_train ${datadir}/swbd2_phase2_train ${datadir}/swbd2_phase3_train \
${datadir}/swbd_cellular1_train ${datadir}/swbd_cellular2_train
utils/validate_data_dir.sh --no-text --no-feats ${datadir}/swbd
utils/fix_data_dir.sh ${datadir}/swbd
# Merge SWBD with SRE into a single pool used later for simulation.
utils/combine_data.sh ${datadir}/swbd_sre ${datadir}/swbd ${datadir}/sre
utils/validate_data_dir.sh --no-text --no-feats ${datadir}/swbd_sre
utils/fix_data_dir.sh ${datadir}/swbd_sre
# 3. Prepare the Callhome portion of NIST SRE 2000.
local/make_callhome.sh ${callhome_root} ${datadir}/
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Stage 1: Dump sph file to wav"
# sph2pipe (shipped with Kaldi tools) is required to decode SPHERE audio.
export PATH=${kaldi_root}/tools/sph2pipe/:${PATH}
if [ ! -f ${kaldi_root}/tools/sph2pipe/sph2pipe ]; then
echo "Can not find sph2pipe in ${kaldi_root}/tools/sph2pipe/,"
echo "please install sph2pipe and put it in the right place."
exit;
fi
for dset in callhome1 callhome2 swbd_sre; do
echo "Stage 1: start to dump ${dset}."
# Keep the original (piped) wav.scp around as sph.scp, then decode every
# recording into a real wav file under ${dumpdir}/${dset}/wavs.
mv ${datadir}/${dset}/wav.scp ${datadir}/${dset}/sph.scp
mkdir -p ${dumpdir}/${dset}/wavs
python -Wignore script/dump_pipe_wav.py ${datadir}/${dset}/sph.scp ${dumpdir}/${dset}/wavs \
--sr ${sr} --nj ${nj} --no_pbar
# Rebuild wav.scp as "<utt-id> <abs-path>"; the utt-id is the wav file's
# basename (second-to-last '/'- or '.'-separated field).
find `pwd`/${dumpdir}/${dset}/wavs -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/${dset}/wav.scp
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Stage 2: Extract non-overlap segments from callhome dataset"
for dset in callhome1 callhome2; do
echo "Stage 2: Extracting non-overlap segments for "${dset}
mkdir -p ${dumpdir}/${dset}/nonoverlap_0s
# Cut single-speaker (non-overlapped) regions out of each recording
# according to ref.rttm; these are later used for speaker profiles.
python -Wignore script/extract_nonoverlap_segments.py \
${datadir}/${dset}/wav.scp ${datadir}/${dset}/ref.rttm ${dumpdir}/${dset}/nonoverlap_0s \
--min_dur 0 --max_spk_num 8 --sr ${sr} --no_pbar --nj ${nj}
mkdir -p ${datadir}/${dset}/nonoverlap_0s
# Build wav.scp ("<utt-id> <abs-path>") and utt2spk from the dumped files;
# the speaker id appears to come from the parent path component — this
# relies on extract_nonoverlap_segments.py's output layout.
find `pwd`/${dumpdir}/${dset}/nonoverlap_0s | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/${dset}/nonoverlap_0s/wav.scp
awk -F'[/.]' '{print $(NF-1),$(NF-2)}' ${datadir}/${dset}/nonoverlap_0s/wav.scp > ${datadir}/${dset}/nonoverlap_0s/utt2spk
echo "Done."
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Stage 3: Generate concatenated waveforms for each speaker in switchboard, sre and callhome1"
# -p makes re-runs of this stage idempotent.
mkdir -p swb_sre_resources
wget --no-check-certificate -P swb_sre_resources/ https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/Speaker_Diar/swb_sre_resources/noise.scp
wget --no-check-certificate -P swb_sre_resources/ https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/Speaker_Diar/swb_sre_resources/swbd_sre_tdnn_vad_segments
mkdir -p ${datadir}/swbd_sre/none_silence
# BUGFIX: symlink targets must be absolute. A relative target is resolved
# against the link's own directory, so the original link
# ("swb_sre_resources/..." placed under ${datadir}/swbd_sre/none_silence/)
# was dangling. -f keeps re-runs from failing on an existing link.
ln -sf `pwd`/swb_sre_resources/swbd_sre_tdnn_vad_segments ${datadir}/swbd_sre/none_silence/segments
cp ${datadir}/swbd_sre/wav.scp ${datadir}/swbd_sre/none_silence/reco.scp
mkdir -p ${dumpdir}/swbd_sre/none_silence
# Cut silence out of each recording using the TDNN VAD segments file.
python -Wignore script/remove_silence_from_wav.py \
${datadir}/swbd_sre/none_silence ${dumpdir}/swbd_sre/none_silence --nj ${nj} --sr 8000
# The utterance number in wav.scp may be different from reco.scp,
# since some recordings don't appear in the segments file, may due to the VAD
echo "find wavs_nosil"
find `pwd`/${dumpdir}/swbd_sre/none_silence -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/swbd_sre/none_silence/wav.scp
echo "concat spk segments"
# BUGFIX: absolute target for the utt2spk link too (same dangling-link issue).
ln -sf `pwd`/${datadir}/swbd_sre/utt2spk ${datadir}/swbd_sre/none_silence/utt2spk
echo "Stage 3: Start to concatnate waveforms for speakers in switchboard and sre"
python -Wignore egs/callhome/concat_spk_segs.py \
${datadir}/swbd_sre/none_silence ${dumpdir}/swbd_sre/spk_wavs --nj ${nj} --sr 8000
echo "Stage 3: Start to concatnate waveforms for speakers in callhome1"
# only use callhome1 as training set to simulate data
python -Wignore egs/callhome/concat_spk_segs.py \
${datadir}/callhome1/nonoverlap_0s ${dumpdir}/callhome1/spk_wavs --nj ${nj} --sr 8000
fi
# simulate data with the pattern of callhome1
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Stage 4: Start to simulate recordings."
if [ ! -e ${dumpdir}/musan ]; then
echo "Stage 4-1: Start to download MUSAN noises from openslr"
wget --no-check-certificate -P ${dumpdir}/musan https://www.openslr.org/resources/17/musan.tar.gz
tar -C ${dumpdir}/musan -xvf ${dumpdir}/musan/musan.tar.gz
fi
if [ ! -e ${dumpdir}/rirs ]; then
echo "Stage 4-2: Start to download RIRs from openslr"
wget --no-check-certificate -P ${dumpdir}/rirs https://www.openslr.org/resources/28/rirs_noises.zip
unzip ${dumpdir}/rirs/rirs_noises.zip -d ${dumpdir}/rirs
fi
mkdir -p ${datadir}/simu_swbd_sre
# only use background noises instead of all noises in MUSAN.
sed "s:/path/to/musan/:`pwd`/${dumpdir}/musan/:g" swb_sre_resources/noise.scp > ${datadir}/simu_swbd_sre/noise.scp
# use simulated RIRs.
find `pwd`/${dumpdir}/rirs/RIRS_NOISES/simulated_rirs/ -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-3)"-"$(NF-1), $0}' > ${datadir}/simu_swbd_sre/rirs.scp
cp ${datadir}/callhome1/{ref.rttm,reco2num_spk} ${datadir}/simu_swbd_sre
find `pwd`/${dumpdir}/swbd_sre/spk_wavs -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/simu_swbd_sre/spk2wav.scp
echo "Stage 4-3: Start to simulate recordings with variable speakers as Callhome1 patterns."
# average duration of callhome is 125s, about 20 chunk with silence
# simulating 22500 (45 jobs x 500 reco) recordings, without random_assign and random_shift_interval
for i in $(seq 0 44); do
cmd="python -Wignore egs/callhome/simu_whole_recordings.py \
${datadir}/simu_swbd_sre \
${dumpdir}/simu_swbd_sre/wavs \
--corpus_name simu_swbd_sre --task_id $i --total_mix 500 --sr 8000 --no_bar &"
echo $cmd
eval $cmd
done
wait;
echo "Stage 4-4: Start to simulate recordings with fixed speakers as Callhome1 patterns."
# simulating 30000 (30 jobs x 1000 reco) recordings for different speaker number 2, 3, 4
for n_spk in $(seq 2 4); do
# BUGFIX: was a hard-coded personal path (/home/neo.dzh/corpus/...), a
# leftover from the author's machine; create the directory the simulation
# below actually writes to.
mkdir -p ${dumpdir}/simu_swbd_sre/${n_spk}spk_wavs
for i in $(seq 0 29); do
cmd="python -Wignore egs/callhome/simu_whole_recordings.py \
${datadir}/simu_swbd_sre \
${dumpdir}/simu_swbd_sre/${n_spk}spk_wavs \
--random_assign_spk --random_interval --spk_num ${n_spk} \
--corpus_name simu_swbd_sre --task_id $i --total_mix 1000 --sr 8000 --no_bar &"
echo $cmd
eval $cmd
done
wait;
done
find `pwd`/${dumpdir}/simu_swbd_sre -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/simu_swbd_sre/wav.scp
awk '{print $1,$1}' ${datadir}/simu_swbd_sre/wav.scp > ${datadir}/simu_swbd_sre/utt2spk
find `pwd`/${dumpdir}/simu_swbd_sre -iname "*.rttm" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/simu_swbd_sre/rttm.scp
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Stage 5: Generate fbank features"
# Feature extraction is run from inside the Kaldi recipe directory so that
# steps/, conf/ and cmd.sh resolve; we cd back afterwards.
# NOTE(review): ${datadir}/${expdir}/${dumpdir} are relative paths and are
# resolved from the Kaldi directory after the cd — confirm they are made
# absolute or symlinked there before running this stage.
home_path=`pwd`
cd ${kaldi_root}/egs/callhome_diarization/v2 || exit
. ./cmd.sh
. ./path.sh
# Whole-recording fbanks for the training/evaluation sets.
for dset in simu_swbd_sre callhome1 callhome2; do
steps/make_fbank.sh --write-utt2num-frames true --fbank-config conf/fbank.conf --nj ${nj} --cmd "$train_cmd" \
${datadir}/${dset} ${expdir}/make_fbank/${dset} ${dumpdir}/${dset}/fbank
utils/fix_data_dir.sh ${datadir}/${dset}
done
# Fbanks for the single-speaker subsets used to extract speaker profiles.
for dset in swbd_sre/none_silence callhome1/nonoverlap_0s callhome2/nonoverlap_0s; do
steps/make_fbank.sh --write-utt2num-frames true --fbank-config conf/fbank.conf --nj ${nj} --cmd "$train_cmd" \
${datadir}/${dset} ${expdir}/make_fbank/${dset} ${dumpdir}/${dset}/fbank
utils/fix_data_dir.sh ${datadir}/${dset}
done
cd ${home_path} || exit
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
echo "Stage 6: Extract speaker embeddings."
# Download the pretrained x-vector speaker-verification model from ModelScope.
git lfs install
git clone https://www.modelscope.cn/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch.git
mv speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ${expdir}/
# FIX: honor ${expdir} instead of hard-coding "exp" — the model was just
# moved into ${expdir} on the previous line, so a non-default expdir broke here.
sv_exp_dir=${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch
# The model consumes precomputed 80-dim fbanks, so pin input_size in the config.
sed "s/input_size: null/input_size: 80/g" ${sv_exp_dir}/sv.yaml > ${sv_exp_dir}/sv_fbank.yaml
for dset in swbd_sre/none_silence callhome1/nonoverlap_0s callhome2/nonoverlap_0s; do
# Cap the job count at the number of utterances, then split the key file.
key_file=${datadir}/${dset}/feats.scp
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
_logdir=${dumpdir}/${dset}/xvecs
mkdir -p ${_logdir}
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/sv_inference.JOB.log \
python -m funasr.bin.sv_inference_launch \
--batch_size 1 \
--ngpu "${_ngpu}" \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${key_file},speech,kaldi_ark" \
--key_file "${_logdir}"/keys.JOB.scp \
--sv_train_config ${sv_exp_dir}/sv_fbank.yaml \
--sv_model_file ${sv_exp_dir}/sv.pth \
--output_dir "${_logdir}"/output.JOB
# Merge the per-job x-vector lists into a single utt2xvec mapping.
cat ${_logdir}/output.*/xvector.scp | sort > ${datadir}/${dset}/utt2xvec
done
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
echo "Stage 7: Generate label files."
for dset in simu_swbd_sre callhome1 callhome2; do
echo "Stage 7: Generate labels for ${dset}."
# Convert per-recording RTTMs into frame-level speaker-activity labels.
# FIX: use the configurable ${nj} (default 16) instead of a hard-coded 16,
# consistent with the other data-processing stages.
python -Wignore script/calc_real_meeting_frame_labels.py \
${datadir}/${dset} ${dumpdir}/${dset}/labels \
--n_spk 8 --frame_shift 0.01 --nj ${nj} --sr 8000
find `pwd`/${dumpdir}/${dset}/labels -iname "*.lbl.mat" | awk -F'[/.]' '{print $(NF-2),$0}' | sort > ${datadir}/${dset}/labels.scp
done
fi
if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
echo "Stage 8: Make training and evaluation files."
# dump simulated data in training mode (randomly shuffle the speaker order).
data_dir=${datadir}/simu_swbd_sre/files_for_dump
# FIX: -p so re-running this stage does not abort on an existing directory.
mkdir -p ${data_dir}
cp ${datadir}/simu_swbd_sre/{feats.scp,labels.scp} ${data_dir}/
cp ${datadir}/swbd_sre/none_silence/{utt2spk,utt2xvec,utt2num_frames} ${data_dir}/
# dump data with the window length of 1600 frames and hop length of 400 frames.
echo "Stage 8: start to dump for simu_swbd_sre."
for i in $(seq 0 49); do
cmd="python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
--out ${dumpdir}/simu_swbd_sre/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode train \
--chunk_size 1600 --chunk_shift 400 \
--task_id ${i} --task_size 2250 &"
echo $cmd
eval $cmd
done
wait;
mkdir -p ${datadir}/simu_swbd_sre/dumped_files
# Merge the per-task scp shards into the final feats/profile/label lists.
cat ${dumpdir}/simu_swbd_sre/dumped_files/data_parts*_feat.scp | sort > ${datadir}/simu_swbd_sre/dumped_files/feats.scp
cat ${dumpdir}/simu_swbd_sre/dumped_files/data_parts*_xvec.scp | sort > ${datadir}/simu_swbd_sre/dumped_files/profile.scp
cat ${dumpdir}/simu_swbd_sre/dumped_files/data_parts*_label.scp | sort > ${datadir}/simu_swbd_sre/dumped_files/label.scp
mkdir -p ${expdir}/simu_swbd_sre_states
# Every chunk is 1600 frames long, so the shape file is constant-valued.
awk '{print $1,"1600"}' ${datadir}/simu_swbd_sre/dumped_files/feats.scp | shuf > ${expdir}/simu_swbd_sre_states/speech_shape
# dump callhome1 data in training mode.
data_dir=${datadir}/callhome1/files_for_dump
# FIX: -p (same rerun-safety fix as above).
mkdir -p ${data_dir}
# filter out zero duration segments
LC_ALL=C awk '{if ($5 > 0){print $0}}' ${datadir}/callhome1/ref.rttm > ${data_dir}/ref.rttm
cp ${datadir}/callhome1/{feats.scp,labels.scp} ${data_dir}/
cp ${datadir}/callhome1/nonoverlap_0s/{utt2spk,utt2xvec,utt2num_frames} ${data_dir}/
echo "Stage 8: start to dump for callhome1."
python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
--out ${dumpdir}/callhome1/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode test \
--chunk_size 1600 --chunk_shift 400 --add_mid_to_speaker true
mkdir -p ${datadir}/callhome1/dumped_files
cat ${dumpdir}/callhome1/dumped_files/data_parts*_feat.scp | sort > ${datadir}/callhome1/dumped_files/feats.scp
cat ${dumpdir}/callhome1/dumped_files/data_parts*_xvec.scp | sort > ${datadir}/callhome1/dumped_files/profile.scp
cat ${dumpdir}/callhome1/dumped_files/data_parts*_label.scp | sort > ${datadir}/callhome1/dumped_files/label.scp
mkdir -p ${expdir}/callhome1_states
awk '{print $1,"1600"}' ${datadir}/callhome1/dumped_files/feats.scp | shuf > ${expdir}/callhome1_states/speech_shape
python -Wignore script/convert_rttm_to_seg_file.py --rttm_scp ${data_dir}/ref.rttm --seg_file ${data_dir}/org_vad.txt
# dump callhome2 data in test mode.
data_dir=${datadir}/callhome2/files_for_dump
# FIX: -p (same rerun-safety fix as above).
mkdir -p ${data_dir}
# filter out zero duration segments
LC_ALL=C awk '{if ($5 > 0){print $0}}' ${datadir}/callhome2/ref.rttm > ${data_dir}/ref.rttm
cp ${datadir}/callhome2/{feats.scp,labels.scp} ${data_dir}/
cp ${datadir}/callhome2/nonoverlap_0s/{utt2spk,utt2xvec,utt2num_frames} ${data_dir}/
echo "Stage 8: start to dump for callhome2."
python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
--out ${dumpdir}/callhome2/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode test \
--chunk_size 1600 --chunk_shift 400 --add_mid_to_speaker true
mkdir -p ${datadir}/callhome2/dumped_files
cat ${dumpdir}/callhome2/dumped_files/data_parts*_feat.scp | sort > ${datadir}/callhome2/dumped_files/feats.scp
cat ${dumpdir}/callhome2/dumped_files/data_parts*_xvec.scp | sort > ${datadir}/callhome2/dumped_files/profile.scp
cat ${dumpdir}/callhome2/dumped_files/data_parts*_label.scp | sort > ${datadir}/callhome2/dumped_files/label.scp
mkdir -p ${expdir}/callhome2_states
awk '{print $1,"1600"}' ${datadir}/callhome2/dumped_files/feats.scp | shuf > ${expdir}/callhome2_states/speech_shape
python -Wignore script/convert_rttm_to_seg_file.py --rttm_scp ${data_dir}/ref.rttm --seg_file ${data_dir}/org_vad.txt
fi
# Training Stage, phase 1, pretraining on simulated data with frozen encoder parameters.
# This training may cost about 1.8 days.
if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
echo "stage 10: training phase 1, pretraining on simulated data"
world_size=$gpu_num # run on one machine
mkdir -p ${expdir}/${model_dir}
mkdir -p ${expdir}/${model_dir}/log
mkdir -p /tmp/${model_dir}
# File-based rendezvous for distributed training; drop a stale file first.
INIT_FILE=/tmp/${model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_opt=""
if [ ! -z "${init_param}" ]; then
init_opt="--init_param ${init_param}"
echo ${init_opt}
fi
freeze_opt=""
if [ ! -z "${freeze_param}" ]; then
freeze_opt="--freeze_param ${freeze_param}"
echo ${freeze_opt}
fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
# Launch one training process per GPU; rank == GPU slot on this machine.
for ((i = 0; i < $gpu_num; ++i)); do
{
rank=$i
local_rank=$i
# FIX: POSIX $((...)) arithmetic instead of the deprecated $[...] form.
gpu_id=$(echo $gpu_devices | cut -d',' -f$((i + 1)))
python -m funasr.bin.diar_train \
--gpu_id $gpu_id \
--use_preprocessor false \
--token_type char \
--token_list $token_list \
--train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/feats.scp,speech,kaldi_ark \
--train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/profile.scp,profile,kaldi_ark \
--train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
--train_shape_file ${expdir}/${train_set}_states/speech_shape \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
--valid_shape_file ${expdir}/${valid_set}_states/speech_shape \
--init_param ${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/sv.pth:encoder:encoder \
--unused_parameters true \
--freeze_param encoder \
${init_opt} \
${freeze_opt} \
--ignore_init_mismatch true \
--resume true \
--output_dir ${expdir}/${model_dir} \
--config $train_config \
--ngpu $gpu_num \
--num_worker_count $count \
--multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${expdir}/${model_dir}/log/train.log.$i 2>&1
} &
done
echo "Training log can be found at ${expdir}/${model_dir}/log/train.log.*"
wait
fi
# evaluate for pretrained model
if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then
echo "stage 11: evaluation for phase-1 model."
for dset in ${test_sets}; do
echo "Processing for $dset"
exp_model_dir=${expdir}/${model_dir}
_inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
_dir="${exp_model_dir}/${_inference_tag}/${inference_model}/${dset}"
_logdir="${_dir}/logdir"
if [ -d ${_dir} ]; then
echo "WARNING: ${_dir} is already exists."
fi
mkdir -p "${_logdir}"
_data="${datadir}/${dset}/dumped_files"
# Cap the job count at the number of utterances, then split the key file
# so the inference jobs can run in parallel.
key_file=${_data}/feats.scp
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
_opt=
if [ ! -z "${inference_config}" ]; then
_opt="--config ${inference_config}"
fi
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
echo "Inference log can be found at ${_logdir}/inference.*.log"
# SOND inference: feature chunks + speaker profiles in, frame labels out.
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
python -m funasr.bin.diar_inference_launch \
--batch_size 1 \
--ngpu "${_ngpu}" \
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
--data_path_and_name_and_type "${_data}/profile.scp,profile,kaldi_ark" \
--key_file "${_logdir}"/keys.JOB.scp \
--diar_train_config "${exp_model_dir}"/config.yaml \
--diar_model_file "${exp_model_dir}"/"${inference_model}" \
--output_dir "${_logdir}"/output.JOB \
--mode sond ${_opt}
done
fi
# Scoring for pretrained model, you may get a DER like 13.73 16.25
# 13.73: with oracle VAD, 16.25: with only SOND outputs, aka, system VAD.
if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then
echo "stage 12: Scoring phase-1 models"
if [ ! -e dscore ]; then
git clone https://github.com/nryant/dscore.git
# add intervaltree to setup.py
fi
for dset in ${test_sets}; do
echo "stage 12: Scoring for ${dset}"
diar_exp=${expdir}/${model_dir}
_data="${datadir}/${dset}"
_inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
_dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"
_logdir="${_dir}/logdir"
# Merge per-job outputs, then convert frame labels to an RTTM file.
cat ${_logdir}/*/labels.txt | sort > ${_dir}/labels.txt
cmd="python -Wignore script/convert_label_to_rttm.py ${_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${_dir}/sys.rttm \
--ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
# echo ${cmd}
eval ${cmd}
# Score with dscore twice: once with oracle VAD boundaries (*.ref_vad),
# once with the system's own VAD decisions (*.sys_vad).
ref=${datadir}/${dset}/files_for_dump/ref.rttm
sys=${_dir}/sys.rttm.ref_vad
OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
ref=${datadir}/${dset}/files_for_dump/ref.rttm
sys=${_dir}/sys.rttm.sys_vad
SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
echo -e "${inference_model} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${_dir}/results.txt
done
fi
# Training Stage, phase 2, training on simulated data without frozen parameters.
# For V100-16G, please set batch_size to 8 in the config, and use 4 GPU to train the model with options like --gpu_devices 4,5,6,7 --gpu_num 4.
# For V100-32G, please set batch_size to 16 in the config, and use 2 GPU to train the model with options like --gpu_devices 4,5,6,7 --gpu_num 2.
# This training may cost about 3.5 days.
if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then
echo "stage 13: training phase 2, training on simulated data"
world_size=$gpu_num # run on one machine
mkdir -p ${expdir}/${model_dir}_phase2
mkdir -p ${expdir}/${model_dir}_phase2/log
mkdir -p /tmp/${model_dir}_phase2
# File-based rendezvous for distributed training; drop a stale file first.
INIT_FILE=/tmp/${model_dir}_phase2/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_opt=""
if [ ! -z "${init_param}" ]; then
init_opt="--init_param ${init_param}"
echo ${init_opt}
fi
freeze_opt=""
if [ ! -z "${freeze_param}" ]; then
freeze_opt="--freeze_param ${freeze_param}"
echo ${freeze_opt}
fi
# Phase-2 hyperparameters live in a sibling "<train_config>_phase2.yaml".
phase2_config="$(dirname "${train_config}")/$(basename "${train_config}" .yaml)_phase2.yaml"
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
{
rank=$i
local_rank=$i
# FIX: POSIX $((...)) arithmetic instead of the deprecated $[...] form.
gpu_id=$(echo $gpu_devices | cut -d',' -f$((i + 1)))
python -m funasr.bin.diar_train \
--gpu_id $gpu_id \
--use_preprocessor false \
--token_type char \
--token_list $token_list \
--train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/feats.scp,speech,kaldi_ark \
--train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/profile.scp,profile,kaldi_ark \
--train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
--train_shape_file ${expdir}/${train_set}_states/speech_shape \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
--valid_shape_file ${expdir}/${valid_set}_states/speech_shape \
--init_param ${expdir}/${model_dir}/valid.der.ave_5best.pth \
--unused_parameters true \
${init_opt} \
${freeze_opt} \
--ignore_init_mismatch true \
--resume true \
--output_dir ${expdir}/${model_dir}_phase2 \
--config ${phase2_config} \
--ngpu $gpu_num \
--num_worker_count $count \
--multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${expdir}/${model_dir}_phase2/log/train.log.$i 2>&1
} &
done
echo "Training log can be found at ${expdir}/${model_dir}_phase2/log/train.log.*"
wait
fi
# evaluate for phase-2 model
if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then
echo "stage 14: evaluation for phase-2 model ${inference_model}."
for dset in ${test_sets}; do
echo "Processing for $dset"
exp_model_dir=${expdir}/${model_dir}_phase2
_inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
_dir="${exp_model_dir}/${_inference_tag}/${inference_model}/${dset}"
_logdir="${_dir}/logdir"
if [ -d ${_dir} ]; then
echo "WARNING: ${_dir} is already exists."
fi
mkdir -p "${_logdir}"
_data="${datadir}/${dset}/dumped_files"
# Cap the job count at the number of utterances, then split the key file
# so the inference jobs can run in parallel.
key_file=${_data}/feats.scp
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
_opt=
if [ ! -z "${inference_config}" ]; then
_opt="--config ${inference_config}"
fi
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
echo "Inference log can be found at ${_logdir}/inference.*.log"
# SOND inference with the phase-2 model.
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
python -m funasr.bin.diar_inference_launch \
--batch_size 1 \
--ngpu "${_ngpu}" \
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
--data_path_and_name_and_type "${_data}/profile.scp,profile,kaldi_ark" \
--key_file "${_logdir}"/keys.JOB.scp \
--diar_train_config "${exp_model_dir}"/config.yaml \
--diar_model_file "${exp_model_dir}"/${inference_model} \
--output_dir "${_logdir}"/output.JOB \
--mode sond ${_opt}
done
fi
# Scoring for pretrained model, you may get a DER like 11.25 15.30
# 11.25: with oracle VAD, 15.30: with only SOND outputs, aka, system VAD.
if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then
echo "stage 15: Scoring phase-2 models"
if [ ! -e dscore ]; then
git clone https://github.com/nryant/dscore.git
# add intervaltree to setup.py
fi
for dset in ${test_sets}; do
echo "stage 15: Scoring for ${dset}"
diar_exp=${expdir}/${model_dir}_phase2
_data="${datadir}/${dset}"
_inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
_dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"
_logdir="${_dir}/logdir"
# Merge per-job outputs, then convert frame labels to an RTTM file.
cat ${_logdir}/*/labels.txt | sort > ${_dir}/labels.txt
cmd="python -Wignore script/convert_label_to_rttm.py ${_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${_dir}/sys.rttm \
--ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
# echo ${cmd}
eval ${cmd}
# Score with dscore twice: oracle VAD (*.ref_vad) vs system VAD (*.sys_vad).
ref=${datadir}/${dset}/files_for_dump/ref.rttm
sys=${_dir}/sys.rttm.ref_vad
OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
ref=${datadir}/${dset}/files_for_dump/ref.rttm
sys=${_dir}/sys.rttm.sys_vad
SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
echo -e "${inference_model} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${_dir}/results.txt
done
fi
# Finetune Stage, phase 3, training on the callhome1 training set.
if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then
echo "stage 16: training phase 3, finetuing on callhome1 real data"
world_size=$gpu_num # run on one machine
mkdir -p ${expdir}/${model_dir}_phase3
mkdir -p ${expdir}/${model_dir}_phase3/log
mkdir -p /tmp/${model_dir}_phase3
# File-based rendezvous for distributed training; drop a stale file first.
INIT_FILE=/tmp/${model_dir}_phase3/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
fi
init_opt=""
if [ ! -z "${init_param}" ]; then
init_opt="--init_param ${init_param}"
echo ${init_opt}
fi
freeze_opt=""
if [ ! -z "${freeze_param}" ]; then
freeze_opt="--freeze_param ${freeze_param}"
echo ${freeze_opt}
fi
# Phase-3 hyperparameters live in a sibling "<train_config>_phase3.yaml".
phase3_config="$(dirname "${train_config}")/$(basename "${train_config}" .yaml)_phase3.yaml"
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
{
rank=$i
local_rank=$i
# FIX: POSIX $((...)) arithmetic instead of the deprecated $[...] form.
gpu_id=$(echo $gpu_devices | cut -d',' -f$((i + 1)))
# Finetuning trains and validates on the same callhome1 set (${valid_set}).
python -m funasr.bin.diar_train \
--gpu_id $gpu_id \
--use_preprocessor false \
--token_type char \
--token_list $token_list \
--train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
--train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
--train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
--train_shape_file ${expdir}/${valid_set}_states/speech_shape \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
--valid_shape_file ${expdir}/${valid_set}_states/speech_shape \
--init_param ${expdir}/${model_dir}_phase2/valid.forward_steps.ave_5best.pth \
--unused_parameters true \
${init_opt} \
${freeze_opt} \
--ignore_init_mismatch true \
--resume true \
--output_dir ${expdir}/${model_dir}_phase3 \
--config ${phase3_config} \
--ngpu $gpu_num \
--num_worker_count $count \
--multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${expdir}/${model_dir}_phase3/log/train.log.$i 2>&1
} &
done
echo "Training log can be found at ${expdir}/${model_dir}_phase3/log/train.log.*"
wait
fi
# evaluate for finetuned model
if [ ${stage} -le 17 ] && [ ${stop_stage} -ge 17 ]; then
echo "stage 17: evaluation for finetuned model ${inference_model}."
for dset in ${test_sets}; do
echo "Processing for $dset"
exp_model_dir=${expdir}/${model_dir}_phase3
_inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
_dir="${exp_model_dir}/${_inference_tag}/${inference_model}/${dset}"
_logdir="${_dir}/logdir"
if [ -d ${_dir} ]; then
echo "WARNING: ${_dir} is already exists."
fi
mkdir -p "${_logdir}"
_data="${datadir}/${dset}/dumped_files"
# Cap the job count at the number of utterances, then split the key file
# so the inference jobs can run in parallel.
key_file=${_data}/feats.scp
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
split_scps=
for n in $(seq "${_nj}"); do
split_scps+=" ${_logdir}/keys.${n}.scp"
done
_opt=
if [ ! -z "${inference_config}" ]; then
_opt="--config ${inference_config}"
fi
# shellcheck disable=SC2086
utils/split_scp.pl "${key_file}" ${split_scps}
echo "Inference log can be found at ${_logdir}/inference.*.log"
# SOND inference with the finetuned (phase-3) model.
${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
python -m funasr.bin.diar_inference_launch \
--batch_size 1 \
--ngpu "${_ngpu}" \
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
--data_path_and_name_and_type "${_data}/profile.scp,profile,kaldi_ark" \
--key_file "${_logdir}"/keys.JOB.scp \
--diar_train_config "${exp_model_dir}"/config.yaml \
--diar_model_file "${exp_model_dir}"/${inference_model} \
--output_dir "${_logdir}"/output.JOB \
--mode sond ${_opt}
done
fi
# average 3 4 5 6 7 epoch
# Scoring for pretrained model, you may get a DER like
# 7.21 8.05 on callhome1
# 8.31 9.32 on callhome2
# Stage 18: convert frame-level labels to RTTM and score DER with dscore,
# once gated by the oracle VAD and once by the system VAD.
if [ ${stage} -le 18 ] && [ ${stop_stage} -ge 18 ]; then
    echo "stage 18: Scoring finetuned models"
    # dscore is fetched on demand; NOTE(review): its setup may additionally
    # require the intervaltree package (see the comment below).
    if [ ! -e dscore ]; then
        git clone https://github.com/nryant/dscore.git
        # add intervaltree to setup.py
    fi
    for dset in ${test_sets}; do
        echo "stage 18: Scoring for ${dset}"
        diar_exp=${expdir}/${model_dir}_phase3
        _data="${datadir}/${dset}"
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"
        _logdir="${_dir}/logdir"
        # Merge per-job label outputs into one sorted file.
        cat ${_logdir}/*/labels.txt | sort > ${_dir}/labels.txt
        # Convert chunked frame labels into RTTM; writes sys.rttm.ref_vad
        # (oracle VAD applied) and sys.rttm.sys_vad variants.
        cmd="python -Wignore script/convert_label_to_rttm.py ${_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${_dir}/sys.rttm \
            --ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
        echo ${cmd}
        eval ${cmd}
        # DER with oracle VAD (0.25 s collar).
        ref=${datadir}/${dset}/files_for_dump/ref.rttm
        sys=${_dir}/sys.rttm.ref_vad
        OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
        # DER with system VAD.
        ref=${datadir}/${dset}/files_for_dump/ref.rttm
        sys=${_dir}/sys.rttm.sys_vad
        SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
        echo -e "${inference_model} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${_dir}/results.txt
    done
fi
# Stage 19: iterative TOLD refinement. Starting from the EEND-OLA clustering
# output, each iteration (1) extracts non-overlapped speech per speaker from
# the previous RTTM, (2) re-estimates speaker x-vector profiles, and
# (3) re-runs SOND inference with those profiles, for told_max_iter rounds.
if [ ${stage} -le 19 ] && [ ${stop_stage} -ge 19 ]; then
    for dset in ${test_sets}; do
        echo "stage 19: Evaluating phase-3 system on ${dset} set with medfilter_size=83 clustering=EEND-OLA"
        sv_exp_dir=${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch
        diar_exp=${expdir}/${model_dir}_phase3
        _data="${datadir}/${dset}/dumped_files"
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"
        for iter in `seq 0 ${told_max_iter}`; do
            eval_dir=${_dir}/iter_${iter}
            # Iteration 0 bootstraps from the EEND-OLA result; later
            # iterations refine the previous iteration's system-VAD RTTM.
            if [ $iter -eq 0 ]; then
                prev_rttm=${expdir}/EEND-OLA/sys.rttm
            else
                prev_rttm=${_dir}/iter_$((${iter}-1))/sys.rttm.sys_vad
            fi
            echo "Use ${prev_rttm} as system outputs."
            echo "Iteration ${iter}, step 1: extracting non-overlap segments"
            cmd="python -Wignore script/extract_nonoverlap_segments.py ${datadir}/${dset}/wav.scp \
                $prev_rttm ${eval_dir}/nonoverlap_segs/ --min_dur 0.1 --max_spk_num 16 --no_pbar --sr 8000"
            # echo ${cmd}
            eval ${cmd}
            echo "Iteration ${iter}, step 2: make data directory"
            mkdir -p ${eval_dir}/data
            # Build a kaldi-style data dir from the extracted per-speaker
            # wavs; file names encode "<speaker>.wav" under "<meeting>/".
            find `pwd`/${eval_dir}/nonoverlap_segs/ -iname "*.wav" | sort > ${eval_dir}/data/wav.flist
            awk -F'[/.]' '{print $(NF-1),$0}' ${eval_dir}/data/wav.flist > ${eval_dir}/data/wav.scp
            awk -F'[/.]' '{print $(NF-1),$(NF-2)}' ${eval_dir}/data/wav.flist > ${eval_dir}/data/utt2spk
            cp $prev_rttm ${eval_dir}/data/sys.rttm
            # NOTE(review): home_path appears unused below — confirm.
            home_path=`pwd`
            echo "Iteration ${iter}, step 3: calc x-vector for each utt"
            key_file=${eval_dir}/data/wav.scp
            num_scp_file="$(<${key_file} wc -l)"
            _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
            _logdir=${eval_dir}/data/xvecs
            mkdir -p ${_logdir}
            split_scps=
            for n in $(seq "${_nj}"); do
                split_scps+=" ${_logdir}/keys.${n}.scp"
            done
            # shellcheck disable=SC2086
            utils/split_scp.pl "${key_file}" ${split_scps}
            # Speaker-verification model extracts one x-vector per segment.
            ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/sv_inference.JOB.log \
                python -m funasr.bin.sv_inference_launch \
                    --njob ${njob} \
                    --batch_size 1 \
                    --ngpu "${_ngpu}" \
                    --gpuid_list ${gpuid_list} \
                    --data_path_and_name_and_type "${key_file},speech,sound" \
                    --key_file "${_logdir}"/keys.JOB.scp \
                    --sv_train_config ${sv_exp_dir}/sv.yaml \
                    --sv_model_file ${sv_exp_dir}/sv.pth \
                    --output_dir "${_logdir}"/output.JOB
            cat ${_logdir}/output.*/xvector.scp | sort > ${eval_dir}/data/utt2xvec
            echo "Iteration ${iter}, step 4: dump x-vector record"
            # "idx" lists the chunk ids the profile dumper must cover.
            awk '{print $1}' ${_data}/feats.scp > ${eval_dir}/data/idx
            python script/dump_speaker_profiles.py --dir ${eval_dir}/data \
                --out ${eval_dir}/global_n16 --n_spk 16 --no_pbar --emb_type global
            spk_profile=${eval_dir}/global_n16_parts00_xvec.scp
            echo "Iteration ${iter}, step 5: perform NN diarization"
            _logdir=${eval_dir}/diar
            mkdir -p ${_logdir}
            key_file=${_data}/feats.scp
            num_scp_file="$(<${key_file} wc -l)"
            _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
            split_scps=
            for n in $(seq "${_nj}"); do
                split_scps+=" ${_logdir}/keys.${n}.scp"
            done
            _opt=
            if [ ! -z "${inference_config}" ]; then
                _opt="--config ${inference_config}"
            fi
            # shellcheck disable=SC2086
            utils/split_scp.pl "${key_file}" ${split_scps}
            echo "Inference log can be found at ${_logdir}/inference.*.log"
            # SOND inference, now with the re-estimated speaker profiles.
            ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
                python -m funasr.bin.diar_inference_launch \
                    --batch_size 1 \
                    --ngpu "${_ngpu}" \
                    --njob ${njob} \
                    --gpuid_list ${gpuid_list} \
                    --data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
                    --data_path_and_name_and_type "${spk_profile},profile,kaldi_ark" \
                    --key_file "${_logdir}"/keys.JOB.scp \
                    --diar_train_config ${diar_exp}/config.yaml \
                    --diar_model_file ${diar_exp}/${inference_model} \
                    --output_dir "${_logdir}"/output.JOB \
                    --mode sond ${_opt}
            echo "Iteration ${iter}, step 6: calc diarization results"
            cat ${_logdir}/output.*/labels.txt | sort > ${eval_dir}/labels.txt
            cmd="python -Wignore script/convert_label_to_rttm.py ${eval_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${eval_dir}/sys.rttm \
                --ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
            # echo ${cmd}
            eval ${cmd}
            # DER with oracle VAD and with system VAD (0.25 s collar).
            ref=${datadir}/${dset}/files_for_dump/ref.rttm
            sys=${eval_dir}/sys.rttm.ref_vad
            OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
            ref=${datadir}/${dset}/files_for_dump/ref.rttm
            sys=${eval_dir}/sys.rttm.sys_vad
            SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')
            echo -e "${inference_model}/iter_${iter} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${eval_dir}/results.txt
        done
        echo "Done."
    done
fi

View File

@ -0,0 +1,21 @@
import os
import sys
import soundfile as sf
from funasr.utils.misc import load_scp_as_list
if __name__ == '__main__':
    # Usage: python <script> wav.scp out_file
    # Writes "<uttid> <num_frame>" per utterance, where num_frame is the
    # number of whole 10 ms frames in the waveform.
    wav_scp = sys.argv[1]
    out_path = sys.argv[2]
    frame_shift = 0.01  # frame shift in seconds (10 ms)
    # os.path.dirname is "" for a bare filename, and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when given one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "wt") as out_file:
        for uttid, wav_path in load_scp_as_list(wav_scp):
            wav, sr = sf.read(wav_path)
            # Complete frames only: samples // samples-per-frame.
            num_frame = wav.shape[0] // int(sr * frame_shift)
            out_file.write(f"{uttid} {num_frame}\n")
            out_file.flush()

View File

@ -0,0 +1,101 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import librosa
import scipy.io as sio
import argparse
from collections import OrderedDict
class MyRunner(MultiProcessRunnerV3):
    """Multi-process runner that rasterizes RTTM annotations into labels.

    Tasks are (meeting_id, wav_path, rttm_lines) tuples; the module-level
    process() worker writes one .lbl.mat file per meeting, so post() is a
    no-op.
    """

    def prepare(self, parser):
        """Parse arguments and build the per-meeting task list.

        Returns (task_list, shared_data, args) for the base runner.
        """
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--n_spk", type=int, default=8)
        parser.add_argument("--remove_sil", default=False, action="store_true")
        parser.add_argument("--frame_shift", type=float, default=0.01)
        args = parser.parse_args()
        # NOTE(review): --sr is presumably registered by MultiProcessRunnerV3
        # itself; it is not added here but is validated below — confirm.
        assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."
        meeting_scp = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
        meeting2rttm = self.load_rttm(args.dir)
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        # Only meetings that actually have RTTM annotations become tasks.
        task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
        return task_list, None, args

    def load_rttm(self, dir_path):
        """Load RTTM lines grouped by meeting id.

        Prefers rttm.scp (meeting id -> rttm file path); falls back to a
        combined ref.rttm whose second field is the meeting id.
        """
        meeting2rttm = OrderedDict()
        if os.path.exists(os.path.join(dir_path, "rttm.scp")):
            rttm_scp = load_scp_as_list(os.path.join(dir_path, "rttm.scp"))
            for mid, rttm_path in rttm_scp:
                meeting2rttm[mid] = []
                for one_line in open(rttm_path, "rt"):
                    meeting2rttm[mid].append(one_line.strip())
        elif os.path.exists(os.path.join(dir_path, "ref.rttm")):
            for one_line in open(os.path.join(dir_path, "ref.rttm"), "rt"):
                mid = one_line.strip().split(" ")[1]
                if mid not in meeting2rttm:
                    meeting2rttm[mid] = []
                meeting2rttm[mid].append(one_line.strip())
        else:
            raise IOError("Neither rttm.scp nor ref.rttm exists in {}".format(dir_path))
        return meeting2rttm

    def post(self, results_list, args):
        """Nothing to aggregate: workers write .mat files directly."""
        pass
def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=8000, frame_shift=0.01):
    """Rasterize speaker turns into a (T, n_spk) binary activity matrix.

    spk_turns entries are (meeting_id, start_sec, dur_sec, speaker); the
    column of each turn is the speaker's index in spk_list. `length` is the
    recording length in samples. With remove_sil=True, all-silent frames
    are dropped before returning.
    """
    shift_samples = int(frame_shift * sr)

    def to_frame(n_samples):
        # Round to the nearest frame by adding half a shift before dividing.
        return int((float(n_samples) + float(shift_samples) / 2) / shift_samples)

    activity = np.zeros([n_spk, to_frame(length)], dtype=int)
    for _, turn_start, turn_dur, spk in spk_turns:
        row = spk_list.index(spk)
        start_samples = int(turn_start * sr)
        end_samples = start_samples + int(turn_dur * sr)
        activity[row, to_frame(start_samples):to_frame(end_samples)] = 1
    if remove_sil:
        voiced = np.nonzero(activity.sum(axis=0))[0]
        return activity[:, voiced].T  # (T', N)
    return activity.T  # (T, N)
def build_labels(wav_path, rttms, n_spk, remove_sil=False, sr=8000, frame_shift=0.01):
    """Parse one recording's RTTM lines and rasterize them via calc_labels.

    Returns (labels, speakers) where labels is (T, n_spk) and speakers lists
    the speaker names in order of first appearance.
    """
    total_samples = int(librosa.get_duration(filename=wav_path, sr=sr) * sr)
    turns = []
    speakers = []
    for line in rttms:
        fields = line.strip().split(" ")
        mid, start, dur, spk = fields[1], float(fields[3]), float(fields[4]), fields[7]
        if spk not in speakers:
            speakers.append(spk)
        turns.append((mid, start, dur, spk))
    labels = calc_labels(turns, speakers, total_samples, n_spk, remove_sil, sr, frame_shift)
    return labels, speakers
def process(task_args):
    """Worker: rasterize labels per meeting and save them as .lbl.mat files."""
    _, tasks, _, args = task_args
    for mid, wav_path, rttms in tasks:
        labels, spk_list = build_labels(
            wav_path, rttms, args.n_spk, args.remove_sil, args.sr, args.frame_shift)
        sio.savemat(
            os.path.join(args.out_dir, "{}.lbl.mat".format(mid)),
            {"labels": labels.astype(bool), "spk_list": spk_list},
        )
    return None
if __name__ == '__main__':
    # Entry point: MyRunner parses CLI args and fans tasks out to `process`.
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -0,0 +1,57 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import librosa
import soundfile as sf
import argparse
class MyRunner(MultiProcessRunnerV3):
    """Runner that concatenates each speaker's utterances into one wav file.

    Tasks are speaker ids; shared data carries spk2utt and wav_scp so each
    worker can gather a speaker's audio. post() has nothing to collect.
    """

    def prepare(self, parser):
        """Parse arguments, build spk2utt, and return speakers as tasks."""
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        args = parser.parse_args()
        # NOTE(review): --sr is presumably registered by the base runner;
        # it is only validated here — confirm.
        assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        print("loading data...")
        wav_scp = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
        utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
        spk2utt = {}
        count = 0
        # Keep only utterances that actually have audio in wav.scp.
        for utt, spk in utt2spk.items():
            if utt in wav_scp:
                if spk not in spk2utt:
                    spk2utt[spk] = []
                spk2utt[spk].append(utt)
                count += 1
        task_list = spk2utt.keys()
        print("total: {} speakers, {} utterances".format(len(spk2utt), count))
        print("starting jobs...")
        return task_list, [spk2utt, wav_scp], args

    def post(self, results_list, args):
        """Nothing to aggregate: workers write wav files directly."""
        pass
def process(task_args):
    """Worker: concatenate all utterances of each speaker into one wav file."""
    _, speakers, [spk2utt, wav_scp], args = task_args
    for spk in speakers:
        # librosa.load yields float waveforms in [-1, 1]; scale back to the
        # int16 range before writing PCM_16.
        pieces = [
            librosa.load(wav_scp[utt], sr=args.sr, mono=True)[0] * 32767
            for utt in spk2utt[spk]
        ]
        merged = np.concatenate(pieces, axis=0)
        out_path = os.path.join(args.out_dir, "{}.wav".format(spk))
        sf.write(out_path, merged.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
    return None
if __name__ == '__main__':
    # Entry point: MyRunner parses CLI args and fans tasks out to `process`.
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -0,0 +1,201 @@
import os
from funasr.utils.job_runner import MultiProcessRunnerV3
import numpy as np
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
from collections import OrderedDict
from tqdm import tqdm
from scipy.ndimage import median_filter
import kaldiio
def load_mid_vad(vad_path):
    """Read an oracle-VAD file into {meeting_id: [(utt_id, start, end), ...]}.

    Each input line is "<utt_id> <meeting_id> <start> <end>" with the times
    given in seconds.
    """
    segments = {}
    with open(vad_path, "rt") as fd:
        for line in fd:
            utt_id, mid, start, end = line.strip().split(" ")
            segments.setdefault(mid, []).append((utt_id, float(start), float(end)))
    return segments
class MyRunner(MultiProcessRunnerV3):
    """Runner converting chunked diarization labels into per-meeting RTTMs.

    prepare() groups chunk labels (and optional system-VAD probabilities) by
    meeting; workers produce RTTM lines, and post() writes two files: one
    gated by the oracle VAD and one by the system VAD.
    """

    def prepare(self, parser):
        """Parse arguments, align labels with optional VAD arks, group by meeting."""
        parser.add_argument("label_txt", type=str)
        parser.add_argument("oracle_vad", type=str, default=None)
        parser.add_argument("out_rttm", type=str)
        parser.add_argument("--sys_vad_prob", type=str, default=None)
        parser.add_argument("--sys_vad_threshold", type=float, default=None)
        parser.add_argument("--vad_smooth_size", type=int, default=7)
        parser.add_argument("--n_spk", type=int, default=4)
        parser.add_argument("--chunk_len", type=int, default=1600)
        parser.add_argument("--shift_len", type=int, default=400)
        parser.add_argument("--ignore_len", type=int, default=5)
        parser.add_argument("--smooth_size", type=int, default=7)
        parser.add_argument("--vote_prob", type=float, default=0.5)
        args = parser.parse_args()
        if not os.path.exists(os.path.dirname(args.out_rttm)):
            os.makedirs(os.path.dirname(args.out_rttm))
        utt2labels = load_scp_as_list(args.label_txt, 'list')
        utt2vad_prob = []
        if args.sys_vad_prob is not None and os.path.exists(args.sys_vad_prob):
            # NOTE(review): args.verbose is presumably registered by the
            # base runner — confirm.
            if args.verbose:
                print("Use system vad ark file {}".format(args.sys_vad_prob))
            # NOTE(review): the ark and the label list are paired by zip and
            # the utterance id is taken from the labels — this assumes both
            # are in the same order; verify upstream sorting.
            for (key, vad_prob), (utt_id, _) in zip(kaldiio.load_ark(args.sys_vad_prob), utt2labels):
                utt2vad_prob.append((utt_id, vad_prob))
            utt2vad_prob = sorted(utt2vad_prob, key=lambda x: x[0])
        utt2labels = sorted(utt2labels, key=lambda x: x[0])
        mid2segment_list = load_mid_vad(args.oracle_vad)
        # Chunk ids look like "<meeting>-<start>-<end>"; group by meeting.
        meeting2labels = OrderedDict()
        for utt_id, chunk_label in utt2labels:
            mid = utt_id.split("-")[0]
            if mid not in meeting2labels:
                meeting2labels[mid] = []
            meeting2labels[mid].append(chunk_label)
        mid2vad_list = {}
        if len(utt2vad_prob) > 0:
            for utt_id, vad_prob in utt2vad_prob:
                mid = utt_id.split("-")[0]
                if mid not in mid2vad_list:
                    mid2vad_list[mid] = []
                mid2vad_list[mid].append(vad_prob)
        # Each task: (mid, chunk labels, oracle segments, system VAD or None).
        task_list = [(mid, labels, mid2segment_list[mid], None) if len(mid2vad_list) == 0 else
                     (mid, labels, mid2segment_list[mid], mid2vad_list[mid])
                     for mid, labels in meeting2labels.items()]
        return task_list, None, args

    def post(self, result_list, args):
        """Write the oracle-VAD and system-VAD RTTM variants side by side."""
        ref_vad_rttm = open(args.out_rttm + ".ref_vad", "wt")
        sys_vad_rttm = open(args.out_rttm + ".sys_vad", "wt")
        for results in result_list:
            for one_result in results:
                ref_vad_rttm.writelines(one_result[0])
                sys_vad_rttm.writelines(one_result[1])
        ref_vad_rttm.close()
        sys_vad_rttm.close()
def int2vec(x, vec_dim=8, dtype=int):
    """Decode integer x into a binary vector of length vec_dim.

    Bit order is little-endian: index 0 holds the least-significant bit.
    The default dtype is the builtin int — the previous default ``np.int``
    was a deprecated alias of int removed in NumPy 1.24 and raises
    AttributeError there.
    """
    bits = ('{:0' + str(vec_dim) + 'b}').format(x)
    # little-endian order: lower bit first
    return (np.array(list(bits)[::-1]) == '1').astype(dtype)
def seq2arr(seq, vec_dim=8):
    """Decode a sequence of integer tokens into a (len(seq), vec_dim) 0/1 array.

    Uses np.vstack instead of the np.row_stack alias, which is deprecated
    since NumPy 2.0.
    """
    return np.vstack([int2vec(int(x), vec_dim) for x in seq])
def sample2ms(sample, sr=16000):
    """Convert a sample index to a 10 ms frame index (100 frames per second)."""
    seconds = float(sample) / sr
    return int(seconds * 100)
def calc_multi_labels(chunk_label_list, chunk_len, shift_len, n_spk, vote_prob=0.5):
    """Overlap-add chunked power-set label strings into a (T, n_spk) 0/1 matrix.

    Each element of chunk_label_list is a list of per-frame integer tokens
    encoding the active-speaker set as a little-endian bit field (see
    seq2arr). Overlapping chunk votes are averaged and thresholded by
    vote_prob.
    """
    n_chunk = len(chunk_label_list)
    # The last chunk overlaps the previous one by (chunk_len - shift_len)
    # frames; only the remainder contributes new frames to the total length.
    last_chunk_valid_frame = len(chunk_label_list[-1]) - (chunk_len - shift_len)
    n_frame = (n_chunk - 2) * shift_len + chunk_len + last_chunk_valid_frame
    multi_labels = np.zeros((n_frame, n_spk), dtype=float)
    weight = np.zeros((n_frame, 1), dtype=float)
    for i in range(n_chunk):
        raw_label = chunk_label_list[i]
        for k in range(len(raw_label)):
            # Replace decoder <unk> tokens with the previous frame's token
            # (or silence "0" at the very start of a chunk).
            if raw_label[k] == '<unk>':
                raw_label[k] = raw_label[k-1] if k > 0 else '0'
        chunk_multi_label = seq2arr(raw_label, n_spk)
        # NOTE: chunk_len is deliberately rebound to this chunk's actual
        # length so the accumulation slice below matches its shape.
        chunk_len = chunk_multi_label.shape[0]
        multi_labels[i*shift_len:i*shift_len+chunk_len, :] += chunk_multi_label
        weight[i*shift_len:i*shift_len+chunk_len, :] += 1
    multi_labels = multi_labels / weight  # normalizing vote
    multi_labels = (multi_labels > vote_prob).astype(int)  # voting results
    return multi_labels
def calc_spk_turns(label_arr, spk_list):
    """Convert a (T, n_spk) 0/1 matrix into [speaker, start_frame, duration] turns.

    Columns whose speaker slot is the padding name "None" are skipped.
    """
    turns = []
    n_frames, n_spk = label_arr.shape
    for col in range(n_spk):
        if spk_list[col] == "None":
            # Padded slot with no real speaker attached.
            continue
        seg_start = None
        for t in range(n_frames):
            active = label_arr[t, col] == 1
            if active and seg_start is None:
                seg_start = t
            elif not active and seg_start is not None:
                turns.append([spk_list[col], seg_start, t - seg_start])
                seg_start = None
        # Close a turn that runs to the end of the recording.
        if seg_start is not None:
            turns.append([spk_list[col], seg_start, n_frames - seg_start])
    return turns
def smooth_multi_labels(multi_label, win_len):
    """Median-filter each speaker track over time to suppress spurious flips.

    The window spans win_len frames along the time axis only; edges are
    zero-padded ("constant" mode).
    """
    smoothed = median_filter(multi_label, (win_len, 1), mode="constant", cval=0.0)
    return smoothed.astype(int)
def calc_vad_mask(segments, total_len):
    """Build a (total_len, 1) 0/1 mask marking oracle-VAD speech frames.

    segments holds (utt_id, start_sec, end_sec) tuples; frames are 10 ms,
    hence the factor 100.
    """
    mask = np.zeros((total_len, 1), dtype=int)
    for _, seg_start, seg_end in segments:
        mask[int(seg_start * 100): int(seg_end * 100)] = 1
    return mask
def calc_system_vad_mask(vad_prob_list, total_len, args):
    """Fuse chunked system-VAD probabilities into a (total_len, 1) 0/1 mask.

    Returns the scalar 1 (identity mask) when no system VAD is available.
    Overlapping chunks vote after thresholding at args.sys_vad_threshold;
    the vote is thresholded at args.vote_prob and median-smoothed.
    """
    if vad_prob_list is None:
        return 1
    votes = np.zeros((total_len, 1), dtype=float)
    counts = np.zeros((total_len, 1), dtype=float)
    for idx, vad_prob in enumerate(vad_prob_list):
        seg = slice(idx * args.shift_len, idx * args.shift_len + args.chunk_len)
        votes[seg] += np.greater(vad_prob, args.sys_vad_threshold).astype(float)
        counts[seg] += 1.0
    # Majority vote over the overlapping chunks, then median smoothing.
    mask = np.greater(votes / counts, args.vote_prob).astype(int)
    return smooth_multi_labels(mask, args.vad_smooth_size)
def generate_rttm(mid, multi_labels, spk_list, args):
    """Render one meeting's speaker turns as RTTM lines (times in seconds)."""
    template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>\n"
    lines = []
    # TODO: handle the leak of segments at the change points
    for spk, start, dur in sorted(calc_spk_turns(multi_labels, spk_list), key=lambda t: t[1]):
        # Drop very short segments (at most ignore_len frames).
        if dur > args.ignore_len:
            lines.append(template.format(mid, float(start) / 100, float(dur) / 100, spk))
    return lines
def process(task_args):
    """Worker: convert each meeting's chunk labels into two RTTM variants.

    Returns one [oracle_vad_rttm_lines, system_vad_rttm_lines] pair per
    meeting.
    """
    _, task_list, _, args = task_args
    # Speaker slots are positional: column k of the label matrix is "spk{k+1}".
    spk_list = ["spk{}".format(i+1) for i in range(args.n_spk)]
    results = []
    for mid, chunk_label_list, segments_list, sys_vad_list in tqdm(task_list, total=len(task_list),
                                                                   ascii=True, disable=args.no_pbar):
        multi_labels = calc_multi_labels(chunk_label_list, args.chunk_len, args.shift_len, args.n_spk, args.vote_prob)
        multi_labels = smooth_multi_labels(multi_labels, args.smooth_size)
        # Gate once by the oracle VAD ...
        oracle_vad_mask = calc_vad_mask(segments_list, multi_labels.shape[0])
        oracle_vad_rttm = generate_rttm(mid, multi_labels * oracle_vad_mask, spk_list, args)
        # ... and once by the system VAD (mask is scalar 1 when absent).
        system_vad_mask = calc_system_vad_mask(sys_vad_list, multi_labels.shape[0], args)
        system_vad_rttm = generate_rttm(mid, multi_labels * system_vad_mask, spk_list, args)
        results.append([oracle_vad_rttm, system_vad_rttm])
    return results
if __name__ == '__main__':
    # Entry point: MyRunner parses CLI args and fans tasks out to `process`.
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -0,0 +1,35 @@
import kaldiio
import os
import sys
import numpy as np
def int2vec(x, vec_dim=8, dtype=np.float32):
    """Decode integer x into a little-endian binary vector of length vec_dim."""
    bit_str = ('{:0' + str(vec_dim) + 'b}').format(x)
    reversed_bits = list(bit_str)[::-1]  # lowest bit first
    return (np.array(reversed_bits) == '1').astype(dtype)
def seq2arr(seq, vec_dim=8):
    """Decode a sequence of integer tokens into a (len(seq), vec_dim) 0/1 array.

    Uses np.vstack instead of the np.row_stack alias, which is deprecated
    since NumPy 2.0.
    """
    return np.vstack([int2vec(int(x), vec_dim) for x in seq])
if __name__ == '__main__':
    # Usage: python <script> scp_file label_file out_file max_spk_num
    # Converts power-set label token strings into per-frame binary matrices
    # and writes them as a kaldi ark/scp pair.
    scp_file = sys.argv[1]
    label_file = sys.argv[2]
    out_file = sys.argv[3]
    max_spk_num = int(sys.argv[4])
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    # Strip the extension: WriteHelper appends .ark/.scp itself.
    out_file = out_file.rsplit('.', maxsplit=1)[0]
    wav_writer = kaldiio.WriteHelper("ark,scp,f:{}.ark,{}.scp".format(out_file, out_file))
    # scp_file supplies utterance ids (first token per line); label_file the
    # space-separated integer tokens. The two files are assumed line-aligned.
    for i, (uttid, pse_str) in enumerate(zip(open(scp_file, "rt"), open(label_file, "rt"))):
        uttid, pse_str = uttid.strip().split(" ", maxsplit=1)[0], pse_str.strip()
        bin_label = seq2arr(pse_str.split(" "), vec_dim=max_spk_num)
        wav_writer(uttid, bin_label)
        if i % 100 == 0:
            print(f"Processed {i} samples, the last is {uttid}")
    wav_writer.close()

View File

@ -0,0 +1,63 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
import os
class MyRunner(MultiProcessRunnerV3):
    """Runner that merges RTTM speaker turns into contiguous speech segments."""

    def prepare(self, parser):
        """Group RTTM lines by meeting id and return them as the task list."""
        parser.add_argument("--rttm_scp", type=str)
        parser.add_argument("--seg_file", type=str)
        args = parser.parse_args()
        if not os.path.exists(os.path.dirname(args.seg_file)):
            os.makedirs(os.path.dirname(args.seg_file))
        meeting2rttms = {}
        for one_line in open(args.rttm_scp, "rt"):
            # Split on spaces, dropping empties from multi-space alignment.
            parts = [x for x in one_line.strip().split(" ") if x != ""]
            # Only the meeting id (field 1) is used for grouping here; the
            # remaining fields are re-parsed by the worker.
            mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
            if mid not in meeting2rttms:
                meeting2rttms[mid] = []
            meeting2rttms[mid].append(one_line)
        task_list = list(meeting2rttms.items())
        return task_list, None, args

    def post(self, results_list, args):
        """Concatenate the workers' segment lines into the output file."""
        with open(args.seg_file, "wt") as fd:
            for results in results_list:
                fd.writelines(results)
def process(task_args):
    """Worker: merge each meeting's RTTM turns into contiguous speech segments.

    Returns lines "<mid>-<start>-<end> <mid> <start_sec> <end_sec>" where the
    start/end indices are 10 ms frames.
    """
    _, tasks, _, args = task_args
    lines = []
    for mid, rttms in tasks:
        turns = []
        total = 0
        for one_line in rttms:
            fields = one_line.strip().split(" ")
            start_s, dur_s = float(fields[3]), float(fields[4])
            start, end = int(start_s * 100), int((start_s + dur_s) * 100)
            total = max(total, end)
            turns.append([mid, start, end, fields[7]])
        # Rasterize all turns onto a shared speech/non-speech timeline; the
        # extra trailing frame guarantees every run is closed below.
        voiced = np.zeros((total + 1,), dtype=bool)
        for _, start, end, _ in turns:
            voiced[start:end] = True
        seg_start, talking = 0, False
        for t in range(total + 1):
            if voiced[t] and not talking:
                seg_start, talking = t, True
            if talking and not voiced[t]:
                talking = False
                lines.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format(
                    mid, seg_start, t, mid, float(seg_start) / 100, float(t) / 100
                ))
    return lines
if __name__ == '__main__':
    # Entry point: MyRunner parses CLI args and fans tasks out to `process`.
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -0,0 +1,176 @@
import kaldiio
from tqdm import tqdm
import os
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import numpy as np
import argparse
import random
import scipy.io as sio
import logging
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
# Module-level registry of speakers already warned about for having too few
# frames; mutated by calc_rand_ivc so each warning appears only once.
short_spk_list = []
def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
    """Average x-vectors of randomly chosen utterances of `spk`.

    Utterances are drawn in random order until at least total_len frames are
    covered; the selected x-vectors are then mean-pooled into one embedding.
    """
    all_utts = spk2utt[spk]
    idx_list = list(range(len(all_utts)))
    random.shuffle(idx_list)
    count = 0
    utt_list = []
    for i in idx_list:
        utt_id = all_utts[i]
        utt_list.append(utt_id)
        count += int(utt2frames[utt_id])
        if count >= total_len:
            break
    # Warn once per speaker (module-level short_spk_list) when fewer than
    # 300 frames exist in total; all available utterances are still used.
    if count < 300 and spk not in short_spk_list:
        logging.warning("{} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300))
        short_spk_list.append(spk)
    ivc_list = [kaldiio.load_mat(utt2ivc[utt])[np.newaxis, :] for utt in utt_list]
    ivc = np.concatenate(ivc_list, axis=0)
    ivc = np.mean(ivc, axis=0, keepdims=False)
    return ivc
def process(feat_scp, labels_scp, spk2utt, utt2xvec, utt2frames, args):
    """Split meetings into fixed-size chunks and dump feats/profiles/labels.

    For every chunk three kaldi arks are written: acoustic features, one
    x-vector per speaker slot (zeros for empty "None" slots), and the
    frame-level speaker-activity labels. In train mode speaker slots are
    shuffled and padded with random distractor speakers.
    """
    out_prefix = "{}_parts{:02d}".format(args.out, args.task_id)
    # Dedicated file logger so parallel tasks do not interleave their logs.
    logger = logging.Logger(out_prefix, logging.INFO)
    file_handler = logging.FileHandler(out_prefix + ".log", mode="w")
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    # NOTE(review): x-vector dimension is hard-coded; confirm it matches the
    # speaker model in use.
    ivc_dim = 256
    chunk_size, chunk_shift = args.chunk_size, args.chunk_shift
    # Bit weights for power-set label encoding; only used by the commented
    # alternative near label_writer below.
    label_weights = 2 ** np.array(list(range(args.n_spk)))
    feat_writer = kaldiio.WriteHelper(f"ark,scp,f:{out_prefix}_feat.ark,{out_prefix}_feat.scp")
    ivc_writer = kaldiio.WriteHelper(f"ark,scp,f:{out_prefix}_xvec.ark,{out_prefix}_xvec.scp")
    label_writer = kaldiio.WriteHelper(f"ark,scp,f:{out_prefix}_label.ark,{out_prefix}_label.scp")
    train_spk_list = list(spk2utt.keys())
    frames_list = []
    non_present_spk_list = []
    for mid, feat_path in tqdm(feat_scp, total=len(feat_scp), ascii=True, disable=args.no_pbar):
        if mid not in labels_scp:
            continue
        feat = kaldiio.load_mat(feat_path)
        data = sio.loadmat(labels_scp[mid])
        labels, meeting_spk_list = data["labels"].astype(int), [x.strip() for x in data["spk_list"]]
        if args.add_mid_to_speaker:
            # Prefix speaker names with the meeting id to disambiguate
            # per-meeting labels like "A"/"B".
            meeting_spk_list = ["{}_{}".format(mid, x) if not x.startswith(mid) else x for x in meeting_spk_list]
        if labels.shape[0] != feat.shape[0]:
            # Labels and features may disagree by a few frames; clip both.
            min_len = min(labels.shape[0], feat.shape[0])
            labels, feat = labels[:min_len], feat[:min_len]
            logger.warning("{}: The expected frame_len is {}, but got {}, clip both to {}".format(
                mid, labels.shape[0], feat.shape[0], min_len))
        num_frame = feat.shape[0]
        num_chunk = int(np.ceil(float(num_frame - chunk_size) / chunk_shift)) + 1
        for i in range(num_chunk):
            st, ed = i*chunk_shift, i*chunk_shift+chunk_size
            utt_id = "{}-{:05d}-{:05d}".format(mid, st, ed)
            chunk_feat = feat[st: ed, :]
            chunk_label = labels[st: ed, :]
            # Zero-pad the tail chunk to chunk_size frames and pad the
            # speaker axis up to n_spk slots.
            frame_pad = chunk_size - chunk_label.shape[0]
            spk_pad = args.n_spk - chunk_label.shape[1]
            chunk_feat = np.pad(chunk_feat, [(0, frame_pad), (0, 0)], "constant", constant_values=0)
            chunk_label = np.pad(chunk_label, [(0, frame_pad), (0, spk_pad)], "constant", constant_values=0)
            feat_writer(utt_id, chunk_feat)
            spk_idx = list(range(max(args.n_spk, len(meeting_spk_list))))
            spk_list = []
            if args.mode == "train":
                # Shuffle slot order and pad with random distractor speakers
                # drawn from the whole training set.
                random.shuffle(spk_idx)
                if args.n_spk > len(meeting_spk_list):
                    n = random.randint(len(meeting_spk_list), args.n_spk)
                    spk_list.extend(meeting_spk_list)
                    while len(spk_list) < n:
                        spk = random.choice(train_spk_list)
                        if spk not in spk_list:
                            spk_list.append(spk)
                    spk_list.extend(["None"] * (args.n_spk - n))
                else:
                    raise ValueError("Argument n_spk is too small ({} < {}).".format(args.n_spk, len(meeting_spk_list)))
            else:
                # NOTE(review): in test mode with more meeting speakers than
                # n_spk, spk_idx holds indices beyond chunk_label's padded
                # width; confirm n_spk always covers the meeting speakers.
                spk_list.extend(meeting_spk_list)
                spk_list.extend(["None"] * max(args.n_spk - len(meeting_spk_list), 0))
            xvec_list = []
            for i, spk in enumerate(spk_list):
                if spk == "None":
                    spk_xvec = np.zeros((ivc_dim,), dtype=np.float32)
                elif spk not in spk2utt:
                    # speaker with very short duration
                    spk_xvec = np.zeros((ivc_dim,), dtype=np.float32)
                    # chunk_label[:, i] = 0
                    if spk not in non_present_spk_list:
                        logging.warning("speaker {}/{} does not appear in spk2utt, since it has very short duration.".format(mid, spk))
                        non_present_spk_list.append(spk)
                else:
                    spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 3000)[np.newaxis, :]
                xvec_list.append(spk_xvec)
            xvec = np.row_stack(xvec_list)
            # shuffle speaker embedding according spk_idx
            xvec = xvec[spk_idx, :]
            ivc_writer(utt_id, xvec)
            # shuffle labels according spk_idx
            feat_label = chunk_label[:, spk_idx]
            # feat_label = np.sum(feat_label * label_weights[np.newaxis, :chunk_label.shape[1]], axis=1).astype(str).tolist()
            label_writer(utt_id, feat_label.astype(np.float32))
        frames_list.append((mid, feat.shape[0]))
        logger.info("{:30s}: {:6d} frames split into {:3d} chunks.".format(mid, num_frame, num_chunk))
    return frames_list
def main():
    """CLI entry: load a kaldi-style data dir and dump chunked training files."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", required=True, type=str, default=None,
                        help="feats.scp")
    parser.add_argument("--out", required=True, type=str, default=None,
                        help="The prefix of dumpped files.")
    parser.add_argument("--n_spk", type=int, default=16)
    parser.add_argument("--use_lfr", default=False, action="store_true")
    parser.add_argument("--no_pbar", default=False, action="store_true")
    parser.add_argument("--sr", type=int, default=8000)
    parser.add_argument("--chunk_size", type=int, default=1600)
    parser.add_argument("--chunk_shift", type=int, default=400)
    parser.add_argument("--mode", type=str, default="train", choices=["train", "test"])
    # task_id/task_size shard feats.scp across parallel invocations.
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--task_size", type=int, default=-1)
    parser.add_argument("--add_mid_to_speaker", type=bool, default=False)
    args = parser.parse_args()
    assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."
    if not os.path.exists(os.path.dirname(args.out)):
        os.makedirs(os.path.dirname(args.out))
    feat_scp = load_scp_as_list(os.path.join(args.dir, "feats.scp"))
    if args.task_size > 0:
        # Keep only this task's shard of the meeting list.
        feat_scp = feat_scp[args.task_size*args.task_id: args.task_size*(args.task_id+1)]
    labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
    utt2frames = load_scp_as_dict(os.path.join(args.dir, "utt2num_frames"))
    spk2utt = {}
    # Keep only utterances that have an x-vector and more than 25 frames.
    for utt, spk in utt2spk.items():
        if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
            if spk not in spk2utt:
                spk2utt[spk] = []
            spk2utt[spk].append(utt)
    logging.info("Obtain {} speakers.".format(len(spk2utt)))
    logging.info("Task {:02d}: start dump {} meetings.".format(args.task_id, len(feat_scp)))
    # random.shuffle(feat_scp)
    meeting_lens = process(feat_scp, labels_scp, spk2utt, utt2xvec, utt2frames, args)
    total_frames = sum([x[1] for x in meeting_lens])
    logging.info("Total meetings: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,48 @@
import os
import argparse
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
class MyRunner(MultiProcessRunnerV3):
    """Runner that executes piped wav.scp commands to materialize wav files."""

    def prepare(self, parser):
        """Load wav.scp (utt_id -> shell command) and return it as the tasks."""
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("out_dir", type=str)
        args = parser.parse_args()
        # assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."
        wav_scp = load_scp_as_list(args.wav_scp)
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        return wav_scp, None, args

    def post(self, result_list, args):
        """Sum the per-worker [success, failure] counters and report them."""
        count = [0, 0]
        for result in result_list:
            count[0] += result[0]
            count[1] += result[1]
        print("All threads done, {} success, {} failed.".format(count[0], count[1]))
def process(task_args):
    """Worker: materialize piped wav.scp commands into real .wav files.

    Each task is (utt_id, cmd) where cmd is a Kaldi-style pipeline ending in
    "|" (e.g. "sph2pipe -f wav x.sph |"). The trailing pipe is replaced with
    a shell redirect into <out_dir>/<utt_id>.wav.

    Returns [n_success, n_failure].
    """
    task_id, task_list, _, args = task_args
    count = [0, 0]
    for utt_id, cmd in task_list:
        wav_path = os.path.join(args.out_dir, "{}.wav".format(utt_id))
        # Replace only the LAST pipe, so commands that contain internal
        # pipes (e.g. "a | b |") still run as a pipeline.
        if "|" in cmd:
            head, _, _ = cmd.rpartition("|")
            cmd = "{} > {}".format(head.rstrip(), wav_path)
        # os.system does not raise when the command fails; check the exit
        # status instead of a (never-triggered) except clause, so the
        # failure counter actually reflects failed commands.
        if os.system(cmd) == 0:
            count[0] += 1
        else:
            print("Failed execute command for {}.".format(utt_id))
            count[1] += 1
    return count
if __name__ == '__main__':
    # Entry point: MyRunner parses CLI args and fans tasks out to `process`.
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -0,0 +1,117 @@
import kaldiio
from tqdm import tqdm
import os
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import numpy as np
import argparse
from kaldiio import WriteHelper
def calc_global_ivc(spk, spk2utt, utt2ivc):
    """Average the x-vectors of all utterances of `spk` into one embedding."""
    stacked = np.concatenate(
        [kaldiio.load_mat(utt2ivc[utt])[np.newaxis, :] for utt in spk2utt[spk]],
        axis=0,
    )
    return np.mean(stacked, axis=0, keepdims=False)
def process(idx_scp, spk2utt, utt2xvec, meeting2spk_list, args):
    """Dump one speaker-profile matrix per chunk utterance id.

    Each output row is a speaker x-vector; the remaining slots up to n_spk
    are zero vectors. Returns [(meeting_id, 1), ...] for bookkeeping.
    """
    out_prefix = args.out
    # NOTE(review): x-vector dimension is hard-coded; confirm it matches the
    # speaker model in use.
    ivc_dim = 256
    print("ivc_dim = {}".format(ivc_dim))
    out_prefix = out_prefix + "_parts00_xvec"
    ivc_writer = WriteHelper(f"ark,scp,f:{out_prefix}.ark,{out_prefix}.scp")
    idx_writer = open(out_prefix + ".idx", "wt")
    spk2xvec = {}
    # NOTE(review): only emb_type == "global" populates spk2xvec; any other
    # value would raise KeyError in the loop below — confirm callers always
    # pass --emb_type global.
    if args.emb_type == "global":
        print("Use global speaker embedding.")
        for spk in spk2utt.keys():
            spk2xvec[spk] = calc_global_ivc(spk, spk2utt, utt2xvec)[np.newaxis, :]
    frames_list = []
    for utt_id in tqdm(idx_scp, total=len(idx_scp), ascii=True, disable=args.no_pbar):
        # Chunk ids look like "<meeting>-<start>-<end>".
        mid = utt_id.split("-")[0]
        idx_writer.write(utt_id+"\n")
        xvec_list = []
        for spk in meeting2spk_list[mid]:
            spk_xvec = spk2xvec[spk]
            xvec_list.append(spk_xvec)
        # Pad remaining speaker slots with zero embeddings.
        for _ in range(args.n_spk - len(xvec_list)):
            xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
        xvec = np.row_stack(xvec_list)
        ivc_writer(utt_id, xvec)
        frames_list.append((mid, 1))
    return frames_list
def calc_spk_list(rttms):
    """Extract normalized speaker names from one meeting's RTTM lines, in order.

    Numeric speaker labels become "<mid>_S<nnn>"; all others become
    "<mid>_<name>". Duplicates are kept only once, preserving first-seen
    order.
    """
    speakers = []
    for line in rttms:
        fields = [tok for tok in line.strip().split(" ") if tok != ""]
        mid, raw = fields[1], fields[7]
        # Strip decorations so e.g. "spk2" or "<mid>-2" normalize to "2".
        core = raw.replace("spk", "").replace(mid, "").replace("-", "")
        if core.isdigit():
            name = "{}_S{:03d}".format(mid, int(core))
        else:
            name = "{}_{}".format(mid, core)
        if name not in speakers:
            speakers.append(name)
    return speakers
def main():
    """CLI: build per-utterance speaker-profile matrices for diarization.

    Reads idx / sys.rttm / utt2spk / utt2xvec under --dir, derives each
    meeting's speaker list from the rttm, drops speakers without x-vectors,
    and dumps the profile matrices via process().
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", required=True, type=str, default=None,
                        help="feats.scp")
    parser.add_argument("--out", required=True, type=str, default=None,
                        help="The prefix of dumpped files.")
    parser.add_argument("--n_spk", type=int, default=4)
    parser.add_argument("--no_pbar", default=False, action="store_true")
    parser.add_argument("--sr", type=int, default=16000)
    parser.add_argument("--emb_type", type=str, default="rand")
    args = parser.parse_args()

    # create the output directory; guard against an empty dirname
    # (os.makedirs("") raises FileNotFoundError)
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # utterance ids to process, one per line (handle closed via `with`)
    with open(os.path.join(args.dir, "idx"), "r") as fd:
        idx_scp = [x.strip() for x in fd]

    # group rttm lines by meeting id (field 1)
    meeting2rttms = {}
    with open(os.path.join(args.dir, "sys.rttm"), "rt") as fd:
        for one_line in fd:
            parts = [x for x in one_line.strip().split(" ") if x != ""]
            meeting2rttms.setdefault(parts[1], []).append(one_line)

    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))

    # invert utt2spk, keeping only utterances that actually have an x-vector
    spk2utt = {}
    for utt, spk in utt2spk.items():
        if utt in utt2xvec:
            spk2utt.setdefault(spk, []).append(utt)

    meeting2spk_list = {}
    for mid, rttms in meeting2rttms.items():
        meeting2spk_list[mid] = calc_spk_list(rttms)
        # drop speakers that have no extracted x-vector
        new_spk_list = [spk for spk in meeting2spk_list[mid] if spk in spk2utt]
        if len(new_spk_list) != len(meeting2spk_list[mid]):
            print("{}: Reduce speaker number from {}(according rttm) to {}(according x-vectors)".format(
                mid, len(meeting2spk_list[mid]), len(new_spk_list)))
        meeting2spk_list[mid] = new_spk_list

    meeting_lens = process(idx_scp, spk2utt, utt2xvec, meeting2spk_list, args)
    print("Total meetings: {:6d}".format(len(meeting_lens)))


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,54 @@
import os
import sys
import torch
import argparse
if __name__ == '__main__':
    # Average the parameters of several saved checkpoints into a single model.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        required=True,
        default=None,
        type=str,
        help="Director contains saved models."
    )
    parser.add_argument(
        "--average_epochs",
        nargs="+",
        type=int,
        default=[],
    )
    parser.add_argument(
        "--metric_name",
        type=str,
        default="der",
        help="The metric name of best models, only used for name."
    )
    cli_args = parser.parse_args()

    ckpt_dir = cli_args.model_dir
    epoch_ids = cli_args.average_epochs
    num_ckpt = len(epoch_ids)
    metric = cli_args.metric_name

    if num_ckpt == 0:
        print("Number of models to average is 0, skip.")
    else:
        # sum all state dicts, then divide the floating-point entries
        summed = None
        for epoch in epoch_ids:
            ckpt_path = os.path.join(ckpt_dir, "{}epoch.pth".format(str(epoch)))
            state_dict = torch.load(ckpt_path, map_location="cpu")
            if summed is None:
                summed = state_dict
            else:
                for key in summed:
                    summed[key] = summed[key] + state_dict[key]
        for key in summed:
            # integer tensors (e.g. step counters) are kept as the raw sum
            if not str(summed[key].dtype).startswith("torch.int"):
                summed[key] = summed[key] / num_ckpt
        out_path = os.path.join(ckpt_dir, "valid.{}.ave_{}best.pth".format(metric, num_ckpt))
        torch.save(summed, out_path)

View File

@ -0,0 +1,116 @@
import numpy as np
import os
import argparse
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import librosa
import soundfile as sf
from tqdm import tqdm
class MyRunner(MultiProcessRunnerV3):
    """Multi-process driver that cuts single-speaker segments out of callhome recordings."""

    def prepare(self, parser):
        """Register CLI arguments and build one task per recording.

        Returns (task_list, shared_state, args) for the base-class dispatcher;
        each task is (meeting_id, wav_path, rttm_lines).
        NOTE(review): --sr and --no_pbar are read elsewhere but not registered
        here — presumably added by MultiProcessRunnerV3; confirm.
        """
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("rttm_scp", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--min_dur", type=float, default=2.0)
        parser.add_argument("--max_spk_num", type=int, default=4)
        args = parser.parse_args()
        assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        wav_scp = load_scp_as_list(args.wav_scp)
        # group rttm lines by recording id (field 1)
        meeting2rttms = {}
        for one_line in open(args.rttm_scp, "rt"):
            parts = [x for x in one_line.strip().split(" ") if x != ""]
            mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
            if mid not in meeting2rttms:
                meeting2rttms[mid] = []
            meeting2rttms[mid].append(one_line)
        task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
        return task_list, None, args

    def post(self, result_list, args):
        """Aggregate per-worker (extracted, total) speaker counts and report them."""
        count = [0, 0]
        for result in result_list:
            count[0] += result[0]
            count[1] += result[1]
        print("Found {} speakers, extracted {}.".format(count[1], count[0]))
# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
    """Render RTTM segments into a (max_spk_num, length) 0/1 activity matrix.

    Returns (activity, spk_list) where spk_list holds the normalized speaker
    names aligned with the rows of the matrix.
    """
    activity = np.zeros([max_spk_num, length], int)
    spk_list = []
    for line in rttms:
        fields = [tok for tok in line.strip().split(" ") if tok != ""]
        mid = fields[1]
        begin_sec, dur_sec = float(fields[3]), float(fields[4])
        raw = fields[7].replace("spk", "").replace(mid, "").replace("-", "")
        if raw.isdigit():
            speaker = "{}_S{:03d}".format(mid, int(raw))
        else:
            speaker = "{}_{}".format(mid, raw)
        if speaker not in spk_list:
            spk_list.append(speaker)
        begin = int(begin_sec * sr)
        end = begin + int(dur_sec * sr)
        activity[spk_list.index(speaker), begin:end] = 1
    return activity, spk_list
def get_nonoverlap_turns(multi_label, spk_list):
    """Extract non-overlapped single-speaker turns from a multi-label matrix.

    Args:
        multi_label: (num_spk, num_samples) 0/1 activity matrix
        spk_list: speaker names aligned with the rows of multi_label
    Returns:
        turns: list of [start_sample, end_sample, speaker_name] covering every
        maximal region where exactly one speaker is active.
    """
    turns = []
    # a sample is a candidate iff exactly one speaker is active there
    label = np.sum(multi_label, axis=0) == 1
    spk, in_turn, st = None, False, 0
    # NOTE: both `if` blocks run in the SAME iteration on purpose — a turn
    # opened at sample i is immediately re-checked below (no close happens
    # since label[i] is true and the speaker matches what was just set).
    for i in range(len(label)):
        if not in_turn and label[i]:
            st, in_turn = i, True
            spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
        if in_turn:
            if not label[i]:
                # overlap or silence: close the current turn
                in_turn = False
                turns.append([st, i, spk])
            elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
                # direct switch to a different single speaker: close and reopen
                turns.append([st, i, spk])
                st, in_turn = i, True
                spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
    if in_turn:
        # flush a turn that runs to the end of the recording
        turns.append([st, len(label), spk])
    return turns
def process(task_args):
    """Worker: cut non-overlapped single-speaker turns out of each recording.

    task_args is (task_id, task_list, shared, args) as dispatched by
    MultiProcessRunnerV3; each task is (meeting_id, wav_path, rttm_lines).
    Turns of at least --min_dur seconds are written to
    out_dir/<mid>/<spk>/<spk>_Unnnn.wav as 16-bit PCM.

    Returns:
        [num_extracted_speakers, num_total_speakers] summed over tasks.
    """
    task_id, task_list, _, args = task_args
    spk_count = [0, 0]
    for mid, wav_path, rttms in task_list:
        # fix: librosa >= 0.10 no longer accepts a positional sample rate
        wav = librosa.load(wav_path, sr=args.sr)[0] * 32767
        multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
        turns = get_nonoverlap_turns(multi_label, spk_list)
        extracted_spk = []
        count = 1
        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
            # keep only turns of at least min_dur seconds
            if (ed - st) >= args.min_dur * args.sr and len(wav[st: ed]) >= args.min_dur * args.sr:
                seg = wav[st: ed]
                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
                count += 1
                if spk not in extracted_spk:
                    extracted_spk.append(spk)
        if len(extracted_spk) != len(spk_list):
            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
            ))
        spk_count[0] += len(extracted_spk)
        spk_count[1] += len(spk_list)
    return spk_count


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -0,0 +1,63 @@
import numpy as np
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import librosa
import soundfile as sf
class MyRunner(MultiProcessRunnerV3):
    """Multi-process driver that concatenates VAD speech segments per recording."""

    def prepare(self, parser):
        """Register CLI arguments, load reco.scp / segments, build per-recording tasks.

        Returns (task_list, shared_state, args); each task is
        (meeting_id, wav_path, [(utt_id, start_sample, end_sample), ...]).
        NOTE(review): --sr is read but not registered here — presumably added
        by MultiProcessRunnerV3; confirm.
        """
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        args = parser.parse_args()
        assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."
        meeting_scp = load_scp_as_list(os.path.join(args.dir, "reco.scp"))
        vad_file = open(os.path.join(args.dir, "segments"))
        # map recording id -> list of (utt_id, start_sample, end_sample)
        meeting2vad = {}
        for one_line in vad_file:
            uid, mid, st, ed = one_line.strip().split(" ")
            st, ed = int(float(st) * args.sr), int(float(ed) * args.sr)
            if mid not in meeting2vad:
                meeting2vad[mid] = []
            meeting2vad[mid].append((uid, st, ed))
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        for mid, _ in meeting_scp:
            if mid not in meeting2vad:
                print("Recording {} doesn't contains speech segments".format(mid))
        task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp if mid in meeting2vad]
        return task_list, None, args

    def post(self, results_list, args):
        """No aggregation needed; workers write their outputs directly."""
        pass
def process(task_args):
    """Worker: concatenate the VAD speech segments of each recording into one wav.

    Writes out_dir/<mid>.wav (16-bit PCM) plus a <mid>.pos map file with lines
    "uid orig_start orig_end new_start new_end" (all positions in samples).
    """
    _, task_list, _, args = task_args
    for mid, wav_path, vad_list in task_list:
        # fix: librosa >= 0.10 no longer accepts positional sr/mono arguments
        wav = librosa.load(wav_path, sr=args.sr, mono=True)[0] * 32767
        seg_list = []
        pos_map = []
        offset = 0
        for uid, st, ed in vad_list:
            seg_list.append(wav[st: ed])
            pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset + ed - st))
            offset = offset + ed - st
        out = np.concatenate(seg_list, axis=0)
        save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
        sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
        map_path = os.path.join(args.out_dir, "{}.pos".format(mid))
        with open(map_path, "wt") as fd:
            fd.writelines(pos_map)
    return None


if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()

View File

@ -0,0 +1,216 @@
import argparse
import numpy as np
import librosa
import soundfile as sf
import os
import random
import json
from funasr.utils.misc import load_scp_as_dict, load_scp_as_list
from tqdm import tqdm
def mix_wav_noise(wav, noise, snr):
    """Add *noise* to *wav* at the given SNR (in dB).

    The noise is tiled end-to-end until it covers the signal, a random start
    offset is drawn, and the segment is rescaled so that
    20*log10(||wav|| / ||noise||) equals snr.  Returns wav + scaled noise
    (same length as wav).
    """
    n_repeat = len(wav) // len(noise) + 1
    # fix: np.tile repeats the whole waveform end-to-end; np.repeat would
    # repeat each individual SAMPLE (sample-and-hold), destroying the noise.
    noise = np.tile(noise, n_repeat)
    st = random.randint(0, len(noise) - len(wav))
    noise = noise[st: st + len(wav)]
    wav_mag = np.linalg.norm(wav, ord=2)
    noise_mag = np.linalg.norm(noise, ord=2)
    scale = wav_mag / (10 ** (float(snr) / 20))
    noise = noise / noise_mag * scale
    # sanity check: warn when the achieved SNR drifts from the request
    check_snr = 20 * np.log10(np.linalg.norm(wav, ord=2) / np.linalg.norm(noise, ord=2))
    if abs(check_snr - snr) >= 1e-2:
        print("SNR: {:.4f}, real SNR: {:.4f}".format(snr, check_snr))
    return wav + noise
def calc_labels(rttms, args):
    """Convert (spk, start_sec, dur_sec) turns into a frame-level label matrix.

    Optionally jitters the intervals (--random_interval) and randomly
    re-assigns speakers (--random_assign_spk) for augmentation, then
    re-derives clean, merged turns from the binarized matrix.

    Returns:
        labels: (num_spk, total_samples) matrix in {0.0, 1.0}
        spk_list: speakers in order of first appearance (by start time)
        new_turns: [(spk, start_sec, dur_sec), ...] sorted by start time
    """
    turns = []
    total_length = 0
    for spk, st, dur in rttms:
        if args.random_interval:
            # random shift the interval with 20% of duration
            x = random.uniform(-dur * 0.2, dur * 0.2)
            st = max(0, st + x)
            # random squeeze or extend the interval
            dur += random.uniform(-dur * 0.5, dur * 0.5)
        if st + dur > total_length:
            total_length = st + dur
        turns.append([spk, st, dur])
    # resort the turns according start point
    turns = sorted(turns, key=lambda x: x[1])
    spk_list = []
    for spk, st, dur in turns:
        if spk not in spk_list:
            spk_list.append(spk)
    total_length = int(total_length * args.sr)
    labels = np.zeros((len(spk_list), total_length), float)
    for spk, org_st, org_dur in turns:
        st, dur = int(org_st * args.sr), int(org_dur * args.sr)
        # random re-assign speaker to make more various samples
        if args.random_assign_spk:
            spk = random.choice(spk_list)
        spk_id = spk_list.index(spk)
        labels[spk_id, st:st + dur] = 1.0
    # re-extract merged turns from the binary matrix
    new_turns = []
    for i in range(len(spk_list)):
        st = 0
        in_interval = False
        for j in range(total_length):
            if labels[i, j] == 1 and not in_interval:
                in_interval = True
                st = j
            if in_interval and (labels[i, j] == 0 or j == total_length - 1):
                in_interval = False
                # fix: an interval running to the last frame previously lost one
                # sample (duration j - st); use an exclusive end index instead
                end = j if labels[i, j] == 0 else j + 1
                new_turns.append((spk_list[i], float(st) / args.sr, float(end - st) / args.sr))
    new_turns = sorted(new_turns, key=lambda x: x[1])
    return labels, spk_list, new_turns
def save_wav(data, wav_path, sr):
    """Write *data* to *wav_path* as 16-bit little-endian PCM, rescaling to 90% full-scale first if it would clip."""
    peak = np.max(np.abs(data)).item()
    if peak > 32767:
        data = data / np.max(np.abs(data)) * 0.9 * 32767
    sf.write(wav_path, data.astype(np.int16), sr, "PCM_16", "LITTLE", "WAV", True)
def build(mid, meeting2rttm, spk2wav, noise_scp, room2rirs, args):
    """Simulate one meeting recording.

    Picks a reference rttm, re-renders its turns with library speech,
    convolves each speaker with a room impulse response and adds noise at a
    random SNR.  Writes <out_dir>/<mid>.wav and <out_dir>/<mid>.rttm.

    Returns:
        (meta_dict, mixture_waveform, label_matrix)
    """
    mid = "m{:05d}".format(mid + 1)
    if args.corpus_name is not None:
        mid = args.corpus_name + "_" + mid
    # fix: random.choice/random.sample require a sequence — passing dict views
    # (dict.keys()) raises TypeError on Python 3
    org_reco_id = random.choice(list(meeting2rttm.keys()))
    rttms = meeting2rttm[org_reco_id]
    labels, org_spk_list, org_turns = calc_labels(rttms, args)
    n_spk = len(org_spk_list)
    expected_length = labels.shape[1]
    meeting_spk_list = random.sample(list(spk2wav.keys()), n_spk)
    # speakers with any activity in the label matrix
    spk_mask = (np.sum(labels, axis=1) > 0).astype(int)
    pos_spk_list = [spk for spk, mask in zip(meeting_spk_list, spk_mask) if mask == 1]
    noise_id, noise_path = random.choice(noise_scp)
    # fix: librosa >= 0.10 no longer accepts positional sr/mono arguments
    noise_wav = librosa.load(noise_path, sr=args.sr, mono=True)[0] * 32767
    snr = random.choice(args.snr_list)
    room_id = random.choice(list(room2rirs.keys()))
    # different speakers can locate at the same position a.k.a. the same rir.
    rir_list = [random.choice(room2rirs[room_id]) for _ in range(n_spk)]
    mata = {
        "id": mid,
        "num_spk": n_spk,
        "pos_spk": pos_spk_list,
        "spk_list": meeting_spk_list,
        "seg_info": [],
        "noise": noise_id,
        "snr": snr,
        "length": expected_length,
        "meeting_info": org_reco_id,
        "room_id": room_id
    }
    sig = np.zeros((expected_length,), dtype=np.float32)
    for i, spk in enumerate(meeting_spk_list):
        if spk in pos_spk_list:
            wav = librosa.load(spk2wav[spk], sr=args.sr, mono=True)[0] * 32767
            if len(wav) <= expected_length:
                # NOTE: to repeat an array, use np.tile rather than np.repeats
                wav = np.tile(wav, expected_length // len(wav) + 1)
            spk_st = np.random.randint(0, len(wav) - expected_length)
            spk_sig = wav[spk_st: spk_st + expected_length]
            spk_sig = spk_sig * labels[i, :]
            rir_wav = librosa.load(rir_list[i][1], sr=args.sr, mono=True)[0] * 32767
            spk_sig = np.convolve(spk_sig, rir_wav, "full")[:expected_length]
            mata["seg_info"].append([spk, spk_st, rir_list[i][0]])
            sig += spk_sig
    mix = mix_wav_noise(sig, noise_wav, snr)
    if np.max(np.abs(mix)).item() > 32767:
        mix = mix / np.max(np.abs(mix)) * 0.9 * 32767
    save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
    sf.write(save_path, mix.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
    # close the rttm file deterministically
    with open(os.path.join(args.out_dir, "{}.rttm".format(mid)), "wt") as rttm_file:
        for spk, st, dur in org_turns:
            rttm_file.write("SPEAKER {} 0 {:.3f} {:.3f} <NA> <NA> {} <NA> <NA>{}".format(
                mid, st, dur, meeting_spk_list[org_spk_list.index(spk)], os.linesep))
    return mata, mix, labels
def filter_spk_num(meeting2rttm, reco2num_spk, spk_num):
    """Keep only the meetings whose annotated speaker count equals *spk_num*."""
    kept = {mid: rttm for mid, rttm in meeting2rttm.items() if int(reco2num_spk[mid]) == spk_num}
    print("Keep {} out of {} according to speaker number {}".format(len(kept), len(meeting2rttm), spk_num))
    return kept
def main():
    """CLI: simulate mixed meeting recordings for diarization training.

    Reads ref.rttm / reco2num_spk / spk2wav.scp / noise.scp / rirs.scp from
    --dir, builds --total_mix simulated meetings via build(), and writes the
    wavs, rttms and a per-task meta json into --out_dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("dir", type=str)
    parser.add_argument("out_dir", type=str)
    parser.add_argument("--total_mix", type=int, default=1)
    parser.add_argument("--sr", type=int, default=8000)
    parser.add_argument("--snr_list", type=int, default=[15, 20, 25], nargs="+")
    parser.add_argument("--spk_num", type=int, default=0)
    parser.add_argument("--corpus_name", type=str, default=None)
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--no_bar", action="store_true", default=False)
    parser.add_argument("--verbose", action="store_true", default=False)
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--random_assign_spk", action="store_true", default=False)
    parser.add_argument("--random_interval", action="store_true", default=False)
    args = parser.parse_args()
    assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."

    # SPEAKER iaaa 0 0 1.08 <NA> <NA> B <NA> <NA>
    meeting2rttm = {}
    with open(os.path.join(args.dir, "ref.rttm")) as fd:
        for one_line in fd:
            parts = one_line.strip().split(" ")
            mid, spk, st, dur = parts[1], parts[7], float(parts[3]), float(parts[4])
            meeting2rttm.setdefault(mid, []).append((spk, st, dur))

    reco2num_spk = load_scp_as_dict(os.path.join(args.dir, "reco2num_spk"))
    if args.spk_num > 1:
        meeting2rttm = filter_spk_num(meeting2rttm, reco2num_spk, args.spk_num)
    spk2wav = load_scp_as_dict(os.path.join(args.dir, "spk2wav.scp"))
    noise_scp = load_scp_as_list(os.path.join(args.dir, "noise.scp"))
    rirs_scp = load_scp_as_list(os.path.join(args.dir, "rirs.scp"))

    # group rirs by room: rir ids look like "<room>-<idx>"
    room2rirs = {}
    for rir_id, rir_path in rirs_scp:
        room_id = rir_id.rsplit("-", 1)[0]
        room2rirs.setdefault(room_id, []).append((rir_id, rir_path))

    os.makedirs(args.out_dir, exist_ok=True)
    # each task id simulates its own disjoint range of meeting ids
    task_list = list(range(args.task_id * args.total_mix, (args.task_id + 1) * args.total_mix))
    mata_data = []
    total = 0
    if args.debug:
        one, wav, label = build(0, meeting2rttm, spk2wav, noise_scp, room2rirs, args)
        mata_data.append(one)
    else:
        for mid in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_bar):
            one, wav, label = build(mid, meeting2rttm, spk2wav, noise_scp, room2rirs, args)
            mata_data.append(one)
            total += one["length"]
            if args.verbose:
                print("File name: {:20s}, segment num: {:5d}, speaker num: {:2d}, duration: {:7.2f}m".format(
                    one["id"], len(one["seg_info"]), one["num_spk"], float(one["length"]) / args.sr / 60))
    print("Total files: {}, Total duration: {:.2f} hours".format(args.total_mix, (1.0 * total / args.sr / 3600)))
    # fix: json.dump() has no "encoding" kwarg on Python 3; encode via open()
    with open(os.path.join(args.out_dir, "mata.{}.json".format(args.task_id)), "wt", encoding="utf-8") as fd:
        json.dump(mata_data, fd, ensure_ascii=False, indent=4, sort_keys=True)


if __name__ == '__main__':
    main()

View File

@ -453,11 +453,17 @@ def get_parser():
help="The batch size for inference",
)
group.add_argument(
"--diar_smooth_size",
"--smooth_size",
type=int,
default=121,
help="The smoothing size for post-processing"
)
group.add_argument(
"--dur_threshold",
type=int,
default=10,
help="The threshold of minimum duration"
)
return parser

View File

@ -3,7 +3,7 @@ import logging
import torch
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.e2e_diar_sond import DiarSondModel
@ -26,6 +26,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.models.specaug.abs_profileaug import AbsProfileAug
from funasr.models.specaug.profileaug import ProfileAug
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.torch_utils.initialize import initialize
@ -52,6 +54,15 @@ specaug_choices = ClassChoices(
default=None,
optional=True,
)
profileaug_choices = ClassChoices(
name="profileaug",
classes=dict(
profileaug=ProfileAug,
),
type_check=AbsProfileAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
@ -64,7 +75,8 @@ normalize_choices = ClassChoices(
label_aggregator_choices = ClassChoices(
"label_aggregator",
classes=dict(
label_aggregator=LabelAggregate
label_aggregator=LabelAggregate,
label_aggregator_max_pool=LabelAggregateMaxPooling,
),
default=None,
optional=True,
@ -155,6 +167,8 @@ class_choices_list = [
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --profileaug and --profileaug_conf
profileaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --label_aggregator and --label_aggregator_conf
@ -217,6 +231,13 @@ def build_diar_model(args):
else:
specaug = None
# Data augmentation for Profiles
if hasattr(args, "profileaug") and args.profileaug is not None:
profileaug_class = profileaug_choices.get_class(args.profileaug)
profileaug = profileaug_class(**args.profileaug_conf)
else:
profileaug = None
# normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
@ -261,6 +282,7 @@ def build_diar_model(args):
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
profileaug=profileaug,
normalize=normalize,
label_aggregator=label_aggregator,
encoder=encoder,

View File

@ -6,8 +6,9 @@ from typing import Union
import numpy as np
import torch
from funasr.modules.nets_utils import pad_list
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.modules.nets_utils import pad_list, pad_list_all_dim
class CommonCollateFn:
@ -77,6 +78,78 @@ def common_collate_fn(
output = (uttids, output)
return output
class DiarCollateFn:
    """Functor wrapping diar_collate_fn() so padding values are configured once.

    Unlike CommonCollateFn, sequences are padded on every dimension
    (pad_list_all_dim), which diarization profiles/labels require.
    """

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None
    ):
        assert check_argument_types()
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        # NOTE(review): stored but not forwarded to diar_collate_fn — confirm intent
        self.max_sample_size = max_sample_size

    def __repr__(self):
        # fix: previously printed float_pad_value for both fields
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        return diar_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
def diar_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Collate a mini-batch of (utt_id, feature-dict) pairs into padded tensors.

    Integer arrays are padded with int_pad_value, everything else with
    float_pad_value; padding is applied on all dimensions via
    pad_list_all_dim.  For every key not listed in not_sequence a
    "<key>_lengths" tensor holding the first-dimension sizes is added.
    """
    assert check_argument_types()
    uttids = [uid for uid, _ in data]
    feature_dicts = [feats for _, feats in data]
    assert all(set(feature_dicts[0]) == set(d) for d in feature_dicts), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in feature_dicts[0]
    ), f"*_lengths is reserved: {list(feature_dicts[0])}"
    batch = {}
    for key in feature_dicts[0]:
        # choose the pad value by dtype kind: "i" marks integer arrays
        pad_value = int_pad_value if feature_dicts[0][key].dtype.kind == "i" else float_pad_value
        tensors = [torch.from_numpy(d[key]) for d in feature_dicts]
        batch[key] = pad_list_all_dim(tensors, pad_value)
        if key not in not_sequence:
            batch[key + "_lengths"] = torch.tensor(
                [d[key].shape[0] for d in feature_dicts], dtype=torch.long
            )
    result = (uttids, batch)
    assert check_return_type(result)
    return result
def crop_to_max_size(feature, target_size):
size = len(feature)
diff = size - target_size

View File

@ -1,7 +1,8 @@
import torch
from typing import Optional
from typing import Tuple
from typeguard import check_argument_types
from torch.nn import functional as F
from funasr.modules.nets_utils import make_pad_mask
@ -78,3 +79,37 @@ class LabelAggregate(torch.nn.Module):
olens = None
return output.to(input.dtype), olens
class LabelAggregateMaxPooling(torch.nn.Module):
    """Aggregate sample-level labels to frame level by non-overlapping max pooling.

    A frame is marked active (1) if any sample inside its hop window is
    active, implemented with F.max_pool1d over the time axis.
    """

    def __init__(
        self,
        hop_length: int = 8,
    ):
        assert check_argument_types()
        super().__init__()
        self.hop_length = hop_length

    def extra_repr(self):
        return (
            f"hop_length={self.hop_length}, "
        )

    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """LabelAggregate forward function.

        Args:
            input: (Batch, Nsamples, Label_dim)
            ilens: (Batch,) valid lengths in samples, or None
        Returns:
            output: (Batch, Frames, Label_dim)
            olens: (Batch,) pooled lengths, or None when ilens is None
        """
        output = F.max_pool1d(input.transpose(1, 2), self.hop_length, self.hop_length).transpose(1, 2)
        # fix: ilens defaults to None, which previously crashed on "None // hop_length"
        olens = ilens // self.hop_length if ilens is not None else None
        return output.to(input.dtype), olens

View File

@ -75,10 +75,10 @@ class SequenceBinaryCrossEntropy(nn.Module):
self.criterion = criterion
def forward(self, pred, label, lengths):
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
return loss.masked_fill(pad_mask, 0).sum() / denom
return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
class NllLoss(nn.Module):

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import logging
import random
from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
@ -12,6 +13,7 @@ from typing import Tuple, List
import numpy as np
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
@ -19,11 +21,13 @@ from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.abs_profileaug import AbsProfileAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
from funasr.utils.hinter import hint_once
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@ -35,12 +39,8 @@ else:
class DiarSondModel(FunASRModel):
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""Speaker overlap-aware neural diarization model
reference: https://arxiv.org/abs/2211.10243
"""
def __init__(
@ -48,6 +48,7 @@ class DiarSondModel(FunASRModel):
vocab_size: int,
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
profileaug: Optional[AbsProfileAug],
normalize: Optional[AbsNormalize],
encoder: torch.nn.Module,
speaker_encoder: Optional[torch.nn.Module],
@ -64,7 +65,11 @@ class DiarSondModel(FunASRModel):
speaker_discrimination_loss_weight: float = 1.0,
inter_score_loss_weight: float = 0.0,
inputs_type: str = "raw",
model_regularizer_weight: float = 0.0,
freeze_encoder: bool = False,
onfly_shuffle_speaker: bool = True,
):
assert check_argument_types()
super().__init__()
@ -75,12 +80,16 @@ class DiarSondModel(FunASRModel):
self.normalize = normalize
self.frontend = frontend
self.specaug = specaug
self.profileaug = profileaug
self.label_aggregator = label_aggregator
self.decoder = decoder
self.token_list = token_list
self.max_spk_num = max_spk_num
self.normalize_speech_speaker = normalize_speech_speaker
self.ignore_id = ignore_id
self.model_regularizer_weight = model_regularizer_weight
self.freeze_encoder = freeze_encoder
self.onfly_shuffle_speaker = onfly_shuffle_speaker
self.criterion_diar = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
@ -95,14 +104,45 @@ class DiarSondModel(FunASRModel):
self.inter_score_loss_weight = inter_score_loss_weight
self.forward_steps = 0
self.inputs_type = inputs_type
self.to_regularize_parameters = None
def get_regularize_parameters(self):
    """Split model parameters into (to_regularize, normal) groups.

    Encoder conv/dense weight tensors (excluding batch-norm weights) are
    selected for L2 regularization; the selection is also cached on the
    instance for later use by calculate_regularizer_loss().

    Returns:
        (to_regularize_parameters, normal_parameters): two lists of
        (name, parameter) tuples covering all named parameters.
    """
    to_regularize_parameters, normal_parameters = [], []
    for name, param in self.named_parameters():
        # select encoder conv1/conv2/conv_sc/dense weights, skip bn weights
        if ("encoder" in name and "weight" in name and "bn" not in name and
            ("conv2" in name or "conv1" in name or "conv_sc" in name or "dense" in name)
        ):
            to_regularize_parameters.append((name, param))
        else:
            normal_parameters.append((name, param))
    self.to_regularize_parameters = to_regularize_parameters
    return to_regularize_parameters, normal_parameters
def generate_pse_embedding(self):
embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float32)
for idx, pse_label in enumerate(self.token_list):
emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float)
emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float32)
embedding[idx] = emb
return torch.from_numpy(embedding)
def rand_permute_speaker(self, raw_profile, raw_binary_labels):
    """Randomly permute the speaker axis, consistently in profiles and labels.

    Args:
        raw_profile: (B, N, D) speaker profile embeddings
        raw_binary_labels: (B, T, N) frame-level binary activity labels
    Returns:
        (profile, binary_labels): clones with the speaker axis shuffled by an
        independent permutation for each sample in the batch.
    """
    assert raw_profile.shape[1] == raw_binary_labels.shape[2], \
        "Num profile: {}, Num label: {}".format(raw_profile.shape[1], raw_binary_labels.shape[-1])
    # clone so the caller's tensors are left untouched
    profile = torch.clone(raw_profile)
    binary_labels = torch.clone(raw_binary_labels)
    bsz, num_spk = profile.shape[0], profile.shape[1]
    for i in range(bsz):
        # draw a fresh permutation per sample; apply it to both tensors
        idx = list(range(num_spk))
        random.shuffle(idx)
        profile[i] = profile[i][idx, :]
        binary_labels[i] = binary_labels[i][:, idx]
    return profile, binary_labels
def forward(
self,
speech: torch.Tensor,
@ -113,6 +153,7 @@ class DiarSondModel(FunASRModel):
binary_labels_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
Args:
speech: (Batch, samples) or (Batch, frames, input_size)
speech_lengths: (Batch,) default None for chunk interator,
@ -127,13 +168,38 @@ class DiarSondModel(FunASRModel):
"""
assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
batch_size = speech.shape[0]
if self.freeze_encoder:
hint_once("Freeze encoder", "freeze_encoder", rank=0)
self.encoder.eval()
self.forward_steps = self.forward_steps + 1
if self.pse_embedding.device != speech.device:
self.pse_embedding = self.pse_embedding.to(speech.device)
self.power_weight = self.power_weight.to(speech.device)
self.int_token_arr = self.int_token_arr.to(speech.device)
# 1. Network forward
if self.onfly_shuffle_speaker:
hint_once("On-the-fly shuffle speaker permutation.", "onfly_shuffle_speaker", rank=0)
profile, binary_labels = self.rand_permute_speaker(profile, binary_labels)
# 0a. Aggregate time-domain labels to match forward outputs
if self.label_aggregator is not None:
binary_labels, binary_labels_lengths = self.label_aggregator(
binary_labels, binary_labels_lengths
)
# 0b. augment profiles
if self.profileaug is not None and self.training:
speech, profile, binary_labels = self.profileaug(
speech, speech_lengths,
profile, profile_lengths,
binary_labels, binary_labels_lengths
)
# 1. Calculate power-set encoding (PSE) labels
pad_bin_labels = F.pad(binary_labels, (0, self.max_spk_num - binary_labels.shape[2]), "constant", 0.0)
raw_pse_labels = torch.sum(pad_bin_labels * self.power_weight, dim=2, keepdim=True)
pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
# 2. Network forward
pred, inter_outputs = self.prediction_forward(
speech, speech_lengths,
profile, profile_lengths,
@ -141,15 +207,6 @@ class DiarSondModel(FunASRModel):
)
(speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs
# 2. Aggregate time-domain labels to match forward outputs
if self.label_aggregator is not None:
binary_labels, binary_labels_lengths = self.label_aggregator(
binary_labels, binary_labels_lengths
)
# 2. Calculate power-set encoding (PSE) labels
raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True)
pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
# If encoder uses conv* as input_layer (i.e., subsampling),
# the sequence length of 'pred' might be slightly less than the
# length of 'spk_labels'. Here we force them to be equal.
@ -165,9 +222,14 @@ class DiarSondModel(FunASRModel):
loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
regularizer_loss = None
if self.model_regularizer_weight > 0 and self.to_regularize_parameters is not None:
regularizer_loss = self.calculate_regularizer_loss()
label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device)
loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
+ self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
# if regularizer_loss is not None:
# loss = loss + regularizer_loss * self.model_regularizer_weight
(
correct,
@ -204,6 +266,7 @@ class DiarSondModel(FunASRModel):
loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
regularizer_loss=regularizer_loss.detach() if regularizer_loss is not None else None,
sad_mr=sad_mr,
sad_fr=sad_fr,
mi=mi,
@ -217,6 +280,12 @@ class DiarSondModel(FunASRModel):
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def calculate_regularizer_loss(self):
    """Return the sum of L2 norms over all parameters marked for regularization.

    ``self.to_regularize_parameters`` is an iterable of ``(name, parameter)``
    pairs; the names are ignored here. Returns ``0.0`` when the iterable is
    empty, otherwise a scalar tensor.
    """
    return sum(
        (torch.norm(param, p=2) for _, param in self.to_regularize_parameters),
        0.0,
    )
def classification_loss(
self,
predictions: torch.Tensor,
@ -388,6 +457,7 @@ class DiarSondModel(FunASRModel):
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch,)
@ -487,4 +557,4 @@ class DiarSondModel(FunASRModel):
speaker_miss,
speaker_falarm,
speaker_error,
)
)

View File

@ -0,0 +1,22 @@
from typing import Optional
from typing import Tuple
import torch
class AbsProfileAug(torch.nn.Module):
    """Abstract base class for speaker-profile augmentation modules.

    The process-flow:
    Frontend --> SpecAug -> Normalization -> Encoder -> Decoder
                    `-> ProfileAug -> Speaker Encoder --'

    Subclasses implement :meth:`forward`, which receives the speech features,
    the speaker profiles and the frame-level binary speaker labels, and
    returns the (possibly augmented) triple.
    """

    def forward(
        self, x: torch.Tensor, x_lengths: torch.Tensor = None,
        profile: torch.Tensor = None, profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None, labels_length: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Must be overridden by concrete augmentation implementations
        # (see e.g. ProfileAug).
        raise NotImplementedError

View File

@ -0,0 +1,122 @@
from typing import Tuple, Optional
import numpy as np
import torch
from torch.nn import functional as F
from funasr.models.specaug.abs_profileaug import AbsProfileAug
class ProfileAug(AbsProfileAug):
    """Data augmentation on speaker profiles (enrollment embeddings).

    Implements three augmentations, each applied per batch sample with an
    independent probability:

    - Split aug: copy one active profile into an empty slot with a small
      random disturbance, simulating one speaker split into two clusters
      (labels stay assigned to the original/main profile).
    - Merge aug: merge two distinct profiles into one; their labels are
      OR-merged onto the kept profile and the other slot is zeroed.
    - Disturb aug: interpolate a profile with another valid profile to
      simulate inaccurate clustering centroids.

    A per-slot mask tracks which profile slots are still eligible, so each
    slot is augmented at most once per forward pass.
    """

    def __init__(
        self,
        apply_split_aug: bool = True,
        split_aug_prob: float = 0.05,
        apply_merge_aug: bool = True,
        merge_aug_prob: float = 0.2,
        apply_disturb_aug: bool = True,
        disturb_aug_prob: float = 0.4,
        disturb_alpha: float = 0.2,
    ) -> None:
        super().__init__()
        self.apply_split_aug = apply_split_aug
        self.split_aug_prob = split_aug_prob
        self.apply_merge_aug = apply_merge_aug
        self.merge_aug_prob = merge_aug_prob
        self.apply_disturb_aug = apply_disturb_aug
        self.disturb_aug_prob = disturb_aug_prob
        self.disturb_alpha = disturb_alpha

    def split_aug(self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor):
        """Copy an active profile into an empty slot with a random disturbance."""
        bsz, dim = profile.shape[0], profile.shape[-1]
        # Frames per speaker: (B, N); a speaker is "active" if it has labels.
        spk_count = binary_labels.sum(dim=1)
        prob = np.random.rand(bsz)
        batch_indices = np.nonzero(prob < self.split_aug_prob)[0]
        for idx in batch_indices:
            valid_spk_idx = torch.nonzero(spk_count[idx] * mask[idx])
            pad_spk_idx = torch.nonzero((spk_count[idx] == 0) * mask[idx])
            if len(valid_spk_idx) == 0 or len(pad_spk_idx) == 0:
                continue
            split_spk_idx = valid_spk_idx[torch.randint(len(valid_spk_idx), ())]
            to_cover_idx = pad_spk_idx[torch.randint(len(pad_spk_idx), ())]
            disturb_vec = torch.randn((dim,)).to(profile)
            disturb_vec = F.normalize(disturb_vec, dim=-1)
            profile[idx, to_cover_idx] = F.normalize(
                profile[idx, split_spk_idx] + self.disturb_alpha * disturb_vec,
                dim=-1,
            )
            # Each slot is augmented at most once.
            mask[idx, split_spk_idx] = 0
            mask[idx, to_cover_idx] = 0
        return profile, binary_labels, mask

    def merge_aug(self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor):
        """Merge two distinct profiles (and their labels) into one slot."""
        bsz = profile.shape[0]
        # Non-zero norm marks a filled profile slot: (B, N).
        profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
        prob = np.random.rand(bsz)
        batch_indices = np.nonzero(prob < self.merge_aug_prob)[0]
        for idx in batch_indices:
            valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx])
            # Bug fix: the original sampled two indices *with* replacement
            # (torch.randint), so the same speaker could be "merged with
            # itself", which zeroed both its profile and its labels while the
            # speech still contained that speaker. Require two distinct
            # speakers and sample without replacement.
            if len(valid_spk_idx) < 2:
                continue
            to_merge = torch.randperm(len(valid_spk_idx))[:2]
            spk_idx_1, spk_idx_2 = valid_spk_idx[to_merge[0]], valid_spk_idx[to_merge[1]]
            # merge profile
            profile[idx, spk_idx_1] = profile[idx, spk_idx_1] + profile[idx, spk_idx_2]
            profile[idx, spk_idx_1] = F.normalize(profile[idx, spk_idx_1], dim=-1)
            profile[idx, spk_idx_2] = 0
            # merge binary labels (logical OR), then clear the merged-away slot
            binary_labels[idx, :, spk_idx_1] = binary_labels[idx, :, spk_idx_1] + binary_labels[idx, :, spk_idx_2]
            binary_labels[idx, :, spk_idx_1] = (binary_labels[idx, :, spk_idx_1] > 0).to(binary_labels)
            binary_labels[idx, :, spk_idx_2] = 0
            mask[idx, spk_idx_1] = 0
            mask[idx, spk_idx_2] = 0
        return profile, binary_labels, mask

    def disturb_aug(self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor):
        """Interpolate an active speaker's profile with another valid profile."""
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
        spk_count = binary_labels.sum(dim=1)
        prob = np.random.rand(bsz)
        batch_indices = np.nonzero(prob < self.disturb_aug_prob)[0]
        for idx in batch_indices:
            pos_spk_idx = torch.nonzero(spk_count[idx] * mask[idx])
            valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx])
            if len(pos_spk_idx) == 0 or len(valid_spk_idx) == 0:
                continue
            to_disturb_idx = pos_spk_idx[torch.randint(len(pos_spk_idx), ())]
            disturb_idx = valid_spk_idx[torch.randint(len(valid_spk_idx), ())]
            # Random interpolation weight in [0, disturb_alpha).
            alpha = self.disturb_alpha * torch.rand(()).item()
            profile[idx, to_disturb_idx] = ((1 - alpha) * profile[idx, to_disturb_idx]
                                            + alpha * profile[idx, disturb_idx])
            profile[idx, to_disturb_idx] = F.normalize(profile[idx, to_disturb_idx], dim=-1)
            mask[idx, to_disturb_idx] = 0
        return profile, binary_labels, mask

    def forward(
        self,
        speech: torch.Tensor, speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None, profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None, labels_length: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Apply the enabled augmentations.

        Returns:
            Tuple of (speech, profile, binary_labels); profile is L2-normalized
            along the embedding dimension even when no augmentation fires.
        """
        # Copy inputs to avoid in-place modification of the caller's tensors.
        speech, profile, binary_labels = torch.clone(speech), torch.clone(profile), torch.clone(binary_labels)
        profile = F.normalize(profile, dim=-1)
        # 1 = slot still eligible for augmentation, 0 = already used/excluded.
        profile_mask = torch.ones(profile.shape[:2]).to(profile)
        if self.apply_disturb_aug:
            profile, binary_labels, profile_mask = self.disturb_aug(profile, binary_labels, profile_mask)
        if self.apply_split_aug:
            profile, binary_labels, profile_mask = self.split_aug(profile, binary_labels, profile_mask)
        if self.apply_merge_aug:
            profile, binary_labels, profile_mask = self.merge_aug(profile, binary_labels, profile_mask)
        return speech, profile, binary_labels

View File

@ -61,6 +61,48 @@ def pad_list(xs, pad_value):
return pad
def pad_list_all_dim(xs, pad_value):
    """Perform padding for the list of tensors along every dimension.

    Unlike ``pad_list``, which pads only the first dimension, this pads each
    dimension to the per-dimension maximum over the list. Works for tensors
    of any rank (the original implementation was limited to 1-D/2-D/3-D).

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list_all_dim(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    n_batch = len(xs)
    num_dim = len(xs[0].shape)
    # Per-dimension maximum extent over the whole list.
    max_len_all_dim = [max(x.size(d) for x in xs) for d in range(num_dim)]
    pad = xs[0].new(n_batch, *max_len_all_dim).fill_(pad_value)
    for i, x in enumerate(xs):
        # Index covering exactly the extent of x in every dimension,
        # e.g. pad[i, :t0, :t1, ...] = x, for arbitrary rank.
        pad[(i,) + tuple(slice(0, s) for s in x.shape)] = x
    return pad
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
"""Make mask tensor containing indices of padded part.

View File

@ -1,11 +1,3 @@
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
import argparse
import logging
import os
@ -21,24 +13,26 @@ from typing import Union
import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.collate_fn import DiarCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
@ -47,16 +41,21 @@ from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
HuggingFaceTransformersPostEncoder, # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.models.specaug.abs_profileaug import AbsProfileAug
from funasr.models.specaug.profileaug import ProfileAug
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@ -72,7 +71,6 @@ frontend_choices = ClassChoices(
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
wav_frontend_mel23=WavFrontendMel23,
),
type_check=AbsFrontend,
default="default",
@ -87,6 +85,15 @@ specaug_choices = ClassChoices(
default=None,
optional=True,
)
# --profileaug and --profileaug_conf: optional augmentation applied to
# speaker profiles (enrollment embeddings) before the speaker encoder.
profileaug_choices = ClassChoices(
    name="profileaug",
    classes=dict(
        profileaug=ProfileAug,
    ),
    type_check=AbsProfileAug,
    # Disabled unless explicitly selected on the command line / config.
    default=None,
    optional=True,
)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
@ -100,7 +107,8 @@ normalize_choices = ClassChoices(
label_aggregator_choices = ClassChoices(
"label_aggregator",
classes=dict(
label_aggregator=LabelAggregate
label_aggregator=LabelAggregate,
label_aggregator_max_pool=LabelAggregateMaxPooling,
),
type_check=torch.nn.Module,
default=None,
@ -110,9 +118,8 @@ model_choices = ClassChoices(
"model",
classes=dict(
sond=DiarSondModel,
eend_ola=DiarEENDOLAModel,
),
type_check=FunASRModel,
type_check=torch.nn.Module,
default="sond",
)
encoder_choices = ClassChoices(
@ -130,7 +137,6 @@ encoder_choices = ClassChoices(
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
ecapa_tdnn=ECAPA_TDNN,
eend_ola_transformer=EENDOLATransformerEncoder,
),
type_check=torch.nn.Module,
default="resnet34",
@ -182,15 +188,6 @@ decoder_choices = ClassChoices(
type_check=torch.nn.Module,
default="fsmn",
)
# encoder_decoder_attractor is used for EEND-OLA
# (--encoder_decoder_attractor and --encoder_decoder_attractor_conf).
encoder_decoder_attractor_choices = ClassChoices(
    "encoder_decoder_attractor",
    classes=dict(
        eda=EncoderDecoderAttractor,
    ),
    type_check=torch.nn.Module,
    default="eda",
)
class DiarTask(AbsTask):
@ -203,6 +200,8 @@ class DiarTask(AbsTask):
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --profileaug and --profileaug_conf
profileaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --label_aggregator and --label_aggregator_conf
@ -342,13 +341,15 @@ class DiarTask(AbsTask):
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
return DiarCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
@ -378,6 +379,7 @@ class DiarTask(AbsTask):
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
@ -396,10 +398,47 @@ class DiarTask(AbsTask):
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_optimizers(
    cls,
    args: argparse.Namespace,
    model: torch.nn.Module,
) -> List[torch.optim.Optimizer]:
    """Build the single optimizer for training.

    When the model defines a positive ``model_regularizer_weight`` and a
    ``get_regularize_parameters`` method, the selected parameters are put in
    a dedicated parameter group with that weight decay; all remaining
    parameters get zero weight decay.
    """
    if cls.num_optimizers != 1:
        raise RuntimeError(
            "build_optimizers() must be overridden if num_optimizers != 1"
        )

    # Imported lazily to match the original module's import structure.
    from funasr.tasks.abs_task import optim_classes

    optim_class = optim_classes.get(args.optim)
    if optim_class is None:
        raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")

    use_grouped_decay = (
        hasattr(model, "model_regularizer_weight")
        and model.model_regularizer_weight > 0.0
        and hasattr(model, "get_regularize_parameters")
    )
    if use_grouped_decay:
        to_regularize_parameters, normal_parameters = model.get_regularize_parameters()
        logging.info(f"Set weight decay {model.model_regularizer_weight} for parameters: "
                     f"{[name for name, value in to_regularize_parameters]}")
        param_groups = [
            {"params": [value for name, value in to_regularize_parameters],
             "weight_decay": model.model_regularizer_weight},
            {"params": [value for name, value in normal_parameters],
             "weight_decay": 0.0},
        ]
        optim = optim_class(param_groups, **args.optim_conf)
    else:
        optim = optim_class(model.parameters(), **args.optim_conf)
    return [optim]
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@ -436,6 +475,13 @@ class DiarTask(AbsTask):
else:
specaug = None
# 2b. Data augmentation for Profiles
if hasattr(args, "profileaug") and args.profileaug is not None:
profileaug_class = profileaug_choices.get_class(args.profileaug)
profileaug = profileaug_class(**args.profileaug_conf)
else:
profileaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
@ -483,6 +529,7 @@ class DiarTask(AbsTask):
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
profileaug=profileaug,
normalize=normalize,
label_aggregator=label_aggregator,
encoder=encoder,
@ -497,7 +544,9 @@ class DiarTask(AbsTask):
# 10. Initialize
if args.init is not None:
initialize(model, args.init)
logging.info(f"Init model parameters with {args.init}.")
assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@ -520,6 +569,7 @@ class DiarTask(AbsTask):
device: Device type, "cpu", "cuda", or "cuda:N".
"""
assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
@ -535,9 +585,9 @@ class DiarTask(AbsTask):
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
if not isinstance(model, FunASRModel):
if not isinstance(model, torch.nn.Module):
raise RuntimeError(
f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
f"model must inherit {torch.nn.Module.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@ -552,13 +602,13 @@ class DiarTask(AbsTask):
if ".bin" in model_name:
model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
else:
model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
if os.path.exists(model_name_pth):
logging.info("model_file is load from pth: {}".format(model_name_pth))
model_dict = torch.load(model_name_pth, map_location=device)
else:
model_dict = cls.convert_tf2torch(model, model_file)
model.load_state_dict(model_dict)
# model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
@ -616,287 +666,3 @@ class DiarTask(AbsTask):
var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
class EENDOLADiarTask(AbsTask):
    """Task definition for the EEND-OLA speaker diarization model.

    Wires a frontend, an encoder and an encoder-decoder attractor (EDA)
    into the model selected via ``--model``, and provides the argparse
    plumbing, collate/preprocess functions and model (re)construction
    helpers used for training and inference.
    """

    # If you need more than 1 optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --encoder_decoder_attractor and --encoder_decoder_attractor_conf
        encoder_decoder_attractor_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register task-, preprocess- and class-choice related CLI arguments."""
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="whether to split text using <space>",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="char",
            choices=["char"],
            help="The text will be tokenized in the specified level token",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of rir scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="THe probability for applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability applying Noise adding.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return the function that collates samples into padded batches."""
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        """Return the per-sample preprocess function.

        EEND-OLA currently uses no preprocessor; the CommonPreprocessor
        construction is kept below (commented out) for reference.
        """
        # if args.use_preprocessor:
        #     retval = CommonPreprocessor(
        #         train=train,
        #         token_type=args.token_type,
        #         token_list=args.token_list,
        #         bpemodel=None,
        #         non_linguistic_symbols=None,
        #         text_cleaner=None,
        #         g2p_type=None,
        #         split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
        #         seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
        #         # NOTE(kamo): Check attribute existence for backward compatibility
        #         rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
        #         rir_apply_prob=args.rir_apply_prob
        #         if hasattr(args, "rir_apply_prob")
        #         else 1.0,
        #         noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
        #         noise_apply_prob=args.noise_apply_prob
        #         if hasattr(args, "noise_apply_prob")
        #         else 1.0,
        #         noise_db_range=args.noise_db_range
        #         if hasattr(args, "noise_db_range")
        #         else "13_15",
        #         speech_volume_normalize=args.speech_volume_normalize
        #         if hasattr(args, "rir_scp")
        #         else None,
        #     )
        # else:
        #     retval = None
        return None

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Only "speech" is required, for both training and inference."""
        if not inference:
            retval = ("speech", )
        else:
            # Recognition mode
            retval = ("speech", )
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """This task declares no optional data keys."""
        retval = ()
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        """Build the EEND-OLA model (frontend + encoder + EDA) from args."""
        # 1. frontend
        if args.input_size is None or args.frontend == "wav_frontend_mel23":
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size
        # NOTE(review): input_size is computed but not passed to the model
        # below — presumably the encoder config carries it; confirm.

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(**args.encoder_conf)

        # 3. EncoderDecoderAttractor
        encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
        encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)

        # 9. Build model
        model_class = model_choices.get_class(args.model)
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            encoder_decoder_attractor=encoder_decoder_attractor,
            **args.model_conf,
        )

        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)

        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            cmvn_file: The cmvn file for front-end
            device: Device type, "cpu", "cuda", or "cuda:N".

        Returns:
            Tuple of (model, args) where args is the parsed training config.
        """
        # NOTE(review): cmvn_file is accepted but unused in this method —
        # confirm whether it should be forwarded into args like in DiarTask.
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            # Fall back to the config saved next to the model file.
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)

        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        if model_file is not None:
            if device == "cuda":
                # Pin the generic "cuda" device string to an explicit index.
                device = f"cuda:{torch.cuda.current_device()}"
            checkpoint = torch.load(model_file, map_location=device)
            # Checkpoints may wrap weights under a "state_dict" key.
            if "state_dict" in checkpoint.keys():
                model.load_state_dict(checkpoint["state_dict"])
            else:
                model.load_state_dict(checkpoint)
        model.to(device)
        return model, args

13
funasr/utils/hinter.py Normal file
View File

@ -0,0 +1,13 @@
import sys
import torch.distributed
import logging

# uids that have already been logged; consulted by hint_once so that each
# hint is emitted at most once per process.
HINTED = set()


def hint_once(content, uid, rank=None):
    """Log ``content`` at most once per ``uid``.

    If ``rank`` is given and torch.distributed is initialized, only the
    process with that rank logs; otherwise every process is eligible.
    """
    on_target_rank = (
        rank is None
        or not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == rank
    )
    if on_target_rank and uid not in HINTED:
        logging.info(content)
        HINTED.add(uid)