mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #799 from alibaba-damo-academy/dev_dzh
Recipe for TOLD/SOND speaker diarization model
This commit is contained in:
commit
c63486e0b9
@ -0,0 +1,133 @@
|
||||
model: sond
|
||||
model_conf:
|
||||
lsm_weight: 0.0
|
||||
length_normalized_loss: true
|
||||
max_spk_num: 16
|
||||
normalize_speech_speaker: true
|
||||
speaker_discrimination_loss_weight: 0
|
||||
inter_score_loss_weight: 0.1
|
||||
model_regularizer_weight: 0.0
|
||||
freeze_encoder: true
|
||||
onfly_shuffle_speaker: false
|
||||
# label aggregator
|
||||
label_aggregator: label_aggregator_max_pool
|
||||
label_aggregator_conf:
|
||||
hop_length: 8
|
||||
|
||||
# speech encoder
|
||||
encoder: resnet34_sp_l2reg
|
||||
encoder_conf:
|
||||
# pass by model, equal to feature dim
|
||||
# input_size: 80
|
||||
batchnorm_momentum: 0.01
|
||||
pooling_type: "window_shift"
|
||||
pool_size: 20
|
||||
stride: 1
|
||||
tf2torch_tensor_name_prefix_torch: encoder
|
||||
tf2torch_tensor_name_prefix_tf: EAND/speech_encoder
|
||||
|
||||
speaker_encoder: null
|
||||
speaker_encoder_conf: {}
|
||||
|
||||
ci_scorer: conv
|
||||
ci_scorer_conf:
|
||||
input_units: 512
|
||||
num_layers: 3
|
||||
num_units: 512
|
||||
kernel_size: 1
|
||||
dropout_rate: 0.0
|
||||
position_encoder: null
|
||||
out_units: 1
|
||||
out_norm: false
|
||||
auxiliary_states: false
|
||||
tf2torch_tensor_name_prefix_torch: ci_scorer
|
||||
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/ci_scorer
|
||||
|
||||
cd_scorer: san
|
||||
cd_scorer_conf:
|
||||
input_size: 512
|
||||
output_size: 512
|
||||
out_units: 1
|
||||
attention_heads: 4
|
||||
linear_units: 1024
|
||||
num_blocks: 4
|
||||
dropout_rate: 0.0
|
||||
positional_dropout_rate: 0.0
|
||||
attention_dropout_rate: 0.0
|
||||
# use string "null" to remove input layer
|
||||
input_layer: "null"
|
||||
pos_enc_class: null
|
||||
normalize_before: true
|
||||
tf2torch_tensor_name_prefix_torch: cd_scorer
|
||||
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/cd_scorer
|
||||
|
||||
# post net
|
||||
decoder: fsmn
|
||||
decoder_conf:
|
||||
in_units: 32
|
||||
out_units: 2517
|
||||
filter_size: 31
|
||||
fsmn_num_layers: 6
|
||||
dnn_num_layers: 1
|
||||
num_memory_units: 16
|
||||
ffn_inner_dim: 512
|
||||
dropout_rate: 0.0
|
||||
tf2torch_tensor_name_prefix_torch: decoder
|
||||
tf2torch_tensor_name_prefix_tf: EAND/post_net
|
||||
|
||||
input_size: 80
|
||||
frontend: null
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
window: povey
|
||||
n_mels: 80
|
||||
frame_length: 25
|
||||
frame_shift: 10
|
||||
filter_length_min: -1
|
||||
filter_length_max: -1
|
||||
lfr_m: 1
|
||||
lfr_n: 1
|
||||
dither: 0.0
|
||||
snip_edges: false
|
||||
upsacle_samples: false
|
||||
|
||||
# minibatch related
|
||||
batch_type: unsorted
|
||||
# 16 samples
|
||||
batch_size: 8
|
||||
num_workers: 8
|
||||
max_epoch: 20
|
||||
num_iters_per_epoch: 10000
|
||||
keep_nbest_models: 20
|
||||
|
||||
# optimization related
|
||||
accum_grad: 1
|
||||
grad_clip: 5.0
|
||||
val_scheduler_criterion:
|
||||
- valid
|
||||
- der
|
||||
- min
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- der
|
||||
- min
|
||||
- - valid
|
||||
- forward_steps
|
||||
- max
|
||||
|
||||
optim: adamw
|
||||
optim_conf:
|
||||
lr: 1.0
|
||||
betas: [0.9, 0.998]
|
||||
weight_decay: 0
|
||||
scheduler: noamlr
|
||||
scheduler_conf:
|
||||
model_size: 512
|
||||
warmup_steps: 10000
|
||||
|
||||
# without spec aug
|
||||
specaug: null
|
||||
|
||||
log_interval: 50
|
||||
# without normalize
|
||||
normalize: null
|
||||
@ -0,0 +1,133 @@
|
||||
model: sond
|
||||
model_conf:
|
||||
lsm_weight: 0.0
|
||||
length_normalized_loss: true
|
||||
max_spk_num: 16
|
||||
normalize_speech_speaker: true
|
||||
speaker_discrimination_loss_weight: 0
|
||||
inter_score_loss_weight: 0.1
|
||||
model_regularizer_weight: 0.0
|
||||
freeze_encoder: false
|
||||
onfly_shuffle_speaker: false
|
||||
# label aggregator
|
||||
label_aggregator: label_aggregator_max_pool
|
||||
label_aggregator_conf:
|
||||
hop_length: 8
|
||||
|
||||
# speech encoder
|
||||
encoder: resnet34_sp_l2reg
|
||||
encoder_conf:
|
||||
# pass by model, equal to feature dim
|
||||
# input_size: 80
|
||||
batchnorm_momentum: 0.01
|
||||
pooling_type: "window_shift"
|
||||
pool_size: 20
|
||||
stride: 1
|
||||
tf2torch_tensor_name_prefix_torch: encoder
|
||||
tf2torch_tensor_name_prefix_tf: EAND/speech_encoder
|
||||
|
||||
speaker_encoder: null
|
||||
speaker_encoder_conf: {}
|
||||
|
||||
ci_scorer: conv
|
||||
ci_scorer_conf:
|
||||
input_units: 512
|
||||
num_layers: 3
|
||||
num_units: 512
|
||||
kernel_size: 1
|
||||
dropout_rate: 0.0
|
||||
position_encoder: null
|
||||
out_units: 1
|
||||
out_norm: false
|
||||
auxiliary_states: false
|
||||
tf2torch_tensor_name_prefix_torch: ci_scorer
|
||||
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/ci_scorer
|
||||
|
||||
cd_scorer: san
|
||||
cd_scorer_conf:
|
||||
input_size: 512
|
||||
output_size: 512
|
||||
out_units: 1
|
||||
attention_heads: 4
|
||||
linear_units: 1024
|
||||
num_blocks: 4
|
||||
dropout_rate: 0.0
|
||||
positional_dropout_rate: 0.0
|
||||
attention_dropout_rate: 0.0
|
||||
# use string "null" to remove input layer
|
||||
input_layer: "null"
|
||||
pos_enc_class: null
|
||||
normalize_before: true
|
||||
tf2torch_tensor_name_prefix_torch: cd_scorer
|
||||
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/cd_scorer
|
||||
|
||||
# post net
|
||||
decoder: fsmn
|
||||
decoder_conf:
|
||||
in_units: 32
|
||||
out_units: 2517
|
||||
filter_size: 31
|
||||
fsmn_num_layers: 6
|
||||
dnn_num_layers: 1
|
||||
num_memory_units: 16
|
||||
ffn_inner_dim: 512
|
||||
dropout_rate: 0.0
|
||||
tf2torch_tensor_name_prefix_torch: decoder
|
||||
tf2torch_tensor_name_prefix_tf: EAND/post_net
|
||||
|
||||
input_size: 80
|
||||
frontend: null
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
window: povey
|
||||
n_mels: 80
|
||||
frame_length: 25
|
||||
frame_shift: 10
|
||||
filter_length_min: -1
|
||||
filter_length_max: -1
|
||||
lfr_m: 1
|
||||
lfr_n: 1
|
||||
dither: 0.0
|
||||
snip_edges: false
|
||||
upsacle_samples: false
|
||||
|
||||
# minibatch related
|
||||
batch_type: unsorted
|
||||
# 6 samples
|
||||
batch_size: 6
|
||||
num_workers: 8
|
||||
max_epoch: 30
|
||||
num_iters_per_epoch: 10000
|
||||
keep_nbest_models: 30
|
||||
|
||||
# optimization related
|
||||
accum_grad: 1
|
||||
grad_clip: 5.0
|
||||
val_scheduler_criterion:
|
||||
- valid
|
||||
- der
|
||||
- min
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- der
|
||||
- min
|
||||
- - valid
|
||||
- forward_steps
|
||||
- max
|
||||
|
||||
optim: adamw
|
||||
optim_conf:
|
||||
lr: 0.0001
|
||||
betas: [0.9, 0.998]
|
||||
weight_decay: 0
|
||||
scheduler: null
|
||||
scheduler_conf:
|
||||
model_size: 512
|
||||
warmup_steps: 10000
|
||||
|
||||
# without spec aug
|
||||
specaug: null
|
||||
|
||||
log_interval: 50
|
||||
# without normalize
|
||||
normalize: null
|
||||
@ -0,0 +1,133 @@
|
||||
model: sond
|
||||
model_conf:
|
||||
lsm_weight: 0.0
|
||||
length_normalized_loss: true
|
||||
max_spk_num: 16
|
||||
normalize_speech_speaker: true
|
||||
speaker_discrimination_loss_weight: 0
|
||||
inter_score_loss_weight: 0.1
|
||||
model_regularizer_weight: 0.0
|
||||
freeze_encoder: false
|
||||
onfly_shuffle_speaker: false
|
||||
# label aggregator
|
||||
label_aggregator: label_aggregator_max_pool
|
||||
label_aggregator_conf:
|
||||
hop_length: 8
|
||||
|
||||
# speech encoder
|
||||
encoder: resnet34_sp_l2reg
|
||||
encoder_conf:
|
||||
# pass by model, equal to feature dim
|
||||
# input_size: 80
|
||||
batchnorm_momentum: 0.01
|
||||
pooling_type: "window_shift"
|
||||
pool_size: 20
|
||||
stride: 1
|
||||
tf2torch_tensor_name_prefix_torch: encoder
|
||||
tf2torch_tensor_name_prefix_tf: EAND/speech_encoder
|
||||
|
||||
speaker_encoder: null
|
||||
speaker_encoder_conf: {}
|
||||
|
||||
ci_scorer: conv
|
||||
ci_scorer_conf:
|
||||
input_units: 512
|
||||
num_layers: 3
|
||||
num_units: 512
|
||||
kernel_size: 1
|
||||
dropout_rate: 0.0
|
||||
position_encoder: null
|
||||
out_units: 1
|
||||
out_norm: false
|
||||
auxiliary_states: false
|
||||
tf2torch_tensor_name_prefix_torch: ci_scorer
|
||||
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/ci_scorer
|
||||
|
||||
cd_scorer: san
|
||||
cd_scorer_conf:
|
||||
input_size: 512
|
||||
output_size: 512
|
||||
out_units: 1
|
||||
attention_heads: 4
|
||||
linear_units: 1024
|
||||
num_blocks: 4
|
||||
dropout_rate: 0.0
|
||||
positional_dropout_rate: 0.0
|
||||
attention_dropout_rate: 0.0
|
||||
# use string "null" to remove input layer
|
||||
input_layer: "null"
|
||||
pos_enc_class: null
|
||||
normalize_before: true
|
||||
tf2torch_tensor_name_prefix_torch: cd_scorer
|
||||
tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer/cd_scorer
|
||||
|
||||
# post net
|
||||
decoder: fsmn
|
||||
decoder_conf:
|
||||
in_units: 32
|
||||
out_units: 2517
|
||||
filter_size: 31
|
||||
fsmn_num_layers: 6
|
||||
dnn_num_layers: 1
|
||||
num_memory_units: 16
|
||||
ffn_inner_dim: 512
|
||||
dropout_rate: 0.0
|
||||
tf2torch_tensor_name_prefix_torch: decoder
|
||||
tf2torch_tensor_name_prefix_tf: EAND/post_net
|
||||
|
||||
input_size: 80
|
||||
frontend: null
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
window: povey
|
||||
n_mels: 80
|
||||
frame_length: 25
|
||||
frame_shift: 10
|
||||
filter_length_min: -1
|
||||
filter_length_max: -1
|
||||
lfr_m: 1
|
||||
lfr_n: 1
|
||||
dither: 0.0
|
||||
snip_edges: false
|
||||
upsacle_samples: false
|
||||
|
||||
# minibatch related
|
||||
batch_type: unsorted
|
||||
# 6 samples
|
||||
batch_size: 6
|
||||
num_workers: 8
|
||||
max_epoch: 12
|
||||
num_iters_per_epoch: 300
|
||||
keep_nbest_models: 5
|
||||
|
||||
# optimization related
|
||||
accum_grad: 1
|
||||
grad_clip: 5.0
|
||||
val_scheduler_criterion:
|
||||
- valid
|
||||
- der
|
||||
- min
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- der
|
||||
- min
|
||||
- - valid
|
||||
- forward_steps
|
||||
- max
|
||||
|
||||
optim: adamw
|
||||
optim_conf:
|
||||
lr: 0.00001
|
||||
betas: [0.9, 0.998]
|
||||
weight_decay: 0
|
||||
scheduler: null
|
||||
scheduler_conf:
|
||||
model_size: 512
|
||||
warmup_steps: 10000
|
||||
|
||||
# without spec aug
|
||||
specaug: null
|
||||
|
||||
log_interval: 50
|
||||
# without normalize
|
||||
normalize: null
|
||||
2
egs/callhome/diarization/sond/conf/basic_inference.yaml
Normal file
2
egs/callhome/diarization/sond/conf/basic_inference.yaml
Normal file
@ -0,0 +1,2 @@
|
||||
smooth_size: 1
|
||||
dur_threshold: 0
|
||||
4
egs/callhome/diarization/sond/conf/fbank.conf
Normal file
4
egs/callhome/diarization/sond/conf/fbank.conf
Normal file
@ -0,0 +1,4 @@
|
||||
--sample-frequency=8000
|
||||
--num-mel-bins=80
|
||||
--frame-length=25
|
||||
--snip-edges=false
|
||||
65536
egs/callhome/diarization/sond/data/token_list/bce_label_n16.txt
Normal file
65536
egs/callhome/diarization/sond/data/token_list/bce_label_n16.txt
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,17 @@
|
||||
0
|
||||
1
|
||||
2
|
||||
4
|
||||
8
|
||||
16
|
||||
32
|
||||
64
|
||||
128
|
||||
256
|
||||
512
|
||||
1024
|
||||
2048
|
||||
4096
|
||||
8192
|
||||
16384
|
||||
32768
|
||||
@ -0,0 +1,137 @@
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
8
|
||||
9
|
||||
10
|
||||
12
|
||||
16
|
||||
17
|
||||
18
|
||||
20
|
||||
24
|
||||
32
|
||||
33
|
||||
34
|
||||
36
|
||||
40
|
||||
48
|
||||
64
|
||||
65
|
||||
66
|
||||
68
|
||||
72
|
||||
80
|
||||
96
|
||||
128
|
||||
129
|
||||
130
|
||||
132
|
||||
136
|
||||
144
|
||||
160
|
||||
192
|
||||
256
|
||||
257
|
||||
258
|
||||
260
|
||||
264
|
||||
272
|
||||
288
|
||||
320
|
||||
384
|
||||
512
|
||||
513
|
||||
514
|
||||
516
|
||||
520
|
||||
528
|
||||
544
|
||||
576
|
||||
640
|
||||
768
|
||||
1024
|
||||
1025
|
||||
1026
|
||||
1028
|
||||
1032
|
||||
1040
|
||||
1056
|
||||
1088
|
||||
1152
|
||||
1280
|
||||
1536
|
||||
2048
|
||||
2049
|
||||
2050
|
||||
2052
|
||||
2056
|
||||
2064
|
||||
2080
|
||||
2112
|
||||
2176
|
||||
2304
|
||||
2560
|
||||
3072
|
||||
4096
|
||||
4097
|
||||
4098
|
||||
4100
|
||||
4104
|
||||
4112
|
||||
4128
|
||||
4160
|
||||
4224
|
||||
4352
|
||||
4608
|
||||
5120
|
||||
6144
|
||||
8192
|
||||
8193
|
||||
8194
|
||||
8196
|
||||
8200
|
||||
8208
|
||||
8224
|
||||
8256
|
||||
8320
|
||||
8448
|
||||
8704
|
||||
9216
|
||||
10240
|
||||
12288
|
||||
16384
|
||||
16385
|
||||
16386
|
||||
16388
|
||||
16392
|
||||
16400
|
||||
16416
|
||||
16448
|
||||
16512
|
||||
16640
|
||||
16896
|
||||
17408
|
||||
18432
|
||||
20480
|
||||
24576
|
||||
32768
|
||||
32769
|
||||
32770
|
||||
32772
|
||||
32776
|
||||
32784
|
||||
32800
|
||||
32832
|
||||
32896
|
||||
33024
|
||||
33280
|
||||
33792
|
||||
34816
|
||||
36864
|
||||
40960
|
||||
49152
|
||||
@ -0,0 +1,697 @@
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
10
|
||||
11
|
||||
12
|
||||
13
|
||||
14
|
||||
16
|
||||
17
|
||||
18
|
||||
19
|
||||
20
|
||||
21
|
||||
22
|
||||
24
|
||||
25
|
||||
26
|
||||
28
|
||||
32
|
||||
33
|
||||
34
|
||||
35
|
||||
36
|
||||
37
|
||||
38
|
||||
40
|
||||
41
|
||||
42
|
||||
44
|
||||
48
|
||||
49
|
||||
50
|
||||
52
|
||||
56
|
||||
64
|
||||
65
|
||||
66
|
||||
67
|
||||
68
|
||||
69
|
||||
70
|
||||
72
|
||||
73
|
||||
74
|
||||
76
|
||||
80
|
||||
81
|
||||
82
|
||||
84
|
||||
88
|
||||
96
|
||||
97
|
||||
98
|
||||
100
|
||||
104
|
||||
112
|
||||
128
|
||||
129
|
||||
130
|
||||
131
|
||||
132
|
||||
133
|
||||
134
|
||||
136
|
||||
137
|
||||
138
|
||||
140
|
||||
144
|
||||
145
|
||||
146
|
||||
148
|
||||
152
|
||||
160
|
||||
161
|
||||
162
|
||||
164
|
||||
168
|
||||
176
|
||||
192
|
||||
193
|
||||
194
|
||||
196
|
||||
200
|
||||
208
|
||||
224
|
||||
256
|
||||
257
|
||||
258
|
||||
259
|
||||
260
|
||||
261
|
||||
262
|
||||
264
|
||||
265
|
||||
266
|
||||
268
|
||||
272
|
||||
273
|
||||
274
|
||||
276
|
||||
280
|
||||
288
|
||||
289
|
||||
290
|
||||
292
|
||||
296
|
||||
304
|
||||
320
|
||||
321
|
||||
322
|
||||
324
|
||||
328
|
||||
336
|
||||
352
|
||||
384
|
||||
385
|
||||
386
|
||||
388
|
||||
392
|
||||
400
|
||||
416
|
||||
448
|
||||
512
|
||||
513
|
||||
514
|
||||
515
|
||||
516
|
||||
517
|
||||
518
|
||||
520
|
||||
521
|
||||
522
|
||||
524
|
||||
528
|
||||
529
|
||||
530
|
||||
532
|
||||
536
|
||||
544
|
||||
545
|
||||
546
|
||||
548
|
||||
552
|
||||
560
|
||||
576
|
||||
577
|
||||
578
|
||||
580
|
||||
584
|
||||
592
|
||||
608
|
||||
640
|
||||
641
|
||||
642
|
||||
644
|
||||
648
|
||||
656
|
||||
672
|
||||
704
|
||||
768
|
||||
769
|
||||
770
|
||||
772
|
||||
776
|
||||
784
|
||||
800
|
||||
832
|
||||
896
|
||||
1024
|
||||
1025
|
||||
1026
|
||||
1027
|
||||
1028
|
||||
1029
|
||||
1030
|
||||
1032
|
||||
1033
|
||||
1034
|
||||
1036
|
||||
1040
|
||||
1041
|
||||
1042
|
||||
1044
|
||||
1048
|
||||
1056
|
||||
1057
|
||||
1058
|
||||
1060
|
||||
1064
|
||||
1072
|
||||
1088
|
||||
1089
|
||||
1090
|
||||
1092
|
||||
1096
|
||||
1104
|
||||
1120
|
||||
1152
|
||||
1153
|
||||
1154
|
||||
1156
|
||||
1160
|
||||
1168
|
||||
1184
|
||||
1216
|
||||
1280
|
||||
1281
|
||||
1282
|
||||
1284
|
||||
1288
|
||||
1296
|
||||
1312
|
||||
1344
|
||||
1408
|
||||
1536
|
||||
1537
|
||||
1538
|
||||
1540
|
||||
1544
|
||||
1552
|
||||
1568
|
||||
1600
|
||||
1664
|
||||
1792
|
||||
2048
|
||||
2049
|
||||
2050
|
||||
2051
|
||||
2052
|
||||
2053
|
||||
2054
|
||||
2056
|
||||
2057
|
||||
2058
|
||||
2060
|
||||
2064
|
||||
2065
|
||||
2066
|
||||
2068
|
||||
2072
|
||||
2080
|
||||
2081
|
||||
2082
|
||||
2084
|
||||
2088
|
||||
2096
|
||||
2112
|
||||
2113
|
||||
2114
|
||||
2116
|
||||
2120
|
||||
2128
|
||||
2144
|
||||
2176
|
||||
2177
|
||||
2178
|
||||
2180
|
||||
2184
|
||||
2192
|
||||
2208
|
||||
2240
|
||||
2304
|
||||
2305
|
||||
2306
|
||||
2308
|
||||
2312
|
||||
2320
|
||||
2336
|
||||
2368
|
||||
2432
|
||||
2560
|
||||
2561
|
||||
2562
|
||||
2564
|
||||
2568
|
||||
2576
|
||||
2592
|
||||
2624
|
||||
2688
|
||||
2816
|
||||
3072
|
||||
3073
|
||||
3074
|
||||
3076
|
||||
3080
|
||||
3088
|
||||
3104
|
||||
3136
|
||||
3200
|
||||
3328
|
||||
3584
|
||||
4096
|
||||
4097
|
||||
4098
|
||||
4099
|
||||
4100
|
||||
4101
|
||||
4102
|
||||
4104
|
||||
4105
|
||||
4106
|
||||
4108
|
||||
4112
|
||||
4113
|
||||
4114
|
||||
4116
|
||||
4120
|
||||
4128
|
||||
4129
|
||||
4130
|
||||
4132
|
||||
4136
|
||||
4144
|
||||
4160
|
||||
4161
|
||||
4162
|
||||
4164
|
||||
4168
|
||||
4176
|
||||
4192
|
||||
4224
|
||||
4225
|
||||
4226
|
||||
4228
|
||||
4232
|
||||
4240
|
||||
4256
|
||||
4288
|
||||
4352
|
||||
4353
|
||||
4354
|
||||
4356
|
||||
4360
|
||||
4368
|
||||
4384
|
||||
4416
|
||||
4480
|
||||
4608
|
||||
4609
|
||||
4610
|
||||
4612
|
||||
4616
|
||||
4624
|
||||
4640
|
||||
4672
|
||||
4736
|
||||
4864
|
||||
5120
|
||||
5121
|
||||
5122
|
||||
5124
|
||||
5128
|
||||
5136
|
||||
5152
|
||||
5184
|
||||
5248
|
||||
5376
|
||||
5632
|
||||
6144
|
||||
6145
|
||||
6146
|
||||
6148
|
||||
6152
|
||||
6160
|
||||
6176
|
||||
6208
|
||||
6272
|
||||
6400
|
||||
6656
|
||||
7168
|
||||
8192
|
||||
8193
|
||||
8194
|
||||
8195
|
||||
8196
|
||||
8197
|
||||
8198
|
||||
8200
|
||||
8201
|
||||
8202
|
||||
8204
|
||||
8208
|
||||
8209
|
||||
8210
|
||||
8212
|
||||
8216
|
||||
8224
|
||||
8225
|
||||
8226
|
||||
8228
|
||||
8232
|
||||
8240
|
||||
8256
|
||||
8257
|
||||
8258
|
||||
8260
|
||||
8264
|
||||
8272
|
||||
8288
|
||||
8320
|
||||
8321
|
||||
8322
|
||||
8324
|
||||
8328
|
||||
8336
|
||||
8352
|
||||
8384
|
||||
8448
|
||||
8449
|
||||
8450
|
||||
8452
|
||||
8456
|
||||
8464
|
||||
8480
|
||||
8512
|
||||
8576
|
||||
8704
|
||||
8705
|
||||
8706
|
||||
8708
|
||||
8712
|
||||
8720
|
||||
8736
|
||||
8768
|
||||
8832
|
||||
8960
|
||||
9216
|
||||
9217
|
||||
9218
|
||||
9220
|
||||
9224
|
||||
9232
|
||||
9248
|
||||
9280
|
||||
9344
|
||||
9472
|
||||
9728
|
||||
10240
|
||||
10241
|
||||
10242
|
||||
10244
|
||||
10248
|
||||
10256
|
||||
10272
|
||||
10304
|
||||
10368
|
||||
10496
|
||||
10752
|
||||
11264
|
||||
12288
|
||||
12289
|
||||
12290
|
||||
12292
|
||||
12296
|
||||
12304
|
||||
12320
|
||||
12352
|
||||
12416
|
||||
12544
|
||||
12800
|
||||
13312
|
||||
14336
|
||||
16384
|
||||
16385
|
||||
16386
|
||||
16387
|
||||
16388
|
||||
16389
|
||||
16390
|
||||
16392
|
||||
16393
|
||||
16394
|
||||
16396
|
||||
16400
|
||||
16401
|
||||
16402
|
||||
16404
|
||||
16408
|
||||
16416
|
||||
16417
|
||||
16418
|
||||
16420
|
||||
16424
|
||||
16432
|
||||
16448
|
||||
16449
|
||||
16450
|
||||
16452
|
||||
16456
|
||||
16464
|
||||
16480
|
||||
16512
|
||||
16513
|
||||
16514
|
||||
16516
|
||||
16520
|
||||
16528
|
||||
16544
|
||||
16576
|
||||
16640
|
||||
16641
|
||||
16642
|
||||
16644
|
||||
16648
|
||||
16656
|
||||
16672
|
||||
16704
|
||||
16768
|
||||
16896
|
||||
16897
|
||||
16898
|
||||
16900
|
||||
16904
|
||||
16912
|
||||
16928
|
||||
16960
|
||||
17024
|
||||
17152
|
||||
17408
|
||||
17409
|
||||
17410
|
||||
17412
|
||||
17416
|
||||
17424
|
||||
17440
|
||||
17472
|
||||
17536
|
||||
17664
|
||||
17920
|
||||
18432
|
||||
18433
|
||||
18434
|
||||
18436
|
||||
18440
|
||||
18448
|
||||
18464
|
||||
18496
|
||||
18560
|
||||
18688
|
||||
18944
|
||||
19456
|
||||
20480
|
||||
20481
|
||||
20482
|
||||
20484
|
||||
20488
|
||||
20496
|
||||
20512
|
||||
20544
|
||||
20608
|
||||
20736
|
||||
20992
|
||||
21504
|
||||
22528
|
||||
24576
|
||||
24577
|
||||
24578
|
||||
24580
|
||||
24584
|
||||
24592
|
||||
24608
|
||||
24640
|
||||
24704
|
||||
24832
|
||||
25088
|
||||
25600
|
||||
26624
|
||||
28672
|
||||
32768
|
||||
32769
|
||||
32770
|
||||
32771
|
||||
32772
|
||||
32773
|
||||
32774
|
||||
32776
|
||||
32777
|
||||
32778
|
||||
32780
|
||||
32784
|
||||
32785
|
||||
32786
|
||||
32788
|
||||
32792
|
||||
32800
|
||||
32801
|
||||
32802
|
||||
32804
|
||||
32808
|
||||
32816
|
||||
32832
|
||||
32833
|
||||
32834
|
||||
32836
|
||||
32840
|
||||
32848
|
||||
32864
|
||||
32896
|
||||
32897
|
||||
32898
|
||||
32900
|
||||
32904
|
||||
32912
|
||||
32928
|
||||
32960
|
||||
33024
|
||||
33025
|
||||
33026
|
||||
33028
|
||||
33032
|
||||
33040
|
||||
33056
|
||||
33088
|
||||
33152
|
||||
33280
|
||||
33281
|
||||
33282
|
||||
33284
|
||||
33288
|
||||
33296
|
||||
33312
|
||||
33344
|
||||
33408
|
||||
33536
|
||||
33792
|
||||
33793
|
||||
33794
|
||||
33796
|
||||
33800
|
||||
33808
|
||||
33824
|
||||
33856
|
||||
33920
|
||||
34048
|
||||
34304
|
||||
34816
|
||||
34817
|
||||
34818
|
||||
34820
|
||||
34824
|
||||
34832
|
||||
34848
|
||||
34880
|
||||
34944
|
||||
35072
|
||||
35328
|
||||
35840
|
||||
36864
|
||||
36865
|
||||
36866
|
||||
36868
|
||||
36872
|
||||
36880
|
||||
36896
|
||||
36928
|
||||
36992
|
||||
37120
|
||||
37376
|
||||
37888
|
||||
38912
|
||||
40960
|
||||
40961
|
||||
40962
|
||||
40964
|
||||
40968
|
||||
40976
|
||||
40992
|
||||
41024
|
||||
41088
|
||||
41216
|
||||
41472
|
||||
41984
|
||||
43008
|
||||
45056
|
||||
49152
|
||||
49153
|
||||
49154
|
||||
49156
|
||||
49160
|
||||
49168
|
||||
49184
|
||||
49216
|
||||
49280
|
||||
49408
|
||||
49664
|
||||
50176
|
||||
51200
|
||||
53248
|
||||
57344
|
||||
File diff suppressed because it is too large
Load Diff
10292
egs/callhome/diarization/sond/exp/EEND-OLA/sys.rttm
Normal file
10292
egs/callhome/diarization/sond/exp/EEND-OLA/sys.rttm
Normal file
File diff suppressed because it is too large
Load Diff
567
egs/callhome/diarization/sond/finetune.sh
Normal file
567
egs/callhome/diarization/sond/finetune.sh
Normal file
@ -0,0 +1,567 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
. ./path.sh || exit 1;
|
||||
|
||||
# This recipe aims at reimplement the results of SOND on Callhome corpus which is represented in
|
||||
# [1] TOLD: A Novel Two-stage Overlap-aware Framework for Speaker Diarization, ICASSP 2023
|
||||
# You can also use it on other dataset such AliMeeting to reproduce the results in
|
||||
# [2] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, EMNLP 2022
|
||||
# We recommend you run this script stage by stage.
|
||||
|
||||
# This recipe includes:
|
||||
# 1. downloading a pretrained model on the simulated data from switchboard and NIST,
|
||||
# 2. finetuning the pretrained model on Callhome1.
|
||||
# Finally, you will get a slightly better DER result 9.95% on Callhome2 than that in the paper 10.14%.
|
||||
|
||||
# environment configuration
|
||||
kaldi_root=
|
||||
|
||||
if [ -z "${kaldi_root}" ]; then
|
||||
echo "We need kaldi to prepare dataset, extract fbank features, please install kaldi first and set kaldi_root."
|
||||
echo "Kaldi installation guide can be found at https://kaldi-asr.org/"
|
||||
exit;
|
||||
fi
|
||||
|
||||
if [ ! -e local ]; then
|
||||
ln -s ${kaldi_root}/egs/callhome_diarization/v2/local ./local
|
||||
fi
|
||||
|
||||
if [ ! -e utils ]; then
|
||||
ln -s ${kaldi_root}/egs/callhome_diarization/v2/utils ./utils
|
||||
fi
|
||||
|
||||
# callhome data root like path/to/NIST/LDC2001S97
|
||||
callhome_root=
|
||||
if [ -z "${kaldi_root}" ]; then
|
||||
echo "We need callhome corpus to prepare data."
|
||||
exit;
|
||||
fi
|
||||
|
||||
# machines configuration
|
||||
gpu_devices="0,1,2,3" # for V100-16G, need 4 gpus.
|
||||
gpu_num=4
|
||||
count=1
|
||||
|
||||
# general configuration
|
||||
stage=0
|
||||
stop_stage=10
|
||||
# number of jobs for data process
|
||||
nj=16
|
||||
sr=8000
|
||||
|
||||
# experiment configuration
|
||||
lang=en
|
||||
feats_type=fbank
|
||||
datadir=data
|
||||
dumpdir=dump
|
||||
expdir=exp
|
||||
train_cmd=utils/run.pl
|
||||
|
||||
# training related
|
||||
tag=""
|
||||
train_set=callhome1
|
||||
valid_set=callhome1
|
||||
train_config=conf/EAND_ResNet34_SAN_L4N512_None_FFN_FSMN_L6N512_bce_dia_loss_01_phase3.yaml
|
||||
token_list=${datadir}/token_list/powerset_label_n16k4.txt
|
||||
init_param=
|
||||
freeze_param=
|
||||
|
||||
# inference related
|
||||
inference_model=valid.der.ave_5best.pb
|
||||
inference_config=conf/basic_inference.yaml
|
||||
inference_tag=""
|
||||
test_sets="callhome2"
|
||||
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
|
||||
# number of jobs for inference
|
||||
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
|
||||
njob=4
|
||||
infer_cmd=utils/run.pl
|
||||
told_max_iter=4
|
||||
|
||||
. utils/parse_options.sh || exit 1;
|
||||
|
||||
model_dir="$(basename "${train_config}" .yaml)_${feats_type}_${lang}${tag}"
|
||||
|
||||
# you can set gpu num for decoding here
|
||||
gpuid_list=$gpu_devices # set gpus for decoding, the same as training stage by default
|
||||
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
|
||||
|
||||
if ${gpu_inference}; then
|
||||
inference_nj=$[${ngpu}*${njob}]
|
||||
_ngpu=1
|
||||
else
|
||||
inference_nj=$njob
|
||||
_ngpu=0
|
||||
fi
|
||||
|
||||
# Prepare datasets
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
echo "Stage 0: Prepare callhome data."
|
||||
local/make_callhome.sh ${callhome_root} ${datadir}/
|
||||
|
||||
# split ref.rttm
|
||||
for dset in callhome1 callhome2; do
|
||||
rm -rf ${datadir}/${dset}/ref.rttm
|
||||
for name in `awk '{print $1}' ${datadir}/${dset}/wav.scp`; do
|
||||
grep ${name} ${datadir}/callhome/fullref.rttm >> ${datadir}/${dset}/ref.rttm;
|
||||
done
|
||||
|
||||
# filter out records which don't have rttm labels.
|
||||
awk '{print $2}' ${datadir}/${dset}/ref.rttm | sort | uniq > ${datadir}/${dset}/uttid
|
||||
mv ${datadir}/${dset}/wav.scp ${datadir}/${dset}/wav.scp.bak
|
||||
awk '{if (NR==FNR){a[$1]=1}else{if (a[$1]==1){print $0}}}' ${datadir}/${dset}/uttid ${datadir}/${dset}/wav.scp.bak > ${datadir}/${dset}/wav.scp
|
||||
mkdir ${datadir}/${dset}/raw
|
||||
mv ${datadir}/${dset}/{reco2num_spk,segments,spk2utt,utt2spk,uttid,wav.scp.bak} ${datadir}/${dset}/raw/
|
||||
awk '{print $1,$1}' ${datadir}/${dset}/wav.scp > ${datadir}/${dset}/utt2spk
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "Stage 1: Dump sph file to wav"
|
||||
export PATH=${kaldi_root}/tools/sph2pipe/:${PATH}
|
||||
if [ ! -f ${kaldi_root}/tools/sph2pipe/sph2pipe ]; then
|
||||
echo "Can not find sph2pipe in ${kaldi_root}/tools/sph2pipe/,"
|
||||
echo "please install sph2pipe and put it in the right place."
|
||||
exit;
|
||||
fi
|
||||
|
||||
for dset in callhome1 callhome2; do
|
||||
echo "Stage 1: start to dump ${dset}."
|
||||
mv ${datadir}/${dset}/wav.scp ${datadir}/${dset}/sph.scp
|
||||
|
||||
mkdir -p ${dumpdir}/${dset}/wavs
|
||||
python -Wignore script/dump_pipe_wav.py ${datadir}/${dset}/sph.scp ${dumpdir}/${dset}/wavs \
|
||||
--sr ${sr} --nj ${nj} --no_pbar
|
||||
find `pwd`/${dumpdir}/${dset}/wavs -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/${dset}/wav.scp
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
echo "Stage 2: Extract non-overlap segments from callhome dataset"
|
||||
for dset in callhome1 callhome2; do
|
||||
echo "Stage 2: Extracting non-overlap segments for "${dset}
|
||||
mkdir -p ${dumpdir}/${dset}/nonoverlap_0s
|
||||
python -Wignore script/extract_nonoverlap_segments.py \
|
||||
${datadir}/${dset}/wav.scp ${datadir}/${dset}/ref.rttm ${dumpdir}/${dset}/nonoverlap_0s \
|
||||
--min_dur 0.1 --max_spk_num 8 --sr ${sr} --no_pbar --nj ${nj}
|
||||
|
||||
mkdir -p ${datadir}/${dset}/nonoverlap_0s
|
||||
find ${dumpdir}/${dset}/nonoverlap_0s/ -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/${dset}/nonoverlap_0s/wav.scp
|
||||
awk -F'[/.]' '{print $(NF-1),$(NF-2)}' ${datadir}/${dset}/nonoverlap_0s/wav.scp > ${datadir}/${dset}/nonoverlap_0s/utt2spk
|
||||
echo "Done."
|
||||
done
|
||||
fi
|
||||
|
||||
# Stage 3: compute fbank features with the tools from kaldi's
# callhome_diarization recipe (steps/make_fbank.sh).
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "Stage 3: Generate fbank features"
    home_path=$(pwd)
    # source kaldi's path.sh from inside the recipe dir so its tools are on PATH
    cd ${kaldi_root}/egs/callhome_diarization/v2 || exit

    export train_cmd="run.pl"
    export cmd="run.pl"
    . ./path.sh
    cd $home_path || exit

    # BUGFIX: plain "ln -s" fails if a stale ./steps link is left over from a
    # previous (interrupted) run; -snf overwrites an existing symlink safely.
    ln -snf ${kaldi_root}/egs/callhome_diarization/v2/steps ./steps
    for dset in callhome1 callhome2; do
        utils/fix_data_dir.sh ${datadir}/${dset}
        steps/make_fbank.sh --write-utt2num-frames true --fbank-config conf/fbank.conf --nj ${nj} --cmd "$train_cmd" \
          ${datadir}/${dset} ${expdir}/make_fbank/${dset} ${dumpdir}/${dset}/fbank
    done
    # remove the temporary symlink again
    rm -f steps

fi
|
||||
|
||||
# Stage 4: extract a speaker embedding (x-vector) for every non-overlap
# segment; the embeddings later serve as speaker profiles for SOND.
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "Stage 4: Extract speaker embeddings."
    # BUGFIX: anchor the model dir at ${expdir}; the clone is moved into
    # ${expdir}/ below, so the previous hard-coded "exp/" prefix broke the
    # existence check (and the --sv_* paths) whenever expdir != exp.
    sv_exp_dir=${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch

    # download the pretrained speaker-verification model if absent
    if [ ! -e ${sv_exp_dir} ]; then
        echo "start to download sv models"
        git lfs install
        git clone https://www.modelscope.cn/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch.git
        mv speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ${expdir}/
        echo "Done."
    fi

    for dset in callhome1/nonoverlap_0s callhome2/nonoverlap_0s; do
        echo "Start to extract speaker embeddings for ${dset}"
        key_file=${datadir}/${dset}/wav.scp
        num_scp_file="$(<${key_file} wc -l)"
        # never launch more jobs than there are utterances
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        _logdir=${dumpdir}/${dset}/xvecs
        mkdir -p ${_logdir}
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}

        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/sv_inference.JOB.log \
            python -m funasr.bin.sv_inference_launch \
                --batch_size 1 \
                --njob ${njob} \
                --ngpu "${_ngpu}" \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${key_file},speech,sound" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --sv_train_config ${sv_exp_dir}/sv.yaml \
                --sv_model_file ${sv_exp_dir}/sv.pth \
                --output_dir "${_logdir}"/output.JOB
        # merge the per-job x-vector lists into a single utt2xvec mapping
        cat ${_logdir}/output.*/xvector.scp | sort > ${datadir}/${dset}/utt2xvec

        python script/calc_num_frames.py ${key_file} ${datadir}/${dset}/utt2num_frames
        echo "Done."
    done

fi
|
||||
|
||||
# Stage 5: turn the reference RTTMs into frame-level speaker-activity labels.
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    echo "Stage 5: Generate label files."

    for part in callhome1 callhome2; do
        echo "Stage 5: Generate labels for ${part}."
        python -Wignore script/calc_real_meeting_frame_labels.py \
          ${datadir}/${part} ${dumpdir}/${part}/labels \
          --n_spk 8 --frame_shift 0.01 --nj 16 --sr 8000
        # index the dumped .lbl.mat files as "<utt-id> <abs-path>"
        find $(pwd)/${dumpdir}/${part}/labels/ -iname "*.lbl.mat" | awk -F'[/.]' '{print $(NF-2),$0}' | sort > ${datadir}/${part}/labels.scp
    done

fi
|
||||
|
||||
# Stage 6: chunk features/profiles/labels into the dumped files SOND trains
# and evaluates on (callhome1 = training split, callhome2 = test split).
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
    echo "Stage 6: Make training and evaluation files."

    # _dump_one_set <set-name> <mode>
    # Build <set>/files_for_dump, chunk the recordings with
    # dump_meeting_chunks.py in the given mode (train|test), and collect the
    # resulting scp files. Factored out because the callhome1/callhome2
    # bodies were duplicated verbatim except for --mode.
    _dump_one_set () {
        local dset=$1
        local mode=$2
        local data_dir=${datadir}/${dset}/files_for_dump
        # -p so reruns do not abort on an existing directory
        mkdir -p ${data_dir}
        # filter out zero duration segments
        LC_ALL=C awk '{if ($5 > 0){print $0}}' ${datadir}/${dset}/ref.rttm > ${data_dir}/ref.rttm
        cp ${datadir}/${dset}/{feats.scp,labels.scp} ${data_dir}/
        cp ${datadir}/${dset}/nonoverlap_0s/{utt2spk,utt2xvec,utt2num_frames} ${data_dir}/

        echo "Stage 6: start to dump for ${dset}."
        python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
          --out ${dumpdir}/${dset}/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode ${mode} \
          --chunk_size 1600 --chunk_shift 400 --add_mid_to_speaker true

        mkdir -p ${datadir}/${dset}/dumped_files
        cat ${dumpdir}/${dset}/dumped_files/data_parts*_feat.scp | sort > ${datadir}/${dset}/dumped_files/feats.scp
        cat ${dumpdir}/${dset}/dumped_files/data_parts*_xvec.scp | sort > ${datadir}/${dset}/dumped_files/profile.scp
        cat ${dumpdir}/${dset}/dumped_files/data_parts*_label.scp | sort > ${datadir}/${dset}/dumped_files/label.scp
        mkdir -p ${expdir}/${dset}_states
        # every chunk is exactly 1600 frames; shuffle the list for batching
        awk '{print $1,"1600"}' ${datadir}/${dset}/dumped_files/feats.scp | shuf > ${expdir}/${dset}_states/speech_shape
        python -Wignore script/convert_rttm_to_seg_file.py --rttm_scp ${data_dir}/ref.rttm --seg_file ${data_dir}/org_vad.txt
    }

    # dump callhome1 data in training mode.
    _dump_one_set callhome1 train

    # dump callhome2 data in test mode.
    _dump_one_set callhome2 test

fi
|
||||
|
||||
# Finetune model on callhome1, this will take about 1.5 hours.
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
    echo "Stage 7: Finetune pretrained model on callhome1."

    # fetch the SOND model pretrained on simulated SWBD/SRE data if absent
    if [ ! -e ${expdir}/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch ]; then
        echo "start to download pretrained models"
        git lfs install
        git clone https://www.modelscope.cn/damo/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch.git
        mv speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch ${expdir}/
        echo "Done."
    fi

    world_size=$gpu_num # run on one machine
    mkdir -p ${expdir}/${model_dir}
    mkdir -p ${expdir}/${model_dir}/log
    mkdir -p /tmp/${model_dir}
    # file-based rendezvous for torch distributed; a stale file must be removed
    INIT_FILE=/tmp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_opt=""
    if [ -n "${init_param}" ]; then
        init_opt="--init_param ${init_param}"
        echo ${init_opt}
    fi

    freeze_opt=""
    if [ -n "${freeze_param}" ]; then
        freeze_opt="--freeze_param ${freeze_param}"
        echo ${freeze_opt}
    fi

    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    # launch one background training worker per GPU; rank i uses the i-th
    # entry of the comma-separated gpu_devices list
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $gpu_devices | cut -d',' -f$((i + 1)))
            python -m funasr.bin.diar_train \
                --gpu_id $gpu_id \
                --use_preprocessor false \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
                --train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
                --train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
                --train_shape_file ${expdir}/${valid_set}_states/speech_shape \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
                --valid_shape_file ${expdir}/${valid_set}_states/speech_shape \
                --init_param ${expdir}/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch/sond.pth \
                --unused_parameters true \
                ${init_opt} \
                ${freeze_opt} \
                --ignore_init_mismatch true \
                --resume true \
                --output_dir ${expdir}/${model_dir} \
                --config ${train_config} \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${expdir}/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    echo "Training log can be found at ${expdir}/${model_dir}/log/train.log.*"
    # wait for all per-GPU workers to finish
    wait
fi
|
||||
|
||||
|
||||
# evaluate for finetuned model
if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
    echo "stage 8: evaluation for finetuned model ${inference_model}."
    for dset in ${test_sets}; do
        echo "Processing for $dset"
        exp_model_dir=${expdir}/${model_dir}
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${exp_model_dir}/${_inference_tag}/${inference_model}/${dset}"
        _logdir="${_dir}/logdir"
        # warn (but do not stop) when a previous run left results behind
        if [ -d ${_dir} ]; then
            echo "WARNING: ${_dir} is already exists."
        fi
        mkdir -p "${_logdir}"
        _data="${datadir}/${dset}/dumped_files"
        key_file=${_data}/feats.scp
        num_scp_file="$(<${key_file} wc -l)"
        # cap the number of jobs at the number of utterances
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        _opt=
        if [ -n "${inference_config}" ]; then
            _opt="--config ${inference_config}"
        fi
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}

        echo "Inference log can be found at ${_logdir}/inference.*.log"
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
            python -m funasr.bin.diar_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
                --data_path_and_name_and_type "${_data}/profile.scp,profile,kaldi_ark" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --diar_train_config "${exp_model_dir}"/config.yaml \
                --diar_model_file "${exp_model_dir}"/${inference_model} \
                --output_dir "${_logdir}"/output.JOB \
                --mode sond ${_opt}
    done
fi
|
||||
|
||||
# Scoring for finetuned model, you may get a DER like:
# oracle_vad | system_vad
#    7.32    |    8.14
if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
    echo "stage 9: Scoring finetuned models"
    # dscore is the reference DER scoring tool
    if [ ! -e dscore ]; then
        git clone https://github.com/nryant/dscore.git
        pip install intervaltree
        # add intervaltree to setup.py
    fi
    for dset in ${test_sets}; do
        echo "stage 9: Scoring for ${dset}"
        diar_exp=${expdir}/${model_dir}
        _data="${datadir}/${dset}"
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"
        _logdir="${_dir}/logdir"
        # gather per-job frame labels into one file
        cat ${_logdir}/*/labels.txt | sort > ${_dir}/labels.txt

        # convert frame labels to RTTM, producing both a ref-VAD and a
        # sys-VAD variant (sys.rttm.ref_vad / sys.rttm.sys_vad)
        cmd="python -Wignore script/convert_label_to_rttm.py ${_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${_dir}/sys.rttm \
          --ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
        echo ${cmd}
        eval ${cmd}
        # DER with oracle VAD ...
        ref=${datadir}/${dset}/files_for_dump/ref.rttm
        sys=${_dir}/sys.rttm.ref_vad
        OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

        # ... and with system VAD
        ref=${datadir}/${dset}/files_for_dump/ref.rttm
        sys=${_dir}/sys.rttm.sys_vad
        SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

        echo -e "${inference_model} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${_dir}/results.txt
    done
fi
|
||||
|
||||
# In this stage, we need the raw waveform files of Callhome corpus.
# Due to the data license, we can't provide them, please get them additionally.
# And convert the sph files to wav files (use scripts/dump_pipe_wav.py).
# Then find the wav files to construct wav.scp and put it at data/callhome2/wav.scp.
# After iteratively perform SOAP, you will get DER results like:
# iters : oracle_vad | system_vad
# iter_0:    9.58    |   10.46
# iter_1:    9.22    |   10.15
# iter_2:    9.21    |   10.14
# iter_3:    9.30    |   10.24
# iter_4:    9.29    |   10.23
if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
    # download the x-vector model if a previous stage has not done so yet
    if [ ! -e ${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ]; then
        git lfs install
        git clone https://www.modelscope.cn/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch.git
        mv speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ${expdir}/
    fi

    for dset in ${test_sets}; do
        echo "stage 10: Evaluating finetuned system on ${dset} set with medfilter_size=83 clustering=EEND-OLA"
        sv_exp_dir=${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch
        diar_exp=${expdir}/${model_dir}
        _data="${datadir}/${dset}/dumped_files"
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"

        for iter in $(seq 0 ${told_max_iter}); do
            eval_dir=${_dir}/iter_${iter}
            # iteration 0 bootstraps from the first-stage EEND-OLA output;
            # every later iteration refines the previous iteration's rttm
            if [ $iter -eq 0 ]; then
                prev_rttm=${expdir}/EEND-OLA/sys.rttm
            else
                prev_rttm=${_dir}/iter_$((${iter}-1))/sys.rttm.sys_vad
            fi
            echo "Use ${prev_rttm} as system outputs."

            echo "Iteration ${iter}, step 1: extracting non-overlap segments"
            cmd="python -Wignore script/extract_nonoverlap_segments.py ${datadir}/${dset}/wav.scp \
              $prev_rttm ${eval_dir}/nonoverlap_segs/ --min_dur 0.1 --max_spk_num 16 --no_pbar --sr 8000"
            eval ${cmd}

            echo "Iteration ${iter}, step 2: make data directory"
            mkdir -p ${eval_dir}/data
            find $(pwd)/${eval_dir}/nonoverlap_segs/ -iname "*.wav" | sort > ${eval_dir}/data/wav.flist
            awk -F'[/.]' '{print $(NF-1),$0}' ${eval_dir}/data/wav.flist > ${eval_dir}/data/wav.scp
            awk -F'[/.]' '{print $(NF-1),$(NF-2)}' ${eval_dir}/data/wav.flist > ${eval_dir}/data/utt2spk
            cp $prev_rttm ${eval_dir}/data/sys.rttm
            home_path=$(pwd)

            echo "Iteration ${iter}, step 3: calc x-vector for each utt"
            key_file=${eval_dir}/data/wav.scp
            num_scp_file="$(<${key_file} wc -l)"
            # cap jobs at the utterance count
            _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
            _logdir=${eval_dir}/data/xvecs
            mkdir -p ${_logdir}
            split_scps=
            for n in $(seq "${_nj}"); do
                split_scps+=" ${_logdir}/keys.${n}.scp"
            done
            # shellcheck disable=SC2086
            utils/split_scp.pl "${key_file}" ${split_scps}

            ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/sv_inference.JOB.log \
                python -m funasr.bin.sv_inference_launch \
                    --njob ${njob} \
                    --batch_size 1 \
                    --ngpu "${_ngpu}" \
                    --gpuid_list ${gpuid_list} \
                    --data_path_and_name_and_type "${key_file},speech,sound" \
                    --key_file "${_logdir}"/keys.JOB.scp \
                    --sv_train_config ${sv_exp_dir}/sv.yaml \
                    --sv_model_file ${sv_exp_dir}/sv.pth \
                    --output_dir "${_logdir}"/output.JOB
            cat ${_logdir}/output.*/xvector.scp | sort > ${eval_dir}/data/utt2xvec

            echo "Iteration ${iter}, step 4: dump x-vector record"
            awk '{print $1}' ${_data}/feats.scp > ${eval_dir}/data/idx
            python script/dump_speaker_profiles.py --dir ${eval_dir}/data \
              --out ${eval_dir}/global_n16 --n_spk 16 --no_pbar --emb_type global
            spk_profile=${eval_dir}/global_n16_parts00_xvec.scp

            echo "Iteration ${iter}, step 5: perform NN diarization"
            _logdir=${eval_dir}/diar
            mkdir -p ${_logdir}
            key_file=${_data}/feats.scp
            num_scp_file="$(<${key_file} wc -l)"
            _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
            split_scps=
            for n in $(seq "${_nj}"); do
                split_scps+=" ${_logdir}/keys.${n}.scp"
            done
            _opt=
            if [ -n "${inference_config}" ]; then
                _opt="--config ${inference_config}"
            fi
            # shellcheck disable=SC2086
            utils/split_scp.pl "${key_file}" ${split_scps}

            echo "Inference log can be found at ${_logdir}/inference.*.log"
            ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
                python -m funasr.bin.diar_inference_launch \
                    --batch_size 1 \
                    --ngpu "${_ngpu}" \
                    --njob ${njob} \
                    --gpuid_list ${gpuid_list} \
                    --data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
                    --data_path_and_name_and_type "${spk_profile},profile,kaldi_ark" \
                    --key_file "${_logdir}"/keys.JOB.scp \
                    --diar_train_config ${diar_exp}/config.yaml \
                    --diar_model_file ${diar_exp}/${inference_model} \
                    --output_dir "${_logdir}"/output.JOB \
                    --mode sond ${_opt}

            echo "Iteration ${iter}, step 6: calc diarization results"
            cat ${_logdir}/output.*/labels.txt | sort > ${eval_dir}/labels.txt

            cmd="python -Wignore script/convert_label_to_rttm.py ${eval_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${eval_dir}/sys.rttm \
              --ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
            eval ${cmd}
            # DER with oracle VAD ...
            ref=${datadir}/${dset}/files_for_dump/ref.rttm
            sys=${eval_dir}/sys.rttm.ref_vad
            OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

            # ... and with system VAD
            ref=${datadir}/${dset}/files_for_dump/ref.rttm
            sys=${eval_dir}/sys.rttm.sys_vad
            SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

            echo -e "${inference_model}/iter_${iter} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${eval_dir}/results.txt
        done

        echo "Done."
    done
fi
|
||||
5
egs/callhome/diarization/sond/path.sh
Normal file
5
egs/callhome/diarization/sond/path.sh
Normal file
@ -0,0 +1,5 @@
|
||||
# Root of the FunASR checkout, three levels up from this recipe directory.
export FUNASR_DIR="$PWD/../../.."

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# Expose FunASR's command-line entry points.
export PATH="$FUNASR_DIR/funasr/bin:$PATH"
|
||||
954
egs/callhome/diarization/sond/run.sh
Normal file
954
egs/callhome/diarization/sond/run.sh
Normal file
@ -0,0 +1,954 @@
|
||||
#!/usr/bin/env bash

. ./path.sh || exit 1;

# This recipe aims at reimplement the results of SOND on Callhome corpus which is represented in
# [1] TOLD: A Novel Two-stage Overlap-aware Framework for Speaker Diarization, ICASSP 2023
# You can also use it on other dataset such AliMeeting to reproduce the results in
# [2] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, EMNLP 2022
# We recommend you run this script stage by stage.

# [developing] This recipe includes:
# 1. simulating data with switchboard and NIST.
# 2. training the model from scratch for 3 stages:
#    2-1. pre-train on simu_swbd_sre
#    2-2. train on simu_swbd_sre
#    2-3. finetune on callhome1
# 3. evaluating model with the results from the first stage EEND-OLA,
# Finally, you will get a similar DER result claimed in the paper.

# environment configuration
kaldi_root=

if [ -z "${kaldi_root}" ]; then
  echo "We need kaldi to prepare dataset, extract fbank features, please install kaldi first and set kaldi_root."
  echo "Kaldi installation guide can be found at https://kaldi-asr.org/"
  exit;
fi

# borrow data-preparation and helper scripts from kaldi's callhome recipe
if [ ! -e local ]; then
  ln -s ${kaldi_root}/egs/callhome_diarization/v2/local ./local
fi

if [ ! -e utils ]; then
  ln -s ${kaldi_root}/egs/callhome_diarization/v2/utils ./utils
fi

# machines configuration
gpu_devices="4,5,6,7" # for V100-16G, use 4 GPUs
gpu_num=4
count=1

# general configuration
stage=3
stop_stage=3
# number of jobs for data process
nj=16
sr=8000

# dataset related
data_root=
callhome_root=path/to/NIST/LDC2001S97

# experiment configuration
lang=en
feats_type=fbank
datadir=data
dumpdir=dump
expdir=exp
train_cmd=utils/run.pl

# training related
tag=""
train_set=simu_swbd_sre
valid_set=callhome1
train_config=conf/EAND_ResNet34_SAN_L4N512_None_FFN_FSMN_L6N512_bce_dia_loss_01.yaml
token_list=${datadir}/token_list/powerset_label_n16k4.txt
init_param=
freeze_param=

# inference related
inference_model=valid.der.ave_5best.pth
inference_config=conf/basic_inference.yaml
inference_tag=""
test_sets="callhome1"
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# number of jobs for inference
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
infer_cmd=utils/run.pl
told_max_iter=2

. utils/parse_options.sh || exit 1;

# experiment directory name derived from the training config
model_dir="$(basename "${train_config}" .yaml)_${feats_type}_${lang}${tag}"

# you can set gpu num for decoding here
gpuid_list=$gpu_devices # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')

if ${gpu_inference}; then
  inference_nj=$((ngpu * njob))
  _ngpu=1
else
  inference_nj=$njob
  _ngpu=0
fi
|
||||
|
||||
# Prepare datasets
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # 1. Prepare a collection of NIST SRE data.
    # BUGFIX: "echp" -> "echo" (the typo made this line fail with
    # "command not found" on every run of stage 0).
    echo "Stage 0: Prepare a collection of NIST SRE data."

    local/make_sre.sh $data_root ${datadir}

    # 2.a Prepare SWB.
    local/make_swbd2_phase1.pl ${data_root}/LDC98S75 \
      ${datadir}/swbd2_phase1_train
    local/make_swbd2_phase2.pl $data_root/LDC99S79 \
      ${datadir}/swbd2_phase2_train
    local/make_swbd2_phase3.pl $data_root/LDC2002S06 \
      ${datadir}/swbd2_phase3_train
    local/make_swbd_cellular1.pl $data_root/LDC2001S13 \
      ${datadir}/swbd_cellular1_train
    local/make_swbd_cellular2.pl $data_root/LDC2004S07 \
      ${datadir}/swbd_cellular2_train
    # 2.b combine all swbd data.
    utils/combine_data.sh ${datadir}/swbd \
      ${datadir}/swbd2_phase1_train ${datadir}/swbd2_phase2_train ${datadir}/swbd2_phase3_train \
      ${datadir}/swbd_cellular1_train ${datadir}/swbd_cellular2_train
    utils/validate_data_dir.sh --no-text --no-feats ${datadir}/swbd
    utils/fix_data_dir.sh ${datadir}/swbd

    # merge SWB and SRE into the combined training pool
    utils/combine_data.sh ${datadir}/swbd_sre ${datadir}/swbd ${datadir}/sre
    utils/validate_data_dir.sh --no-text --no-feats ${datadir}/swbd_sre
    utils/fix_data_dir.sh ${datadir}/swbd_sre

    # 3. Prepare the Callhome portion of NIST SRE 2000.
    local/make_callhome.sh ${callhome_root} ${datadir}/

fi
|
||||
|
||||
# Stage 1: convert sph recordings (callhome + swbd_sre pool) to wav via sph2pipe.
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "Stage 1: Dump sph file to wav"
    export PATH=${kaldi_root}/tools/sph2pipe/:${PATH}
    if [ ! -f ${kaldi_root}/tools/sph2pipe/sph2pipe ]; then
        echo "Can not find sph2pipe in ${kaldi_root}/tools/sph2pipe/,"
        echo "please install sph2pipe and put it in the right place."
        exit;
    fi

    for part in callhome1 callhome2 swbd_sre; do
        echo "Stage 1: start to dump ${part}."
        # keep the original pipe-style scp around as sph.scp
        mv ${datadir}/${part}/wav.scp ${datadir}/${part}/sph.scp

        mkdir -p ${dumpdir}/${part}/wavs
        python -Wignore script/dump_pipe_wav.py ${datadir}/${part}/sph.scp ${dumpdir}/${part}/wavs \
          --sr ${sr} --nj ${nj} --no_pbar
        # rebuild wav.scp as "<utt-id> <abs-path>" from the dumped wav files
        find $(pwd)/${dumpdir}/${part}/wavs -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/${part}/wav.scp
    done
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "Stage 2: Extract non-overlap segments from callhome dataset"
    for dset in callhome1 callhome2; do
        echo "Stage 2: Extracting non-overlap segments for "${dset}
        mkdir -p ${dumpdir}/${dset}/nonoverlap_0s
        python -Wignore script/extract_nonoverlap_segments.py \
            ${datadir}/${dset}/wav.scp ${datadir}/${dset}/ref.rttm ${dumpdir}/${dset}/nonoverlap_0s \
            --min_dur 0 --max_spk_num 8 --sr ${sr} --no_pbar --nj ${nj}

        mkdir -p ${datadir}/${dset}/nonoverlap_0s
        # fixed: restrict find to *.wav; a bare "find dir" also emits the
        # directory itself, injecting a bogus entry into wav.scp
        find `pwd`/${dumpdir}/${dset}/nonoverlap_0s -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/${dset}/nonoverlap_0s/wav.scp
        # speaker id is the second-to-last path component of each wav path
        awk -F'[/.]' '{print $(NF-1),$(NF-2)}' ${datadir}/${dset}/nonoverlap_0s/wav.scp > ${datadir}/${dset}/nonoverlap_0s/utt2spk
        echo "Done."
    done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "Stage 3: Generate concatenated waveforms for each speaker in switchboard, sre and callhome1"
    # fixed: mkdir -p so reruns of this stage do not abort
    mkdir -p swb_sre_resources
    wget --no-check-certificate -P swb_sre_resources/ https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/Speaker_Diar/swb_sre_resources/noise.scp
    wget --no-check-certificate -P swb_sre_resources/ https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/Speaker_Diar/swb_sre_resources/swbd_sre_tdnn_vad_segments
    mkdir -p ${datadir}/swbd_sre/none_silence
    # fixed: symlink targets must be absolute — a relative target resolves
    # against the link's own directory, not the current working directory
    ln -s $(readlink -f swb_sre_resources/swbd_sre_tdnn_vad_segments) ${datadir}/swbd_sre/none_silence/segments
    cp ${datadir}/swbd_sre/wav.scp ${datadir}/swbd_sre/none_silence/reco.scp

    mkdir -p ${dumpdir}/swbd_sre/none_silence
    python -Wignore script/remove_silence_from_wav.py \
        ${datadir}/swbd_sre/none_silence ${dumpdir}/swbd_sre/none_silence --nj ${nj} --sr 8000
    # The utterance number in wav.scp may be different from reco.scp,
    # since some recordings don't appear in the segments file, may due to the VAD
    echo "find wavs_nosil"
    find `pwd`/${dumpdir}/swbd_sre/none_silence -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/swbd_sre/none_silence/wav.scp
    echo "concat spk segments"

    # fixed: absolute target for the same reason as the segments link above
    ln -s $(readlink -f ${datadir}/swbd_sre/utt2spk) ${datadir}/swbd_sre/none_silence/utt2spk

    echo "Stage 3: Start to concatenate waveforms for speakers in switchboard and sre"
    python -Wignore egs/callhome/concat_spk_segs.py \
        ${datadir}/swbd_sre/none_silence ${dumpdir}/swbd_sre/spk_wavs --nj ${nj} --sr 8000

    echo "Stage 3: Start to concatenate waveforms for speakers in callhome1"
    # only use callhome1 as training set to simulate data
    python -Wignore egs/callhome/concat_spk_segs.py \
        ${datadir}/callhome1/nonoverlap_0s ${dumpdir}/callhome1/spk_wavs --nj ${nj} --sr 8000
fi
# simulate data with the pattern of callhome1
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "Stage 4: Start to simulate recordings."

    if [ ! -e ${dumpdir}/musan ]; then
        echo "Stage 4-1: Start to download MUSAN noises from openslr"
        wget --no-check-certificate -P ${dumpdir}/musan https://www.openslr.org/resources/17/musan.tar.gz
        tar -C ${dumpdir}/musan -xvf ${dumpdir}/musan/musan.tar.gz
    fi

    if [ ! -e ${dumpdir}/rirs ]; then
        echo "Stage 4-2: Start to download RIRs from openslr"
        wget --no-check-certificate -P ${dumpdir}/rirs https://www.openslr.org/resources/28/rirs_noises.zip
        unzip ${dumpdir}/rirs/rirs_noises.zip -d ${dumpdir}/rirs
    fi

    mkdir -p ${datadir}/simu_swbd_sre
    # only use background noises instead of all noises in MUSAN.
    sed "s:/path/to/musan/:`pwd`/${dumpdir}/musan/:g" swb_sre_resources/noise.scp > ${datadir}/simu_swbd_sre/noise.scp
    # use simulated RIRs.
    find `pwd`/${dumpdir}/rirs/RIRS_NOISES/simulated_rirs/ -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-3)"-"$(NF-1), $0}' > ${datadir}/simu_swbd_sre/rirs.scp
    cp ${datadir}/callhome1/{ref.rttm,reco2num_spk} ${datadir}/simu_swbd_sre
    find `pwd`/${dumpdir}/swbd_sre/spk_wavs -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/simu_swbd_sre/spk2wav.scp

    echo "Stage 4-3: Start to simulate recordings with variable speakers as Callhome1 patterns."
    # average duration of callhome is 125s, about 20 chunk with silence
    # simulating 22500 (45 jobs x 500 reco) recordings, without random_assign and random_shift_interval
    for i in $(seq 0 44); do
        cmd="python -Wignore egs/callhome/simu_whole_recordings.py \
            ${datadir}/simu_swbd_sre \
            ${dumpdir}/simu_swbd_sre/wavs \
            --corpus_name simu_swbd_sre --task_id $i --total_mix 500 --sr 8000 --no_bar &"
        echo $cmd
        eval $cmd
    done
    wait;

    echo "Stage 4-4: Start to simulate recordings with fixed speakers as Callhome1 patterns."
    # simulating 30000 (30 jobs x 1000 reco) recordings for different speaker number 2, 3, 4
    for n_spk in $(seq 2 4); do
        # fixed: this was a hard-coded developer home directory
        # (/home/neo.dzh/corpus/...); create the output dir the jobs below
        # actually write into instead
        mkdir -p ${dumpdir}/simu_swbd_sre/${n_spk}spk_wavs
        for i in $(seq 0 29); do
            cmd="python -Wignore egs/callhome/simu_whole_recordings.py \
                ${datadir}/simu_swbd_sre \
                ${dumpdir}/simu_swbd_sre/${n_spk}spk_wavs \
                --random_assign_spk --random_interval --spk_num ${n_spk} \
                --corpus_name simu_swbd_sre --task_id $i --total_mix 1000 --sr 8000 --no_bar &"
            echo $cmd
            eval $cmd
        done
        wait;
    done

    # collect all simulated wavs/rttms (variable- and fixed-speaker subsets)
    find `pwd`/${dumpdir}/simu_swbd_sre -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/simu_swbd_sre/wav.scp
    awk '{print $1,$1}' ${datadir}/simu_swbd_sre/wav.scp > ${datadir}/simu_swbd_sre/utt2spk
    find `pwd`/${dumpdir}/simu_swbd_sre -iname "*.rttm" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/simu_swbd_sre/rttm.scp
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    echo "Stage 5: Generate fbank features"
    # run Kaldi feature extraction from its own recipe directory, then return
    home_path=$(pwd)
    cd ${kaldi_root}/egs/callhome_diarization/v2 || exit

    . ./cmd.sh
    . ./path.sh

    # whole recordings first, then the per-speaker non-overlap subsets
    for dset in simu_swbd_sre callhome1 callhome2 \
                swbd_sre/none_silence callhome1/nonoverlap_0s callhome2/nonoverlap_0s; do
        steps/make_fbank.sh --write-utt2num-frames true --fbank-config conf/fbank.conf --nj ${nj} --cmd "$train_cmd" \
            ${datadir}/${dset} ${expdir}/make_fbank/${dset} ${dumpdir}/${dset}/fbank
        utils/fix_data_dir.sh ${datadir}/${dset}
    done

    cd ${home_path} || exit
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
    echo "Stage 6: Extract speaker embeddings."
    # download the pretrained x-vector model from ModelScope
    git lfs install
    git clone https://www.modelscope.cn/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch.git
    mv speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ${expdir}/

    # fixed: the model was just moved under ${expdir}, so do not hard-code "exp/"
    sv_exp_dir=${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch
    # the dumped features are 80-dim fbanks; pin input_size accordingly
    sed "s/input_size: null/input_size: 80/g" ${sv_exp_dir}/sv.yaml > ${sv_exp_dir}/sv_fbank.yaml
    for dset in swbd_sre/none_silence callhome1/nonoverlap_0s callhome2/nonoverlap_0s; do
        key_file=${datadir}/${dset}/feats.scp
        num_scp_file="$(<${key_file} wc -l)"
        # never launch more jobs than there are keys
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        _logdir=${dumpdir}/${dset}/xvecs
        mkdir -p ${_logdir}
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}

        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/sv_inference.JOB.log \
            python -m funasr.bin.sv_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${key_file},speech,kaldi_ark" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --sv_train_config ${sv_exp_dir}/sv_fbank.yaml \
                --sv_model_file ${sv_exp_dir}/sv.pth \
                --output_dir "${_logdir}"/output.JOB
        # merge per-job outputs into a single sorted utt -> xvector map
        cat ${_logdir}/output.*/xvector.scp | sort > ${datadir}/${dset}/utt2xvec
    done
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
    echo "Stage 7: Generate label files."

    # frame-level speaker activity labels for each dataset
    for dset in simu_swbd_sre callhome1 callhome2; do
        echo "Stage 7: Generate labels for ${dset}."
        python -Wignore script/calc_real_meeting_frame_labels.py \
            ${datadir}/${dset} ${dumpdir}/${dset}/labels \
            --n_spk 8 --frame_shift 0.01 --nj 16 --sr 8000
        # index the generated *.lbl.mat files as "<reco-id> <abs-path>"
        find `pwd`/${dumpdir}/${dset}/labels -iname "*.lbl.mat" | awk -F'[/.]' '{print $(NF-2),$0}' | sort > ${datadir}/${dset}/labels.scp
    done
fi
if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
    echo "Stage 8: Make training and evaluation files."

    # dump simulated data in training mode (randomly shuffle the speaker order).
    data_dir=${datadir}/simu_swbd_sre/files_for_dump
    # fixed: mkdir -p so reruns of this stage do not abort
    mkdir -p ${data_dir}
    cp ${datadir}/simu_swbd_sre/{feats.scp,labels.scp} ${data_dir}/
    cp ${datadir}/swbd_sre/none_silence/{utt2spk,utt2xvec,utt2num_frames} ${data_dir}/
    # dump data with the window length of 1600 frames and hop length of 400 frames.
    echo "Stage 8: start to dump for simu_swbd_sre."
    for i in $(seq 0 49); do
        cmd="python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
            --out ${dumpdir}/simu_swbd_sre/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode train \
            --chunk_size 1600 --chunk_shift 400 \
            --task_id ${i} --task_size 2250 &"
        echo $cmd
        eval $cmd
    done
    wait;
    mkdir -p ${datadir}/simu_swbd_sre/dumped_files
    cat ${dumpdir}/simu_swbd_sre/dumped_files/data_parts*_feat.scp | sort > ${datadir}/simu_swbd_sre/dumped_files/feats.scp
    cat ${dumpdir}/simu_swbd_sre/dumped_files/data_parts*_xvec.scp | sort > ${datadir}/simu_swbd_sre/dumped_files/profile.scp
    cat ${dumpdir}/simu_swbd_sre/dumped_files/data_parts*_label.scp | sort > ${datadir}/simu_swbd_sre/dumped_files/label.scp
    mkdir -p ${expdir}/simu_swbd_sre_states
    # every chunk is 1600 frames; shuffle once for the training data loader
    awk '{print $1,"1600"}' ${datadir}/simu_swbd_sre/dumped_files/feats.scp | shuf > ${expdir}/simu_swbd_sre_states/speech_shape

    # dump callhome1 data in test mode (the command below passes --mode test;
    # the original comment claimed training mode).
    data_dir=${datadir}/callhome1/files_for_dump
    mkdir -p ${data_dir}
    # filter out zero duration segments
    LC_ALL=C awk '{if ($5 > 0){print $0}}' ${datadir}/callhome1/ref.rttm > ${data_dir}/ref.rttm
    cp ${datadir}/callhome1/{feats.scp,labels.scp} ${data_dir}/
    cp ${datadir}/callhome1/nonoverlap_0s/{utt2spk,utt2xvec,utt2num_frames} ${data_dir}/

    echo "Stage 8: start to dump for callhome1."
    python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
        --out ${dumpdir}/callhome1/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode test \
        --chunk_size 1600 --chunk_shift 400 --add_mid_to_speaker true

    mkdir -p ${datadir}/callhome1/dumped_files
    cat ${dumpdir}/callhome1/dumped_files/data_parts*_feat.scp | sort > ${datadir}/callhome1/dumped_files/feats.scp
    cat ${dumpdir}/callhome1/dumped_files/data_parts*_xvec.scp | sort > ${datadir}/callhome1/dumped_files/profile.scp
    cat ${dumpdir}/callhome1/dumped_files/data_parts*_label.scp | sort > ${datadir}/callhome1/dumped_files/label.scp
    mkdir -p ${expdir}/callhome1_states
    awk '{print $1,"1600"}' ${datadir}/callhome1/dumped_files/feats.scp | shuf > ${expdir}/callhome1_states/speech_shape
    # oracle VAD segments, used later when rescoring with reference VAD
    python -Wignore script/convert_rttm_to_seg_file.py --rttm_scp ${data_dir}/ref.rttm --seg_file ${data_dir}/org_vad.txt

    # dump callhome2 data in test mode.
    data_dir=${datadir}/callhome2/files_for_dump
    mkdir -p ${data_dir}
    # filter out zero duration segments
    LC_ALL=C awk '{if ($5 > 0){print $0}}' ${datadir}/callhome2/ref.rttm > ${data_dir}/ref.rttm
    cp ${datadir}/callhome2/{feats.scp,labels.scp} ${data_dir}/
    cp ${datadir}/callhome2/nonoverlap_0s/{utt2spk,utt2xvec,utt2num_frames} ${data_dir}/

    echo "Stage 8: start to dump for callhome2."
    python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
        --out ${dumpdir}/callhome2/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode test \
        --chunk_size 1600 --chunk_shift 400 --add_mid_to_speaker true

    mkdir -p ${datadir}/callhome2/dumped_files
    cat ${dumpdir}/callhome2/dumped_files/data_parts*_feat.scp | sort > ${datadir}/callhome2/dumped_files/feats.scp
    cat ${dumpdir}/callhome2/dumped_files/data_parts*_xvec.scp | sort > ${datadir}/callhome2/dumped_files/profile.scp
    cat ${dumpdir}/callhome2/dumped_files/data_parts*_label.scp | sort > ${datadir}/callhome2/dumped_files/label.scp
    mkdir -p ${expdir}/callhome2_states
    awk '{print $1,"1600"}' ${datadir}/callhome2/dumped_files/feats.scp | shuf > ${expdir}/callhome2_states/speech_shape
    python -Wignore script/convert_rttm_to_seg_file.py --rttm_scp ${data_dir}/ref.rttm --seg_file ${data_dir}/org_vad.txt
fi
# Training Stage, phase 1, pretraining on simulated data with frozen encoder parameters.
# This training may cost about 1.8 days.
if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
    echo "stage 10: training phase 1, pretraining on simulated data"
    world_size=$gpu_num  # run on one machine
    mkdir -p ${expdir}/${model_dir}/log
    mkdir -p /tmp/${model_dir}
    # DDP rendezvous file; must be removed from previous runs
    INIT_FILE=/tmp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_opt=""
    if [ ! -z "${init_param}" ]; then
        init_opt="--init_param ${init_param}"
        echo ${init_opt}
    fi

    freeze_opt=""
    if [ ! -z "${freeze_param}" ]; then
        freeze_opt="--freeze_param ${freeze_param}"
        echo ${freeze_opt}
    fi

    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    # launch one training process per GPU in the background
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            # fixed deprecated $[...] arithmetic; cut fields are 1-based
            gpu_id=$(echo $gpu_devices | cut -d',' -f$((i + 1)))
            python -m funasr.bin.diar_train \
                --gpu_id $gpu_id \
                --use_preprocessor false \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/feats.scp,speech,kaldi_ark \
                --train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/profile.scp,profile,kaldi_ark \
                --train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
                --train_shape_file ${expdir}/${train_set}_states/speech_shape \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
                --valid_shape_file ${expdir}/${valid_set}_states/speech_shape \
                --init_param ${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/sv.pth:encoder:encoder \
                --unused_parameters true \
                --freeze_param encoder \
                ${init_opt} \
                ${freeze_opt} \
                --ignore_init_mismatch true \
                --resume true \
                --output_dir ${expdir}/${model_dir} \
                --config $train_config \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${expdir}/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    echo "Training log can be found at ${expdir}/${model_dir}/log/train.log.*"
    wait
fi
# evaluate for pretrained model
if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then
    echo "stage 11: evaluation for phase-1 model."
    for dset in ${test_sets}; do
        echo "Processing for $dset"
        exp_model_dir=${expdir}/${model_dir}
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${exp_model_dir}/${_inference_tag}/${inference_model}/${dset}"
        _logdir="${_dir}/logdir"
        # warn (but continue) if a previous run left results behind
        if [ -d ${_dir} ]; then
            echo "WARNING: ${_dir} already exists."
        fi
        mkdir -p "${_logdir}"
        _data="${datadir}/${dset}/dumped_files"
        key_file=${_data}/feats.scp
        num_scp_file="$(<${key_file} wc -l)"
        # never launch more jobs than there are keys
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        _opt=
        if [ ! -z "${inference_config}" ]; then
            _opt="--config ${inference_config}"
        fi
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}

        echo "Inference log can be found at ${_logdir}/inference.*.log"
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
            python -m funasr.bin.diar_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
                --data_path_and_name_and_type "${_data}/profile.scp,profile,kaldi_ark" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --diar_train_config "${exp_model_dir}"/config.yaml \
                --diar_model_file "${exp_model_dir}"/"${inference_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode sond ${_opt}
    done
fi
# Scoring for pretrained model, you may get a DER like 13.73 16.25
# 13.73: with oracle VAD, 16.25: with only SOND outputs, aka, system VAD.
if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then
    echo "stage 12: Scoring phase-1 models"
    if [ ! -e dscore ]; then
        git clone https://github.com/nryant/dscore.git
        # add intervaltree to setup.py
    fi
    for dset in ${test_sets}; do
        echo "stage 12: Scoring for ${dset}"
        diar_exp=${expdir}/${model_dir}
        _data="${datadir}/${dset}"
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"
        _logdir="${_dir}/logdir"
        # gather per-job label outputs into one sorted file
        cat ${_logdir}/*/labels.txt | sort > ${_dir}/labels.txt

        # post-process frame labels into system RTTMs (both oracle- and system-VAD variants)
        cmd="python -Wignore script/convert_label_to_rttm.py ${_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${_dir}/sys.rttm \
            --ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
        eval ${cmd}

        ref=${datadir}/${dset}/files_for_dump/ref.rttm

        # DER with oracle VAD
        sys=${_dir}/sys.rttm.ref_vad
        OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

        # DER with system VAD (SOND outputs only)
        sys=${_dir}/sys.rttm.sys_vad
        SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

        echo -e "${inference_model} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${_dir}/results.txt
    done
fi
# Training Stage, phase 2, training on simulated data without frozen parameters.
# For V100-16G, please set batch_size to 8 in the config, and use 4 GPU to train the model with options like --gpu_devices 4,5,6,7 --gpu_num 4.
# For V100-32G, please set batch_size to 16 in the config, and use 2 GPU to train the model with options like --gpu_devices 4,5,6,7 --gpu_num 2.
# This training may cost about 3.5 days.
if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then
    echo "stage 13: training phase 2, training on simulated data"
    world_size=$gpu_num  # run on one machine
    mkdir -p ${expdir}/${model_dir}_phase2/log
    mkdir -p /tmp/${model_dir}_phase2
    # DDP rendezvous file; must be removed from previous runs
    INIT_FILE=/tmp/${model_dir}_phase2/ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_opt=""
    if [ ! -z "${init_param}" ]; then
        init_opt="--init_param ${init_param}"
        echo ${init_opt}
    fi

    freeze_opt=""
    if [ ! -z "${freeze_param}" ]; then
        freeze_opt="--freeze_param ${freeze_param}"
        echo ${freeze_opt}
    fi

    # phase-2 uses a sibling config derived from the phase-1 config name
    phase2_config="$(dirname "${train_config}")/$(basename "${train_config}" .yaml)_phase2.yaml"

    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            # fixed deprecated $[...] arithmetic; cut fields are 1-based
            gpu_id=$(echo $gpu_devices | cut -d',' -f$((i + 1)))
            python -m funasr.bin.diar_train \
                --gpu_id $gpu_id \
                --use_preprocessor false \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/feats.scp,speech,kaldi_ark \
                --train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/profile.scp,profile,kaldi_ark \
                --train_data_path_and_name_and_type ${datadir}/${train_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
                --train_shape_file ${expdir}/${train_set}_states/speech_shape \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
                --valid_shape_file ${expdir}/${valid_set}_states/speech_shape \
                --init_param ${expdir}/${model_dir}/valid.der.ave_5best.pth \
                --unused_parameters true \
                ${init_opt} \
                ${freeze_opt} \
                --ignore_init_mismatch true \
                --resume true \
                --output_dir ${expdir}/${model_dir}_phase2 \
                --config ${phase2_config} \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${expdir}/${model_dir}_phase2/log/train.log.$i 2>&1
        } &
    done
    echo "Training log can be found at ${expdir}/${model_dir}_phase2/log/train.log.*"
    wait
fi
# evaluate for phase-2 model
if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then
    echo "stage 14: evaluation for phase-2 model ${inference_model}."
    for dset in ${test_sets}; do
        echo "Processing for $dset"
        exp_model_dir=${expdir}/${model_dir}_phase2
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${exp_model_dir}/${_inference_tag}/${inference_model}/${dset}"
        _logdir="${_dir}/logdir"
        # warn (but continue) if a previous run left results behind
        if [ -d ${_dir} ]; then
            echo "WARNING: ${_dir} already exists."
        fi
        mkdir -p "${_logdir}"
        _data="${datadir}/${dset}/dumped_files"
        key_file=${_data}/feats.scp
        num_scp_file="$(<${key_file} wc -l)"
        # never launch more jobs than there are keys
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        _opt=
        if [ ! -z "${inference_config}" ]; then
            _opt="--config ${inference_config}"
        fi
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}

        echo "Inference log can be found at ${_logdir}/inference.*.log"
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
            python -m funasr.bin.diar_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
                --data_path_and_name_and_type "${_data}/profile.scp,profile,kaldi_ark" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --diar_train_config "${exp_model_dir}"/config.yaml \
                --diar_model_file "${exp_model_dir}"/${inference_model} \
                --output_dir "${_logdir}"/output.JOB \
                --mode sond ${_opt}
    done
fi
# Scoring for pretrained model, you may get a DER like 11.25 15.30
# 11.25: with oracle VAD, 15.30: with only SOND outputs, aka, system VAD.
if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then
    echo "stage 15: Scoring phase-2 models"
    if [ ! -e dscore ]; then
        git clone https://github.com/nryant/dscore.git
        # add intervaltree to setup.py
    fi
    for dset in ${test_sets}; do
        echo "stage 15: Scoring for ${dset}"
        diar_exp=${expdir}/${model_dir}_phase2
        _data="${datadir}/${dset}"
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"
        _logdir="${_dir}/logdir"
        # gather per-job label outputs into one sorted file
        cat ${_logdir}/*/labels.txt | sort > ${_dir}/labels.txt

        # post-process frame labels into system RTTMs (both oracle- and system-VAD variants)
        cmd="python -Wignore script/convert_label_to_rttm.py ${_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${_dir}/sys.rttm \
            --ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
        eval ${cmd}

        ref=${datadir}/${dset}/files_for_dump/ref.rttm

        # DER with oracle VAD
        sys=${_dir}/sys.rttm.ref_vad
        OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

        # DER with system VAD (SOND outputs only)
        sys=${_dir}/sys.rttm.sys_vad
        SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

        echo -e "${inference_model} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${_dir}/results.txt
    done
fi
|
||||
|
||||
# Finetune Stage, phase 3, training on callhome1 training set
if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then
    # fixed typo in the log message: "finetuing" -> "finetuning"
    echo "stage 16: training phase 3, finetuning on callhome1 real data"
    world_size=$gpu_num  # run on one machine
    mkdir -p ${expdir}/${model_dir}_phase3/log
    mkdir -p /tmp/${model_dir}_phase3
    # DDP rendezvous file; must be removed from previous runs
    INIT_FILE=/tmp/${model_dir}_phase3/ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_opt=""
    if [ ! -z "${init_param}" ]; then
        init_opt="--init_param ${init_param}"
        echo ${init_opt}
    fi

    freeze_opt=""
    if [ ! -z "${freeze_param}" ]; then
        freeze_opt="--freeze_param ${freeze_param}"
        echo ${freeze_opt}
    fi

    # phase-3 uses a sibling config derived from the phase-1 config name
    phase3_config="$(dirname "${train_config}")/$(basename "${train_config}" .yaml)_phase3.yaml"

    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            # fixed deprecated $[...] arithmetic; cut fields are 1-based
            gpu_id=$(echo $gpu_devices | cut -d',' -f$((i + 1)))
            # NOTE: finetuning intentionally trains on ${valid_set} (callhome1)
            python -m funasr.bin.diar_train \
                --gpu_id $gpu_id \
                --use_preprocessor false \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
                --train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
                --train_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
                --train_shape_file ${expdir}/${valid_set}_states/speech_shape \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/feats.scp,speech,kaldi_ark \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
                --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
                --valid_shape_file ${expdir}/${valid_set}_states/speech_shape \
                --init_param ${expdir}/${model_dir}_phase2/valid.forward_steps.ave_5best.pth \
                --unused_parameters true \
                ${init_opt} \
                ${freeze_opt} \
                --ignore_init_mismatch true \
                --resume true \
                --output_dir ${expdir}/${model_dir}_phase3 \
                --config ${phase3_config} \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${expdir}/${model_dir}_phase3/log/train.log.$i 2>&1
        } &
    done
    echo "Training log can be found at ${expdir}/${model_dir}_phase3/log/train.log.*"
    wait
fi
# evaluate for finetuned model
if [ ${stage} -le 17 ] && [ ${stop_stage} -ge 17 ]; then
    echo "stage 17: evaluation for finetuned model ${inference_model}."
    for dset in ${test_sets}; do
        echo "Processing for $dset"
        exp_model_dir=${expdir}/${model_dir}_phase3
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${exp_model_dir}/${_inference_tag}/${inference_model}/${dset}"
        _logdir="${_dir}/logdir"
        # warn (but continue) if a previous run left results behind
        if [ -d ${_dir} ]; then
            echo "WARNING: ${_dir} already exists."
        fi
        mkdir -p "${_logdir}"
        _data="${datadir}/${dset}/dumped_files"
        key_file=${_data}/feats.scp
        num_scp_file="$(<${key_file} wc -l)"
        # never launch more jobs than there are keys
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        _opt=
        if [ ! -z "${inference_config}" ]; then
            _opt="--config ${inference_config}"
        fi
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}

        echo "Inference log can be found at ${_logdir}/inference.*.log"
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
            python -m funasr.bin.diar_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
                --data_path_and_name_and_type "${_data}/profile.scp,profile,kaldi_ark" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --diar_train_config "${exp_model_dir}"/config.yaml \
                --diar_model_file "${exp_model_dir}"/${inference_model} \
                --output_dir "${_logdir}"/output.JOB \
                --mode sond ${_opt}
    done
fi
|
||||
# average 3 4 5 6 7 epoch
# Scoring for pretrained model, you may get a DER like
# 7.21 8.05 on callhome1
# 8.31 9.32 on callhome2
# Stage 18: merge per-job label outputs, convert them to RTTM, and score DER
# with dscore under both oracle VAD and system VAD.
if [ ${stage} -le 18 ] && [ ${stop_stage} -ge 18 ]; then
    echo "stage 18: Scoring finetuned models"
    # fetch the dscore toolkit on first use
    if [ ! -e dscore ]; then
        git clone https://github.com/nryant/dscore.git
        # add intervaltree to setup.py
    fi
    for dset in ${test_sets}; do
        echo "stage 18: Scoring for ${dset}"
        diar_exp=${expdir}/${model_dir}_phase3
        _data="${datadir}/${dset}"
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"
        _logdir="${_dir}/logdir"
        # merge the per-job chunk labels into one sorted file
        cat ${_logdir}/*/labels.txt | sort > ${_dir}/labels.txt

        # writes ${_dir}/sys.rttm.ref_vad and ${_dir}/sys.rttm.sys_vad
        cmd="python -Wignore script/convert_label_to_rttm.py ${_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${_dir}/sys.rttm \
            --ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
        echo ${cmd}
        eval ${cmd}
        # DER with oracle VAD applied
        ref=${datadir}/${dset}/files_for_dump/ref.rttm
        sys=${_dir}/sys.rttm.ref_vad
        OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

        # DER with system VAD applied
        ref=${datadir}/${dset}/files_for_dump/ref.rttm
        sys=${_dir}/sys.rttm.sys_vad
        SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

        echo -e "${inference_model} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${_dir}/results.txt
    done
fi
|
||||
|
||||
|
||||
# Stage 19: iterative TOLD refinement. Starting from EEND-OLA's diarization,
# alternate (a) speaker-profile re-estimation via x-vectors on non-overlap
# segments and (b) SOND re-decoding, for told_max_iter+1 iterations.
if [ ${stage} -le 19 ] && [ ${stop_stage} -ge 19 ]; then
    for dset in ${test_sets}; do
        echo "stage 19: Evaluating phase-3 system on ${dset} set with medfilter_size=83 clustering=EEND-OLA"
        # pretrained speaker-verification model used for x-vector extraction
        sv_exp_dir=${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch
        diar_exp=${expdir}/${model_dir}_phase3
        _data="${datadir}/${dset}/dumped_files"
        _inference_tag="$(basename "${inference_config}" .yaml)${inference_tag}"
        _dir="${diar_exp}/${_inference_tag}/${inference_model}/${dset}"

        for iter in `seq 0 ${told_max_iter}`; do
            eval_dir=${_dir}/iter_${iter}
            # iteration 0 bootstraps from EEND-OLA; later iterations consume the
            # previous iteration's system-VAD RTTM
            if [ $iter -eq 0 ]; then
                prev_rttm=${expdir}/EEND-OLA/sys.rttm
            else
                prev_rttm=${_dir}/iter_$((${iter}-1))/sys.rttm.sys_vad
            fi
            echo "Use ${prev_rttm} as system outputs."

            echo "Iteration ${iter}, step 1: extracting non-overlap segments"
            cmd="python -Wignore script/extract_nonoverlap_segments.py ${datadir}/${dset}/wav.scp \
                $prev_rttm ${eval_dir}/nonoverlap_segs/ --min_dur 0.1 --max_spk_num 16 --no_pbar --sr 8000"
            # echo ${cmd}
            eval ${cmd}

            echo "Iteration ${iter}, step 2: make data directory"
            mkdir -p ${eval_dir}/data
            find `pwd`/${eval_dir}/nonoverlap_segs/ -iname "*.wav" | sort > ${eval_dir}/data/wav.flist
            # utt id = file basename (last path field before .wav); speaker = parent dir name
            awk -F'[/.]' '{print $(NF-1),$0}' ${eval_dir}/data/wav.flist > ${eval_dir}/data/wav.scp
            awk -F'[/.]' '{print $(NF-1),$(NF-2)}' ${eval_dir}/data/wav.flist > ${eval_dir}/data/utt2spk
            cp $prev_rttm ${eval_dir}/data/sys.rttm
            home_path=`pwd`

            echo "Iteration ${iter}, step 3: calc x-vector for each utt"
            key_file=${eval_dir}/data/wav.scp
            num_scp_file="$(<${key_file} wc -l)"
            # never launch more jobs than there are utterances
            _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
            _logdir=${eval_dir}/data/xvecs
            mkdir -p ${_logdir}
            split_scps=
            for n in $(seq "${_nj}"); do
                split_scps+=" ${_logdir}/keys.${n}.scp"
            done
            # shellcheck disable=SC2086
            utils/split_scp.pl "${key_file}" ${split_scps}

            ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/sv_inference.JOB.log \
                python -m funasr.bin.sv_inference_launch \
                    --njob ${njob} \
                    --batch_size 1 \
                    --ngpu "${_ngpu}" \
                    --gpuid_list ${gpuid_list} \
                    --data_path_and_name_and_type "${key_file},speech,sound" \
                    --key_file "${_logdir}"/keys.JOB.scp \
                    --sv_train_config ${sv_exp_dir}/sv.yaml \
                    --sv_model_file ${sv_exp_dir}/sv.pth \
                    --output_dir "${_logdir}"/output.JOB
            cat ${_logdir}/output.*/xvector.scp | sort > ${eval_dir}/data/utt2xvec

            echo "Iteration ${iter}, step 4: dump x-vector record"
            awk '{print $1}' ${_data}/feats.scp > ${eval_dir}/data/idx
            python script/dump_speaker_profiles.py --dir ${eval_dir}/data \
                --out ${eval_dir}/global_n16 --n_spk 16 --no_pbar --emb_type global
            spk_profile=${eval_dir}/global_n16_parts00_xvec.scp

            echo "Iteration ${iter}, step 5: perform NN diarization"
            _logdir=${eval_dir}/diar
            mkdir -p ${_logdir}
            key_file=${_data}/feats.scp
            num_scp_file="$(<${key_file} wc -l)"
            _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
            split_scps=
            for n in $(seq "${_nj}"); do
                split_scps+=" ${_logdir}/keys.${n}.scp"
            done
            _opt=
            if [ ! -z "${inference_config}" ]; then
                _opt="--config ${inference_config}"
            fi
            # shellcheck disable=SC2086
            utils/split_scp.pl "${key_file}" ${split_scps}

            echo "Inference log can be found at ${_logdir}/inference.*.log"
            ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/inference.JOB.log \
                python -m funasr.bin.diar_inference_launch \
                    --batch_size 1 \
                    --ngpu "${_ngpu}" \
                    --njob ${njob} \
                    --gpuid_list ${gpuid_list} \
                    --data_path_and_name_and_type "${_data}/feats.scp,speech,kaldi_ark" \
                    --data_path_and_name_and_type "${spk_profile},profile,kaldi_ark" \
                    --key_file "${_logdir}"/keys.JOB.scp \
                    --diar_train_config ${diar_exp}/config.yaml \
                    --diar_model_file ${diar_exp}/${inference_model} \
                    --output_dir "${_logdir}"/output.JOB \
                    --mode sond ${_opt}

            echo "Iteration ${iter}, step 6: calc diarization results"
            cat ${_logdir}/output.*/labels.txt | sort > ${eval_dir}/labels.txt

            cmd="python -Wignore script/convert_label_to_rttm.py ${eval_dir}/labels.txt ${datadir}/${dset}/files_for_dump/org_vad.txt ${eval_dir}/sys.rttm \
                --ignore_len 10 --no_pbar --smooth_size 83 --vote_prob 0.5 --n_spk 16"
            # echo ${cmd}
            eval ${cmd}
            # DER with oracle VAD applied
            ref=${datadir}/${dset}/files_for_dump/ref.rttm
            sys=${eval_dir}/sys.rttm.ref_vad
            OVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

            # DER with system VAD applied (this RTTM also seeds the next iteration)
            ref=${datadir}/${dset}/files_for_dump/ref.rttm
            sys=${eval_dir}/sys.rttm.sys_vad
            SysVAD_DER=$(python -Wignore dscore/score.py -r $ref -s $sys --collar 0.25 2>&1 | grep OVERALL | awk '{print $4}')

            echo -e "${inference_model}/iter_${iter} ${OVAD_DER} ${SysVAD_DER}" | tee -a ${eval_dir}/results.txt
        done

        echo "Done."
    done
fi
|
||||
21
egs/callhome/diarization/sond/script/calc_num_frames.py
Normal file
21
egs/callhome/diarization/sond/script/calc_num_frames.py
Normal file
@ -0,0 +1,21 @@
|
||||
import os
import sys
import soundfile as sf
from funasr.utils.misc import load_scp_as_list


if __name__ == '__main__':
    # Write "<uttid> <num_frames>" per utterance, where num_frames is the
    # number of whole 10 ms frames: floor(num_samples / (sr * frame_shift)).
    scp_path, out_path = sys.argv[1], sys.argv[2]
    frame_shift = 0.01  # seconds per frame

    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    with open(out_path, "wt") as writer:
        for uttid, wav_path in load_scp_as_list(scp_path):
            samples, sr = sf.read(wav_path)
            hop = int(sr * frame_shift)
            writer.write(f"{uttid} {samples.shape[0] // hop}\n")
            # flush per utterance so progress is visible while the job runs
            writer.flush()
|
||||
@ -0,0 +1,101 @@
|
||||
import numpy as np
|
||||
from funasr.utils.job_runner import MultiProcessRunnerV3
|
||||
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
|
||||
import os
|
||||
import librosa
|
||||
import scipy.io as sio
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class MyRunner(MultiProcessRunnerV3):
    """Multi-process job that rasterizes per-meeting RTTMs into label matrices."""

    def prepare(self, parser):
        """Parse CLI args and build one task per meeting.

        Returns (task_list, shared_data, args) for the base runner, where each
        task is (meeting_id, wav_path, rttm_lines).
        """
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--n_spk", type=int, default=8)
        parser.add_argument("--remove_sil", default=False, action="store_true")
        parser.add_argument("--frame_shift", type=float, default=0.01)
        args = parser.parse_args()
        # NOTE(review): args.sr is presumably added by the MultiProcessRunnerV3
        # base parser — confirm against funasr.utils.job_runner
        assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."

        meeting_scp = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
        meeting2rttm = self.load_rttm(args.dir)

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        # only meetings that have RTTM annotations become tasks
        task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
        return task_list, None, args

    def load_rttm(self, dir_path):
        """Load RTTM lines grouped by meeting id.

        Prefers a per-meeting "rttm.scp" index; falls back to a single
        "ref.rttm" whose second field is the meeting id. Raises IOError when
        neither file exists.
        """
        meeting2rttm = OrderedDict()
        if os.path.exists(os.path.join(dir_path, "rttm.scp")):
            rttm_scp = load_scp_as_list(os.path.join(dir_path, "rttm.scp"))
            for mid, rttm_path in rttm_scp:
                meeting2rttm[mid] = []
                for one_line in open(rttm_path, "rt"):
                    meeting2rttm[mid].append(one_line.strip())
        elif os.path.exists(os.path.join(dir_path, "ref.rttm")):
            for one_line in open(os.path.join(dir_path, "ref.rttm"), "rt"):
                mid = one_line.strip().split(" ")[1]
                if mid not in meeting2rttm:
                    meeting2rttm[mid] = []
                meeting2rttm[mid].append(one_line.strip())
        else:
            raise IOError("Neither rttm.scp nor ref.rttm exists in {}".format(dir_path))

        return meeting2rttm

    def post(self, results_list, args):
        # workers write .mat files themselves; nothing to aggregate
        pass
|
||||
|
||||
|
||||
def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=8000, frame_shift=0.01):
    """Rasterize speaker turns into a frame-level multi-hot label matrix.

    Args:
        spk_turns: iterable of (meeting_id, start_sec, dur_sec, speaker).
        spk_list: speaker names; a speaker's row index is its position here.
        length: total length of the recording in samples.
        n_spk: number of label rows (>= len(spk_list)).
        remove_sil: if True, drop frames where no speaker is active.
        sr: sample rate in Hz.
        frame_shift: frame hop in seconds.

    Returns:
        (T, n_spk) int array of 0/1 activity labels.
    """
    hop = int(frame_shift * sr)

    def to_frame(sample):
        # nearest-frame index: add half a hop before truncating division
        return int((float(sample) + float(hop) / 2) / hop)

    total_frames = to_frame(length)
    labels = np.zeros([n_spk, total_frames], dtype=int)
    for _, start_sec, dur_sec, spk in spk_turns:
        row = spk_list.index(spk)
        start_smp = int(start_sec * sr)
        end_smp = start_smp + int(dur_sec * sr)
        labels[row, to_frame(start_smp):to_frame(end_smp)] = 1

    if remove_sil:
        active = np.nonzero(labels.sum(axis=0))[0]
        return labels[:, active].T  # (T', N): silence frames dropped
    return labels.T  # (T, N)
|
||||
|
||||
|
||||
def build_labels(wav_path, rttms, n_spk, remove_sil=False, sr=8000, frame_shift=0.01):
    """Parse one meeting's RTTM lines and rasterize them with calc_labels.

    Returns (labels, spk_list): a (T, n_spk) 0/1 matrix and the speakers in
    order of first appearance in the RTTM.
    """
    # NOTE(review): `filename=` is a deprecated keyword in newer librosa
    # releases — confirm the pinned librosa version before changing it
    total_samples = int(librosa.get_duration(filename=wav_path, sr=sr) * sr)
    turns, speakers = [], []
    for line in rttms:
        fields = line.strip().split(" ")
        mid, onset, dur, spk = fields[1], float(fields[3]), float(fields[4]), fields[7]
        if spk not in speakers:
            speakers.append(spk)
        turns.append((mid, onset, dur, spk))
    labels = calc_labels(turns, speakers, total_samples, n_spk, remove_sil, sr, frame_shift)
    return labels, speakers
|
||||
|
||||
|
||||
def process(task_args):
    """Worker: rasterize each meeting's labels and save them as <mid>.lbl.mat."""
    _, task_list, _, args = task_args
    for mid, wav_path, rttms in task_list:
        labels, speakers = build_labels(
            wav_path, rttms, args.n_spk, args.remove_sil, args.sr, args.frame_shift)
        out_path = os.path.join(args.out_dir, "{}.lbl.mat".format(mid))
        # store as bool to keep the .mat files small
        sio.savemat(out_path, {"labels": labels.astype(bool), "spk_list": speakers})
    return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # fan the per-meeting tasks out across worker processes
    MyRunner(process).run()
|
||||
57
egs/callhome/diarization/sond/script/concat_spk_segs.py
Normal file
57
egs/callhome/diarization/sond/script/concat_spk_segs.py
Normal file
@ -0,0 +1,57 @@
|
||||
import numpy as np
|
||||
from funasr.utils.job_runner import MultiProcessRunnerV3
|
||||
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
|
||||
import os
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
import argparse
|
||||
|
||||
|
||||
class MyRunner(MultiProcessRunnerV3):
    """Multi-process job that concatenates each speaker's utterances into one wav."""

    def prepare(self, parser):
        """Parse CLI args, invert utt2spk, and return one task per speaker.

        Returns (task_list, [spk2utt, wav_scp], args); task_list is the set of
        speaker ids to process.
        """
        parser.add_argument("dir", type=str)
        parser.add_argument("out_dir", type=str)
        args = parser.parse_args()
        # NOTE(review): args.sr presumably comes from the base runner's parser
        assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        print("loading data...")
        wav_scp = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
        utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))

        # invert utt2spk -> spk2utt, keeping only utts that have audio
        spk2utt = {}
        count = 0
        for utt, spk in utt2spk.items():
            if utt in wav_scp:
                if spk not in spk2utt:
                    spk2utt[spk] = []
                spk2utt[spk].append(utt)
                count += 1
        task_list = spk2utt.keys()
        print("total: {} speakers, {} utterances".format(len(spk2utt), count))
        print("starting jobs...")
        return task_list, [spk2utt, wav_scp], args

    def post(self, results_list, args):
        # workers write wav files themselves; nothing to aggregate
        pass
|
||||
|
||||
|
||||
def process(task_args):
    """Worker: concatenate all of one speaker's utterances into a 16-bit PCM wav."""
    _, task_list, [spk2utt, wav_scp], args = task_args
    for spk in task_list:
        pieces = []
        for utt in spk2utt[spk]:
            # load mono at args.sr; rescale [-1, 1] floats to the int16 range
            audio = librosa.load(wav_scp[utt], sr=args.sr, mono=True)[0] * 32767
            pieces.append(audio)
        joined = np.concatenate(pieces, axis=0)
        out_path = os.path.join(args.out_dir, "{}.wav".format(spk))
        sf.write(out_path, joined.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
    return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # fan the per-speaker tasks out across worker processes
    MyRunner(process).run()
|
||||
201
egs/callhome/diarization/sond/script/convert_label_to_rttm.py
Normal file
201
egs/callhome/diarization/sond/script/convert_label_to_rttm.py
Normal file
@ -0,0 +1,201 @@
|
||||
import os
|
||||
from funasr.utils.job_runner import MultiProcessRunnerV3
|
||||
import numpy as np
|
||||
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
|
||||
from collections import OrderedDict
|
||||
from tqdm import tqdm
|
||||
from scipy.ndimage import median_filter
|
||||
import kaldiio
|
||||
|
||||
|
||||
def load_mid_vad(vad_path):
    """Parse an oracle-VAD text file into per-meeting segment lists.

    Each input line is "utt_id meeting_id start end" (times in seconds).
    Returns {meeting_id: [(utt_id, start, end), ...]} in file order.
    """
    segments_by_meeting = {}
    with open(vad_path, "rt") as fd:
        for raw in fd:
            utt_id, mid, start, end = raw.strip().split(" ")
            segments_by_meeting.setdefault(mid, []).append(
                (utt_id, float(start), float(end)))
    return segments_by_meeting
|
||||
|
||||
|
||||
class MyRunner(MultiProcessRunnerV3):
    """Multi-process job that converts chunk-level PSE labels to RTTM files."""

    def prepare(self, parser):
        """Parse CLI args and group chunk labels (and optional VAD probs) by meeting.

        Each task is (meeting_id, chunk_label_list, oracle_vad_segments,
        sys_vad_prob_list_or_None).
        """
        parser.add_argument("label_txt", type=str)
        parser.add_argument("oracle_vad", type=str, default=None)
        parser.add_argument("out_rttm", type=str)
        parser.add_argument("--sys_vad_prob", type=str, default=None)
        parser.add_argument("--sys_vad_threshold", type=float, default=None)
        parser.add_argument("--vad_smooth_size", type=int, default=7)
        parser.add_argument("--n_spk", type=int, default=4)
        parser.add_argument("--chunk_len", type=int, default=1600)
        parser.add_argument("--shift_len", type=int, default=400)
        parser.add_argument("--ignore_len", type=int, default=5)
        parser.add_argument("--smooth_size", type=int, default=7)
        parser.add_argument("--vote_prob", type=float, default=0.5)
        args = parser.parse_args()

        if not os.path.exists(os.path.dirname(args.out_rttm)):
            os.makedirs(os.path.dirname(args.out_rttm))

        utt2labels = load_scp_as_list(args.label_txt, 'list')
        utt2vad_prob = []
        if args.sys_vad_prob is not None and os.path.exists(args.sys_vad_prob):
            # NOTE(review): args.verbose / args.no_pbar presumably come from the
            # base runner's parser — confirm against funasr.utils.job_runner
            if args.verbose:
                print("Use system vad ark file {}".format(args.sys_vad_prob))
            # pair ark entries with label utt ids positionally (pre-sort order)
            for (key, vad_prob), (utt_id, _) in zip(kaldiio.load_ark(args.sys_vad_prob), utt2labels):
                utt2vad_prob.append((utt_id, vad_prob))
            utt2vad_prob = sorted(utt2vad_prob, key=lambda x: x[0])

        utt2labels = sorted(utt2labels, key=lambda x: x[0])
        mid2segment_list = load_mid_vad(args.oracle_vad)
        # group sorted chunk labels by meeting; utt ids look like "<mid>-...".
        meeting2labels = OrderedDict()
        for utt_id, chunk_label in utt2labels:
            mid = utt_id.split("-")[0]
            if mid not in meeting2labels:
                meeting2labels[mid] = []
            meeting2labels[mid].append(chunk_label)

        mid2vad_list = {}
        if len(utt2vad_prob) > 0:
            for utt_id, vad_prob in utt2vad_prob:
                mid = utt_id.split("-")[0]
                if mid not in mid2vad_list:
                    mid2vad_list[mid] = []
                mid2vad_list[mid].append(vad_prob)

        # attach system VAD only when it was loaded; otherwise pass None
        task_list = [(mid, labels, mid2segment_list[mid], None) if len(mid2vad_list) == 0 else
                     (mid, labels, mid2segment_list[mid], mid2vad_list[mid])
                     for mid, labels in meeting2labels.items()]

        return task_list, None, args

    def post(self, result_list, args):
        """Write the workers' RTTM lines to <out_rttm>.ref_vad and <out_rttm>.sys_vad."""
        ref_vad_rttm = open(args.out_rttm + ".ref_vad", "wt")
        sys_vad_rttm = open(args.out_rttm + ".sys_vad", "wt")
        for results in result_list:
            for one_result in results:
                # one_result = [oracle_vad_lines, system_vad_lines]
                ref_vad_rttm.writelines(one_result[0])
                sys_vad_rttm.writelines(one_result[1])
        ref_vad_rttm.close()
        sys_vad_rttm.close()
|
||||
|
||||
|
||||
def int2vec(x, vec_dim=8, dtype=int):
    """Decode integer ``x`` into a length-``vec_dim`` binary vector.

    The vector is little-endian: index 0 holds the lowest bit, so speaker k
    corresponds to bit 2**k of the encoded label.

    Note: the default dtype was ``np.int``, an alias removed in NumPy 1.24;
    the builtin ``int`` is the drop-in replacement with identical semantics.
    """
    b = ('{:0' + str(vec_dim) + 'b}').format(x)
    # little-endian order: lower bit first
    return (np.array(list(b)[::-1]) == '1').astype(dtype)
|
||||
|
||||
|
||||
def seq2arr(seq, vec_dim=8):
    """Decode a sequence of integer tokens into a (len(seq), vec_dim) bit matrix.

    Each token is decoded with int2vec (little-endian bits) and stacked as a
    row. Uses np.vstack: np.row_stack is a deprecated alias since NumPy 1.24.
    """
    return np.vstack([int2vec(int(x), vec_dim) for x in seq])
|
||||
|
||||
|
||||
def sample2ms(sample, sr=16000):
    """Convert a sample count at rate ``sr`` into 10 ms frame units (centiseconds)."""
    seconds = float(sample) / sr
    return int(seconds * 100)
|
||||
|
||||
|
||||
def calc_multi_labels(chunk_label_list, chunk_len, shift_len, n_spk, vote_prob=0.5):
    """Overlap-add chunk-level PSE label tokens into one (T, n_spk) 0/1 matrix.

    Chunks of ``chunk_len`` frames are hopped by ``shift_len``; overlapping
    frames are averaged and then thresholded by ``vote_prob`` (majority vote).
    """
    n_chunk = len(chunk_label_list)
    # the last chunk may be shorter; only its tail beyond the previous chunk's
    # span contributes new frames
    last_chunk_valid_frame = len(chunk_label_list[-1]) - (chunk_len - shift_len)
    # NOTE(review): assumes n_chunk >= 2 — for a single-chunk meeting this
    # formula under-counts frames; confirm upstream always yields >= 2 chunks
    n_frame = (n_chunk - 2) * shift_len + chunk_len + last_chunk_valid_frame
    multi_labels = np.zeros((n_frame, n_spk), dtype=float)
    weight = np.zeros((n_frame, 1), dtype=float)
    for i in range(n_chunk):
        raw_label = chunk_label_list[i]
        # replace decoder '<unk>' tokens by the previous token ('0' at position 0);
        # mutates the caller's list in place
        for k in range(len(raw_label)):
            if raw_label[k] == '<unk>':
                raw_label[k] = raw_label[k-1] if k > 0 else '0'
        chunk_multi_label = seq2arr(raw_label, n_spk)
        # NOTE: rebinds the chunk_len parameter to this chunk's actual length
        # so the accumulation slices match a shorter final chunk
        chunk_len = chunk_multi_label.shape[0]
        multi_labels[i*shift_len:i*shift_len+chunk_len, :] += chunk_multi_label
        weight[i*shift_len:i*shift_len+chunk_len, :] += 1
    multi_labels = multi_labels / weight  # normalizing vote
    multi_labels = (multi_labels > vote_prob).astype(int)  # voting results
    return multi_labels
|
||||
|
||||
|
||||
def calc_spk_turns(label_arr, spk_list):
    """Extract speaker turns from a (T, N) 0/1 activity matrix.

    Returns a list of [speaker_name, start_frame, duration_frames] runs;
    columns whose name is the placeholder "None" are skipped.
    """
    turns = []
    n_frame, n_cols = label_arr.shape[0], label_arr.shape[1]
    for col in range(n_cols):
        name = spk_list[col]
        if name == "None":
            # padded / unused speaker slot
            continue
        run_start, active = 0, False
        for t in range(n_frame):
            val = label_arr[t, col]
            if val == 1 and not active:
                run_start, active = t, True
            elif val == 0 and active:
                turns.append([name, run_start, t - run_start])
                active = False
        if active:
            # close a run that extends to the final frame
            turns.append([name, run_start, n_frame - run_start])
    return turns
|
||||
|
||||
|
||||
def smooth_multi_labels(multi_label, win_len):
    """Median-filter labels along time only (win_len x 1 window), per speaker.

    Borders are zero-padded (mode="constant"), so short spikes near the edges
    are suppressed like interior ones. Returns an int array.
    """
    filtered = median_filter(multi_label, (win_len, 1), mode="constant", cval=0.0)
    return filtered.astype(int)
|
||||
|
||||
|
||||
def calc_vad_mask(segments, total_len):
    """Build a (total_len, 1) 0/1 column marking oracle-VAD speech frames.

    ``segments`` holds (utt_id, start_sec, end_sec) tuples; times are mapped
    onto a 10 ms frame grid (x100). Regions beyond total_len are clipped by
    slicing semantics.
    """
    mask = np.zeros((total_len, 1), dtype=int)
    for _, seg_start, seg_end in segments:
        lo, hi = int(seg_start * 100), int(seg_end * 100)
        mask[lo:hi] = 1
    return mask
|
||||
|
||||
|
||||
def calc_system_vad_mask(vad_prob_list, total_len, args):
    """Fuse per-chunk system-VAD probabilities into a (total_len, 1) 0/1 mask.

    Each chunk's probabilities are thresholded by args.sys_vad_threshold,
    overlap-averaged on the frame grid (chunk_len / shift_len hopping),
    majority-voted with args.vote_prob, then median-smoothed.
    """
    if vad_prob_list is None:
        # no system VAD available: scalar 1 leaves labels unchanged when multiplied
        return 1

    threshold = args.sys_vad_threshold
    chunk_len = args.chunk_len
    shift_len = args.shift_len
    frame_vad_mask = np.zeros((total_len, 1), dtype=float)
    weight = np.zeros((total_len, 1), dtype=float)
    for i, vad_prob in enumerate(vad_prob_list):
        frame_vad_mask[i * shift_len: i * shift_len + chunk_len] += np.greater(vad_prob, threshold).astype(float)
        weight[i * shift_len: i * shift_len + chunk_len] += 1.0
    # NOTE(review): frames covered by no chunk divide 0/0 -> NaN, which
    # compares False below and is treated as non-speech — confirm intended
    frame_vad_mask = np.greater(frame_vad_mask / weight, args.vote_prob)
    frame_vad_mask = frame_vad_mask.astype(int)
    frame_vad_mask = smooth_multi_labels(frame_vad_mask.astype(int), args.vad_smooth_size)
    return frame_vad_mask
|
||||
|
||||
|
||||
def generate_rttm(mid, multi_labels, spk_list, args):
    """Convert a (T, N) frame-label matrix into RTTM SPEAKER lines for one meeting.

    Turns shorter than args.ignore_len frames are dropped; frame indices are
    rendered as seconds (frame grid is 10 ms, i.e. /100).
    """
    line_fmt = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>\n"
    turns = sorted(calc_spk_turns(multi_labels, spk_list), key=lambda t: t[1])
    rttm_lines = []
    for spk, start, dur in turns:
        # TODO: handle the leak of segments at the change points
        if dur > args.ignore_len:
            rttm_lines.append(line_fmt.format(mid, float(start) / 100, float(dur) / 100, spk))
    return rttm_lines
|
||||
|
||||
|
||||
def process(task_args):
    """Worker: turn each meeting's chunk labels into oracle-VAD and system-VAD RTTMs.

    Returns a list of [oracle_vad_lines, system_vad_lines] pairs, one per
    meeting, which post() writes to the two output files.
    """
    _, task_list, _, args = task_args
    # speaker columns are anonymous slots named spk1..spkN
    spk_list = ["spk{}".format(i+1) for i in range(args.n_spk)]
    results = []
    for mid, chunk_label_list, segments_list, sys_vad_list in tqdm(task_list, total=len(task_list),
                                                                   ascii=True, disable=args.no_pbar):
        # overlap-vote the chunk labels, then median-smooth along time
        multi_labels = calc_multi_labels(chunk_label_list, args.chunk_len, args.shift_len, args.n_spk, args.vote_prob)
        multi_labels = smooth_multi_labels(multi_labels, args.smooth_size)
        # mask with oracle VAD segments, and (if available) with system VAD
        oracle_vad_mask = calc_vad_mask(segments_list, multi_labels.shape[0])
        oracle_vad_rttm = generate_rttm(mid, multi_labels * oracle_vad_mask, spk_list, args)
        system_vad_mask = calc_system_vad_mask(sys_vad_list, multi_labels.shape[0], args)
        system_vad_rttm = generate_rttm(mid, multi_labels * system_vad_mask, spk_list, args)
        results.append([oracle_vad_rttm, system_vad_rttm])
    return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # fan the per-meeting tasks out across worker processes
    MyRunner(process).run()
|
||||
@ -0,0 +1,35 @@
|
||||
import kaldiio
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
|
||||
def int2vec(x, vec_dim=8, dtype=np.float32):
    """Decode integer ``x`` into a little-endian binary vector of length ``vec_dim``.

    Index 0 holds the lowest bit, so speaker k corresponds to bit 2**k.
    """
    bit_str = format(x, '0{}b'.format(vec_dim))
    # little-endian order: lower bit first
    low_first = bit_str[::-1]
    return (np.array(list(low_first)) == '1').astype(dtype)
|
||||
|
||||
|
||||
def seq2arr(seq, vec_dim=8):
    """Stack per-token little-endian bit vectors into a (len(seq), vec_dim) array."""
    rows = [int2vec(int(tok), vec_dim) for tok in seq]
    return np.row_stack(rows)
||||
|
||||
|
||||
if __name__ == '__main__':
    # Convert integer-encoded PSE label strings into binary matrices and dump
    # them as a Kaldi ark/scp pair named after out_file's stem.
    scp_file = sys.argv[1]
    label_file = sys.argv[2]
    out_file = sys.argv[3]
    max_spk_num = int(sys.argv[4])

    os.makedirs(os.path.dirname(out_file), exist_ok=True)

    # strip the extension; the writer appends .ark/.scp itself
    out_file = out_file.rsplit('.', maxsplit=1)[0]
    wav_writer = kaldiio.WriteHelper("ark,scp,f:{}.ark,{}.scp".format(out_file, out_file))
    # scp and label files are paired line-by-line (same order assumed)
    for i, (uttid, pse_str) in enumerate(zip(open(scp_file, "rt"), open(label_file, "rt"))):
        uttid, pse_str = uttid.strip().split(" ", maxsplit=1)[0], pse_str.strip()
        # one row per frame token, max_spk_num little-endian bits per row
        bin_label = seq2arr(pse_str.split(" "), vec_dim=max_spk_num)
        wav_writer(uttid, bin_label)

        if i % 100 == 0:
            print(f"Processed {i} samples, the last is {uttid}")

    wav_writer.close()
|
||||
@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
from funasr.utils.job_runner import MultiProcessRunnerV3
|
||||
import os
|
||||
|
||||
|
||||
class MyRunner(MultiProcessRunnerV3):
    """Multi-process job that merges RTTM turns into contiguous speech segments."""

    def prepare(self, parser):
        """Parse CLI args and group raw RTTM lines by meeting id (field 2)."""
        parser.add_argument("--rttm_scp", type=str)
        parser.add_argument("--seg_file", type=str)
        args = parser.parse_args()

        if not os.path.exists(os.path.dirname(args.seg_file)):
            os.makedirs(os.path.dirname(args.seg_file))

        meeting2rttms = {}
        for one_line in open(args.rttm_scp, "rt"):
            # drop empty tokens caused by repeated spaces
            parts = [x for x in one_line.strip().split(" ") if x != ""]
            # st/dur/spk_name are parsed only to validate the line's format
            mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
            if mid not in meeting2rttms:
                meeting2rttms[mid] = []
            meeting2rttms[mid].append(one_line)

        task_list = list(meeting2rttms.items())
        return task_list, None, args

    def post(self, results_list, args):
        """Concatenate all workers' segment lines into the single output file."""
        with open(args.seg_file, "wt") as fd:
            for results in results_list:
                fd.writelines(results)
|
||||
|
||||
|
||||
def process(task_args):
    """Worker: merge each meeting's RTTM turns into contiguous speech segments.

    Emits one line per merged segment:
    "<mid>-<start:07d>-<end:07d> <mid> <start_sec:.2f> <end_sec:.2f>"
    where start/end are on a 10 ms (centisecond) grid.
    """
    _, task_list, _, args = task_args
    seg_lines = []
    for mid, rttm_lines in task_list:
        turns = []
        max_ed = 0
        for line in rttm_lines:
            fields = line.strip().split(" ")
            onset, dur = float(fields[3]), float(fields[4])
            st_cs = int(onset * 100)
            ed_cs = int((onset + dur) * 100)
            max_ed = max(max_ed, ed_cs)
            turns.append([mid, st_cs, ed_cs, fields[7]])

        # union of all turns on the centisecond grid; index max_ed stays False,
        # so every open run is guaranteed to be closed by the scan below
        speech = np.zeros((max_ed + 1,), dtype=bool)
        for _, st_cs, ed_cs, _ in turns:
            speech[st_cs:ed_cs] = True

        seg_start, open_seg = 0, False
        for idx in range(max_ed + 1):
            if speech[idx] and not open_seg:
                seg_start, open_seg = idx, True
            elif not speech[idx] and open_seg:
                open_seg = False
                seg_lines.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format(
                    mid, seg_start, idx, mid, float(seg_start) / 100, float(idx) / 100))
    return seg_lines
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
my_runner = MyRunner(process)
|
||||
my_runner.run()
|
||||
176
egs/callhome/diarization/sond/script/dump_meeting_chunks.py
Normal file
176
egs/callhome/diarization/sond/script/dump_meeting_chunks.py
Normal file
@ -0,0 +1,176 @@
|
||||
import kaldiio
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
|
||||
import numpy as np
|
||||
import argparse
|
||||
import random
|
||||
import scipy.io as sio
|
||||
import logging
|
||||
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
|
||||
|
||||
|
||||
# speakers already warned about for having too little audio (avoid log spam)
short_spk_list = []


def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
    """Average x-vectors of randomly chosen utterances of ``spk``.

    Utterances are shuffled and accumulated until at least ``total_len``
    frames are covered (or the speaker's audio is exhausted). Returns the
    mean embedding over the selected utterances.
    """
    all_utts = spk2utt[spk]
    idx_list = list(range(len(all_utts)))
    random.shuffle(idx_list)
    count = 0
    utt_list = []
    for i in idx_list:
        utt_id = all_utts[i]
        utt_list.append(utt_id)
        count += int(utt2frames[utt_id])
        if count >= total_len:
            break
    # warn once per speaker when less than 300 frames are available
    if count < 300 and spk not in short_spk_list:
        logging.warning("{} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300))
        short_spk_list.append(spk)

    ivc_list = [kaldiio.load_mat(utt2ivc[utt])[np.newaxis, :] for utt in utt_list]
    ivc = np.concatenate(ivc_list, axis=0)
    ivc = np.mean(ivc, axis=0, keepdims=False)
    return ivc
|
||||
|
||||
|
||||
def process(feat_scp, labels_scp, spk2utt, utt2xvec, utt2frames, args):
    """Split each meeting into fixed-size chunks and dump feat/xvec/label arks.

    For every meeting in ``feat_scp`` that has a rasterized label .mat file,
    writes per-chunk entries to three kaldiio ark/scp pairs
    (<out>_partsNN_{feat,xvec,label}). In train mode speaker slots are
    shuffled and padded with random non-meeting speakers; in test mode the
    meeting's speakers are kept in order and padded with "None".

    Returns:
        list of (meeting_id, num_frames) for the processed meetings.
    """
    out_prefix = "{}_parts{:02d}".format(args.out, args.task_id)
    # dedicated per-task logger writing to <out_prefix>.log
    logger = logging.Logger(out_prefix, logging.INFO)
    file_handler = logging.FileHandler(out_prefix + ".log", mode="w")
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    ivc_dim = 256  # NOTE(review): assumed x-vector dimensionality — confirm
    chunk_size, chunk_shift = args.chunk_size, args.chunk_shift
    # kept for the (commented-out) power-of-two label packing below
    label_weights = 2 ** np.array(list(range(args.n_spk)))
    feat_writer = kaldiio.WriteHelper(f"ark,scp,f:{out_prefix}_feat.ark,{out_prefix}_feat.scp")
    ivc_writer = kaldiio.WriteHelper(f"ark,scp,f:{out_prefix}_xvec.ark,{out_prefix}_xvec.scp")
    label_writer = kaldiio.WriteHelper(f"ark,scp,f:{out_prefix}_label.ark,{out_prefix}_label.scp")
    train_spk_list = list(spk2utt.keys())

    frames_list = []
    non_present_spk_list = []
    for mid, feat_path in tqdm(feat_scp, total=len(feat_scp), ascii=True, disable=args.no_pbar):
        if mid not in labels_scp:
            continue
        feat = kaldiio.load_mat(feat_path)
        data = sio.loadmat(labels_scp[mid])
        labels, meeting_spk_list = data["labels"].astype(int), [x.strip() for x in data["spk_list"]]
        if args.add_mid_to_speaker:
            # globally unique speaker names across meetings
            meeting_spk_list = ["{}_{}".format(mid, x) if not x.startswith(mid) else x for x in meeting_spk_list]
        if labels.shape[0] != feat.shape[0]:
            # fix: capture the mismatched lengths BEFORE clipping — the original
            # logged labels.shape[0]/feat.shape[0] after both were set to min_len,
            # so the warning always printed three identical numbers
            label_len, feat_len = labels.shape[0], feat.shape[0]
            min_len = min(label_len, feat_len)
            labels, feat = labels[:min_len], feat[:min_len]
            logger.warning("{}: The expected frame_len is {}, but got {}, clip both to {}".format(
                mid, label_len, feat_len, min_len))
        num_frame = feat.shape[0]
        num_chunk = int(np.ceil(float(num_frame - chunk_size) / chunk_shift)) + 1
        for i in range(num_chunk):
            st, ed = i*chunk_shift, i*chunk_shift+chunk_size
            utt_id = "{}-{:05d}-{:05d}".format(mid, st, ed)
            chunk_feat = feat[st: ed, :]
            chunk_label = labels[st: ed, :]
            # zero-pad the final short chunk in time and any missing speaker slots
            frame_pad = chunk_size - chunk_label.shape[0]
            spk_pad = args.n_spk - chunk_label.shape[1]
            chunk_feat = np.pad(chunk_feat, [(0, frame_pad), (0, 0)], "constant", constant_values=0)
            chunk_label = np.pad(chunk_label, [(0, frame_pad), (0, spk_pad)], "constant", constant_values=0)

            feat_writer(utt_id, chunk_feat)

            spk_idx = list(range(max(args.n_spk, len(meeting_spk_list))))
            spk_list = []
            if args.mode == "train":
                # randomize slot order so the model cannot learn slot positions
                random.shuffle(spk_idx)

                if args.n_spk > len(meeting_spk_list):
                    # pad with random distractor speakers up to n, then "None"
                    n = random.randint(len(meeting_spk_list), args.n_spk)
                    spk_list.extend(meeting_spk_list)
                    while len(spk_list) < n:
                        spk = random.choice(train_spk_list)
                        if spk not in spk_list:
                            spk_list.append(spk)
                    spk_list.extend(["None"] * (args.n_spk - n))
                else:
                    raise ValueError("Argument n_spk is too small ({} < {}).".format(args.n_spk, len(meeting_spk_list)))
            else:
                # test mode: keep meeting speakers in order, pad with "None"
                spk_list.extend(meeting_spk_list)
                spk_list.extend(["None"] * max(args.n_spk - len(meeting_spk_list), 0))

            xvec_list = []
            # fix: loop variable renamed (was `i`, shadowing the chunk index)
            for slot, spk in enumerate(spk_list):
                if spk == "None":
                    spk_xvec = np.zeros((ivc_dim,), dtype=np.float32)
                elif spk not in spk2utt:
                    # speaker with very short duration
                    spk_xvec = np.zeros((ivc_dim,), dtype=np.float32)
                    # chunk_label[:, slot] = 0
                    if spk not in non_present_spk_list:
                        logging.warning("speaker {}/{} does not appear in spk2utt, since it has very short duration.".format(mid, spk))
                        non_present_spk_list.append(spk)
                else:
                    spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 3000)[np.newaxis, :]
                xvec_list.append(spk_xvec)
            xvec = np.row_stack(xvec_list)
            # shuffle speaker embedding according spk_idx
            xvec = xvec[spk_idx, :]
            ivc_writer(utt_id, xvec)

            # shuffle labels according spk_idx (same permutation as embeddings)
            feat_label = chunk_label[:, spk_idx]
            # feat_label = np.sum(feat_label * label_weights[np.newaxis, :chunk_label.shape[1]], axis=1).astype(str).tolist()
            label_writer(utt_id, feat_label.astype(np.float32))

        frames_list.append((mid, feat.shape[0]))

        logger.info("{:30s}: {:6d} frames split into {:3d} chunks.".format(mid, num_frame, num_chunk))
    return frames_list
|
||||
|
||||
|
||||
def _str2bool(value):
    """Parse textual booleans for argparse; `type=bool` treats any non-empty
    string (including "false") as True, which is a classic argparse pitfall."""
    return str(value).lower() in ("1", "true", "t", "yes", "y")


def main():
    """Entry point: dump chunked features/labels/x-vector profiles for SOND.

    Reads feats.scp / labels.scp / utt2spk / utt2xvec / utt2num_frames from
    --dir and writes the dumped chunks with prefix --out via process().
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", required=True, type=str, default=None,
                        help="feats.scp")
    parser.add_argument("--out", required=True, type=str, default=None,
                        help="The prefix of dumpped files.")
    parser.add_argument("--n_spk", type=int, default=16)
    parser.add_argument("--use_lfr", default=False, action="store_true")
    parser.add_argument("--no_pbar", default=False, action="store_true")
    parser.add_argument("--sr", type=int, default=8000)
    parser.add_argument("--chunk_size", type=int, default=1600)
    parser.add_argument("--chunk_shift", type=int, default=400)
    parser.add_argument("--mode", type=str, default="train", choices=["train", "test"])
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--task_size", type=int, default=-1)
    # BUG FIX: was `type=bool`, which maps "--add_mid_to_speaker false" to True
    # because bool("false") is True.
    parser.add_argument("--add_mid_to_speaker", type=_str2bool, default=False)
    args = parser.parse_args()
    assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."

    if not os.path.exists(os.path.dirname(args.out)):
        # exist_ok avoids a race when several array-job tasks start at once.
        os.makedirs(os.path.dirname(args.out), exist_ok=True)

    feat_scp = load_scp_as_list(os.path.join(args.dir, "feats.scp"))
    if args.task_size > 0:
        # Each parallel task handles its own contiguous shard of feats.scp.
        feat_scp = feat_scp[args.task_size * args.task_id: args.task_size * (args.task_id + 1)]
    labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
    utt2frames = load_scp_as_dict(os.path.join(args.dir, "utt2num_frames"))

    # Keep only utterances that have an x-vector and more than 25 frames.
    spk2utt = {}
    for utt, spk in utt2spk.items():
        if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
            if spk not in spk2utt:
                spk2utt[spk] = []
            spk2utt[spk].append(utt)
    logging.info("Obtain {} speakers.".format(len(spk2utt)))
    logging.info("Task {:02d}: start dump {} meetings.".format(args.task_id, len(feat_scp)))
    # random.shuffle(feat_scp)
    meeting_lens = process(feat_scp, labels_scp, spk2utt, utt2xvec, utt2frames, args)
    total_frames = sum([x[1] for x in meeting_lens])
    logging.info("Total meetings: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
48
egs/callhome/diarization/sond/script/dump_pipe_wav.py
Normal file
48
egs/callhome/diarization/sond/script/dump_pipe_wav.py
Normal file
@ -0,0 +1,48 @@
|
||||
import os
|
||||
import argparse
|
||||
from funasr.utils.job_runner import MultiProcessRunnerV3
|
||||
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
|
||||
|
||||
|
||||
class MyRunner(MultiProcessRunnerV3):
    """Multi-process runner that materializes piped Kaldi wav.scp entries as wav files."""

    def prepare(self, parser):
        """Parse CLI args and build the task list (one task per wav.scp entry).

        Returns (task_list, shared_data, args) as expected by MultiProcessRunnerV3.
        """
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument("wav_scp", type=str)   # Kaldi wav.scp with piped commands
        parser.add_argument("out_dir", type=str)   # directory for the extracted wav files
        args = parser.parse_args()
        # assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."

        wav_scp = load_scp_as_list(args.wav_scp)
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        return wav_scp, None, args

    def post(self, result_list, args):
        """Aggregate per-worker (success, failure) counts and report them."""
        count = [0, 0]
        for result in result_list:
            count[0] += result[0]
            count[1] += result[1]
        print("All threads done, {} success, {} failed.".format(count[0], count[1]))
|
||||
|
||||
|
||||
def process(task_args):
    """Worker: execute each piped-command entry of a Kaldi wav.scp, redirecting
    the pipe output into "<out_dir>/<utt_id>.wav".

    Returns [num_success, num_failed] for this worker.
    """
    task_id, task_list, _, args = task_args

    count = [0, 0]  # [success, failed]
    for utt_id, cmd in task_list:
        wav_path = os.path.join(args.out_dir, "{}.wav".format(utt_id))
        # BUG FIX: only the *trailing* "|" marks the pipe output of a wav.scp
        # entry; cmd.replace("|", ...) also clobbered internal pipes such as
        # "sph2pipe ... | sox ... |".  Replace only the last occurrence.
        cmd = "> {}".format(wav_path).join(cmd.rsplit("|", 1))
        # BUG FIX: os.system() does not raise on command failure, so the old
        # try/except never counted a failure.  Check the exit status instead.
        if os.system(cmd) == 0:
            count[0] += 1
        else:
            print("Failed execute command for {}.".format(utt_id))
            count[1] += 1

    return count
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
my_runner = MyRunner(process)
|
||||
my_runner.run()
|
||||
117
egs/callhome/diarization/sond/script/dump_speaker_profiles.py
Normal file
117
egs/callhome/diarization/sond/script/dump_speaker_profiles.py
Normal file
@ -0,0 +1,117 @@
|
||||
import kaldiio
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
|
||||
import numpy as np
|
||||
import argparse
|
||||
from kaldiio import WriteHelper
|
||||
|
||||
|
||||
def calc_global_ivc(spk, spk2utt, utt2ivc):
    """Return the mean embedding over all utterances of speaker *spk*."""
    embeddings = []
    for utt in spk2utt[spk]:
        embeddings.append(kaldiio.load_mat(utt2ivc[utt])[np.newaxis, :])
    return np.concatenate(embeddings, axis=0).mean(axis=0)
|
||||
|
||||
|
||||
def process(idx_scp, spk2utt, utt2xvec, meeting2spk_list, args):
    """Dump one speaker-profile matrix (args.n_spk x ivc_dim) per chunk id.

    For every utterance id in *idx_scp*, stack the embeddings of the speakers
    of its meeting (rows padded with zero profiles up to args.n_spk) and write
    them to "<args.out>_parts00_xvec.{ark,scp,idx}".

    Returns a list of (meeting_id, 1) tuples, one entry per chunk.
    """
    out_prefix = args.out

    ivc_dim = 256  # embedding dimensionality of the x-vector extractor
    print("ivc_dim = {}".format(ivc_dim))
    out_prefix = out_prefix + "_parts00_xvec"
    ivc_writer = WriteHelper(f"ark,scp,f:{out_prefix}.ark,{out_prefix}.scp")
    idx_writer = open(out_prefix + ".idx", "wt")
    spk2xvec = {}
    if args.emb_type == "global":
        print("Use global speaker embedding.")
        for spk in spk2utt.keys():
            spk2xvec[spk] = calc_global_ivc(spk, spk2utt, utt2xvec)[np.newaxis, :]
    else:
        # BUG FIX: the loop below indexes spk2xvec unconditionally, so any
        # other emb_type (including the CLI default "rand") used to die with a
        # bare KeyError on the first utterance.  Fail fast with a clear error.
        raise NotImplementedError(
            "emb_type '{}' is not supported here; only 'global' is implemented.".format(args.emb_type))

    frames_list = []
    try:
        for utt_id in tqdm(idx_scp, total=len(idx_scp), ascii=True, disable=args.no_pbar):
            mid = utt_id.split("-")[0]  # meeting id is the prefix of the chunk id
            idx_writer.write(utt_id + "\n")

            xvec_list = []
            for spk in meeting2spk_list[mid]:
                xvec_list.append(spk2xvec[spk])
            # Pad with all-zero profiles up to the fixed number of speaker slots.
            for _ in range(args.n_spk - len(xvec_list)):
                xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
            # np.vstack == np.row_stack, which is a deprecated alias (NumPy 2.0).
            xvec = np.vstack(xvec_list)
            ivc_writer(utt_id, xvec)

            frames_list.append((mid, 1))
    finally:
        # BUG FIX: both writers previously leaked (never closed/flushed).
        idx_writer.close()
        ivc_writer.close()
    return frames_list
|
||||
|
||||
|
||||
def calc_spk_list(rttms):
    """Return canonical speaker names ("<mid>_S###" for numeric names, else
    "<mid>_<name>") in first-appearance order over the given RTTM lines."""
    seen = []
    for one_line in rttms:
        fields = [tok for tok in one_line.strip().split(" ") if tok != ""]
        mid, raw_name = fields[1], fields[7]
        # Parsed only to preserve the original's format validation (unused).
        _start, _duration = float(fields[3]), float(fields[4])
        cleaned = raw_name.replace("spk", "").replace(mid, "").replace("-", "")
        if cleaned.isdigit():
            canonical = "{}_S{:03d}".format(mid, int(cleaned))
        else:
            canonical = "{}_{}".format(mid, cleaned)
        if canonical not in seen:
            seen.append(canonical)

    return seen
|
||||
|
||||
|
||||
def main():
    """Dump fixed-size speaker-profile matrices for SOND inference.

    Reads idx / sys.rttm / utt2spk / utt2xvec from --dir, derives each
    meeting's speaker list from the system RTTM, and writes one profile matrix
    per chunk id with prefix --out (see process()).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", required=True, type=str, default=None,
                        help="feats.scp")
    parser.add_argument("--out", required=True, type=str, default=None,
                        help="The prefix of dumpped files.")
    parser.add_argument("--n_spk", type=int, default=4)
    parser.add_argument("--no_pbar", default=False, action="store_true")
    parser.add_argument("--sr", type=int, default=16000)
    parser.add_argument("--emb_type", type=str, default="rand")
    args = parser.parse_args()

    if not os.path.exists(os.path.dirname(args.out)):
        os.makedirs(os.path.dirname(args.out))

    idx_scp = open(os.path.join(args.dir, "idx"), "r").readlines()
    idx_scp = [x.strip() for x in idx_scp]
    # Group system RTTM lines by recording id (field 1).
    meeting2rttms = {}
    for one_line in open(os.path.join(args.dir, "sys.rttm"), "rt"):
        parts = [x for x in one_line.strip().split(" ") if x != ""]
        mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
        if mid not in meeting2rttms:
            meeting2rttms[mid] = []
        meeting2rttms[mid].append(one_line)

    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))

    # Keep only utterances that actually have an extracted x-vector.
    spk2utt = {}
    for utt, spk in utt2spk.items():
        if utt in utt2xvec:
            if spk not in spk2utt:
                spk2utt[spk] = []
            spk2utt[spk].append(utt)

    # Speakers named in the RTTM but lacking any x-vector are dropped (with a
    # diagnostic), since no profile can be built for them.
    meeting2spk_list = {}
    for mid, rttms in meeting2rttms.items():
        meeting2spk_list[mid] = calc_spk_list(rttms)
        new_spk_list = []
        for spk in meeting2spk_list[mid]:
            if spk in spk2utt:
                new_spk_list.append(spk)
        if len(new_spk_list) != len(meeting2spk_list[mid]):
            print("{}: Reduce speaker number from {}(according rttm) to {}(according x-vectors)".format(
                mid, len(meeting2spk_list[mid]), len(new_spk_list)))
        meeting2spk_list[mid] = new_spk_list

    meeting_lens = process(idx_scp, spk2utt, utt2xvec, meeting2spk_list, args)
    print("Total meetings: {:6d}".format(len(meeting_lens)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
54
egs/callhome/diarization/sond/script/easy_average_models.py
Normal file
54
egs/callhome/diarization/sond/script/easy_average_models.py
Normal file
@ -0,0 +1,54 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
if __name__ == '__main__':
    # Average the parameters of several saved checkpoints into a single model
    # ("checkpoint averaging").
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        required=True,
        default=None,
        type=str,
        help="Director contains saved models."
    )
    parser.add_argument(
        "--average_epochs",
        nargs="+",
        type=int,
        default=[],
    )
    parser.add_argument(
        "--metric_name",
        type=str,
        default="der",
        help="The metric name of best models, only used for name."
    )
    args = parser.parse_args()

    root_path = args.model_dir
    idx_list = args.average_epochs   # epoch numbers to average
    n_models = len(idx_list)
    metric = args.metric_name        # only used in the output file name

    if n_models > 0:
        # Sum all state dicts element-wise (accumulating into the first one)...
        avg = None
        for idx in idx_list:
            model_file = os.path.join(root_path, "{}epoch.pth".format(str(idx)))
            states = torch.load(model_file, map_location="cpu")
            if avg is None:
                avg = states
            else:
                for k in avg:
                    avg[k] = avg[k] + states[k]

        # ...then divide the float tensors by the number of models.
        for k in avg:
            # NOTE(review): integer tensors (e.g. step counters) are left as
            # the raw *sum*, not averaged — mirrors ESPnet's averaging code;
            # confirm this is intended.
            if str(avg[k].dtype).startswith("torch.int"):
                pass
            else:
                avg[k] = avg[k] / n_models

        output_file = os.path.join(root_path, "valid.{}.ave_{}best.pth".format(metric, n_models))
        torch.save(avg, output_file)
    else:
        print("Number of models to average is 0, skip.")
|
||||
@ -0,0 +1,116 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import argparse
|
||||
from funasr.utils.job_runner import MultiProcessRunnerV3
|
||||
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class MyRunner(MultiProcessRunnerV3):
    """Runner that cuts non-overlapped single-speaker segments out of meeting wavs."""

    def prepare(self, parser):
        """Parse args, group RTTM lines per recording, and build the task list.

        Returns (task_list, shared_data, args); each task is
        (meeting_id, wav_path, rttm_lines).
        """
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("rttm_scp", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--min_dur", type=float, default=2.0)    # seconds
        parser.add_argument("--max_spk_num", type=int, default=4)
        args = parser.parse_args()
        # NOTE(review): --sr is not declared here, so it is presumably added by
        # MultiProcessRunnerV3 — confirm; otherwise this assert would raise
        # AttributeError instead of the intended message.
        assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        wav_scp = load_scp_as_list(args.wav_scp)
        # Group RTTM lines by recording id (field 1 of each RTTM line).
        meeting2rttms = {}
        for one_line in open(args.rttm_scp, "rt"):
            parts = [x for x in one_line.strip().split(" ") if x != ""]
            mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
            if mid not in meeting2rttms:
                meeting2rttms[mid] = []
            meeting2rttms[mid].append(one_line)

        task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
        return task_list, None, args

    def post(self, result_list, args):
        """Sum per-worker (extracted, found) speaker counts and report them."""
        count = [0, 0]
        for result in result_list:
            count[0] += result[0]
            count[1] += result[1]
        print("Found {} speakers, extracted {}.".format(count[1], count[0]))
|
||||
|
||||
|
||||
# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
|
||||
def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
|
||||
labels = np.zeros([max_spk_num, length], int)
|
||||
spk_list = []
|
||||
for one_line in rttms:
|
||||
parts = [x for x in one_line.strip().split(" ") if x != ""]
|
||||
mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
|
||||
spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "")
|
||||
if spk_name.isdigit():
|
||||
spk_name = "{}_S{:03d}".format(mid, int(spk_name))
|
||||
else:
|
||||
spk_name = "{}_{}".format(mid, spk_name)
|
||||
if spk_name not in spk_list:
|
||||
spk_list.append(spk_name)
|
||||
st, dur = int(st*sr), int(dur*sr)
|
||||
idx = spk_list.index(spk_name)
|
||||
labels[idx, st:st+dur] = 1
|
||||
return labels, spk_list
|
||||
|
||||
|
||||
def get_nonoverlap_turns(multi_label, spk_list):
    """Extract [start, end, speaker] turns from the sample ranges where
    exactly one speaker is active (overlapped and silent regions are skipped;
    a speaker change closes the current turn and opens a new one)."""
    single_active = np.sum(multi_label, axis=0) == 1
    turns = []
    cur_spk, start, inside = None, 0, False
    for pos, active in enumerate(single_active):
        dominant = spk_list[int(np.argmax(multi_label[:, pos], axis=0))]
        if not inside and active:
            start, inside, cur_spk = pos, True, dominant
        if inside:
            if not active:
                inside = False
                turns.append([start, pos, cur_spk])
            elif cur_spk != dominant:
                turns.append([start, pos, cur_spk])
                start, cur_spk = pos, dominant
    if inside:
        turns.append([start, len(single_active), cur_spk])
    return turns
|
||||
|
||||
|
||||
def process(task_args):
    """Worker: write every non-overlapped speaker turn longer than --min_dur
    as "<out_dir>/<mid>/<spk>/<spk>_U####.wav".

    Returns [num_speakers_extracted, num_speakers_found] accumulated over the
    recordings handled by this worker.
    """
    task_id, task_list, _, args = task_args
    spk_count = [0, 0]
    for mid, wav_path, rttms in task_list:
        # NOTE(review): positional sr (librosa.load(path, sr)) only works on
        # librosa < 0.10; newer releases require sr= as a keyword — confirm
        # the pinned librosa version.
        wav = librosa.load(wav_path, args.sr)[0] * 32767  # rescale to int16 range
        multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
        turns = get_nonoverlap_turns(multi_label, spk_list)
        extracted_spk = []
        count = 1  # per-recording running index used in segment file names
        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
            # Keep only turns at least min_dur seconds long.
            if (ed - st) >= args.min_dur * args.sr and len(wav[st: ed]) >= args.min_dur * args.sr:
                seg = wav[st: ed]
                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
                count += 1
                if spk not in extracted_spk:
                    extracted_spk.append(spk)
        if len(extracted_spk) != len(spk_list):
            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
            ))
        spk_count[0] += len(extracted_spk)
        spk_count[1] += len(spk_list)
    return spk_count
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
my_runner = MyRunner(process)
|
||||
my_runner.run()
|
||||
@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
from funasr.utils.job_runner import MultiProcessRunnerV3
|
||||
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
|
||||
import os
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
class MyRunner(MultiProcessRunnerV3):
    """Runner that concatenates the VAD speech segments of each recording into
    one continuous wav, keeping a position map for later back-projection."""

    def prepare(self, parser):
        """Parse args, read reco.scp + segments, and build one task per recording."""
        parser.add_argument("dir", type=str)       # directory with reco.scp and segments
        parser.add_argument("out_dir", type=str)
        args = parser.parse_args()
        # NOTE(review): --sr is presumably declared by MultiProcessRunnerV3 —
        # confirm; otherwise this assert raises AttributeError.
        assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."

        meeting_scp = load_scp_as_list(os.path.join(args.dir, "reco.scp"))
        vad_file = open(os.path.join(args.dir, "segments"))
        # Map recording id -> list of (utt_id, start_sample, end_sample).
        meeting2vad = {}
        for one_line in vad_file:
            uid, mid, st, ed = one_line.strip().split(" ")
            st, ed = int(float(st) * args.sr), int(float(ed) * args.sr)
            if mid not in meeting2vad:
                meeting2vad[mid] = []
            meeting2vad[mid].append((uid, st, ed))

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)

        # Recordings with no VAD segments are reported and skipped.
        for mid, _ in meeting_scp:
            if mid not in meeting2vad:
                print("Recording {} doesn't contains speech segments".format(mid))
        task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp if mid in meeting2vad]
        return task_list, None, args

    def post(self, results_list, args):
        """No aggregation needed; workers write their outputs directly."""
        pass
|
||||
|
||||
|
||||
def process(task_args):
    """Worker: splice the VAD segments of each recording into
    "<out_dir>/<mid>.wav" and write "<mid>.pos" mapping each segment's
    original sample range onto its position in the concatenated signal."""
    _, task_list, _, args = task_args
    for mid, wav_path, vad_list in task_list:
        # NOTE(review): positional args to librosa.load only work on
        # librosa < 0.10; newer releases require sr=/mono= keywords.
        wav = librosa.load(wav_path, args.sr, True)[0] * 32767  # rescale to int16 range
        seg_list = []
        pos_map = []   # lines: "<uid> <orig_st> <orig_ed> <new_st> <new_ed>"
        offset = 0     # running write position in the concatenated wav
        for uid, st, ed in vad_list:
            seg_list.append(wav[st: ed])
            pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset + ed - st))
            offset = offset + ed - st
        out = np.concatenate(seg_list, axis=0)
        save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
        sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
        map_path = os.path.join(args.out_dir, "{}.pos".format(mid))
        with open(map_path, "wt") as fd:
            fd.writelines(pos_map)
        # print mid
    return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
my_runner = MyRunner(process)
|
||||
my_runner.run()
|
||||
216
egs/callhome/diarization/sond/script/simu_whole_recordings.py
Normal file
216
egs/callhome/diarization/sond/script/simu_whole_recordings.py
Normal file
@ -0,0 +1,216 @@
|
||||
import argparse
|
||||
import numpy as np
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
from funasr.utils.misc import load_scp_as_dict, load_scp_as_list
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def mix_wav_noise(wav, noise, snr):
    """Add *noise* to *wav* at the requested *snr* (dB).

    The noise is looped until it covers len(wav), a random sub-window is cut,
    and it is scaled so that 20*log10(||wav|| / ||noise||) == snr.
    Returns wav + scaled noise (same shape as wav).
    """
    n_repeat = len(wav) // len(noise) + 1
    # BUG FIX: np.repeat(noise, n) duplicates each individual *sample*
    # (time-stretching the noise by n); np.tile loops the whole waveform,
    # which is what "repeat the noise until it is long enough" means.
    noise = np.tile(noise, n_repeat)
    st = random.randint(0, len(noise) - len(wav))
    noise = noise[st: st + len(wav)]

    wav_mag = np.linalg.norm(wav, ord=2)
    noise_mag = np.linalg.norm(noise, ord=2)
    # Scale the noise energy so the resulting SNR matches the target.
    scale = wav_mag / (10 ** (float(snr) / 20))
    noise = noise / noise_mag * scale
    # Sanity check: warn if the realized SNR drifts from the target.
    check_snr = 20 * np.log10(np.linalg.norm(wav, ord=2) / np.linalg.norm(noise, ord=2))
    if abs(check_snr - snr) >= 1e-2:
        print("SNR: {:.4f}, real SNR: {:.4f}".format(snr, check_snr))
    return wav + noise
|
||||
|
||||
|
||||
def calc_labels(rttms, args):
    """Build a sample-level activity matrix from (spk, start, dur) turns.

    Optionally jitters turn boundaries (args.random_interval) and randomly
    reassigns turns to speakers (args.random_assign_spk) to diversify the
    simulated data.

    Returns (labels, spk_list, new_turns): labels is (n_spk, total_samples),
    spk_list preserves first-appearance order (after sorting by start time),
    and new_turns is the turn list re-derived from the label matrix (merged
    overlaps), sorted by start time, with times in seconds.
    """
    turns = []
    total_length = 0
    for spk, st, dur in rttms:
        if args.random_interval:
            # random shift the interval with 20% of duration
            x = random.uniform(-dur*0.2, dur*0.2)
            st = max(0, st + x)
            # random squeeze or extend the interval
            dur += random.uniform(-dur*0.5, dur*0.5)
        if st + dur > total_length:
            total_length = st + dur
        turns.append([spk, st, dur])

    # resort the turns according start point
    turns = sorted(turns, key=lambda x: x[1])

    # Speaker order follows first appearance after the sort above.
    spk_list = []
    for spk, st, dur in turns:
        if spk not in spk_list:
            spk_list.append(spk)

    total_length = int(total_length * args.sr)  # seconds -> samples
    labels = np.zeros((len(spk_list), total_length), float)
    for spk, org_st, org_dur in turns:
        # random re-assign speaker to make more various samples
        st, dur = int(org_st * args.sr), int(org_dur * args.sr)
        if args.random_assign_spk:
            spk = random.choice(spk_list)
        spk_id = spk_list.index(spk)
        labels[spk_id, st:st+dur] = 1.0

    # Re-extract contiguous active intervals per speaker so that overlapping
    # or reassigned turns stay consistent with the label matrix.
    new_turns = []
    for i in range(len(spk_list)):
        st = 0
        in_interval = False
        for j in range(total_length):
            if labels[i, j] == 1 and not in_interval:
                in_interval = True
                st = j
            if (labels[i, j] == 0 or j == total_length-1) and in_interval:
                in_interval = False
                new_turns.append((spk_list[i], float(st) / args.sr, float(j - st) / args.sr))
    new_turns = sorted(new_turns, key=lambda x: x[1])

    return labels, spk_list, new_turns
|
||||
|
||||
|
||||
def save_wav(data, wav_path, sr):
    """Write *data* to *wav_path* as 16-bit little-endian PCM at *sr* Hz,
    rescaling to 90% of full scale first if it would clip int16."""
    peak = np.max(np.abs(data)).item()
    if peak > 32767:
        data = data / np.max(np.abs(data)) * 0.9 * 32767
    sf.write(wav_path, data.astype(np.int16), sr, "PCM_16", "LITTLE", "WAV", True)
|
||||
|
||||
|
||||
def build(mid, meeting2rttm, spk2wav, noise_scp, room2rirs, args):
    """Simulate one whole meeting-style recording.

    Uses the RTTM of a randomly chosen real recording as the timing template,
    re-voices every turn with randomly sampled speakers, convolves each
    speaker's signal with a room impulse response, and mixes in noise at a
    random SNR.  Writes "<out_dir>/<mid>.wav" and "<out_dir>/<mid>.rttm".

    Returns (metadata_dict, mixed_signal, label_matrix).
    """
    mid = "m{:05d}".format(mid+1)
    if args.corpus_name is not None:
        mid = args.corpus_name + "_" + mid
    # BUG FIX (Py3): random.choice()/random.sample() need a sequence; passing
    # dict views (.keys()) raises TypeError — wrap them in list().
    org_reco_id = random.choice(list(meeting2rttm.keys()))
    rttms = meeting2rttm[org_reco_id]
    labels, org_spk_list, org_turns = calc_labels(rttms, args)
    n_spk = len(org_spk_list)

    expected_length = labels.shape[1]
    meeting_spk_list = random.sample(list(spk2wav.keys()), n_spk)
    # Speakers with no active sample in the template become "absent" profiles.
    spk_mask = (np.sum(labels, axis=1) > 0).astype(int)
    pos_spk_list = [spk for spk, mask in zip(meeting_spk_list, spk_mask) if mask == 1]
    noise_id, noise_path = random.choice(noise_scp)
    noise_wav = librosa.load(noise_path, args.sr, True)[0] * 32767
    snr = random.choice(args.snr_list)
    room_id = random.choice(list(room2rirs.keys()))
    # different speakers can locate at the same position a.k.a. the same rir.
    rir_list = [random.choice(room2rirs[room_id]) for _ in range(n_spk)]

    mata = {
        "id": mid,
        "num_spk": n_spk,
        "pos_spk": pos_spk_list,
        "spk_list": meeting_spk_list,
        "seg_info": [],
        "noise": noise_id,
        "snr": snr,
        "length": expected_length,
        "meeting_info": org_reco_id,
        "room_id": room_id
    }
    sig = np.zeros((expected_length, ), dtype=np.float32)
    for i, spk in enumerate(meeting_spk_list):
        if spk in pos_spk_list:
            wav = librosa.load(spk2wav[spk], args.sr, True)[0] * 32767
            if len(wav) <= expected_length:
                # NOTE: to repeat an array, use np.tile rather than np.repeats
                wav = np.tile(wav, expected_length // len(wav) + 1)
            spk_st = np.random.randint(0, len(wav) - expected_length)
            spk_sig = wav[spk_st: spk_st+expected_length]
            spk_sig = spk_sig * labels[i, :]  # gate the voice by the activity labels
            rir_wav = librosa.load(rir_list[i][1], args.sr, True)[0] * 32767
            spk_sig = np.convolve(spk_sig, rir_wav, "full")[:expected_length]
            mata["seg_info"].append([spk, spk_st, rir_list[i][0]])
            sig += spk_sig

    mix = mix_wav_noise(sig, noise_wav, snr)
    if np.max(np.abs(mix)).item() > 32767:
        # Rescale to 90% of int16 full scale to avoid clipping.
        mix = mix / np.max(np.abs(mix)) * 0.9 * 32767
    save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
    sf.write(save_path, mix.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)

    rttm_file = open(os.path.join(args.out_dir, "{}.rttm".format(mid)), "wt")
    for spk, st, dur in org_turns:
        rttm_file.write("SPEAKER {} 0 {:.3f} {:.3f} <NA> <NA> {} <NA> <NA>{}".format(
            mid, st, dur, meeting_spk_list[org_spk_list.index(spk)], os.linesep))
    rttm_file.close()

    return mata, mix, labels
|
||||
|
||||
|
||||
def filter_spk_num(meeting2rttm, reco2num_spk, spk_num):
    """Keep only the recordings whose annotated speaker count equals *spk_num*."""
    kept = {mid: rttm for mid, rttm in meeting2rttm.items()
            if int(reco2num_spk[mid]) == spk_num}
    print("Keep {} out of {} according to speaker number {}".format(
        len(kept), len(meeting2rttm), spk_num))
    return kept
|
||||
|
||||
|
||||
def main():
    """Entry point: simulate --total_mix whole recordings for this --task_id.

    Reads ref.rttm / reco2num_spk / spk2wav.scp / noise.scp / rirs.scp from
    the input dir, generates the mixtures via build(), and dumps the per-task
    metadata to "<out_dir>/mata.<task_id>.json".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("dir", type=str)
    parser.add_argument("out_dir", type=str)
    parser.add_argument("--total_mix", type=int, default=1)
    parser.add_argument("--sr", type=int, default=8000)
    parser.add_argument("--snr_list", type=int, default=[15, 20, 25], nargs="+")
    parser.add_argument("--spk_num", type=int, default=0)

    parser.add_argument("--corpus_name", type=str, default=None)
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--no_bar", action="store_true", default=False)
    parser.add_argument("--verbose", action="store_true", default=False)
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--random_assign_spk", action="store_true", default=False)
    parser.add_argument("--random_interval", action="store_true", default=False)
    args = parser.parse_args()
    assert args.sr == 8000, "For callhome dataset, the sample rate should be 8000, use --sr 8000."

    # SPEAKER iaaa 0 0 1.08 <NA> <NA> B <NA> <NA>
    meeting2rttm = {}
    for one_line in open(os.path.join(args.dir, "ref.rttm")):
        parts = one_line.strip().split(" ")
        mid, spk, st, dur = parts[1], parts[7], float(parts[3]), float(parts[4])
        if mid not in meeting2rttm:
            meeting2rttm[mid] = []
        meeting2rttm[mid].append((spk, st, dur))
    reco2num_spk = load_scp_as_dict(os.path.join(args.dir, "reco2num_spk"))
    if args.spk_num > 1:
        meeting2rttm = filter_spk_num(meeting2rttm, reco2num_spk, args.spk_num)

    spk2wav = load_scp_as_dict(os.path.join(args.dir, "spk2wav.scp"))
    noise_scp = load_scp_as_list(os.path.join(args.dir, "noise.scp"))
    rirs_scp = load_scp_as_list(os.path.join(args.dir, "rirs.scp"))
    # Group RIRs by room: rir ids look like "<room>-<position>".
    room2rirs = {}
    for rir_id, rir_path in rirs_scp:
        room_id = rir_id.rsplit("-", 1)[0]
        if room_id not in room2rirs:
            room2rirs[room_id] = []
        room2rirs[room_id].append((rir_id, rir_path))

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)
    # Disjoint id ranges keep parallel tasks from colliding on file names.
    task_list = list(range(args.task_id * args.total_mix, (args.task_id + 1) * args.total_mix))

    mata_data = []
    total = 0
    if args.debug:
        one, wav, label = build(0, meeting2rttm, spk2wav, noise_scp, room2rirs, args)
        mata_data.append(one)
    else:
        for mid in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_bar):
            one, wav, label = build(mid, meeting2rttm, spk2wav, noise_scp, room2rirs, args)
            mata_data.append(one)
            total += one["length"]
            if args.verbose:
                print("File name: {:20s}, segment num: {:5d}, speaker num: {:2d}, duration: {:7.2f}m".format(
                    one["id"], len(one["seg_info"]), one["num_spk"], float(one["length"]) / args.sr / 60))
    print("Total files: {}, Total duration: {:.2f} hours".format(args.total_mix, (1.0 * total / args.sr / 3600)))
    # BUG FIX (Py3): json.dump() has no `encoding` keyword (TypeError), which
    # crashed here after all the audio had already been generated.  Encode via
    # the file handle instead.
    json_path = os.path.join(args.out_dir, "mata.{}.json".format(args.task_id))
    with open(json_path, "wt", encoding="utf-8") as json_fp:
        json.dump(mata_data, json_fp, ensure_ascii=False, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -453,11 +453,17 @@ def get_parser():
|
||||
help="The batch size for inference",
|
||||
)
|
||||
group.add_argument(
|
||||
"--diar_smooth_size",
|
||||
"--smooth_size",
|
||||
type=int,
|
||||
default=121,
|
||||
help="The smoothing size for post-processing"
|
||||
)
|
||||
group.add_argument(
|
||||
"--dur_threshold",
|
||||
type=int,
|
||||
default=10,
|
||||
help="The threshold of minimum duration"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import torch
|
||||
|
||||
from funasr.layers.global_mvn import GlobalMVN
|
||||
from funasr.layers.label_aggregation import LabelAggregate
|
||||
from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
|
||||
from funasr.layers.utterance_mvn import UtteranceMVN
|
||||
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
|
||||
from funasr.models.e2e_diar_sond import DiarSondModel
|
||||
@ -26,6 +26,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
||||
from funasr.models.frontend.windowing import SlidingWindow
|
||||
from funasr.models.specaug.specaug import SpecAug
|
||||
from funasr.models.specaug.specaug import SpecAugLFR
|
||||
from funasr.models.specaug.abs_profileaug import AbsProfileAug
|
||||
from funasr.models.specaug.profileaug import ProfileAug
|
||||
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
|
||||
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
@ -52,6 +54,15 @@ specaug_choices = ClassChoices(
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
profileaug_choices = ClassChoices(
|
||||
name="profileaug",
|
||||
classes=dict(
|
||||
profileaug=ProfileAug,
|
||||
),
|
||||
type_check=AbsProfileAug,
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
normalize_choices = ClassChoices(
|
||||
"normalize",
|
||||
classes=dict(
|
||||
@ -64,7 +75,8 @@ normalize_choices = ClassChoices(
|
||||
label_aggregator_choices = ClassChoices(
|
||||
"label_aggregator",
|
||||
classes=dict(
|
||||
label_aggregator=LabelAggregate
|
||||
label_aggregator=LabelAggregate,
|
||||
label_aggregator_max_pool=LabelAggregateMaxPooling,
|
||||
),
|
||||
default=None,
|
||||
optional=True,
|
||||
@ -155,6 +167,8 @@ class_choices_list = [
|
||||
frontend_choices,
|
||||
# --specaug and --specaug_conf
|
||||
specaug_choices,
|
||||
# --profileaug and --profileaug_conf
|
||||
profileaug_choices,
|
||||
# --normalize and --normalize_conf
|
||||
normalize_choices,
|
||||
# --label_aggregator and --label_aggregator_conf
|
||||
@ -217,6 +231,13 @@ def build_diar_model(args):
|
||||
else:
|
||||
specaug = None
|
||||
|
||||
# Data augmentation for Profiles
|
||||
if hasattr(args, "profileaug") and args.profileaug is not None:
|
||||
profileaug_class = profileaug_choices.get_class(args.profileaug)
|
||||
profileaug = profileaug_class(**args.profileaug_conf)
|
||||
else:
|
||||
profileaug = None
|
||||
|
||||
# normalization layer
|
||||
if args.normalize is not None:
|
||||
normalize_class = normalize_choices.get_class(args.normalize)
|
||||
@ -261,6 +282,7 @@ def build_diar_model(args):
|
||||
vocab_size=vocab_size,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
profileaug=profileaug,
|
||||
normalize=normalize,
|
||||
label_aggregator=label_aggregator,
|
||||
encoder=encoder,
|
||||
|
||||
@ -6,8 +6,9 @@ from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from funasr.modules.nets_utils import pad_list
|
||||
from typeguard import check_argument_types
|
||||
from typeguard import check_return_type
|
||||
from funasr.modules.nets_utils import pad_list, pad_list_all_dim
|
||||
|
||||
|
||||
class CommonCollateFn:
|
||||
@ -77,6 +78,78 @@ def common_collate_fn(
|
||||
output = (uttids, output)
|
||||
return output
|
||||
|
||||
|
||||
class DiarCollateFn:
    """Collate functor for diarization batches.

    Unlike CommonCollateFn, arrays are padded along *all* dimensions (via
    diar_collate_fn / pad_list_all_dim), which handles per-sample differences
    in speaker counts of labels and profiles.
    """

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None
    ):
        assert check_argument_types()
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        # NOTE(review): stored but never used by __call__/diar_collate_fn —
        # confirm whether cropping to max_sample_size was meant to happen here.
        self.max_sample_size = max_sample_size

    def __repr__(self):
        # BUG FIX: int_pad_value used to be printed from self.float_pad_value.
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        return diar_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
|
||||
|
||||
|
||||
def diar_collate_fn(
|
||||
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
|
||||
float_pad_value: Union[float, int] = 0.0,
|
||||
int_pad_value: int = -32768,
|
||||
not_sequence: Collection[str] = (),
|
||||
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
|
||||
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
uttids = [u for u, _ in data]
|
||||
data = [d for _, d in data]
|
||||
|
||||
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
|
||||
assert all(
|
||||
not k.endswith("_lengths") for k in data[0]
|
||||
), f"*_lengths is reserved: {list(data[0])}"
|
||||
|
||||
output = {}
|
||||
for key in data[0]:
|
||||
if data[0][key].dtype.kind == "i":
|
||||
pad_value = int_pad_value
|
||||
else:
|
||||
pad_value = float_pad_value
|
||||
|
||||
array_list = [d[key] for d in data]
|
||||
tensor_list = [torch.from_numpy(a) for a in array_list]
|
||||
tensor = pad_list_all_dim(tensor_list, pad_value)
|
||||
output[key] = tensor
|
||||
|
||||
if key not in not_sequence:
|
||||
lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
|
||||
output[key + "_lengths"] = lens
|
||||
|
||||
output = (uttids, output)
|
||||
assert check_return_type(output)
|
||||
return output
|
||||
|
||||
|
||||
def crop_to_max_size(feature, target_size):
|
||||
size = len(feature)
|
||||
diff = size - target_size
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from typeguard import check_argument_types
|
||||
from torch.nn import functional as F
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
@ -78,3 +79,37 @@ class LabelAggregate(torch.nn.Module):
|
||||
olens = None
|
||||
|
||||
return output.to(input.dtype), olens
|
||||
|
||||
|
||||
class LabelAggregateMaxPooling(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hop_length: int = 8,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
|
||||
self.hop_length = hop_length
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
f"hop_length={self.hop_length}, "
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, ilens: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""LabelAggregate forward function.
|
||||
|
||||
Args:
|
||||
input: (Batch, Nsamples, Label_dim)
|
||||
ilens: (Batch)
|
||||
Returns:
|
||||
output: (Batch, Frames, Label_dim)
|
||||
|
||||
"""
|
||||
|
||||
output = F.max_pool1d(input.transpose(1, 2), self.hop_length, self.hop_length).transpose(1, 2)
|
||||
olens = ilens // self.hop_length
|
||||
|
||||
return output.to(input.dtype), olens
|
||||
@ -75,10 +75,10 @@ class SequenceBinaryCrossEntropy(nn.Module):
|
||||
self.criterion = criterion
|
||||
|
||||
def forward(self, pred, label, lengths):
|
||||
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
|
||||
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
|
||||
loss = self.criterion(pred, label)
|
||||
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
|
||||
return loss.masked_fill(pad_mask, 0).sum() / denom
|
||||
return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
|
||||
|
||||
|
||||
class NllLoss(nn.Module):
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
import logging
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import LooseVersion
|
||||
from itertools import permutations
|
||||
@ -12,6 +13,7 @@ from typing import Tuple, List
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.modules.nets_utils import to_device
|
||||
from funasr.modules.nets_utils import make_pad_mask
|
||||
@ -19,11 +21,13 @@ from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.models.specaug.abs_profileaug import AbsProfileAug
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
|
||||
from funasr.utils.misc import int2vec
|
||||
from funasr.utils.hinter import hint_once
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
@ -35,12 +39,8 @@ else:
|
||||
|
||||
|
||||
class DiarSondModel(FunASRModel):
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
|
||||
https://arxiv.org/abs/2211.10243
|
||||
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
|
||||
https://arxiv.org/abs/2303.05397
|
||||
"""Speaker overlap-aware neural diarization model
|
||||
reference: https://arxiv.org/abs/2211.10243
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -48,6 +48,7 @@ class DiarSondModel(FunASRModel):
|
||||
vocab_size: int,
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
profileaug: Optional[AbsProfileAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
encoder: torch.nn.Module,
|
||||
speaker_encoder: Optional[torch.nn.Module],
|
||||
@ -64,7 +65,11 @@ class DiarSondModel(FunASRModel):
|
||||
speaker_discrimination_loss_weight: float = 1.0,
|
||||
inter_score_loss_weight: float = 0.0,
|
||||
inputs_type: str = "raw",
|
||||
model_regularizer_weight: float = 0.0,
|
||||
freeze_encoder: bool = False,
|
||||
onfly_shuffle_speaker: bool = True,
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
super().__init__()
|
||||
|
||||
@ -75,12 +80,16 @@ class DiarSondModel(FunASRModel):
|
||||
self.normalize = normalize
|
||||
self.frontend = frontend
|
||||
self.specaug = specaug
|
||||
self.profileaug = profileaug
|
||||
self.label_aggregator = label_aggregator
|
||||
self.decoder = decoder
|
||||
self.token_list = token_list
|
||||
self.max_spk_num = max_spk_num
|
||||
self.normalize_speech_speaker = normalize_speech_speaker
|
||||
self.ignore_id = ignore_id
|
||||
self.model_regularizer_weight = model_regularizer_weight
|
||||
self.freeze_encoder = freeze_encoder
|
||||
self.onfly_shuffle_speaker = onfly_shuffle_speaker
|
||||
self.criterion_diar = LabelSmoothingLoss(
|
||||
size=vocab_size,
|
||||
padding_idx=ignore_id,
|
||||
@ -95,14 +104,45 @@ class DiarSondModel(FunASRModel):
|
||||
self.inter_score_loss_weight = inter_score_loss_weight
|
||||
self.forward_steps = 0
|
||||
self.inputs_type = inputs_type
|
||||
self.to_regularize_parameters = None
|
||||
|
||||
def get_regularize_parameters(self):
|
||||
to_regularize_parameters, normal_parameters = [], []
|
||||
for name, param in self.named_parameters():
|
||||
if ("encoder" in name and "weight" in name and "bn" not in name and
|
||||
("conv2" in name or "conv1" in name or "conv_sc" in name or "dense" in name)
|
||||
):
|
||||
to_regularize_parameters.append((name, param))
|
||||
else:
|
||||
normal_parameters.append((name, param))
|
||||
self.to_regularize_parameters = to_regularize_parameters
|
||||
return to_regularize_parameters, normal_parameters
|
||||
|
||||
def generate_pse_embedding(self):
|
||||
embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
|
||||
embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float32)
|
||||
for idx, pse_label in enumerate(self.token_list):
|
||||
emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float)
|
||||
emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float32)
|
||||
embedding[idx] = emb
|
||||
return torch.from_numpy(embedding)
|
||||
|
||||
def rand_permute_speaker(self, raw_profile, raw_binary_labels):
|
||||
"""
|
||||
raw_profile: B, N, D
|
||||
raw_binary_labels: B, T, N
|
||||
"""
|
||||
assert raw_profile.shape[1] == raw_binary_labels.shape[2], \
|
||||
"Num profile: {}, Num label: {}".format(raw_profile.shape[1], raw_binary_labels.shape[-1])
|
||||
profile = torch.clone(raw_profile)
|
||||
binary_labels = torch.clone(raw_binary_labels)
|
||||
bsz, num_spk = profile.shape[0], profile.shape[1]
|
||||
for i in range(bsz):
|
||||
idx = list(range(num_spk))
|
||||
random.shuffle(idx)
|
||||
profile[i] = profile[i][idx, :]
|
||||
binary_labels[i] = binary_labels[i][:, idx]
|
||||
|
||||
return profile, binary_labels
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
@ -113,6 +153,7 @@ class DiarSondModel(FunASRModel):
|
||||
binary_labels_lengths: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, samples) or (Batch, frames, input_size)
|
||||
speech_lengths: (Batch,) default None for chunk interator,
|
||||
@ -127,13 +168,38 @@ class DiarSondModel(FunASRModel):
|
||||
"""
|
||||
assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
|
||||
batch_size = speech.shape[0]
|
||||
if self.freeze_encoder:
|
||||
hint_once("Freeze encoder", "freeze_encoder", rank=0)
|
||||
self.encoder.eval()
|
||||
self.forward_steps = self.forward_steps + 1
|
||||
if self.pse_embedding.device != speech.device:
|
||||
self.pse_embedding = self.pse_embedding.to(speech.device)
|
||||
self.power_weight = self.power_weight.to(speech.device)
|
||||
self.int_token_arr = self.int_token_arr.to(speech.device)
|
||||
|
||||
# 1. Network forward
|
||||
if self.onfly_shuffle_speaker:
|
||||
hint_once("On-the-fly shuffle speaker permutation.", "onfly_shuffle_speaker", rank=0)
|
||||
profile, binary_labels = self.rand_permute_speaker(profile, binary_labels)
|
||||
|
||||
# 0a. Aggregate time-domain labels to match forward outputs
|
||||
if self.label_aggregator is not None:
|
||||
binary_labels, binary_labels_lengths = self.label_aggregator(
|
||||
binary_labels, binary_labels_lengths
|
||||
)
|
||||
# 0b. augment profiles
|
||||
if self.profileaug is not None and self.training:
|
||||
speech, profile, binary_labels = self.profileaug(
|
||||
speech, speech_lengths,
|
||||
profile, profile_lengths,
|
||||
binary_labels, binary_labels_lengths
|
||||
)
|
||||
|
||||
# 1. Calculate power-set encoding (PSE) labels
|
||||
pad_bin_labels = F.pad(binary_labels, (0, self.max_spk_num - binary_labels.shape[2]), "constant", 0.0)
|
||||
raw_pse_labels = torch.sum(pad_bin_labels * self.power_weight, dim=2, keepdim=True)
|
||||
pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
|
||||
|
||||
# 2. Network forward
|
||||
pred, inter_outputs = self.prediction_forward(
|
||||
speech, speech_lengths,
|
||||
profile, profile_lengths,
|
||||
@ -141,15 +207,6 @@ class DiarSondModel(FunASRModel):
|
||||
)
|
||||
(speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs
|
||||
|
||||
# 2. Aggregate time-domain labels to match forward outputs
|
||||
if self.label_aggregator is not None:
|
||||
binary_labels, binary_labels_lengths = self.label_aggregator(
|
||||
binary_labels, binary_labels_lengths
|
||||
)
|
||||
# 2. Calculate power-set encoding (PSE) labels
|
||||
raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True)
|
||||
pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
|
||||
|
||||
# If encoder uses conv* as input_layer (i.e., subsampling),
|
||||
# the sequence length of 'pred' might be slightly less than the
|
||||
# length of 'spk_labels'. Here we force them to be equal.
|
||||
@ -165,9 +222,14 @@ class DiarSondModel(FunASRModel):
|
||||
loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
|
||||
loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
|
||||
loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
|
||||
regularizer_loss = None
|
||||
if self.model_regularizer_weight > 0 and self.to_regularize_parameters is not None:
|
||||
regularizer_loss = self.calculate_regularizer_loss()
|
||||
label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device)
|
||||
loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
|
||||
+ self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
|
||||
# if regularizer_loss is not None:
|
||||
# loss = loss + regularizer_loss * self.model_regularizer_weight
|
||||
|
||||
(
|
||||
correct,
|
||||
@ -204,6 +266,7 @@ class DiarSondModel(FunASRModel):
|
||||
loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
|
||||
loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
|
||||
loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
|
||||
regularizer_loss=regularizer_loss.detach() if regularizer_loss is not None else None,
|
||||
sad_mr=sad_mr,
|
||||
sad_fr=sad_fr,
|
||||
mi=mi,
|
||||
@ -217,6 +280,12 @@ class DiarSondModel(FunASRModel):
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
def calculate_regularizer_loss(self):
|
||||
regularizer_loss = 0.0
|
||||
for name, param in self.to_regularize_parameters:
|
||||
regularizer_loss = regularizer_loss + torch.norm(param, p=2)
|
||||
return regularizer_loss
|
||||
|
||||
def classification_loss(
|
||||
self,
|
||||
predictions: torch.Tensor,
|
||||
@ -388,6 +457,7 @@ class DiarSondModel(FunASRModel):
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch,)
|
||||
@ -487,4 +557,4 @@ class DiarSondModel(FunASRModel):
|
||||
speaker_miss,
|
||||
speaker_falarm,
|
||||
speaker_error,
|
||||
)
|
||||
)
|
||||
|
||||
22
funasr/models/specaug/abs_profileaug.py
Normal file
22
funasr/models/specaug/abs_profileaug.py
Normal file
@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AbsProfileAug(torch.nn.Module):
|
||||
"""Abstract class for the augmentation of profile
|
||||
|
||||
The process-flow:
|
||||
|
||||
Frontend --> SpecAug -> Normalization -> Encoder -> Decoder
|
||||
`-> ProfileAug -> Speaker Encoder --'
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lengths: torch.Tensor = None,
|
||||
profile: torch.Tensor = None, profile_lengths: torch.Tensor = None,
|
||||
binary_labels: torch.Tensor = None, labels_length: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
raise NotImplementedError
|
||||
122
funasr/models/specaug/profileaug.py
Normal file
122
funasr/models/specaug/profileaug.py
Normal file
@ -0,0 +1,122 @@
|
||||
from typing import Tuple, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from funasr.models.specaug.abs_profileaug import AbsProfileAug
|
||||
|
||||
|
||||
class ProfileAug(AbsProfileAug):
|
||||
"""
|
||||
Implement the augmentation for profiles including:
|
||||
- Split aug: split one profile into two profiles, i.e., main and inaccurate, labels assigned to main
|
||||
- Merge aug: merge two profiles into one, labels are also merged into one, the other set to zero
|
||||
- Disturb aug: disturb some profile with others to simulate the inaccurate clustering centroids.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
apply_split_aug: bool = True,
|
||||
split_aug_prob: float = 0.05,
|
||||
apply_merge_aug: bool = True,
|
||||
merge_aug_prob: float = 0.2,
|
||||
apply_disturb_aug: bool = True,
|
||||
disturb_aug_prob: float = 0.4,
|
||||
disturb_alpha: float = 0.2,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.apply_split_aug = apply_split_aug
|
||||
self.split_aug_prob = split_aug_prob
|
||||
self.apply_merge_aug = apply_merge_aug
|
||||
self.merge_aug_prob = merge_aug_prob
|
||||
self.apply_disturb_aug = apply_disturb_aug
|
||||
self.disturb_aug_prob = disturb_aug_prob
|
||||
self.disturb_alpha = disturb_alpha
|
||||
|
||||
def split_aug(self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor):
|
||||
# B, N
|
||||
bsz, dim = profile.shape[0], profile.shape[-1]
|
||||
profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
|
||||
spk_count = binary_labels.sum(dim=1)
|
||||
prob = np.random.rand(bsz)
|
||||
batch_indices = np.nonzero(prob < self.split_aug_prob)[0]
|
||||
for idx in batch_indices:
|
||||
valid_spk_idx = torch.nonzero(spk_count[idx] * mask[idx])
|
||||
pad_spk_idx = torch.nonzero((spk_count[idx] == 0) * mask[idx])
|
||||
if len(valid_spk_idx) == 0 or len(pad_spk_idx) == 0:
|
||||
continue
|
||||
split_spk_idx = valid_spk_idx[torch.randint(len(valid_spk_idx), ())]
|
||||
to_cover_idx = pad_spk_idx[torch.randint(len(pad_spk_idx), ())]
|
||||
disturb_vec = torch.randn((dim,)).to(profile)
|
||||
disturb_vec = F.normalize(disturb_vec, dim=-1)
|
||||
profile[idx, to_cover_idx] = F.normalize(profile[idx, split_spk_idx] +
|
||||
self.disturb_alpha * disturb_vec)
|
||||
mask[idx, split_spk_idx] = 0
|
||||
mask[idx, to_cover_idx] = 0
|
||||
return profile, binary_labels, mask
|
||||
|
||||
def merge_aug(self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor):
|
||||
bsz, dim = profile.shape[0], profile.shape[-1]
|
||||
profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
|
||||
spk_count = binary_labels.sum(dim=1)
|
||||
prob = np.random.rand(bsz)
|
||||
batch_indices = np.nonzero(prob < self.merge_aug_prob)[0]
|
||||
for idx in batch_indices:
|
||||
valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx])
|
||||
if len(valid_spk_idx) == 0:
|
||||
continue
|
||||
to_merge = torch.randint(len(valid_spk_idx), (2, ))
|
||||
spk_idx_1, spk_idx_2 = valid_spk_idx[to_merge[0]], valid_spk_idx[to_merge[1]]
|
||||
# merge profile
|
||||
profile[idx, spk_idx_1] = profile[idx, spk_idx_1] + profile[idx, spk_idx_2]
|
||||
profile[idx, spk_idx_1] = F.normalize(profile[idx, spk_idx_1], dim=-1)
|
||||
profile[idx, spk_idx_2] = 0
|
||||
# merge binary labels
|
||||
binary_labels[idx, :, spk_idx_1] = binary_labels[idx, :, spk_idx_1] + binary_labels[idx, :, spk_idx_2]
|
||||
binary_labels[idx, :, spk_idx_1] = (binary_labels[idx, :, spk_idx_1] > 0).to(binary_labels)
|
||||
binary_labels[idx, :, spk_idx_2] = 0
|
||||
|
||||
mask[idx, spk_idx_1] = 0
|
||||
mask[idx, spk_idx_2] = 0
|
||||
|
||||
return profile, binary_labels, mask
|
||||
|
||||
def disturb_aug(self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor):
|
||||
bsz, dim = profile.shape[0], profile.shape[-1]
|
||||
profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
|
||||
spk_count = binary_labels.sum(dim=1)
|
||||
prob = np.random.rand(bsz)
|
||||
batch_indices = np.nonzero(prob < self.disturb_aug_prob)[0]
|
||||
for idx in batch_indices:
|
||||
pos_spk_idx = torch.nonzero(spk_count[idx] * mask[idx])
|
||||
valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx])
|
||||
if len(pos_spk_idx) == 0 or len(valid_spk_idx) == 0:
|
||||
continue
|
||||
to_disturb_idx = pos_spk_idx[torch.randint(len(pos_spk_idx), ())]
|
||||
disturb_idx = valid_spk_idx[torch.randint(len(valid_spk_idx), ())]
|
||||
alpha = self.disturb_alpha * torch.rand(()).item()
|
||||
profile[idx, to_disturb_idx] = ((1 - alpha) * profile[idx, to_disturb_idx]
|
||||
+ alpha * profile[idx, disturb_idx])
|
||||
profile[idx, to_disturb_idx] = F.normalize(profile[idx, to_disturb_idx], dim=-1)
|
||||
mask[idx, to_disturb_idx] = 0
|
||||
|
||||
return profile, binary_labels, mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor, speech_lengths: torch.Tensor = None,
|
||||
profile: torch.Tensor = None, profile_lengths: torch.Tensor = None,
|
||||
binary_labels: torch.Tensor = None, labels_length: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
# copy inputs to avoid inplace-operation
|
||||
speech, profile, binary_labels = torch.clone(speech), torch.clone(profile), torch.clone(binary_labels)
|
||||
profile = F.normalize(profile, dim=-1)
|
||||
|
||||
profile_mask = torch.ones(profile.shape[:2]).to(profile)
|
||||
if self.apply_disturb_aug:
|
||||
profile, binary_labels, profile_mask = self.disturb_aug(profile, binary_labels, profile_mask)
|
||||
if self.apply_split_aug:
|
||||
profile, binary_labels, profile_mask = self.split_aug(profile, binary_labels, profile_mask)
|
||||
if self.apply_merge_aug:
|
||||
profile, binary_labels, profile_mask = self.merge_aug(profile, binary_labels, profile_mask)
|
||||
|
||||
return speech, profile, binary_labels
|
||||
@ -61,6 +61,48 @@ def pad_list(xs, pad_value):
|
||||
return pad
|
||||
|
||||
|
||||
def pad_list_all_dim(xs, pad_value):
|
||||
"""Perform padding for the list of tensors.
|
||||
|
||||
Args:
|
||||
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
||||
pad_value (float): Value for padding.
|
||||
|
||||
Returns:
|
||||
Tensor: Padded tensor (B, Tmax, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
||||
>>> x
|
||||
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
||||
>>> pad_list(x, 0)
|
||||
tensor([[1., 1., 1., 1.],
|
||||
[1., 1., 0., 0.],
|
||||
[1., 0., 0., 0.]])
|
||||
|
||||
"""
|
||||
n_batch = len(xs)
|
||||
num_dim = len(xs[0].shape)
|
||||
max_len_all_dim = []
|
||||
for i in range(num_dim):
|
||||
max_len_all_dim.append(max(x.size(i) for x in xs))
|
||||
pad = xs[0].new(n_batch, *max_len_all_dim).fill_(pad_value)
|
||||
|
||||
for i in range(n_batch):
|
||||
if num_dim == 1:
|
||||
pad[i, : xs[i].size(0)] = xs[i]
|
||||
elif num_dim == 2:
|
||||
pad[i, : xs[i].size(0), : xs[i].size(1)] = xs[i]
|
||||
elif num_dim == 3:
|
||||
pad[i, : xs[i].size(0), : xs[i].size(1), : xs[i].size(2)] = xs[i]
|
||||
else:
|
||||
raise ValueError(
|
||||
"pad_list_all_dim only support 1-D, 2-D and 3-D tensors, not {}-D.".format(num_dim)
|
||||
)
|
||||
|
||||
return pad
|
||||
|
||||
|
||||
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
|
||||
"""Make mask tensor containing indices of padded part.
|
||||
|
||||
|
||||
@ -1,11 +1,3 @@
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
|
||||
https://arxiv.org/abs/2211.10243
|
||||
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
|
||||
https://arxiv.org/abs/2303.05397
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
@ -21,24 +13,26 @@ from typing import Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from typeguard import check_argument_types
|
||||
from typeguard import check_return_type
|
||||
|
||||
from funasr.datasets.collate_fn import CommonCollateFn
|
||||
from funasr.datasets.collate_fn import DiarCollateFn
|
||||
from funasr.datasets.preprocessor import CommonPreprocessor
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.layers.global_mvn import GlobalMVN
|
||||
from funasr.layers.label_aggregation import LabelAggregate
|
||||
from funasr.layers.utterance_mvn import UtteranceMVN
|
||||
from funasr.models.e2e_diar_sond import DiarSondModel
|
||||
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
||||
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
||||
from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
|
||||
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
|
||||
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
|
||||
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
|
||||
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
|
||||
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
|
||||
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
|
||||
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
|
||||
from funasr.models.e2e_diar_sond import DiarSondModel
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
||||
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
||||
from funasr.models.encoder.rnn_encoder import RNNEncoder
|
||||
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
|
||||
from funasr.models.encoder.transformer_encoder import TransformerEncoder
|
||||
@ -47,16 +41,21 @@ from funasr.models.frontend.default import DefaultFrontend
|
||||
from funasr.models.frontend.fused import FusedFrontends
|
||||
from funasr.models.frontend.s3prl import S3prlFrontend
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||
from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
||||
from funasr.models.frontend.windowing import SlidingWindow
|
||||
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
|
||||
HuggingFaceTransformersPostEncoder, # noqa: H301
|
||||
)
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.preencoder.linear import LinearProjection
|
||||
from funasr.models.preencoder.sinc import LightweightSincConvs
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.models.specaug.specaug import SpecAug
|
||||
from funasr.models.specaug.specaug import SpecAugLFR
|
||||
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
|
||||
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
|
||||
from funasr.models.specaug.abs_profileaug import AbsProfileAug
|
||||
from funasr.models.specaug.profileaug import ProfileAug
|
||||
from funasr.tasks.abs_task import AbsTask
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
from funasr.models.base_model import FunASRModel
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
from funasr.train.trainer import Trainer
|
||||
from funasr.utils.types import float_or_none
|
||||
@ -72,7 +71,6 @@ frontend_choices = ClassChoices(
|
||||
s3prl=S3prlFrontend,
|
||||
fused=FusedFrontends,
|
||||
wav_frontend=WavFrontend,
|
||||
wav_frontend_mel23=WavFrontendMel23,
|
||||
),
|
||||
type_check=AbsFrontend,
|
||||
default="default",
|
||||
@ -87,6 +85,15 @@ specaug_choices = ClassChoices(
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
profileaug_choices = ClassChoices(
|
||||
name="profileaug",
|
||||
classes=dict(
|
||||
profileaug=ProfileAug,
|
||||
),
|
||||
type_check=AbsProfileAug,
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
normalize_choices = ClassChoices(
|
||||
"normalize",
|
||||
classes=dict(
|
||||
@ -100,7 +107,8 @@ normalize_choices = ClassChoices(
|
||||
label_aggregator_choices = ClassChoices(
|
||||
"label_aggregator",
|
||||
classes=dict(
|
||||
label_aggregator=LabelAggregate
|
||||
label_aggregator=LabelAggregate,
|
||||
label_aggregator_max_pool=LabelAggregateMaxPooling,
|
||||
),
|
||||
type_check=torch.nn.Module,
|
||||
default=None,
|
||||
@ -110,9 +118,8 @@ model_choices = ClassChoices(
|
||||
"model",
|
||||
classes=dict(
|
||||
sond=DiarSondModel,
|
||||
eend_ola=DiarEENDOLAModel,
|
||||
),
|
||||
type_check=FunASRModel,
|
||||
type_check=torch.nn.Module,
|
||||
default="sond",
|
||||
)
|
||||
encoder_choices = ClassChoices(
|
||||
@ -130,7 +137,6 @@ encoder_choices = ClassChoices(
|
||||
sanm_chunk_opt=SANMEncoderChunkOpt,
|
||||
data2vec_encoder=Data2VecEncoder,
|
||||
ecapa_tdnn=ECAPA_TDNN,
|
||||
eend_ola_transformer=EENDOLATransformerEncoder,
|
||||
),
|
||||
type_check=torch.nn.Module,
|
||||
default="resnet34",
|
||||
@ -182,15 +188,6 @@ decoder_choices = ClassChoices(
|
||||
type_check=torch.nn.Module,
|
||||
default="fsmn",
|
||||
)
|
||||
# encoder_decoder_attractor is used for EEND-OLA
|
||||
encoder_decoder_attractor_choices = ClassChoices(
|
||||
"encoder_decoder_attractor",
|
||||
classes=dict(
|
||||
eda=EncoderDecoderAttractor,
|
||||
),
|
||||
type_check=torch.nn.Module,
|
||||
default="eda",
|
||||
)
|
||||
|
||||
|
||||
class DiarTask(AbsTask):
|
||||
@ -203,6 +200,8 @@ class DiarTask(AbsTask):
|
||||
frontend_choices,
|
||||
# --specaug and --specaug_conf
|
||||
specaug_choices,
|
||||
# --profileaug and --profileaug_conf
|
||||
profileaug_choices,
|
||||
# --normalize and --normalize_conf
|
||||
normalize_choices,
|
||||
# --label_aggregator and --label_aggregator_conf
|
||||
@ -342,13 +341,15 @@ class DiarTask(AbsTask):
|
||||
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
|
||||
Tuple[List[str], Dict[str, torch.Tensor]],
|
||||
]:
|
||||
assert check_argument_types()
|
||||
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
|
||||
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
|
||||
return DiarCollateFn(float_pad_value=0.0, int_pad_value=-1)
|
||||
|
||||
@classmethod
|
||||
def build_preprocess_fn(
|
||||
cls, args: argparse.Namespace, train: bool
|
||||
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
|
||||
assert check_argument_types()
|
||||
if args.use_preprocessor:
|
||||
retval = CommonPreprocessor(
|
||||
train=train,
|
||||
@ -378,6 +379,7 @@ class DiarTask(AbsTask):
|
||||
)
|
||||
else:
|
||||
retval = None
|
||||
assert check_return_type(retval)
|
||||
return retval
|
||||
|
||||
@classmethod
|
||||
@ -396,10 +398,47 @@ class DiarTask(AbsTask):
|
||||
cls, train: bool = True, inference: bool = False
|
||||
) -> Tuple[str, ...]:
|
||||
retval = ()
|
||||
assert check_return_type(retval)
|
||||
return retval
|
||||
|
||||
@classmethod
|
||||
def build_optimizers(
|
||||
cls,
|
||||
args: argparse.Namespace,
|
||||
model: torch.nn.Module,
|
||||
) -> List[torch.optim.Optimizer]:
|
||||
if cls.num_optimizers != 1:
|
||||
raise RuntimeError(
|
||||
"build_optimizers() must be overridden if num_optimizers != 1"
|
||||
)
|
||||
from funasr.tasks.abs_task import optim_classes
|
||||
optim_class = optim_classes.get(args.optim)
|
||||
if optim_class is None:
|
||||
raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
|
||||
else:
|
||||
if (hasattr(model, "model_regularizer_weight") and
|
||||
model.model_regularizer_weight > 0.0 and
|
||||
hasattr(model, "get_regularize_parameters")
|
||||
):
|
||||
to_regularize_parameters, normal_parameters = model.get_regularize_parameters()
|
||||
logging.info(f"Set weight decay {model.model_regularizer_weight} for parameters: "
|
||||
f"{[name for name, value in to_regularize_parameters]}")
|
||||
module_optim_config = [
|
||||
{"params": [value for name, value in to_regularize_parameters],
|
||||
"weight_decay": model.model_regularizer_weight},
|
||||
{"params": [value for name, value in normal_parameters],
|
||||
"weight_decay": 0.0}
|
||||
]
|
||||
optim = optim_class(module_optim_config, **args.optim_conf)
|
||||
else:
|
||||
optim = optim_class(model.parameters(), **args.optim_conf)
|
||||
|
||||
optimizers = [optim]
|
||||
return optimizers
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args: argparse.Namespace):
|
||||
assert check_argument_types()
|
||||
if isinstance(args.token_list, str):
|
||||
with open(args.token_list, encoding="utf-8") as f:
|
||||
token_list = [line.rstrip() for line in f]
|
||||
@ -436,6 +475,13 @@ class DiarTask(AbsTask):
|
||||
else:
|
||||
specaug = None
|
||||
|
||||
# 2b. Data augmentation for Profiles
|
||||
if hasattr(args, "profileaug") and args.profileaug is not None:
|
||||
profileaug_class = profileaug_choices.get_class(args.profileaug)
|
||||
profileaug = profileaug_class(**args.profileaug_conf)
|
||||
else:
|
||||
profileaug = None
|
||||
|
||||
# 3. Normalization layer
|
||||
if args.normalize is not None:
|
||||
normalize_class = normalize_choices.get_class(args.normalize)
|
||||
@ -483,6 +529,7 @@ class DiarTask(AbsTask):
|
||||
vocab_size=vocab_size,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
profileaug=profileaug,
|
||||
normalize=normalize,
|
||||
label_aggregator=label_aggregator,
|
||||
encoder=encoder,
|
||||
@ -497,7 +544,9 @@ class DiarTask(AbsTask):
|
||||
# 10. Initialize
|
||||
if args.init is not None:
|
||||
initialize(model, args.init)
|
||||
logging.info(f"Init model parameters with {args.init}.")
|
||||
|
||||
assert check_return_type(model)
|
||||
return model
|
||||
|
||||
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
|
||||
@ -520,6 +569,7 @@ class DiarTask(AbsTask):
|
||||
device: Device type, "cpu", "cuda", or "cuda:N".
|
||||
|
||||
"""
|
||||
assert check_argument_types()
|
||||
if config_file is None:
|
||||
assert model_file is not None, (
|
||||
"The argument 'model_file' must be provided "
|
||||
@ -535,9 +585,9 @@ class DiarTask(AbsTask):
|
||||
args["cmvn_file"] = cmvn_file
|
||||
args = argparse.Namespace(**args)
|
||||
model = cls.build_model(args)
|
||||
if not isinstance(model, FunASRModel):
|
||||
if not isinstance(model, torch.nn.Module):
|
||||
raise RuntimeError(
|
||||
f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
|
||||
f"model must inherit {torch.nn.Module.__name__}, but got {type(model)}"
|
||||
)
|
||||
model.to(device)
|
||||
model_dict = dict()
|
||||
@ -552,13 +602,13 @@ class DiarTask(AbsTask):
|
||||
if ".bin" in model_name:
|
||||
model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
|
||||
else:
|
||||
model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
|
||||
model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
|
||||
if os.path.exists(model_name_pth):
|
||||
logging.info("model_file is load from pth: {}".format(model_name_pth))
|
||||
model_dict = torch.load(model_name_pth, map_location=device)
|
||||
else:
|
||||
model_dict = cls.convert_tf2torch(model, model_file)
|
||||
model.load_state_dict(model_dict)
|
||||
# model.load_state_dict(model_dict)
|
||||
else:
|
||||
model_dict = torch.load(model_file, map_location=device)
|
||||
model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
|
||||
@ -616,287 +666,3 @@ class DiarTask(AbsTask):
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
|
||||
return var_dict_torch_update
|
||||
|
||||
|
||||
class EENDOLADiarTask(AbsTask):
|
||||
# If you need more than 1 optimizer, change this value
|
||||
num_optimizers: int = 1
|
||||
|
||||
# Add variable objects configurations
|
||||
class_choices_list = [
|
||||
# --frontend and --frontend_conf
|
||||
frontend_choices,
|
||||
# --specaug and --specaug_conf
|
||||
model_choices,
|
||||
# --encoder and --encoder_conf
|
||||
encoder_choices,
|
||||
# --speaker_encoder and --speaker_encoder_conf
|
||||
encoder_decoder_attractor_choices,
|
||||
]
|
||||
|
||||
# If you need to modify train() or eval() procedures, change Trainer class here
|
||||
trainer = Trainer
|
||||
|
||||
@classmethod
|
||||
def add_task_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(description="Task related")
|
||||
|
||||
# NOTE(kamo): add_arguments(..., required=True) can't be used
|
||||
# to provide --print_config mode. Instead of it, do as
|
||||
# required = parser.get_default("required")
|
||||
# required += ["token_list"]
|
||||
|
||||
group.add_argument(
|
||||
"--token_list",
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
help="A text mapping int-id to token",
|
||||
)
|
||||
group.add_argument(
|
||||
"--split_with_space",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="whether to split text using <space>",
|
||||
)
|
||||
group.add_argument(
|
||||
"--seg_dict_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="seg_dict_file for text processing",
|
||||
)
|
||||
group.add_argument(
|
||||
"--init",
|
||||
type=lambda x: str_or_none(x.lower()),
|
||||
default=None,
|
||||
help="The initialization method",
|
||||
choices=[
|
||||
"chainer",
|
||||
"xavier_uniform",
|
||||
"xavier_normal",
|
||||
"kaiming_uniform",
|
||||
"kaiming_normal",
|
||||
None,
|
||||
],
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input_size",
|
||||
type=int_or_none,
|
||||
default=None,
|
||||
help="The number of input dimension of the feature",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group(description="Preprocess related")
|
||||
group.add_argument(
|
||||
"--use_preprocessor",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Apply preprocessing to data or not",
|
||||
)
|
||||
group.add_argument(
|
||||
"--token_type",
|
||||
type=str,
|
||||
default="char",
|
||||
choices=["char"],
|
||||
help="The text will be tokenized in the specified level token",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speech_volume_normalize",
|
||||
type=float_or_none,
|
||||
default=None,
|
||||
help="Scale the maximum amplitude to the given value.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rir_scp",
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
help="The file path of rir scp file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rir_apply_prob",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="THe probability for applying RIR convolution.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cmvn_file",
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
help="The file path of noise scp file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_scp",
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
help="The file path of noise scp file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_apply_prob",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The probability applying Noise adding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_db_range",
|
||||
type=str,
|
||||
default="13_15",
|
||||
help="The range of noise decibel level.",
|
||||
)
|
||||
|
||||
for class_choices in cls.class_choices_list:
|
||||
# Append --<name> and --<name>_conf.
|
||||
# e.g. --encoder and --encoder_conf
|
||||
class_choices.add_arguments(group)
|
||||
|
||||
@classmethod
|
||||
def build_collate_fn(
|
||||
cls, args: argparse.Namespace, train: bool
|
||||
) -> Callable[
|
||||
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
|
||||
Tuple[List[str], Dict[str, torch.Tensor]],
|
||||
]:
|
||||
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
|
||||
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
|
||||
|
||||
@classmethod
|
||||
def build_preprocess_fn(
|
||||
cls, args: argparse.Namespace, train: bool
|
||||
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
|
||||
# if args.use_preprocessor:
|
||||
# retval = CommonPreprocessor(
|
||||
# train=train,
|
||||
# token_type=args.token_type,
|
||||
# token_list=args.token_list,
|
||||
# bpemodel=None,
|
||||
# non_linguistic_symbols=None,
|
||||
# text_cleaner=None,
|
||||
# g2p_type=None,
|
||||
# split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
|
||||
# seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
|
||||
# # NOTE(kamo): Check attribute existence for backward compatibility
|
||||
# rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
|
||||
# rir_apply_prob=args.rir_apply_prob
|
||||
# if hasattr(args, "rir_apply_prob")
|
||||
# else 1.0,
|
||||
# noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
|
||||
# noise_apply_prob=args.noise_apply_prob
|
||||
# if hasattr(args, "noise_apply_prob")
|
||||
# else 1.0,
|
||||
# noise_db_range=args.noise_db_range
|
||||
# if hasattr(args, "noise_db_range")
|
||||
# else "13_15",
|
||||
# speech_volume_normalize=args.speech_volume_normalize
|
||||
# if hasattr(args, "rir_scp")
|
||||
# else None,
|
||||
# )
|
||||
# else:
|
||||
# retval = None
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def required_data_names(
|
||||
cls, train: bool = True, inference: bool = False
|
||||
) -> Tuple[str, ...]:
|
||||
if not inference:
|
||||
retval = ("speech", )
|
||||
else:
|
||||
# Recognition mode
|
||||
retval = ("speech", )
|
||||
return retval
|
||||
|
||||
@classmethod
|
||||
def optional_data_names(
|
||||
cls, train: bool = True, inference: bool = False
|
||||
) -> Tuple[str, ...]:
|
||||
retval = ()
|
||||
return retval
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args: argparse.Namespace):
|
||||
|
||||
# 1. frontend
|
||||
if args.input_size is None or args.frontend == "wav_frontend_mel23":
|
||||
# Extract features in the model
|
||||
frontend_class = frontend_choices.get_class(args.frontend)
|
||||
if args.frontend == 'wav_frontend':
|
||||
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
|
||||
else:
|
||||
frontend = frontend_class(**args.frontend_conf)
|
||||
input_size = frontend.output_size()
|
||||
else:
|
||||
# Give features from data-loader
|
||||
args.frontend = None
|
||||
args.frontend_conf = {}
|
||||
frontend = None
|
||||
input_size = args.input_size
|
||||
|
||||
# 2. Encoder
|
||||
encoder_class = encoder_choices.get_class(args.encoder)
|
||||
encoder = encoder_class(**args.encoder_conf)
|
||||
|
||||
# 3. EncoderDecoderAttractor
|
||||
encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
|
||||
encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
|
||||
|
||||
# 9. Build model
|
||||
model_class = model_choices.get_class(args.model)
|
||||
model = model_class(
|
||||
frontend=frontend,
|
||||
encoder=encoder,
|
||||
encoder_decoder_attractor=encoder_decoder_attractor,
|
||||
**args.model_conf,
|
||||
)
|
||||
|
||||
# 10. Initialize
|
||||
if args.init is not None:
|
||||
initialize(model, args.init)
|
||||
|
||||
return model
|
||||
|
||||
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
|
||||
@classmethod
|
||||
def build_model_from_file(
|
||||
cls,
|
||||
config_file: Union[Path, str] = None,
|
||||
model_file: Union[Path, str] = None,
|
||||
cmvn_file: Union[Path, str] = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
"""Build model from the files.
|
||||
|
||||
This method is used for inference or fine-tuning.
|
||||
|
||||
Args:
|
||||
config_file: The yaml file saved when training.
|
||||
model_file: The model file saved when training.
|
||||
cmvn_file: The cmvn file for front-end
|
||||
device: Device type, "cpu", "cuda", or "cuda:N".
|
||||
|
||||
"""
|
||||
if config_file is None:
|
||||
assert model_file is not None, (
|
||||
"The argument 'model_file' must be provided "
|
||||
"if the argument 'config_file' is not specified."
|
||||
)
|
||||
config_file = Path(model_file).parent / "config.yaml"
|
||||
else:
|
||||
config_file = Path(config_file)
|
||||
|
||||
with config_file.open("r", encoding="utf-8") as f:
|
||||
args = yaml.safe_load(f)
|
||||
args = argparse.Namespace(**args)
|
||||
model = cls.build_model(args)
|
||||
if not isinstance(model, FunASRModel):
|
||||
raise RuntimeError(
|
||||
f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
|
||||
)
|
||||
if model_file is not None:
|
||||
if device == "cuda":
|
||||
device = f"cuda:{torch.cuda.current_device()}"
|
||||
checkpoint = torch.load(model_file, map_location=device)
|
||||
if "state_dict" in checkpoint.keys():
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
model.to(device)
|
||||
return model, args
|
||||
|
||||
13
funasr/utils/hinter.py
Normal file
13
funasr/utils/hinter.py
Normal file
@ -0,0 +1,13 @@
|
||||
import sys
|
||||
import torch.distributed
|
||||
import logging
|
||||
|
||||
HINTED = set()
|
||||
|
||||
|
||||
def hint_once(content, uid, rank=None):
|
||||
if (rank is None) or (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank:
|
||||
if uid not in HINTED:
|
||||
logging.info(content)
|
||||
HINTED.add(uid)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user