diff --git a/egs/alimeeting/diarization/sond/README.md b/egs/alimeeting/diarization/sond/README.md new file mode 100644 index 000000000..8bef142dc --- /dev/null +++ b/egs/alimeeting/diarization/sond/README.md @@ -0,0 +1,6 @@ +# Results +You should get a DER of about 4.21%, as reported in [1], Table 6, row "SOND Oracle Profile". + +# Reference +[1] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis. Zhihao Du, Shiliang Zhang, +Siqi Zheng, Zhijie Yan. EMNLP 2022. \ No newline at end of file diff --git a/egs/alimeeting/diarization/sond/config.yaml b/egs/alimeeting/diarization/sond/config.yaml new file mode 100644 index 000000000..072c171c3 --- /dev/null +++ b/egs/alimeeting/diarization/sond/config.yaml @@ -0,0 +1,2740 @@ +config: finetune.yaml +print_config: false +log_level: INFO +dry_run: false +iterator_type: sequence +output_dir: exp/sond +ngpu: 1 +seed: 0 +num_workers: 16 +num_att_plot: 0 +dist_backend: nccl +dist_init_method: env:// +dist_world_size: null +dist_rank: null +local_rank: 0 +dist_master_addr: null +dist_master_port: null +dist_launcher: null +multiprocessing_distributed: true +distributed: false +unused_parameters: true +sharded_ddp: false +ddp_backend: pytorch_ddp +cudnn_enabled: true +cudnn_benchmark: false +cudnn_deterministic: true +collect_stats: false +write_collected_feats: false +max_epoch: 50 +patience: null +val_scheduler_criterion: +- valid +- acc +early_stopping_criterion: +- valid +- loss +- min +best_model_criterion: +- - valid + - acc + - max +keep_nbest_models: 10 +nbest_averaging_interval: 0 +grad_clip: 5 +grad_clip_type: 2.0 +grad_noise: false +accum_grad: 1 +no_forward_run: false +resume: true +train_dtype: float32 +use_amp: false +log_interval: 50 +use_matplotlib: false +use_tensorboard: true +use_wandb: false +wandb_project: null +wandb_id: null +wandb_entity: null +wandb_name: null +wandb_model_log_interval: -1 +use_pai: true +detect_anomaly: false +pretrain_path: null +init_param: [] +ignore_init_mismatch: false +freeze_param: [] +num_iters_per_epoch: null +batch_size: 20 +valid_batch_size: null +batch_bins: 10000 +valid_batch_bins: null +train_shape_file: +- /data/volume1/youyan/aishell/ark/train/speech_shape.1 +- /data/volume1/youyan/aishell/ark/train/text_shape.1 +valid_shape_file: +- /data/volume1/youyan/aishell/ark/dev/speech_shape.1 +- /data/volume1/youyan/aishell/ark/dev/text_shape.1 +batch_type: length +valid_batch_type: null +fold_length: +- 512 +- 150 +sort_in_batch: descending +sort_batch: descending +multiple_iterator: false +chunk_length: 500 +chunk_shift_ratio: 0.5 +num_cache_chunks: 1024 +train_data_path_and_name_and_type: +- - /data/volume1/youyan/aishell/ark/train/data.scp + - speech + - kaldi_ark +- - /data/volume1/youyan/aishell/ark/train/data.text.1 + - text + - text +valid_data_path_and_name_and_type: +- - /data/volume1/youyan/aishell/ark/dev/data.scp + - speech + - kaldi_ark +- - /data/volume1/youyan/aishell/ark/dev/data.text.1 + - text + - text +allow_variable_data_keys: false +max_cache_size: 0.0 +max_cache_fd: 32 +valid_max_cache_size: null +optim: adam +optim_conf: + lr: 0.0005 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 30000 +token_list: +- '0' +- '1' +- '2' +- '3' +- '4' +- '5' +- '6' +- '7' +- '8' +- '9' +- '10' +- '11' +- '12' +- '13' +- '14' +- '15' +- '16' +- '17' +- '18' +- '19' +- '20' +- '21' +- '22' +- '23' +- '24' +- '25' +- '26' +- '27' +- '28' +- '29' +- '30' +- '32' +- '33' +- '34' +- '35' +- '36' +- '37' +- '38' +- '39' +- '40' +- '41' +- '42' +- '43' +- '44' +- '45' +- '46'
+- '48' +- '49' +- '50' +- '51' +- '52' +- '53' +- '54' +- '56' +- '57' +- '58' +- '60' +- '64' +- '65' +- '66' +- '67' +- '68' +- '69' +- '70' +- '71' +- '72' +- '73' +- '74' +- '75' +- '76' +- '77' +- '78' +- '80' +- '81' +- '82' +- '83' +- '84' +- '85' +- '86' +- '88' +- '89' +- '90' +- '92' +- '96' +- '97' +- '98' +- '99' +- '100' +- '101' +- '102' +- '104' +- '105' +- '106' +- '108' +- '112' +- '113' +- '114' +- '116' +- '120' +- '128' +- '129' +- '130' +- '131' +- '132' +- '133' +- '134' +- '135' +- '136' +- '137' +- '138' +- '139' +- '140' +- '141' +- '142' +- '144' +- '145' +- '146' +- '147' +- '148' +- '149' +- '150' +- '152' +- '153' +- '154' +- '156' +- '160' +- '161' +- '162' +- '163' +- '164' +- '165' +- '166' +- '168' +- '169' +- '170' +- '172' +- '176' +- '177' +- '178' +- '180' +- '184' +- '192' +- '193' +- '194' +- '195' +- '196' +- '197' +- '198' +- '200' +- '201' +- '202' +- '204' +- '208' +- '209' +- '210' +- '212' +- '216' +- '224' +- '225' +- '226' +- '228' +- '232' +- '240' +- '256' +- '257' +- '258' +- '259' +- '260' +- '261' +- '262' +- '263' +- '264' +- '265' +- '266' +- '267' +- '268' +- '269' +- '270' +- '272' +- '273' +- '274' +- '275' +- '276' +- '277' +- '278' +- '280' +- '281' +- '282' +- '284' +- '288' +- '289' +- '290' +- '291' +- '292' +- '293' +- '294' +- '296' +- '297' +- '298' +- '300' +- '304' +- '305' +- '306' +- '308' +- '312' +- '320' +- '321' +- '322' +- '323' +- '324' +- '325' +- '326' +- '328' +- '329' +- '330' +- '332' +- '336' +- '337' +- '338' +- '340' +- '344' +- '352' +- '353' +- '354' +- '356' +- '360' +- '368' +- '384' +- '385' +- '386' +- '387' +- '388' +- '389' +- '390' +- '392' +- '393' +- '394' +- '396' +- '400' +- '401' +- '402' +- '404' +- '408' +- '416' +- '417' +- '418' +- '420' +- '424' +- '432' +- '448' +- '449' +- '450' +- '452' +- '456' +- '464' +- '480' +- '512' +- '513' +- '514' +- '515' +- '516' +- '517' +- '518' +- '519' +- '520' +- '521' +- '522' +- '523' +- '524' +- '525' +- '526' +- '528' +- '529' +- '530' +- '531' +- '532' +- '533' +- '534' +- '536' +- '537' +- '538' +- '540' +- '544' +- '545' +- '546' +- '547' +- '548' +- '549' +- '550' +- '552' +- '553' +- '554' +- '556' +- '560' +- '561' +- '562' +- '564' +- '568' +- '576' +- '577' +- '578' +- '579' +- '580' +- '581' +- '582' +- '584' +- '585' +- '586' +- '588' +- '592' +- '593' +- '594' +- '596' +- '600' +- '608' +- '609' +- '610' +- '612' +- '616' +- '624' +- '640' +- '641' +- '642' +- '643' +- '644' +- '645' +- '646' +- '648' +- '649' +- '650' +- '652' +- '656' +- '657' +- '658' +- '660' +- '664' +- '672' +- '673' +- '674' +- '676' +- '680' +- '688' +- '704' +- '705' +- '706' +- '708' +- '712' +- '720' +- '736' +- '768' +- '769' +- '770' +- '771' +- '772' +- '773' +- '774' +- '776' +- '777' +- '778' +- '780' +- '784' +- '785' +- '786' +- '788' +- '792' +- '800' +- '801' +- '802' +- '804' +- '808' +- '816' +- '832' +- '833' +- '834' +- '836' +- '840' +- '848' +- '864' +- '896' +- '897' +- '898' +- '900' +- '904' +- '912' +- '928' +- '960' +- '1024' +- '1025' +- '1026' +- '1027' +- '1028' +- '1029' +- '1030' +- '1031' +- '1032' +- '1033' +- '1034' +- '1035' +- '1036' +- '1037' +- '1038' +- '1040' +- '1041' +- '1042' +- '1043' +- '1044' +- '1045' +- '1046' +- '1048' +- '1049' +- '1050' +- '1052' +- '1056' +- '1057' +- '1058' +- '1059' +- '1060' +- '1061' +- '1062' +- '1064' +- '1065' +- '1066' +- '1068' +- '1072' +- '1073' +- '1074' +- '1076' +- '1080' +- '1088' +- '1089' +- '1090' +- '1091' +- '1092' +- '1093' +- '1094' +- '1096' +- '1097' +- '1098' +- '1100' +- 
'1104' +- '1105' +- '1106' +- '1108' +- '1112' +- '1120' +- '1121' +- '1122' +- '1124' +- '1128' +- '1136' +- '1152' +- '1153' +- '1154' +- '1155' +- '1156' +- '1157' +- '1158' +- '1160' +- '1161' +- '1162' +- '1164' +- '1168' +- '1169' +- '1170' +- '1172' +- '1176' +- '1184' +- '1185' +- '1186' +- '1188' +- '1192' +- '1200' +- '1216' +- '1217' +- '1218' +- '1220' +- '1224' +- '1232' +- '1248' +- '1280' +- '1281' +- '1282' +- '1283' +- '1284' +- '1285' +- '1286' +- '1288' +- '1289' +- '1290' +- '1292' +- '1296' +- '1297' +- '1298' +- '1300' +- '1304' +- '1312' +- '1313' +- '1314' +- '1316' +- '1320' +- '1328' +- '1344' +- '1345' +- '1346' +- '1348' +- '1352' +- '1360' +- '1376' +- '1408' +- '1409' +- '1410' +- '1412' +- '1416' +- '1424' +- '1440' +- '1472' +- '1536' +- '1537' +- '1538' +- '1539' +- '1540' +- '1541' +- '1542' +- '1544' +- '1545' +- '1546' +- '1548' +- '1552' +- '1553' +- '1554' +- '1556' +- '1560' +- '1568' +- '1569' +- '1570' +- '1572' +- '1576' +- '1584' +- '1600' +- '1601' +- '1602' +- '1604' +- '1608' +- '1616' +- '1632' +- '1664' +- '1665' +- '1666' +- '1668' +- '1672' +- '1680' +- '1696' +- '1728' +- '1792' +- '1793' +- '1794' +- '1796' +- '1800' +- '1808' +- '1824' +- '1856' +- '1920' +- '2048' +- '2049' +- '2050' +- '2051' +- '2052' +- '2053' +- '2054' +- '2055' +- '2056' +- '2057' +- '2058' +- '2059' +- '2060' +- '2061' +- '2062' +- '2064' +- '2065' +- '2066' +- '2067' +- '2068' +- '2069' +- '2070' +- '2072' +- '2073' +- '2074' +- '2076' +- '2080' +- '2081' +- '2082' +- '2083' +- '2084' +- '2085' +- '2086' +- '2088' +- '2089' +- '2090' +- '2092' +- '2096' +- '2097' +- '2098' +- '2100' +- '2104' +- '2112' +- '2113' +- '2114' +- '2115' +- '2116' +- '2117' +- '2118' +- '2120' +- '2121' +- '2122' +- '2124' +- '2128' +- '2129' +- '2130' +- '2132' +- '2136' +- '2144' +- '2145' +- '2146' +- '2148' +- '2152' +- '2160' +- '2176' +- '2177' +- '2178' +- '2179' +- '2180' +- '2181' +- '2182' +- '2184' +- '2185' +- '2186' +- '2188' +- '2192' +- '2193' +- '2194' +- '2196' +- '2200' +- '2208' +- '2209' +- '2210' +- '2212' +- '2216' +- '2224' +- '2240' +- '2241' +- '2242' +- '2244' +- '2248' +- '2256' +- '2272' +- '2304' +- '2305' +- '2306' +- '2307' +- '2308' +- '2309' +- '2310' +- '2312' +- '2313' +- '2314' +- '2316' +- '2320' +- '2321' +- '2322' +- '2324' +- '2328' +- '2336' +- '2337' +- '2338' +- '2340' +- '2344' +- '2352' +- '2368' +- '2369' +- '2370' +- '2372' +- '2376' +- '2384' +- '2400' +- '2432' +- '2433' +- '2434' +- '2436' +- '2440' +- '2448' +- '2464' +- '2496' +- '2560' +- '2561' +- '2562' +- '2563' +- '2564' +- '2565' +- '2566' +- '2568' +- '2569' +- '2570' +- '2572' +- '2576' +- '2577' +- '2578' +- '2580' +- '2584' +- '2592' +- '2593' +- '2594' +- '2596' +- '2600' +- '2608' +- '2624' +- '2625' +- '2626' +- '2628' +- '2632' +- '2640' +- '2656' +- '2688' +- '2689' +- '2690' +- '2692' +- '2696' +- '2704' +- '2720' +- '2752' +- '2816' +- '2817' +- '2818' +- '2820' +- '2824' +- '2832' +- '2848' +- '2880' +- '2944' +- '3072' +- '3073' +- '3074' +- '3075' +- '3076' +- '3077' +- '3078' +- '3080' +- '3081' +- '3082' +- '3084' +- '3088' +- '3089' +- '3090' +- '3092' +- '3096' +- '3104' +- '3105' +- '3106' +- '3108' +- '3112' +- '3120' +- '3136' +- '3137' +- '3138' +- '3140' +- '3144' +- '3152' +- '3168' +- '3200' +- '3201' +- '3202' +- '3204' +- '3208' +- '3216' +- '3232' +- '3264' +- '3328' +- '3329' +- '3330' +- '3332' +- '3336' +- '3344' +- '3360' +- '3392' +- '3456' +- '3584' +- '3585' +- '3586' +- '3588' +- '3592' +- '3600' +- '3616' +- '3648' +- '3712' +- '3840' +- 
'4096' +- '4097' +- '4098' +- '4099' +- '4100' +- '4101' +- '4102' +- '4103' +- '4104' +- '4105' +- '4106' +- '4107' +- '4108' +- '4109' +- '4110' +- '4112' +- '4113' +- '4114' +- '4115' +- '4116' +- '4117' +- '4118' +- '4120' +- '4121' +- '4122' +- '4124' +- '4128' +- '4129' +- '4130' +- '4131' +- '4132' +- '4133' +- '4134' +- '4136' +- '4137' +- '4138' +- '4140' +- '4144' +- '4145' +- '4146' +- '4148' +- '4152' +- '4160' +- '4161' +- '4162' +- '4163' +- '4164' +- '4165' +- '4166' +- '4168' +- '4169' +- '4170' +- '4172' +- '4176' +- '4177' +- '4178' +- '4180' +- '4184' +- '4192' +- '4193' +- '4194' +- '4196' +- '4200' +- '4208' +- '4224' +- '4225' +- '4226' +- '4227' +- '4228' +- '4229' +- '4230' +- '4232' +- '4233' +- '4234' +- '4236' +- '4240' +- '4241' +- '4242' +- '4244' +- '4248' +- '4256' +- '4257' +- '4258' +- '4260' +- '4264' +- '4272' +- '4288' +- '4289' +- '4290' +- '4292' +- '4296' +- '4304' +- '4320' +- '4352' +- '4353' +- '4354' +- '4355' +- '4356' +- '4357' +- '4358' +- '4360' +- '4361' +- '4362' +- '4364' +- '4368' +- '4369' +- '4370' +- '4372' +- '4376' +- '4384' +- '4385' +- '4386' +- '4388' +- '4392' +- '4400' +- '4416' +- '4417' +- '4418' +- '4420' +- '4424' +- '4432' +- '4448' +- '4480' +- '4481' +- '4482' +- '4484' +- '4488' +- '4496' +- '4512' +- '4544' +- '4608' +- '4609' +- '4610' +- '4611' +- '4612' +- '4613' +- '4614' +- '4616' +- '4617' +- '4618' +- '4620' +- '4624' +- '4625' +- '4626' +- '4628' +- '4632' +- '4640' +- '4641' +- '4642' +- '4644' +- '4648' +- '4656' +- '4672' +- '4673' +- '4674' +- '4676' +- '4680' +- '4688' +- '4704' +- '4736' +- '4737' +- '4738' +- '4740' +- '4744' +- '4752' +- '4768' +- '4800' +- '4864' +- '4865' +- '4866' +- '4868' +- '4872' +- '4880' +- '4896' +- '4928' +- '4992' +- '5120' +- '5121' +- '5122' +- '5123' +- '5124' +- '5125' +- '5126' +- '5128' +- '5129' +- '5130' +- '5132' +- '5136' +- '5137' +- '5138' +- '5140' +- '5144' +- '5152' +- '5153' +- '5154' +- '5156' +- '5160' +- '5168' +- '5184' +- '5185' +- '5186' +- '5188' +- '5192' +- '5200' +- '5216' +- '5248' +- '5249' +- '5250' +- '5252' +- '5256' +- '5264' +- '5280' +- '5312' +- '5376' +- '5377' +- '5378' +- '5380' +- '5384' +- '5392' +- '5408' +- '5440' +- '5504' +- '5632' +- '5633' +- '5634' +- '5636' +- '5640' +- '5648' +- '5664' +- '5696' +- '5760' +- '5888' +- '6144' +- '6145' +- '6146' +- '6147' +- '6148' +- '6149' +- '6150' +- '6152' +- '6153' +- '6154' +- '6156' +- '6160' +- '6161' +- '6162' +- '6164' +- '6168' +- '6176' +- '6177' +- '6178' +- '6180' +- '6184' +- '6192' +- '6208' +- '6209' +- '6210' +- '6212' +- '6216' +- '6224' +- '6240' +- '6272' +- '6273' +- '6274' +- '6276' +- '6280' +- '6288' +- '6304' +- '6336' +- '6400' +- '6401' +- '6402' +- '6404' +- '6408' +- '6416' +- '6432' +- '6464' +- '6528' +- '6656' +- '6657' +- '6658' +- '6660' +- '6664' +- '6672' +- '6688' +- '6720' +- '6784' +- '6912' +- '7168' +- '7169' +- '7170' +- '7172' +- '7176' +- '7184' +- '7200' +- '7232' +- '7296' +- '7424' +- '7680' +- '8192' +- '8193' +- '8194' +- '8195' +- '8196' +- '8197' +- '8198' +- '8199' +- '8200' +- '8201' +- '8202' +- '8203' +- '8204' +- '8205' +- '8206' +- '8208' +- '8209' +- '8210' +- '8211' +- '8212' +- '8213' +- '8214' +- '8216' +- '8217' +- '8218' +- '8220' +- '8224' +- '8225' +- '8226' +- '8227' +- '8228' +- '8229' +- '8230' +- '8232' +- '8233' +- '8234' +- '8236' +- '8240' +- '8241' +- '8242' +- '8244' +- '8248' +- '8256' +- '8257' +- '8258' +- '8259' +- '8260' +- '8261' +- '8262' +- '8264' +- '8265' +- '8266' +- '8268' +- '8272' +- '8273' +- '8274' +- 
'8276' +- '8280' +- '8288' +- '8289' +- '8290' +- '8292' +- '8296' +- '8304' +- '8320' +- '8321' +- '8322' +- '8323' +- '8324' +- '8325' +- '8326' +- '8328' +- '8329' +- '8330' +- '8332' +- '8336' +- '8337' +- '8338' +- '8340' +- '8344' +- '8352' +- '8353' +- '8354' +- '8356' +- '8360' +- '8368' +- '8384' +- '8385' +- '8386' +- '8388' +- '8392' +- '8400' +- '8416' +- '8448' +- '8449' +- '8450' +- '8451' +- '8452' +- '8453' +- '8454' +- '8456' +- '8457' +- '8458' +- '8460' +- '8464' +- '8465' +- '8466' +- '8468' +- '8472' +- '8480' +- '8481' +- '8482' +- '8484' +- '8488' +- '8496' +- '8512' +- '8513' +- '8514' +- '8516' +- '8520' +- '8528' +- '8544' +- '8576' +- '8577' +- '8578' +- '8580' +- '8584' +- '8592' +- '8608' +- '8640' +- '8704' +- '8705' +- '8706' +- '8707' +- '8708' +- '8709' +- '8710' +- '8712' +- '8713' +- '8714' +- '8716' +- '8720' +- '8721' +- '8722' +- '8724' +- '8728' +- '8736' +- '8737' +- '8738' +- '8740' +- '8744' +- '8752' +- '8768' +- '8769' +- '8770' +- '8772' +- '8776' +- '8784' +- '8800' +- '8832' +- '8833' +- '8834' +- '8836' +- '8840' +- '8848' +- '8864' +- '8896' +- '8960' +- '8961' +- '8962' +- '8964' +- '8968' +- '8976' +- '8992' +- '9024' +- '9088' +- '9216' +- '9217' +- '9218' +- '9219' +- '9220' +- '9221' +- '9222' +- '9224' +- '9225' +- '9226' +- '9228' +- '9232' +- '9233' +- '9234' +- '9236' +- '9240' +- '9248' +- '9249' +- '9250' +- '9252' +- '9256' +- '9264' +- '9280' +- '9281' +- '9282' +- '9284' +- '9288' +- '9296' +- '9312' +- '9344' +- '9345' +- '9346' +- '9348' +- '9352' +- '9360' +- '9376' +- '9408' +- '9472' +- '9473' +- '9474' +- '9476' +- '9480' +- '9488' +- '9504' +- '9536' +- '9600' +- '9728' +- '9729' +- '9730' +- '9732' +- '9736' +- '9744' +- '9760' +- '9792' +- '9856' +- '9984' +- '10240' +- '10241' +- '10242' +- '10243' +- '10244' +- '10245' +- '10246' +- '10248' +- '10249' +- '10250' +- '10252' +- '10256' +- '10257' +- '10258' +- '10260' +- '10264' +- '10272' +- '10273' +- '10274' +- '10276' +- '10280' +- '10288' +- '10304' +- '10305' +- '10306' +- '10308' +- '10312' +- '10320' +- '10336' +- '10368' +- '10369' +- '10370' +- '10372' +- '10376' +- '10384' +- '10400' +- '10432' +- '10496' +- '10497' +- '10498' +- '10500' +- '10504' +- '10512' +- '10528' +- '10560' +- '10624' +- '10752' +- '10753' +- '10754' +- '10756' +- '10760' +- '10768' +- '10784' +- '10816' +- '10880' +- '11008' +- '11264' +- '11265' +- '11266' +- '11268' +- '11272' +- '11280' +- '11296' +- '11328' +- '11392' +- '11520' +- '11776' +- '12288' +- '12289' +- '12290' +- '12291' +- '12292' +- '12293' +- '12294' +- '12296' +- '12297' +- '12298' +- '12300' +- '12304' +- '12305' +- '12306' +- '12308' +- '12312' +- '12320' +- '12321' +- '12322' +- '12324' +- '12328' +- '12336' +- '12352' +- '12353' +- '12354' +- '12356' +- '12360' +- '12368' +- '12384' +- '12416' +- '12417' +- '12418' +- '12420' +- '12424' +- '12432' +- '12448' +- '12480' +- '12544' +- '12545' +- '12546' +- '12548' +- '12552' +- '12560' +- '12576' +- '12608' +- '12672' +- '12800' +- '12801' +- '12802' +- '12804' +- '12808' +- '12816' +- '12832' +- '12864' +- '12928' +- '13056' +- '13312' +- '13313' +- '13314' +- '13316' +- '13320' +- '13328' +- '13344' +- '13376' +- '13440' +- '13568' +- '13824' +- '14336' +- '14337' +- '14338' +- '14340' +- '14344' +- '14352' +- '14368' +- '14400' +- '14464' +- '14592' +- '14848' +- '15360' +- '16384' +- '16385' +- '16386' +- '16387' +- '16388' +- '16389' +- '16390' +- '16391' +- '16392' +- '16393' +- '16394' +- '16395' +- '16396' +- '16397' +- '16398' +- '16400' +- '16401' +- 
'16402' +- '16403' +- '16404' +- '16405' +- '16406' +- '16408' +- '16409' +- '16410' +- '16412' +- '16416' +- '16417' +- '16418' +- '16419' +- '16420' +- '16421' +- '16422' +- '16424' +- '16425' +- '16426' +- '16428' +- '16432' +- '16433' +- '16434' +- '16436' +- '16440' +- '16448' +- '16449' +- '16450' +- '16451' +- '16452' +- '16453' +- '16454' +- '16456' +- '16457' +- '16458' +- '16460' +- '16464' +- '16465' +- '16466' +- '16468' +- '16472' +- '16480' +- '16481' +- '16482' +- '16484' +- '16488' +- '16496' +- '16512' +- '16513' +- '16514' +- '16515' +- '16516' +- '16517' +- '16518' +- '16520' +- '16521' +- '16522' +- '16524' +- '16528' +- '16529' +- '16530' +- '16532' +- '16536' +- '16544' +- '16545' +- '16546' +- '16548' +- '16552' +- '16560' +- '16576' +- '16577' +- '16578' +- '16580' +- '16584' +- '16592' +- '16608' +- '16640' +- '16641' +- '16642' +- '16643' +- '16644' +- '16645' +- '16646' +- '16648' +- '16649' +- '16650' +- '16652' +- '16656' +- '16657' +- '16658' +- '16660' +- '16664' +- '16672' +- '16673' +- '16674' +- '16676' +- '16680' +- '16688' +- '16704' +- '16705' +- '16706' +- '16708' +- '16712' +- '16720' +- '16736' +- '16768' +- '16769' +- '16770' +- '16772' +- '16776' +- '16784' +- '16800' +- '16832' +- '16896' +- '16897' +- '16898' +- '16899' +- '16900' +- '16901' +- '16902' +- '16904' +- '16905' +- '16906' +- '16908' +- '16912' +- '16913' +- '16914' +- '16916' +- '16920' +- '16928' +- '16929' +- '16930' +- '16932' +- '16936' +- '16944' +- '16960' +- '16961' +- '16962' +- '16964' +- '16968' +- '16976' +- '16992' +- '17024' +- '17025' +- '17026' +- '17028' +- '17032' +- '17040' +- '17056' +- '17088' +- '17152' +- '17153' +- '17154' +- '17156' +- '17160' +- '17168' +- '17184' +- '17216' +- '17280' +- '17408' +- '17409' +- '17410' +- '17411' +- '17412' +- '17413' +- '17414' +- '17416' +- '17417' +- '17418' +- '17420' +- '17424' +- '17425' +- '17426' +- '17428' +- '17432' +- '17440' +- '17441' +- '17442' +- '17444' +- '17448' +- '17456' +- '17472' +- '17473' +- '17474' +- '17476' +- '17480' +- '17488' +- '17504' +- '17536' +- '17537' +- '17538' +- '17540' +- '17544' +- '17552' +- '17568' +- '17600' +- '17664' +- '17665' +- '17666' +- '17668' +- '17672' +- '17680' +- '17696' +- '17728' +- '17792' +- '17920' +- '17921' +- '17922' +- '17924' +- '17928' +- '17936' +- '17952' +- '17984' +- '18048' +- '18176' +- '18432' +- '18433' +- '18434' +- '18435' +- '18436' +- '18437' +- '18438' +- '18440' +- '18441' +- '18442' +- '18444' +- '18448' +- '18449' +- '18450' +- '18452' +- '18456' +- '18464' +- '18465' +- '18466' +- '18468' +- '18472' +- '18480' +- '18496' +- '18497' +- '18498' +- '18500' +- '18504' +- '18512' +- '18528' +- '18560' +- '18561' +- '18562' +- '18564' +- '18568' +- '18576' +- '18592' +- '18624' +- '18688' +- '18689' +- '18690' +- '18692' +- '18696' +- '18704' +- '18720' +- '18752' +- '18816' +- '18944' +- '18945' +- '18946' +- '18948' +- '18952' +- '18960' +- '18976' +- '19008' +- '19072' +- '19200' +- '19456' +- '19457' +- '19458' +- '19460' +- '19464' +- '19472' +- '19488' +- '19520' +- '19584' +- '19712' +- '19968' +- '20480' +- '20481' +- '20482' +- '20483' +- '20484' +- '20485' +- '20486' +- '20488' +- '20489' +- '20490' +- '20492' +- '20496' +- '20497' +- '20498' +- '20500' +- '20504' +- '20512' +- '20513' +- '20514' +- '20516' +- '20520' +- '20528' +- '20544' +- '20545' +- '20546' +- '20548' +- '20552' +- '20560' +- '20576' +- '20608' +- '20609' +- '20610' +- '20612' +- '20616' +- '20624' +- '20640' +- '20672' +- '20736' +- '20737' +- '20738' +- '20740' +- 
'20744' +- '20752' +- '20768' +- '20800' +- '20864' +- '20992' +- '20993' +- '20994' +- '20996' +- '21000' +- '21008' +- '21024' +- '21056' +- '21120' +- '21248' +- '21504' +- '21505' +- '21506' +- '21508' +- '21512' +- '21520' +- '21536' +- '21568' +- '21632' +- '21760' +- '22016' +- '22528' +- '22529' +- '22530' +- '22532' +- '22536' +- '22544' +- '22560' +- '22592' +- '22656' +- '22784' +- '23040' +- '23552' +- '24576' +- '24577' +- '24578' +- '24579' +- '24580' +- '24581' +- '24582' +- '24584' +- '24585' +- '24586' +- '24588' +- '24592' +- '24593' +- '24594' +- '24596' +- '24600' +- '24608' +- '24609' +- '24610' +- '24612' +- '24616' +- '24624' +- '24640' +- '24641' +- '24642' +- '24644' +- '24648' +- '24656' +- '24672' +- '24704' +- '24705' +- '24706' +- '24708' +- '24712' +- '24720' +- '24736' +- '24768' +- '24832' +- '24833' +- '24834' +- '24836' +- '24840' +- '24848' +- '24864' +- '24896' +- '24960' +- '25088' +- '25089' +- '25090' +- '25092' +- '25096' +- '25104' +- '25120' +- '25152' +- '25216' +- '25344' +- '25600' +- '25601' +- '25602' +- '25604' +- '25608' +- '25616' +- '25632' +- '25664' +- '25728' +- '25856' +- '26112' +- '26624' +- '26625' +- '26626' +- '26628' +- '26632' +- '26640' +- '26656' +- '26688' +- '26752' +- '26880' +- '27136' +- '27648' +- '28672' +- '28673' +- '28674' +- '28676' +- '28680' +- '28688' +- '28704' +- '28736' +- '28800' +- '28928' +- '29184' +- '29696' +- '30720' +- '32768' +- '32769' +- '32770' +- '32771' +- '32772' +- '32773' +- '32774' +- '32775' +- '32776' +- '32777' +- '32778' +- '32779' +- '32780' +- '32781' +- '32782' +- '32784' +- '32785' +- '32786' +- '32787' +- '32788' +- '32789' +- '32790' +- '32792' +- '32793' +- '32794' +- '32796' +- '32800' +- '32801' +- '32802' +- '32803' +- '32804' +- '32805' +- '32806' +- '32808' +- '32809' +- '32810' +- '32812' +- '32816' +- '32817' +- '32818' +- '32820' +- '32824' +- '32832' +- '32833' +- '32834' +- '32835' +- '32836' +- '32837' +- '32838' +- '32840' +- '32841' +- '32842' +- '32844' +- '32848' +- '32849' +- '32850' +- '32852' +- '32856' +- '32864' +- '32865' +- '32866' +- '32868' +- '32872' +- '32880' +- '32896' +- '32897' +- '32898' +- '32899' +- '32900' +- '32901' +- '32902' +- '32904' +- '32905' +- '32906' +- '32908' +- '32912' +- '32913' +- '32914' +- '32916' +- '32920' +- '32928' +- '32929' +- '32930' +- '32932' +- '32936' +- '32944' +- '32960' +- '32961' +- '32962' +- '32964' +- '32968' +- '32976' +- '32992' +- '33024' +- '33025' +- '33026' +- '33027' +- '33028' +- '33029' +- '33030' +- '33032' +- '33033' +- '33034' +- '33036' +- '33040' +- '33041' +- '33042' +- '33044' +- '33048' +- '33056' +- '33057' +- '33058' +- '33060' +- '33064' +- '33072' +- '33088' +- '33089' +- '33090' +- '33092' +- '33096' +- '33104' +- '33120' +- '33152' +- '33153' +- '33154' +- '33156' +- '33160' +- '33168' +- '33184' +- '33216' +- '33280' +- '33281' +- '33282' +- '33283' +- '33284' +- '33285' +- '33286' +- '33288' +- '33289' +- '33290' +- '33292' +- '33296' +- '33297' +- '33298' +- '33300' +- '33304' +- '33312' +- '33313' +- '33314' +- '33316' +- '33320' +- '33328' +- '33344' +- '33345' +- '33346' +- '33348' +- '33352' +- '33360' +- '33376' +- '33408' +- '33409' +- '33410' +- '33412' +- '33416' +- '33424' +- '33440' +- '33472' +- '33536' +- '33537' +- '33538' +- '33540' +- '33544' +- '33552' +- '33568' +- '33600' +- '33664' +- '33792' +- '33793' +- '33794' +- '33795' +- '33796' +- '33797' +- '33798' +- '33800' +- '33801' +- '33802' +- '33804' +- '33808' +- '33809' +- '33810' +- '33812' +- '33816' +- '33824' +- 
'33825' +- '33826' +- '33828' +- '33832' +- '33840' +- '33856' +- '33857' +- '33858' +- '33860' +- '33864' +- '33872' +- '33888' +- '33920' +- '33921' +- '33922' +- '33924' +- '33928' +- '33936' +- '33952' +- '33984' +- '34048' +- '34049' +- '34050' +- '34052' +- '34056' +- '34064' +- '34080' +- '34112' +- '34176' +- '34304' +- '34305' +- '34306' +- '34308' +- '34312' +- '34320' +- '34336' +- '34368' +- '34432' +- '34560' +- '34816' +- '34817' +- '34818' +- '34819' +- '34820' +- '34821' +- '34822' +- '34824' +- '34825' +- '34826' +- '34828' +- '34832' +- '34833' +- '34834' +- '34836' +- '34840' +- '34848' +- '34849' +- '34850' +- '34852' +- '34856' +- '34864' +- '34880' +- '34881' +- '34882' +- '34884' +- '34888' +- '34896' +- '34912' +- '34944' +- '34945' +- '34946' +- '34948' +- '34952' +- '34960' +- '34976' +- '35008' +- '35072' +- '35073' +- '35074' +- '35076' +- '35080' +- '35088' +- '35104' +- '35136' +- '35200' +- '35328' +- '35329' +- '35330' +- '35332' +- '35336' +- '35344' +- '35360' +- '35392' +- '35456' +- '35584' +- '35840' +- '35841' +- '35842' +- '35844' +- '35848' +- '35856' +- '35872' +- '35904' +- '35968' +- '36096' +- '36352' +- '36864' +- '36865' +- '36866' +- '36867' +- '36868' +- '36869' +- '36870' +- '36872' +- '36873' +- '36874' +- '36876' +- '36880' +- '36881' +- '36882' +- '36884' +- '36888' +- '36896' +- '36897' +- '36898' +- '36900' +- '36904' +- '36912' +- '36928' +- '36929' +- '36930' +- '36932' +- '36936' +- '36944' +- '36960' +- '36992' +- '36993' +- '36994' +- '36996' +- '37000' +- '37008' +- '37024' +- '37056' +- '37120' +- '37121' +- '37122' +- '37124' +- '37128' +- '37136' +- '37152' +- '37184' +- '37248' +- '37376' +- '37377' +- '37378' +- '37380' +- '37384' +- '37392' +- '37408' +- '37440' +- '37504' +- '37632' +- '37888' +- '37889' +- '37890' +- '37892' +- '37896' +- '37904' +- '37920' +- '37952' +- '38016' +- '38144' +- '38400' +- '38912' +- '38913' +- '38914' +- '38916' +- '38920' +- '38928' +- '38944' +- '38976' +- '39040' +- '39168' +- '39424' +- '39936' +- '40960' +- '40961' +- '40962' +- '40963' +- '40964' +- '40965' +- '40966' +- '40968' +- '40969' +- '40970' +- '40972' +- '40976' +- '40977' +- '40978' +- '40980' +- '40984' +- '40992' +- '40993' +- '40994' +- '40996' +- '41000' +- '41008' +- '41024' +- '41025' +- '41026' +- '41028' +- '41032' +- '41040' +- '41056' +- '41088' +- '41089' +- '41090' +- '41092' +- '41096' +- '41104' +- '41120' +- '41152' +- '41216' +- '41217' +- '41218' +- '41220' +- '41224' +- '41232' +- '41248' +- '41280' +- '41344' +- '41472' +- '41473' +- '41474' +- '41476' +- '41480' +- '41488' +- '41504' +- '41536' +- '41600' +- '41728' +- '41984' +- '41985' +- '41986' +- '41988' +- '41992' +- '42000' +- '42016' +- '42048' +- '42112' +- '42240' +- '42496' +- '43008' +- '43009' +- '43010' +- '43012' +- '43016' +- '43024' +- '43040' +- '43072' +- '43136' +- '43264' +- '43520' +- '44032' +- '45056' +- '45057' +- '45058' +- '45060' +- '45064' +- '45072' +- '45088' +- '45120' +- '45184' +- '45312' +- '45568' +- '46080' +- '47104' +- '49152' +- '49153' +- '49154' +- '49155' +- '49156' +- '49157' +- '49158' +- '49160' +- '49161' +- '49162' +- '49164' +- '49168' +- '49169' +- '49170' +- '49172' +- '49176' +- '49184' +- '49185' +- '49186' +- '49188' +- '49192' +- '49200' +- '49216' +- '49217' +- '49218' +- '49220' +- '49224' +- '49232' +- '49248' +- '49280' +- '49281' +- '49282' +- '49284' +- '49288' +- '49296' +- '49312' +- '49344' +- '49408' +- '49409' +- '49410' +- '49412' +- '49416' +- '49424' +- '49440' +- '49472' +- '49536' +- 
'49664' +- '49665' +- '49666' +- '49668' +- '49672' +- '49680' +- '49696' +- '49728' +- '49792' +- '49920' +- '50176' +- '50177' +- '50178' +- '50180' +- '50184' +- '50192' +- '50208' +- '50240' +- '50304' +- '50432' +- '50688' +- '51200' +- '51201' +- '51202' +- '51204' +- '51208' +- '51216' +- '51232' +- '51264' +- '51328' +- '51456' +- '51712' +- '52224' +- '53248' +- '53249' +- '53250' +- '53252' +- '53256' +- '53264' +- '53280' +- '53312' +- '53376' +- '53504' +- '53760' +- '54272' +- '55296' +- '57344' +- '57345' +- '57346' +- '57348' +- '57352' +- '57360' +- '57376' +- '57408' +- '57472' +- '57600' +- '57856' +- '58368' +- '59392' +- '61440' +init: null +input_size: null +cmvn_file: null +ctc_conf: + dropout_rate: 0.0 + ctc_type: builtin + reduce: true + ignore_nan_grad: true +joint_net_conf: null +use_preprocessor: true +token_type: char +bpemodel: null +non_linguistic_symbols: null +cleaner: null +g2p: null +speech_volume_normalize: null +rir_scp: null +rir_apply_prob: 1.0 +noise_scp: null +noise_apply_prob: 1.0 +noise_db_range: '13_15' +specaug: null +specaug_conf: {} +normalize: null +normalize_conf: {} +label_aggregator: null +label_aggregator_conf: {} +model: sond +model_conf: + # ctc_weight: 0.0 + lsm_weight: 0.1 + length_normalized_loss: true + max_spk_num: 16 + # predictor_weight: 1.0 + # predictor_bias: 1 + # sampling_ratio: 0.75 +# speech encoder +encoder: resnet34 +encoder_conf: + # pass by model, equal to feature dim + # input_size: 80 + pooling_type: "window_shift" + pool_size: 20 + stride: 1 + tf2torch_tensor_name_prefix_torch: encoder + tf2torch_tensor_name_prefix_tf: EAND/speech_encoder +speaker_encoder: conv +speaker_encoder_conf: + input_units: 256 + num_layers: 3 + num_units: 256 + kernel_size: 1 + dropout_rate: 0.0 + position_encoder: null + out_units: 256 + out_norm: false + auxiliary_states: false + tf2torch_tensor_name_prefix_torch: speaker_encoder + tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder +ci_scorer: dot +ci_scorer_conf: {} +cd_scorer: san +cd_scorer_conf: + input_size: 512 + output_size: 512 + out_units: 1 + attention_heads: 4 + linear_units: 1024 + num_blocks: 4 + dropout_rate: 0.0 + positional_dropout_rate: 0.0 + attention_dropout_rate: 0.0 + # use string "null" to remove input layer + input_layer: "null" + pos_enc_class: null + normalize_before: true + tf2torch_tensor_name_prefix_torch: cd_scorer + tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer +# post net +decoder: fsmn +decoder_conf: + in_units: 32 + out_units: 2517 + filter_size: 31 + fsmn_num_layers: 6 + dnn_num_layers: 1 + num_memory_units: 512 + ffn_inner_dim: 512 + dropout_rate: 0.0 + tf2torch_tensor_name_prefix_torch: decoder + tf2torch_tensor_name_prefix_tf: EAND/post_net +frontend: wav_frontend +frontend_conf: + fs: 16000 + window: povey + n_mels: 80 + frame_length: 25 + frame_shift: 10 + filter_length_min: -1 + filter_length_max: -1 + lfr_m: 1 + lfr_n: 1 + dither: 0.0 + snip_edges: false +num_worker_count: 1 +required: +- output_dir +- token_list +oss_bucket: 'null' +version: 0.1.4 diff --git a/egs/alimeeting/diarization/sond/config_fbank.yaml b/egs/alimeeting/diarization/sond/config_fbank.yaml new file mode 100644 index 000000000..cb4b8a921 --- /dev/null +++ b/egs/alimeeting/diarization/sond/config_fbank.yaml @@ -0,0 +1,2728 @@ +config: finetune.yaml +print_config: false +log_level: INFO +dry_run: false +iterator_type: sequence +output_dir: exp/sond +ngpu: 1 +seed: 0 +num_workers: 16 +num_att_plot: 0 +dist_backend: nccl +dist_init_method: env:// +dist_world_size: 
null +dist_rank: null +local_rank: 0 +dist_master_addr: null +dist_master_port: null +dist_launcher: null +multiprocessing_distributed: true +distributed: false +unused_parameters: true +sharded_ddp: false +ddp_backend: pytorch_ddp +cudnn_enabled: true +cudnn_benchmark: false +cudnn_deterministic: true +collect_stats: false +write_collected_feats: false +max_epoch: 50 +patience: null +val_scheduler_criterion: +- valid +- acc +early_stopping_criterion: +- valid +- loss +- min +best_model_criterion: +- - valid + - acc + - max +keep_nbest_models: 10 +nbest_averaging_interval: 0 +grad_clip: 5 +grad_clip_type: 2.0 +grad_noise: false +accum_grad: 1 +no_forward_run: false +resume: true +train_dtype: float32 +use_amp: false +log_interval: 50 +use_matplotlib: false +use_tensorboard: true +use_wandb: false +wandb_project: null +wandb_id: null +wandb_entity: null +wandb_name: null +wandb_model_log_interval: -1 +use_pai: true +detect_anomaly: false +pretrain_path: null +init_param: [] +ignore_init_mismatch: false +freeze_param: [] +num_iters_per_epoch: null +batch_size: 20 +valid_batch_size: null +batch_bins: 10000 +valid_batch_bins: null +train_shape_file: +- /data/volume1/youyan/aishell/ark/train/speech_shape.1 +- /data/volume1/youyan/aishell/ark/train/text_shape.1 +valid_shape_file: +- /data/volume1/youyan/aishell/ark/dev/speech_shape.1 +- /data/volume1/youyan/aishell/ark/dev/text_shape.1 +batch_type: length +valid_batch_type: null +fold_length: +- 512 +- 150 +sort_in_batch: descending +sort_batch: descending +multiple_iterator: false +chunk_length: 500 +chunk_shift_ratio: 0.5 +num_cache_chunks: 1024 +train_data_path_and_name_and_type: +- - /data/volume1/youyan/aishell/ark/train/data.scp + - speech + - kaldi_ark +- - /data/volume1/youyan/aishell/ark/train/data.text.1 + - text + - text +valid_data_path_and_name_and_type: +- - /data/volume1/youyan/aishell/ark/dev/data.scp + - speech + - kaldi_ark +- - /data/volume1/youyan/aishell/ark/dev/data.text.1 + - text + - text +allow_variable_data_keys: false +max_cache_size: 0.0 +max_cache_fd: 32 +valid_max_cache_size: null +optim: adam +optim_conf: + lr: 0.0005 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 30000 +token_list: +- '0' +- '1' +- '2' +- '3' +- '4' +- '5' +- '6' +- '7' +- '8' +- '9' +- '10' +- '11' +- '12' +- '13' +- '14' +- '15' +- '16' +- '17' +- '18' +- '19' +- '20' +- '21' +- '22' +- '23' +- '24' +- '25' +- '26' +- '27' +- '28' +- '29' +- '30' +- '32' +- '33' +- '34' +- '35' +- '36' +- '37' +- '38' +- '39' +- '40' +- '41' +- '42' +- '43' +- '44' +- '45' +- '46' +- '48' +- '49' +- '50' +- '51' +- '52' +- '53' +- '54' +- '56' +- '57' +- '58' +- '60' +- '64' +- '65' +- '66' +- '67' +- '68' +- '69' +- '70' +- '71' +- '72' +- '73' +- '74' +- '75' +- '76' +- '77' +- '78' +- '80' +- '81' +- '82' +- '83' +- '84' +- '85' +- '86' +- '88' +- '89' +- '90' +- '92' +- '96' +- '97' +- '98' +- '99' +- '100' +- '101' +- '102' +- '104' +- '105' +- '106' +- '108' +- '112' +- '113' +- '114' +- '116' +- '120' +- '128' +- '129' +- '130' +- '131' +- '132' +- '133' +- '134' +- '135' +- '136' +- '137' +- '138' +- '139' +- '140' +- '141' +- '142' +- '144' +- '145' +- '146' +- '147' +- '148' +- '149' +- '150' +- '152' +- '153' +- '154' +- '156' +- '160' +- '161' +- '162' +- '163' +- '164' +- '165' +- '166' +- '168' +- '169' +- '170' +- '172' +- '176' +- '177' +- '178' +- '180' +- '184' +- '192' +- '193' +- '194' +- '195' +- '196' +- '197' +- '198' +- '200' +- '201' +- '202' +- '204' +- '208' +- '209' +- '210' +- '212' +- '216' +- '224' +- '225' +- '226' +- '228' 
+- '232' +- '240' +- '256' +- '257' +- '258' +- '259' +- '260' +- '261' +- '262' +- '263' +- '264' +- '265' +- '266' +- '267' +- '268' +- '269' +- '270' +- '272' +- '273' +- '274' +- '275' +- '276' +- '277' +- '278' +- '280' +- '281' +- '282' +- '284' +- '288' +- '289' +- '290' +- '291' +- '292' +- '293' +- '294' +- '296' +- '297' +- '298' +- '300' +- '304' +- '305' +- '306' +- '308' +- '312' +- '320' +- '321' +- '322' +- '323' +- '324' +- '325' +- '326' +- '328' +- '329' +- '330' +- '332' +- '336' +- '337' +- '338' +- '340' +- '344' +- '352' +- '353' +- '354' +- '356' +- '360' +- '368' +- '384' +- '385' +- '386' +- '387' +- '388' +- '389' +- '390' +- '392' +- '393' +- '394' +- '396' +- '400' +- '401' +- '402' +- '404' +- '408' +- '416' +- '417' +- '418' +- '420' +- '424' +- '432' +- '448' +- '449' +- '450' +- '452' +- '456' +- '464' +- '480' +- '512' +- '513' +- '514' +- '515' +- '516' +- '517' +- '518' +- '519' +- '520' +- '521' +- '522' +- '523' +- '524' +- '525' +- '526' +- '528' +- '529' +- '530' +- '531' +- '532' +- '533' +- '534' +- '536' +- '537' +- '538' +- '540' +- '544' +- '545' +- '546' +- '547' +- '548' +- '549' +- '550' +- '552' +- '553' +- '554' +- '556' +- '560' +- '561' +- '562' +- '564' +- '568' +- '576' +- '577' +- '578' +- '579' +- '580' +- '581' +- '582' +- '584' +- '585' +- '586' +- '588' +- '592' +- '593' +- '594' +- '596' +- '600' +- '608' +- '609' +- '610' +- '612' +- '616' +- '624' +- '640' +- '641' +- '642' +- '643' +- '644' +- '645' +- '646' +- '648' +- '649' +- '650' +- '652' +- '656' +- '657' +- '658' +- '660' +- '664' +- '672' +- '673' +- '674' +- '676' +- '680' +- '688' +- '704' +- '705' +- '706' +- '708' +- '712' +- '720' +- '736' +- '768' +- '769' +- '770' +- '771' +- '772' +- '773' +- '774' +- '776' +- '777' +- '778' +- '780' +- '784' +- '785' +- '786' +- '788' +- '792' +- '800' +- '801' +- '802' +- '804' +- '808' +- '816' +- '832' +- '833' +- '834' +- '836' +- '840' +- '848' +- '864' +- '896' +- '897' +- '898' +- '900' +- '904' +- '912' +- '928' +- '960' +- '1024' +- '1025' +- '1026' +- '1027' +- '1028' +- '1029' +- '1030' +- '1031' +- '1032' +- '1033' +- '1034' +- '1035' +- '1036' +- '1037' +- '1038' +- '1040' +- '1041' +- '1042' +- '1043' +- '1044' +- '1045' +- '1046' +- '1048' +- '1049' +- '1050' +- '1052' +- '1056' +- '1057' +- '1058' +- '1059' +- '1060' +- '1061' +- '1062' +- '1064' +- '1065' +- '1066' +- '1068' +- '1072' +- '1073' +- '1074' +- '1076' +- '1080' +- '1088' +- '1089' +- '1090' +- '1091' +- '1092' +- '1093' +- '1094' +- '1096' +- '1097' +- '1098' +- '1100' +- '1104' +- '1105' +- '1106' +- '1108' +- '1112' +- '1120' +- '1121' +- '1122' +- '1124' +- '1128' +- '1136' +- '1152' +- '1153' +- '1154' +- '1155' +- '1156' +- '1157' +- '1158' +- '1160' +- '1161' +- '1162' +- '1164' +- '1168' +- '1169' +- '1170' +- '1172' +- '1176' +- '1184' +- '1185' +- '1186' +- '1188' +- '1192' +- '1200' +- '1216' +- '1217' +- '1218' +- '1220' +- '1224' +- '1232' +- '1248' +- '1280' +- '1281' +- '1282' +- '1283' +- '1284' +- '1285' +- '1286' +- '1288' +- '1289' +- '1290' +- '1292' +- '1296' +- '1297' +- '1298' +- '1300' +- '1304' +- '1312' +- '1313' +- '1314' +- '1316' +- '1320' +- '1328' +- '1344' +- '1345' +- '1346' +- '1348' +- '1352' +- '1360' +- '1376' +- '1408' +- '1409' +- '1410' +- '1412' +- '1416' +- '1424' +- '1440' +- '1472' +- '1536' +- '1537' +- '1538' +- '1539' +- '1540' +- '1541' +- '1542' +- '1544' +- '1545' +- '1546' +- '1548' +- '1552' +- '1553' +- '1554' +- '1556' +- '1560' +- '1568' +- '1569' +- '1570' +- '1572' +- '1576' +- '1584' +- '1600' 
+- '1601' +- '1602' +- '1604' +- '1608' +- '1616' +- '1632' +- '1664' +- '1665' +- '1666' +- '1668' +- '1672' +- '1680' +- '1696' +- '1728' +- '1792' +- '1793' +- '1794' +- '1796' +- '1800' +- '1808' +- '1824' +- '1856' +- '1920' +- '2048' +- '2049' +- '2050' +- '2051' +- '2052' +- '2053' +- '2054' +- '2055' +- '2056' +- '2057' +- '2058' +- '2059' +- '2060' +- '2061' +- '2062' +- '2064' +- '2065' +- '2066' +- '2067' +- '2068' +- '2069' +- '2070' +- '2072' +- '2073' +- '2074' +- '2076' +- '2080' +- '2081' +- '2082' +- '2083' +- '2084' +- '2085' +- '2086' +- '2088' +- '2089' +- '2090' +- '2092' +- '2096' +- '2097' +- '2098' +- '2100' +- '2104' +- '2112' +- '2113' +- '2114' +- '2115' +- '2116' +- '2117' +- '2118' +- '2120' +- '2121' +- '2122' +- '2124' +- '2128' +- '2129' +- '2130' +- '2132' +- '2136' +- '2144' +- '2145' +- '2146' +- '2148' +- '2152' +- '2160' +- '2176' +- '2177' +- '2178' +- '2179' +- '2180' +- '2181' +- '2182' +- '2184' +- '2185' +- '2186' +- '2188' +- '2192' +- '2193' +- '2194' +- '2196' +- '2200' +- '2208' +- '2209' +- '2210' +- '2212' +- '2216' +- '2224' +- '2240' +- '2241' +- '2242' +- '2244' +- '2248' +- '2256' +- '2272' +- '2304' +- '2305' +- '2306' +- '2307' +- '2308' +- '2309' +- '2310' +- '2312' +- '2313' +- '2314' +- '2316' +- '2320' +- '2321' +- '2322' +- '2324' +- '2328' +- '2336' +- '2337' +- '2338' +- '2340' +- '2344' +- '2352' +- '2368' +- '2369' +- '2370' +- '2372' +- '2376' +- '2384' +- '2400' +- '2432' +- '2433' +- '2434' +- '2436' +- '2440' +- '2448' +- '2464' +- '2496' +- '2560' +- '2561' +- '2562' +- '2563' +- '2564' +- '2565' +- '2566' +- '2568' +- '2569' +- '2570' +- '2572' +- '2576' +- '2577' +- '2578' +- '2580' +- '2584' +- '2592' +- '2593' +- '2594' +- '2596' +- '2600' +- '2608' +- '2624' +- '2625' +- '2626' +- '2628' +- '2632' +- '2640' +- '2656' +- '2688' +- '2689' +- '2690' +- '2692' +- '2696' +- '2704' +- '2720' +- '2752' +- '2816' +- '2817' +- '2818' +- '2820' +- '2824' +- '2832' +- '2848' +- '2880' +- '2944' +- '3072' +- '3073' +- '3074' +- '3075' +- '3076' +- '3077' +- '3078' +- '3080' +- '3081' +- '3082' +- '3084' +- '3088' +- '3089' +- '3090' +- '3092' +- '3096' +- '3104' +- '3105' +- '3106' +- '3108' +- '3112' +- '3120' +- '3136' +- '3137' +- '3138' +- '3140' +- '3144' +- '3152' +- '3168' +- '3200' +- '3201' +- '3202' +- '3204' +- '3208' +- '3216' +- '3232' +- '3264' +- '3328' +- '3329' +- '3330' +- '3332' +- '3336' +- '3344' +- '3360' +- '3392' +- '3456' +- '3584' +- '3585' +- '3586' +- '3588' +- '3592' +- '3600' +- '3616' +- '3648' +- '3712' +- '3840' +- '4096' +- '4097' +- '4098' +- '4099' +- '4100' +- '4101' +- '4102' +- '4103' +- '4104' +- '4105' +- '4106' +- '4107' +- '4108' +- '4109' +- '4110' +- '4112' +- '4113' +- '4114' +- '4115' +- '4116' +- '4117' +- '4118' +- '4120' +- '4121' +- '4122' +- '4124' +- '4128' +- '4129' +- '4130' +- '4131' +- '4132' +- '4133' +- '4134' +- '4136' +- '4137' +- '4138' +- '4140' +- '4144' +- '4145' +- '4146' +- '4148' +- '4152' +- '4160' +- '4161' +- '4162' +- '4163' +- '4164' +- '4165' +- '4166' +- '4168' +- '4169' +- '4170' +- '4172' +- '4176' +- '4177' +- '4178' +- '4180' +- '4184' +- '4192' +- '4193' +- '4194' +- '4196' +- '4200' +- '4208' +- '4224' +- '4225' +- '4226' +- '4227' +- '4228' +- '4229' +- '4230' +- '4232' +- '4233' +- '4234' +- '4236' +- '4240' +- '4241' +- '4242' +- '4244' +- '4248' +- '4256' +- '4257' +- '4258' +- '4260' +- '4264' +- '4272' +- '4288' +- '4289' +- '4290' +- '4292' +- '4296' +- '4304' +- '4320' +- '4352' +- '4353' +- '4354' +- '4355' +- '4356' +- '4357' +- '4358' +- 
'4360' +- '4361' +- '4362' +- '4364' +- '4368' +- '4369' +- '4370' +- '4372' +- '4376' +- '4384' +- '4385' +- '4386' +- '4388' +- '4392' +- '4400' +- '4416' +- '4417' +- '4418' +- '4420' +- '4424' +- '4432' +- '4448' +- '4480' +- '4481' +- '4482' +- '4484' +- '4488' +- '4496' +- '4512' +- '4544' +- '4608' +- '4609' +- '4610' +- '4611' +- '4612' +- '4613' +- '4614' +- '4616' +- '4617' +- '4618' +- '4620' +- '4624' +- '4625' +- '4626' +- '4628' +- '4632' +- '4640' +- '4641' +- '4642' +- '4644' +- '4648' +- '4656' +- '4672' +- '4673' +- '4674' +- '4676' +- '4680' +- '4688' +- '4704' +- '4736' +- '4737' +- '4738' +- '4740' +- '4744' +- '4752' +- '4768' +- '4800' +- '4864' +- '4865' +- '4866' +- '4868' +- '4872' +- '4880' +- '4896' +- '4928' +- '4992' +- '5120' +- '5121' +- '5122' +- '5123' +- '5124' +- '5125' +- '5126' +- '5128' +- '5129' +- '5130' +- '5132' +- '5136' +- '5137' +- '5138' +- '5140' +- '5144' +- '5152' +- '5153' +- '5154' +- '5156' +- '5160' +- '5168' +- '5184' +- '5185' +- '5186' +- '5188' +- '5192' +- '5200' +- '5216' +- '5248' +- '5249' +- '5250' +- '5252' +- '5256' +- '5264' +- '5280' +- '5312' +- '5376' +- '5377' +- '5378' +- '5380' +- '5384' +- '5392' +- '5408' +- '5440' +- '5504' +- '5632' +- '5633' +- '5634' +- '5636' +- '5640' +- '5648' +- '5664' +- '5696' +- '5760' +- '5888' +- '6144' +- '6145' +- '6146' +- '6147' +- '6148' +- '6149' +- '6150' +- '6152' +- '6153' +- '6154' +- '6156' +- '6160' +- '6161' +- '6162' +- '6164' +- '6168' +- '6176' +- '6177' +- '6178' +- '6180' +- '6184' +- '6192' +- '6208' +- '6209' +- '6210' +- '6212' +- '6216' +- '6224' +- '6240' +- '6272' +- '6273' +- '6274' +- '6276' +- '6280' +- '6288' +- '6304' +- '6336' +- '6400' +- '6401' +- '6402' +- '6404' +- '6408' +- '6416' +- '6432' +- '6464' +- '6528' +- '6656' +- '6657' +- '6658' +- '6660' +- '6664' +- '6672' +- '6688' +- '6720' +- '6784' +- '6912' +- '7168' +- '7169' +- '7170' +- '7172' +- '7176' +- '7184' +- '7200' +- '7232' +- '7296' +- '7424' +- '7680' +- '8192' +- '8193' +- '8194' +- '8195' +- '8196' +- '8197' +- '8198' +- '8199' +- '8200' +- '8201' +- '8202' +- '8203' +- '8204' +- '8205' +- '8206' +- '8208' +- '8209' +- '8210' +- '8211' +- '8212' +- '8213' +- '8214' +- '8216' +- '8217' +- '8218' +- '8220' +- '8224' +- '8225' +- '8226' +- '8227' +- '8228' +- '8229' +- '8230' +- '8232' +- '8233' +- '8234' +- '8236' +- '8240' +- '8241' +- '8242' +- '8244' +- '8248' +- '8256' +- '8257' +- '8258' +- '8259' +- '8260' +- '8261' +- '8262' +- '8264' +- '8265' +- '8266' +- '8268' +- '8272' +- '8273' +- '8274' +- '8276' +- '8280' +- '8288' +- '8289' +- '8290' +- '8292' +- '8296' +- '8304' +- '8320' +- '8321' +- '8322' +- '8323' +- '8324' +- '8325' +- '8326' +- '8328' +- '8329' +- '8330' +- '8332' +- '8336' +- '8337' +- '8338' +- '8340' +- '8344' +- '8352' +- '8353' +- '8354' +- '8356' +- '8360' +- '8368' +- '8384' +- '8385' +- '8386' +- '8388' +- '8392' +- '8400' +- '8416' +- '8448' +- '8449' +- '8450' +- '8451' +- '8452' +- '8453' +- '8454' +- '8456' +- '8457' +- '8458' +- '8460' +- '8464' +- '8465' +- '8466' +- '8468' +- '8472' +- '8480' +- '8481' +- '8482' +- '8484' +- '8488' +- '8496' +- '8512' +- '8513' +- '8514' +- '8516' +- '8520' +- '8528' +- '8544' +- '8576' +- '8577' +- '8578' +- '8580' +- '8584' +- '8592' +- '8608' +- '8640' +- '8704' +- '8705' +- '8706' +- '8707' +- '8708' +- '8709' +- '8710' +- '8712' +- '8713' +- '8714' +- '8716' +- '8720' +- '8721' +- '8722' +- '8724' +- '8728' +- '8736' +- '8737' +- '8738' +- '8740' +- '8744' +- '8752' +- '8768' +- '8769' +- '8770' +- '8772' +- 
'8776' +- '8784' +- '8800' +- '8832' +- '8833' +- '8834' +- '8836' +- '8840' +- '8848' +- '8864' +- '8896' +- '8960' +- '8961' +- '8962' +- '8964' +- '8968' +- '8976' +- '8992' +- '9024' +- '9088' +- '9216' +- '9217' +- '9218' +- '9219' +- '9220' +- '9221' +- '9222' +- '9224' +- '9225' +- '9226' +- '9228' +- '9232' +- '9233' +- '9234' +- '9236' +- '9240' +- '9248' +- '9249' +- '9250' +- '9252' +- '9256' +- '9264' +- '9280' +- '9281' +- '9282' +- '9284' +- '9288' +- '9296' +- '9312' +- '9344' +- '9345' +- '9346' +- '9348' +- '9352' +- '9360' +- '9376' +- '9408' +- '9472' +- '9473' +- '9474' +- '9476' +- '9480' +- '9488' +- '9504' +- '9536' +- '9600' +- '9728' +- '9729' +- '9730' +- '9732' +- '9736' +- '9744' +- '9760' +- '9792' +- '9856' +- '9984' +- '10240' +- '10241' +- '10242' +- '10243' +- '10244' +- '10245' +- '10246' +- '10248' +- '10249' +- '10250' +- '10252' +- '10256' +- '10257' +- '10258' +- '10260' +- '10264' +- '10272' +- '10273' +- '10274' +- '10276' +- '10280' +- '10288' +- '10304' +- '10305' +- '10306' +- '10308' +- '10312' +- '10320' +- '10336' +- '10368' +- '10369' +- '10370' +- '10372' +- '10376' +- '10384' +- '10400' +- '10432' +- '10496' +- '10497' +- '10498' +- '10500' +- '10504' +- '10512' +- '10528' +- '10560' +- '10624' +- '10752' +- '10753' +- '10754' +- '10756' +- '10760' +- '10768' +- '10784' +- '10816' +- '10880' +- '11008' +- '11264' +- '11265' +- '11266' +- '11268' +- '11272' +- '11280' +- '11296' +- '11328' +- '11392' +- '11520' +- '11776' +- '12288' +- '12289' +- '12290' +- '12291' +- '12292' +- '12293' +- '12294' +- '12296' +- '12297' +- '12298' +- '12300' +- '12304' +- '12305' +- '12306' +- '12308' +- '12312' +- '12320' +- '12321' +- '12322' +- '12324' +- '12328' +- '12336' +- '12352' +- '12353' +- '12354' +- '12356' +- '12360' +- '12368' +- '12384' +- '12416' +- '12417' +- '12418' +- '12420' +- '12424' +- '12432' +- '12448' +- '12480' +- '12544' +- '12545' +- '12546' +- '12548' +- '12552' +- '12560' +- '12576' +- '12608' +- '12672' +- '12800' +- '12801' +- '12802' +- '12804' +- '12808' +- '12816' +- '12832' +- '12864' +- '12928' +- '13056' +- '13312' +- '13313' +- '13314' +- '13316' +- '13320' +- '13328' +- '13344' +- '13376' +- '13440' +- '13568' +- '13824' +- '14336' +- '14337' +- '14338' +- '14340' +- '14344' +- '14352' +- '14368' +- '14400' +- '14464' +- '14592' +- '14848' +- '15360' +- '16384' +- '16385' +- '16386' +- '16387' +- '16388' +- '16389' +- '16390' +- '16391' +- '16392' +- '16393' +- '16394' +- '16395' +- '16396' +- '16397' +- '16398' +- '16400' +- '16401' +- '16402' +- '16403' +- '16404' +- '16405' +- '16406' +- '16408' +- '16409' +- '16410' +- '16412' +- '16416' +- '16417' +- '16418' +- '16419' +- '16420' +- '16421' +- '16422' +- '16424' +- '16425' +- '16426' +- '16428' +- '16432' +- '16433' +- '16434' +- '16436' +- '16440' +- '16448' +- '16449' +- '16450' +- '16451' +- '16452' +- '16453' +- '16454' +- '16456' +- '16457' +- '16458' +- '16460' +- '16464' +- '16465' +- '16466' +- '16468' +- '16472' +- '16480' +- '16481' +- '16482' +- '16484' +- '16488' +- '16496' +- '16512' +- '16513' +- '16514' +- '16515' +- '16516' +- '16517' +- '16518' +- '16520' +- '16521' +- '16522' +- '16524' +- '16528' +- '16529' +- '16530' +- '16532' +- '16536' +- '16544' +- '16545' +- '16546' +- '16548' +- '16552' +- '16560' +- '16576' +- '16577' +- '16578' +- '16580' +- '16584' +- '16592' +- '16608' +- '16640' +- '16641' +- '16642' +- '16643' +- '16644' +- '16645' +- '16646' +- '16648' +- '16649' +- '16650' +- '16652' +- '16656' +- '16657' +- '16658' +- '16660' +- 
'16664' +- '16672' +- '16673' +- '16674' +- '16676' +- '16680' +- '16688' +- '16704' +- '16705' +- '16706' +- '16708' +- '16712' +- '16720' +- '16736' +- '16768' +- '16769' +- '16770' +- '16772' +- '16776' +- '16784' +- '16800' +- '16832' +- '16896' +- '16897' +- '16898' +- '16899' +- '16900' +- '16901' +- '16902' +- '16904' +- '16905' +- '16906' +- '16908' +- '16912' +- '16913' +- '16914' +- '16916' +- '16920' +- '16928' +- '16929' +- '16930' +- '16932' +- '16936' +- '16944' +- '16960' +- '16961' +- '16962' +- '16964' +- '16968' +- '16976' +- '16992' +- '17024' +- '17025' +- '17026' +- '17028' +- '17032' +- '17040' +- '17056' +- '17088' +- '17152' +- '17153' +- '17154' +- '17156' +- '17160' +- '17168' +- '17184' +- '17216' +- '17280' +- '17408' +- '17409' +- '17410' +- '17411' +- '17412' +- '17413' +- '17414' +- '17416' +- '17417' +- '17418' +- '17420' +- '17424' +- '17425' +- '17426' +- '17428' +- '17432' +- '17440' +- '17441' +- '17442' +- '17444' +- '17448' +- '17456' +- '17472' +- '17473' +- '17474' +- '17476' +- '17480' +- '17488' +- '17504' +- '17536' +- '17537' +- '17538' +- '17540' +- '17544' +- '17552' +- '17568' +- '17600' +- '17664' +- '17665' +- '17666' +- '17668' +- '17672' +- '17680' +- '17696' +- '17728' +- '17792' +- '17920' +- '17921' +- '17922' +- '17924' +- '17928' +- '17936' +- '17952' +- '17984' +- '18048' +- '18176' +- '18432' +- '18433' +- '18434' +- '18435' +- '18436' +- '18437' +- '18438' +- '18440' +- '18441' +- '18442' +- '18444' +- '18448' +- '18449' +- '18450' +- '18452' +- '18456' +- '18464' +- '18465' +- '18466' +- '18468' +- '18472' +- '18480' +- '18496' +- '18497' +- '18498' +- '18500' +- '18504' +- '18512' +- '18528' +- '18560' +- '18561' +- '18562' +- '18564' +- '18568' +- '18576' +- '18592' +- '18624' +- '18688' +- '18689' +- '18690' +- '18692' +- '18696' +- '18704' +- '18720' +- '18752' +- '18816' +- '18944' +- '18945' +- '18946' +- '18948' +- '18952' +- '18960' +- '18976' +- '19008' +- '19072' +- '19200' +- '19456' +- '19457' +- '19458' +- '19460' +- '19464' +- '19472' +- '19488' +- '19520' +- '19584' +- '19712' +- '19968' +- '20480' +- '20481' +- '20482' +- '20483' +- '20484' +- '20485' +- '20486' +- '20488' +- '20489' +- '20490' +- '20492' +- '20496' +- '20497' +- '20498' +- '20500' +- '20504' +- '20512' +- '20513' +- '20514' +- '20516' +- '20520' +- '20528' +- '20544' +- '20545' +- '20546' +- '20548' +- '20552' +- '20560' +- '20576' +- '20608' +- '20609' +- '20610' +- '20612' +- '20616' +- '20624' +- '20640' +- '20672' +- '20736' +- '20737' +- '20738' +- '20740' +- '20744' +- '20752' +- '20768' +- '20800' +- '20864' +- '20992' +- '20993' +- '20994' +- '20996' +- '21000' +- '21008' +- '21024' +- '21056' +- '21120' +- '21248' +- '21504' +- '21505' +- '21506' +- '21508' +- '21512' +- '21520' +- '21536' +- '21568' +- '21632' +- '21760' +- '22016' +- '22528' +- '22529' +- '22530' +- '22532' +- '22536' +- '22544' +- '22560' +- '22592' +- '22656' +- '22784' +- '23040' +- '23552' +- '24576' +- '24577' +- '24578' +- '24579' +- '24580' +- '24581' +- '24582' +- '24584' +- '24585' +- '24586' +- '24588' +- '24592' +- '24593' +- '24594' +- '24596' +- '24600' +- '24608' +- '24609' +- '24610' +- '24612' +- '24616' +- '24624' +- '24640' +- '24641' +- '24642' +- '24644' +- '24648' +- '24656' +- '24672' +- '24704' +- '24705' +- '24706' +- '24708' +- '24712' +- '24720' +- '24736' +- '24768' +- '24832' +- '24833' +- '24834' +- '24836' +- '24840' +- '24848' +- '24864' +- '24896' +- '24960' +- '25088' +- '25089' +- '25090' +- '25092' +- '25096' +- '25104' +- '25120' +- 
'25152' +- '25216' +- '25344' +- '25600' +- '25601' +- '25602' +- '25604' +- '25608' +- '25616' +- '25632' +- '25664' +- '25728' +- '25856' +- '26112' +- '26624' +- '26625' +- '26626' +- '26628' +- '26632' +- '26640' +- '26656' +- '26688' +- '26752' +- '26880' +- '27136' +- '27648' +- '28672' +- '28673' +- '28674' +- '28676' +- '28680' +- '28688' +- '28704' +- '28736' +- '28800' +- '28928' +- '29184' +- '29696' +- '30720' +- '32768' +- '32769' +- '32770' +- '32771' +- '32772' +- '32773' +- '32774' +- '32775' +- '32776' +- '32777' +- '32778' +- '32779' +- '32780' +- '32781' +- '32782' +- '32784' +- '32785' +- '32786' +- '32787' +- '32788' +- '32789' +- '32790' +- '32792' +- '32793' +- '32794' +- '32796' +- '32800' +- '32801' +- '32802' +- '32803' +- '32804' +- '32805' +- '32806' +- '32808' +- '32809' +- '32810' +- '32812' +- '32816' +- '32817' +- '32818' +- '32820' +- '32824' +- '32832' +- '32833' +- '32834' +- '32835' +- '32836' +- '32837' +- '32838' +- '32840' +- '32841' +- '32842' +- '32844' +- '32848' +- '32849' +- '32850' +- '32852' +- '32856' +- '32864' +- '32865' +- '32866' +- '32868' +- '32872' +- '32880' +- '32896' +- '32897' +- '32898' +- '32899' +- '32900' +- '32901' +- '32902' +- '32904' +- '32905' +- '32906' +- '32908' +- '32912' +- '32913' +- '32914' +- '32916' +- '32920' +- '32928' +- '32929' +- '32930' +- '32932' +- '32936' +- '32944' +- '32960' +- '32961' +- '32962' +- '32964' +- '32968' +- '32976' +- '32992' +- '33024' +- '33025' +- '33026' +- '33027' +- '33028' +- '33029' +- '33030' +- '33032' +- '33033' +- '33034' +- '33036' +- '33040' +- '33041' +- '33042' +- '33044' +- '33048' +- '33056' +- '33057' +- '33058' +- '33060' +- '33064' +- '33072' +- '33088' +- '33089' +- '33090' +- '33092' +- '33096' +- '33104' +- '33120' +- '33152' +- '33153' +- '33154' +- '33156' +- '33160' +- '33168' +- '33184' +- '33216' +- '33280' +- '33281' +- '33282' +- '33283' +- '33284' +- '33285' +- '33286' +- '33288' +- '33289' +- '33290' +- '33292' +- '33296' +- '33297' +- '33298' +- '33300' +- '33304' +- '33312' +- '33313' +- '33314' +- '33316' +- '33320' +- '33328' +- '33344' +- '33345' +- '33346' +- '33348' +- '33352' +- '33360' +- '33376' +- '33408' +- '33409' +- '33410' +- '33412' +- '33416' +- '33424' +- '33440' +- '33472' +- '33536' +- '33537' +- '33538' +- '33540' +- '33544' +- '33552' +- '33568' +- '33600' +- '33664' +- '33792' +- '33793' +- '33794' +- '33795' +- '33796' +- '33797' +- '33798' +- '33800' +- '33801' +- '33802' +- '33804' +- '33808' +- '33809' +- '33810' +- '33812' +- '33816' +- '33824' +- '33825' +- '33826' +- '33828' +- '33832' +- '33840' +- '33856' +- '33857' +- '33858' +- '33860' +- '33864' +- '33872' +- '33888' +- '33920' +- '33921' +- '33922' +- '33924' +- '33928' +- '33936' +- '33952' +- '33984' +- '34048' +- '34049' +- '34050' +- '34052' +- '34056' +- '34064' +- '34080' +- '34112' +- '34176' +- '34304' +- '34305' +- '34306' +- '34308' +- '34312' +- '34320' +- '34336' +- '34368' +- '34432' +- '34560' +- '34816' +- '34817' +- '34818' +- '34819' +- '34820' +- '34821' +- '34822' +- '34824' +- '34825' +- '34826' +- '34828' +- '34832' +- '34833' +- '34834' +- '34836' +- '34840' +- '34848' +- '34849' +- '34850' +- '34852' +- '34856' +- '34864' +- '34880' +- '34881' +- '34882' +- '34884' +- '34888' +- '34896' +- '34912' +- '34944' +- '34945' +- '34946' +- '34948' +- '34952' +- '34960' +- '34976' +- '35008' +- '35072' +- '35073' +- '35074' +- '35076' +- '35080' +- '35088' +- '35104' +- '35136' +- '35200' +- '35328' +- '35329' +- '35330' +- '35332' +- '35336' +- '35344' +- 
'35360' +- '35392' +- '35456' +- '35584' +- '35840' +- '35841' +- '35842' +- '35844' +- '35848' +- '35856' +- '35872' +- '35904' +- '35968' +- '36096' +- '36352' +- '36864' +- '36865' +- '36866' +- '36867' +- '36868' +- '36869' +- '36870' +- '36872' +- '36873' +- '36874' +- '36876' +- '36880' +- '36881' +- '36882' +- '36884' +- '36888' +- '36896' +- '36897' +- '36898' +- '36900' +- '36904' +- '36912' +- '36928' +- '36929' +- '36930' +- '36932' +- '36936' +- '36944' +- '36960' +- '36992' +- '36993' +- '36994' +- '36996' +- '37000' +- '37008' +- '37024' +- '37056' +- '37120' +- '37121' +- '37122' +- '37124' +- '37128' +- '37136' +- '37152' +- '37184' +- '37248' +- '37376' +- '37377' +- '37378' +- '37380' +- '37384' +- '37392' +- '37408' +- '37440' +- '37504' +- '37632' +- '37888' +- '37889' +- '37890' +- '37892' +- '37896' +- '37904' +- '37920' +- '37952' +- '38016' +- '38144' +- '38400' +- '38912' +- '38913' +- '38914' +- '38916' +- '38920' +- '38928' +- '38944' +- '38976' +- '39040' +- '39168' +- '39424' +- '39936' +- '40960' +- '40961' +- '40962' +- '40963' +- '40964' +- '40965' +- '40966' +- '40968' +- '40969' +- '40970' +- '40972' +- '40976' +- '40977' +- '40978' +- '40980' +- '40984' +- '40992' +- '40993' +- '40994' +- '40996' +- '41000' +- '41008' +- '41024' +- '41025' +- '41026' +- '41028' +- '41032' +- '41040' +- '41056' +- '41088' +- '41089' +- '41090' +- '41092' +- '41096' +- '41104' +- '41120' +- '41152' +- '41216' +- '41217' +- '41218' +- '41220' +- '41224' +- '41232' +- '41248' +- '41280' +- '41344' +- '41472' +- '41473' +- '41474' +- '41476' +- '41480' +- '41488' +- '41504' +- '41536' +- '41600' +- '41728' +- '41984' +- '41985' +- '41986' +- '41988' +- '41992' +- '42000' +- '42016' +- '42048' +- '42112' +- '42240' +- '42496' +- '43008' +- '43009' +- '43010' +- '43012' +- '43016' +- '43024' +- '43040' +- '43072' +- '43136' +- '43264' +- '43520' +- '44032' +- '45056' +- '45057' +- '45058' +- '45060' +- '45064' +- '45072' +- '45088' +- '45120' +- '45184' +- '45312' +- '45568' +- '46080' +- '47104' +- '49152' +- '49153' +- '49154' +- '49155' +- '49156' +- '49157' +- '49158' +- '49160' +- '49161' +- '49162' +- '49164' +- '49168' +- '49169' +- '49170' +- '49172' +- '49176' +- '49184' +- '49185' +- '49186' +- '49188' +- '49192' +- '49200' +- '49216' +- '49217' +- '49218' +- '49220' +- '49224' +- '49232' +- '49248' +- '49280' +- '49281' +- '49282' +- '49284' +- '49288' +- '49296' +- '49312' +- '49344' +- '49408' +- '49409' +- '49410' +- '49412' +- '49416' +- '49424' +- '49440' +- '49472' +- '49536' +- '49664' +- '49665' +- '49666' +- '49668' +- '49672' +- '49680' +- '49696' +- '49728' +- '49792' +- '49920' +- '50176' +- '50177' +- '50178' +- '50180' +- '50184' +- '50192' +- '50208' +- '50240' +- '50304' +- '50432' +- '50688' +- '51200' +- '51201' +- '51202' +- '51204' +- '51208' +- '51216' +- '51232' +- '51264' +- '51328' +- '51456' +- '51712' +- '52224' +- '53248' +- '53249' +- '53250' +- '53252' +- '53256' +- '53264' +- '53280' +- '53312' +- '53376' +- '53504' +- '53760' +- '54272' +- '55296' +- '57344' +- '57345' +- '57346' +- '57348' +- '57352' +- '57360' +- '57376' +- '57408' +- '57472' +- '57600' +- '57856' +- '58368' +- '59392' +- '61440' +init: null +input_size: 80 +ctc_conf: + dropout_rate: 0.0 + ctc_type: builtin + reduce: true + ignore_nan_grad: true +joint_net_conf: null +use_preprocessor: true +token_type: char +bpemodel: null +non_linguistic_symbols: null +cleaner: null +g2p: null +speech_volume_normalize: null +rir_scp: null +rir_apply_prob: 1.0 +noise_scp: null 
+noise_apply_prob: 1.0 +noise_db_range: '13_15' +frontend: null +frontend_conf: {} +specaug: null +specaug_conf: {} +normalize: null +normalize_conf: {} +label_aggregator: null +label_aggregator_conf: {} +model: sond +model_conf: + # ctc_weight: 0.0 + lsm_weight: 0.1 + length_normalized_loss: true + max_spk_num: 16 + # predictor_weight: 1.0 + # predictor_bias: 1 + # sampling_ratio: 0.75 +# speech encoder +encoder: resnet34 +encoder_conf: + # pass by model, equal to feature dim + # input_size: 80 + pooling_type: "window_shift" + pool_size: 20 + stride: 1 + tf2torch_tensor_name_prefix_torch: encoder + tf2torch_tensor_name_prefix_tf: EAND/speech_encoder +speaker_encoder: conv +speaker_encoder_conf: + input_units: 256 + num_layers: 3 + num_units: 256 + kernel_size: 1 + dropout_rate: 0.0 + position_encoder: null + out_units: 256 + out_norm: false + auxiliary_states: false + tf2torch_tensor_name_prefix_torch: speaker_encoder + tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder +ci_scorer: dot +ci_scorer_conf: {} +cd_scorer: san +cd_scorer_conf: + input_size: 512 + output_size: 512 + out_units: 1 + attention_heads: 4 + linear_units: 1024 + num_blocks: 4 + dropout_rate: 0.0 + positional_dropout_rate: 0.0 + attention_dropout_rate: 0.0 + # use string "null" to remove input layer + input_layer: "null" + pos_enc_class: null + normalize_before: true + tf2torch_tensor_name_prefix_torch: cd_scorer + tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer +# post net +decoder: fsmn +decoder_conf: + in_units: 32 + out_units: 2517 + filter_size: 31 + fsmn_num_layers: 6 + dnn_num_layers: 1 + num_memory_units: 512 + ffn_inner_dim: 512 + dropout_rate: 0.0 + tf2torch_tensor_name_prefix_torch: decoder + tf2torch_tensor_name_prefix_tf: EAND/post_net +num_worker_count: 1 +required: +- output_dir +- token_list +oss_bucket: 'null' +version: 0.1.6 diff --git a/egs/alimeeting/diarization/sond/infer_alimeeting_test.py b/egs/alimeeting/diarization/sond/infer_alimeeting_test.py new file mode 100644 index 000000000..0988f5d03 --- /dev/null +++ b/egs/alimeeting/diarization/sond/infer_alimeeting_test.py @@ -0,0 +1,24 @@ +from funasr.bin.diar_inference_launch import inference_launch +import sys + + +def main(): + diar_config_path = sys.argv[1] if len(sys.argv) > 1 else "sond_fbank.yaml" + diar_model_path = sys.argv[2] if len(sys.argv) > 2 else "sond.pth" + output_dir = sys.argv[3] if len(sys.argv) > 3 else "./outputs" + data_path_and_name_and_type = [ + ("data/test_rmsil/feats.scp", "speech", "kaldi_ark"), + ("data/test_rmsil/test_rmsil_tdnn6_xvec.scp", "profile", "kaldi_ark"), + ] + pipeline = inference_launch( + mode="sond", + diar_train_config=diar_config_path, + diar_model_file=diar_model_path, + output_dir=output_dir, + num_workers=1 + ) + pipeline(data_path_and_name_and_type) + + +if __name__ == '__main__': + main() diff --git a/egs/alimeeting/diarization/sond/local/convert_label_to_rttm.py b/egs/alimeeting/diarization/sond/local/convert_label_to_rttm.py new file mode 100644 index 000000000..880f60fe7 --- /dev/null +++ b/egs/alimeeting/diarization/sond/local/convert_label_to_rttm.py @@ -0,0 +1,132 @@ +import os +from funasr.utils.job_runner import MultiProcessRunnerV3 +import numpy as np +from funasr.utils.misc import load_scp_as_list, load_scp_as_dict +from collections import OrderedDict +from tqdm import tqdm +from scipy.ndimage import median_filter + + +class MyRunner(MultiProcessRunnerV3): + def prepare(self, parser): + parser.add_argument("label_txt", type=str) + parser.add_argument("map_scp", type=str) + 
parser.add_argument("out_rttm", type=str) + parser.add_argument("--n_spk", type=int, default=4) + parser.add_argument("--chunk_len", type=int, default=1600) + parser.add_argument("--shift_len", type=int, default=400) + parser.add_argument("--ignore_len", type=int, default=5) + parser.add_argument("--smooth_size", type=int, default=7) + parser.add_argument("--vote_prob", type=float, default=0.5) + args = parser.parse_args() + + if not os.path.exists(os.path.dirname(args.out_rttm)): + os.makedirs(os.path.dirname(args.out_rttm)) + + utt2labels = load_scp_as_list(args.label_txt, 'list') + utt2labels = sorted(utt2labels, key=lambda x: x[0]) + meeting2map = load_scp_as_dict(args.map_scp) + meeting2labels = OrderedDict() + for utt_id, chunk_label in utt2labels: + mid = utt_id.split("-")[0] + if mid not in meeting2labels: + meeting2labels[mid] = [] + meeting2labels[mid].append(chunk_label) + task_list = [(mid, labels, meeting2map[mid]) for mid, labels in meeting2labels.items()] + + return task_list, None, args + + def post(self, result_list, args): + with open(args.out_rttm, "wt") as fd: + for results in result_list: + fd.writelines(results) + + +def int2vec(x, vec_dim=8, dtype=np.int): + b = ('{:0' + str(vec_dim) + 'b}').format(x) + # little-endian order: lower bit first + return (np.array(list(b)[::-1]) == '1').astype(dtype) + + +def seq2arr(seq, vec_dim=8): + return np.row_stack([int2vec(int(x), vec_dim) for x in seq]) + + +def sample2ms(sample, sr=16000): + return int(float(sample) / sr * 100) + + +def calc_multi_labels(chunk_label_list, chunk_len, shift_len, n_spk, vote_prob=0.5): + n_chunk = len(chunk_label_list) + last_chunk_valid_frame = len(chunk_label_list[-1]) - (chunk_len - shift_len) + n_frame = (n_chunk - 2) * shift_len + chunk_len + last_chunk_valid_frame + multi_labels = np.zeros((n_frame, n_spk), dtype=float) + weight = np.zeros((n_frame, 1), dtype=float) + for i in range(n_chunk): + raw_label = chunk_label_list[i] + for k in range(len(raw_label)): + if raw_label[k] == '': + raw_label[k] = raw_label[k-1] if k > 0 else '0' + chunk_multi_label = seq2arr(raw_label, n_spk) + chunk_len = chunk_multi_label.shape[0] + multi_labels[i*shift_len:i*shift_len+chunk_len, :] += chunk_multi_label + weight[i*shift_len:i*shift_len+chunk_len, :] += 1 + multi_labels = multi_labels / weight # normalizing vote + multi_labels = (multi_labels > vote_prob).astype(int) # voting results + return multi_labels + + +def calc_spk_turns(label_arr, spk_list): + turn_list = [] + length = label_arr.shape[0] + n_spk = label_arr.shape[1] + for k in range(n_spk): + if spk_list[k] == "None": + continue + in_utt = False + start = 0 + for i in range(length): + if label_arr[i, k] == 1 and in_utt is False: + start = i + in_utt = True + if label_arr[i, k] == 0 and in_utt is True: + turn_list.append([spk_list[k], start, i - start]) + in_utt = False + if in_utt: + turn_list.append([spk_list[k], start, length - start]) + return turn_list + + +def smooth_multi_labels(multi_label, win_len): + multi_label = median_filter(multi_label, (win_len, 1), mode="constant", cval=0.0).astype(int) + return multi_label + + +def process(task_args): + _, task_list, _, args = task_args + spk_list = ["spk{}".format(i+1) for i in range(args.n_spk)] + template = "SPEAKER {} 1 {:.2f} {:.2f} {} \n" + results = [] + for mid, chunk_label_list, map_file_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar): + utt2map = load_scp_as_list(map_file_path, 'list') + multi_labels = calc_multi_labels(chunk_label_list, args.chunk_len, 
args.shift_len, args.n_spk, args.vote_prob) + multi_labels = smooth_multi_labels(multi_labels, args.smooth_size) + org_len = sample2ms(int(utt2map[-1][1][1]), args.sr) + org_multi_labels = np.zeros((org_len, args.n_spk)) + for seg_id, [org_st, org_ed, st, ed] in utt2map: + org_st, org_dur = sample2ms(int(org_st), args.sr), sample2ms(int(org_ed) - int(org_st), args.sr) + st, dur = sample2ms(int(st), args.sr), sample2ms(int(ed) - int(st), args.sr) + ll = min(org_multi_labels[org_st: org_st+org_dur, :].shape[0], multi_labels[st: st+dur, :].shape[0]) + org_multi_labels[org_st: org_st+ll, :] = multi_labels[st: st+ll, :] + spk_turns = calc_spk_turns(org_multi_labels, spk_list) + spk_turns = sorted(spk_turns, key=lambda x: x[1]) + for spk, st, dur in spk_turns: + # TODO: handle the leak of segments at the change points + if dur > args.ignore_len: + results.append(template.format(mid, float(st)/100, float(dur)/100, spk)) + return results + + +if __name__ == '__main__': + my_runner = MyRunner(process) + my_runner.run() diff --git a/egs/alimeeting/diarization/sond/path.sh b/egs/alimeeting/diarization/sond/path.sh new file mode 100755 index 000000000..7972642d0 --- /dev/null +++ b/egs/alimeeting/diarization/sond/path.sh @@ -0,0 +1,5 @@ +export FUNASR_DIR=$PWD/../../.. + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PATH=$FUNASR_DIR/funasr/bin:$PATH diff --git a/egs/alimeeting/diarization/sond/run.sh b/egs/alimeeting/diarization/sond/run.sh new file mode 100644 index 000000000..7e9a7f7ba --- /dev/null +++ b/egs/alimeeting/diarization/sond/run.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +. ./path.sh || exit 1; + +stage=0 +stop_stage=2 + +. utils/parse_options.sh || exit 1; + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "Downloading AliMeeting test set data..." + wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/alimeeting_test_data_for_sond.tar.gz + echo "Done. Extracting data..." + tar zxf alimeeting_test_data_for_sond.tar.gz + echo "Done." + + echo "Downloading Pre-trained model..." + git clone https://www.modelscope.cn/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch.git + git clone https://www.modelscope.cn/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch.git + ln -s speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth ./sv.pth + cp speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.yaml ./sv.yaml + ln -s speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.pth ./sond.pth + cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond_fbank.yaml ./sond_fbank.yaml + cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.yaml ./sond.yaml + echo "Done." + + echo "Downloading dscore for scoring..." + git clone https://github.com/nryant/dscore.git +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "Calculating diarization results..." + python infer_alimeeting_test.py sond_fbank.yaml sond.pth outputs + python local/convert_label_to_rttm.py \ + outputs/labels.txt \ + data/test_rmsil/raw_rmsil_map.scp \ + outputs/prediction_sm_83.rttm \ + --ignore_len 10 --no_pbar --smooth_size 83 \ + --vote_prob 0.5 --n_spk 16 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "Scoring..." 
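+  # --collar 0.25 excludes a 250 ms no-score margin around each reference boundary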
+ python dscore/score.py \ + -r data/test_rmsil/test_org.crttm \ + -s outputs/prediction_sm_83.rttm \ + --collar 0.25 +fi diff --git a/egs/alimeeting/diarization/sond/unit_test.py b/egs/alimeeting/diarization/sond/unit_test.py new file mode 100644 index 000000000..84a424762 --- /dev/null +++ b/egs/alimeeting/diarization/sond/unit_test.py @@ -0,0 +1,97 @@ +from funasr.bin.diar_inference_launch import inference_launch +import os + + +def test_fbank_cpu_infer(): + diar_config_path = "config_fbank.yaml" + diar_model_path = "sond.pth" + output_dir = "./outputs" + data_path_and_name_and_type = [ + ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"), + ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"), + ] + pipeline = inference_launch( + mode="sond", + diar_train_config=diar_config_path, + diar_model_file=diar_model_path, + output_dir=output_dir, + num_workers=1, + log_level="WARNING", + ) + results = pipeline(data_path_and_name_and_type) + print(results) + + +def test_fbank_gpu_infer(): + diar_config_path = "config_fbank.yaml" + diar_model_path = "sond.pth" + output_dir = "./outputs" + data_path_and_name_and_type = [ + ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"), + ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"), + ] + pipeline = inference_launch( + mode="sond", + diar_train_config=diar_config_path, + diar_model_file=diar_model_path, + output_dir=output_dir, + ngpu=1, + num_workers=1, + log_level="WARNING", + ) + results = pipeline(data_path_and_name_and_type) + print(results) + + +def test_wav_gpu_infer(): + diar_config_path = "config.yaml" + diar_model_path = "sond.pth" + output_dir = "./outputs" + data_path_and_name_and_type = [ + ("data/unit_test/test_wav.scp", "speech", "sound"), + ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"), + ] + pipeline = inference_launch( + mode="sond", + diar_train_config=diar_config_path, + diar_model_file=diar_model_path, + output_dir=output_dir, + ngpu=1, + num_workers=1, + log_level="WARNING", + ) + results = pipeline(data_path_and_name_and_type) + print(results) + + +def test_without_profile_gpu_infer(): + diar_config_path = "config.yaml" + diar_model_path = "sond.pth" + output_dir = "./outputs" + raw_inputs = [[ + "data/unit_test/raw_inputs/record.wav", + "data/unit_test/raw_inputs/spk1.wav", + "data/unit_test/raw_inputs/spk2.wav", + "data/unit_test/raw_inputs/spk3.wav", + "data/unit_test/raw_inputs/spk4.wav" + ]] + pipeline = inference_launch( + mode="sond_demo", + diar_train_config=diar_config_path, + diar_model_file=diar_model_path, + output_dir=output_dir, + ngpu=1, + num_workers=1, + log_level="WARNING", + param_dict={}, + ) + results = pipeline(raw_inputs=raw_inputs) + print(results) + + +if __name__ == '__main__': + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + test_fbank_cpu_infer() + test_fbank_gpu_infer() + test_wav_gpu_infer() + test_without_profile_gpu_infer() diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py new file mode 100755 index 000000000..c3e210bf7 --- /dev/null +++ b/funasr/bin/diar_inference_launch.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import os
+import sys
+from typing import Union, Dict, Any
+
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="Speaker Diarization Inference",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=False)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument(
+        "--njob",
+        type=int,
+        default=1,
+        help="The number of jobs for each gpu",
+    )
+    parser.add_argument(
+        "--gpuid_list",
+        type=str,
+        default="",
+        help="The visible gpus",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=False,
+        action="append",
+    )
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--vad_infer_config",
+        type=str,
+        help="VAD infer configuration",
+    )
+    group.add_argument(
+        "--vad_model_file",
+        type=str,
+        help="VAD model parameter file",
+    )
+    group.add_argument(
+        "--diar_train_config",
+        type=str,
+        help="Diarization training configuration",
+    )
+    group.add_argument(
+        "--diar_model_file",
+        type=str,
+        help="Diarization model parameter file",
+    )
+    group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global CMVN file",
+    )
+    group.add_argument(
+        "--model_tag",
+        type=str,
+        help="Pretrained model tag. If this option is specified, *_train_config and "
+             "*_file will be overwritten",
+    )
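+
+    # post-processing related: diar_smooth_size is the window length (in frames)
+    # used to smooth the frame-level decisions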
+    group = parser.add_argument_group("The inference configuration related")
+    group.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    group.add_argument(
+        "--diar_smooth_size",
+        type=int,
+        default=121,
+        help="The smoothing size for post-processing"
+    )
+
+    return parser
+
+
+def inference_launch(mode, **kwargs):
+    if mode == "sond":
+        from funasr.bin.sond_inference import inference_modelscope
+        return inference_modelscope(**kwargs)
+    elif mode == "sond_demo":
+        from funasr.bin.sond_inference import inference_modelscope
+        # the demo mode extracts speaker profiles on the fly with a pretrained
+        # speaker verification model (sv.yaml / sv.pth) instead of reading them from files
+        param_dict = {
+            "extract_profile": True,
+            "sv_train_config": "sv.yaml",
+            "sv_model_file": "sv.pth",
+        }
+        if "param_dict" in kwargs:
+            kwargs["param_dict"].update(param_dict)
+        else:
+            kwargs["param_dict"] = param_dict
+        return inference_modelscope(**kwargs)
+    else:
+        logging.error("Unknown decoding mode: {}".format(mode))
+        return None
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="sond",
+        help="The decoding mode",
+    )
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    kwargs.pop("config", None)
+
+    # set logging messages
+    logging.basicConfig(
+        level=args.log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    logging.info("Decoding args: {}".format(kwargs))
+
+    # gpu setting
+    if args.ngpu > 0:
+        jobid = int(args.output_dir.split(".")[-1])
+        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+    inference_launch(**kwargs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/funasr/bin/sond_inference.py b/funasr/bin/sond_inference.py
new file mode 100755
index 000000000..299de0dda
--- /dev/null
+++ b/funasr/bin/sond_inference.py
@@ -0,0 +1,544 @@
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+from collections import OrderedDict
+import numpy as np
+import soundfile
+import torch
+from torch.nn import functional as F
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.diar import DiarTask
+from funasr.tasks.asr import ASRTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from scipy.ndimage import median_filter
+from funasr.utils.misc import statistic_model_parameters
+
+class Speech2Diarization:
+    """Speech2Diarization class
+
+    Examples:
+        >>> import soundfile
+        >>> import numpy as np
+        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
+        >>> profile = np.load("profiles.npy")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2diar(audio, profile)
+        {"spk1": [(int, int), ...], ...}
+
+    """
+
+    def __init__(
+        self,
+        diar_train_config: Union[Path, str] = None,
+        diar_model_file: Union[Path, str] = None,
+        device: str = "cpu",
+        batch_size: int = 1,
+        dtype: str = "float32",
+        streaming: bool = False,
+        smooth_size: int = 83,
+        dur_threshold: float = 10,
+    ):
+        assert check_argument_types()
+
+        # 1. Build the diarization model
+        diar_model, diar_train_args = DiarTask.build_model_from_file(
+            config_file=diar_train_config,
+            model_file=diar_model_file,
+            device=device
+        )
+        logging.info("diar_model: {}".format(diar_model))
+        logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
+        logging.info("diar_train_args: {}".format(diar_train_args))
+        diar_model.to(dtype=getattr(torch, dtype)).eval()
+
+        self.diar_model = diar_model
+        self.diar_train_args = diar_train_args
+        self.token_list = diar_train_args.token_list
+        self.smooth_size = smooth_size
+        self.dur_threshold = dur_threshold
+        self.device = device
+        self.dtype = dtype
+
+    def smooth_multi_labels(self, multi_label):
+        multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
+        return multi_label
+
+    @staticmethod
+    def calc_spk_turns(label_arr, spk_list):
+        # convert a (n_frame, n_spk) 0/1 matrix into [speaker, start, duration] turns
+        turn_list = []
+        length = label_arr.shape[0]
+        n_spk = label_arr.shape[1]
+        for k in range(n_spk):
+            if spk_list[k] == "None":
+                continue
+            in_utt = False
+            start = 0
+            for i in range(length):
+                if label_arr[i, k] == 1 and in_utt is False:
+                    start = i
+                    in_utt = True
+                if label_arr[i, k] == 0 and in_utt is True:
+                    turn_list.append([spk_list[k], start, i - start])
+                    in_utt = False
+            if in_utt:
+                turn_list.append([spk_list[k], start, length - start])
+        return turn_list
+
+    @staticmethod
+    def seq2arr(seq, vec_dim=8):
+        def int2vec(x, vec_dim=8, dtype=int):
+            b = ('{:0' + str(vec_dim) + 'b}').format(x)
+            # little-endian order: lower bit first
+            return (np.array(list(b)[::-1]) == '1').astype(dtype)
+
+        return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
+
+    def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
+        logits_idx = raw_logits.argmax(-1)  # B, T, vocab_size -> B, T
+        # upsampling outputs to match inputs
+        ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
+        logits_idx = F.interpolate(
+            logits_idx.unsqueeze(1).float(),
+            size=(ut, ),
+            mode="nearest",
+        ).squeeze(1).long()
+        logits_idx = logits_idx[0].tolist()
+        pse_labels = [self.token_list[x] for x in logits_idx]
+        # each PSE token is an integer whose binary digits mark the speakers
+        # active in that frame; decode it to a (T, spk_num) 0/1 matrix
+        multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num]  # remove padding speakers
+        multi_labels = self.smooth_multi_labels(multi_labels)
+        spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
+        spk_turns = self.calc_spk_turns(multi_labels, spk_list)
+        results = OrderedDict()
+        for spk, st, dur in spk_turns:
+            if spk not in results:
+                results[spk] = []
+            # drop turns shorter than dur_threshold frames
+            if dur > self.dur_threshold:
+                results[spk].append((st, st+dur))
+
+        # sort segments in start time ascending
+        for spk in results:
+            results[spk] = sorted(results[spk], key=lambda x: x[0])
+
+        return results, pse_labels
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        speech: Union[torch.Tensor, np.ndarray],
+        profile: Union[torch.Tensor, np.ndarray],
+    ):
+        """Inference
+
+        Args:
+            speech: Input speech data
+            profile: Speaker profiles
+        Returns:
+            diarization results for each speaker
+
+        """
+        assert check_argument_types()
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        if isinstance(profile, np.ndarray):
+            profile = torch.tensor(profile)
+
+        # data: (Nsamples,) -> (1, Nsamples)
+        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
+        # lengths: (1,)
+        speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+        profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
+        batch = {"speech": speech, "speech_lengths": speech_lengths,
+                 "profile": profile, "profile_lengths": profile_lengths}
+        # a. To device
+        batch = to_device(batch, device=self.device)
+
+        logits = self.diar_model.prediction_forward(**batch)
+        results, pse_labels = self.post_processing(logits, profile.shape[1])
+
+        return results, pse_labels
+
+    @staticmethod
+    def from_pretrained(
+        model_tag: Optional[str] = None,
+        **kwargs: Optional[Any],
+    ):
+        """Build Speech2Diarization instance from the pretrained model.
+
+        Args:
+            model_tag (Optional[str]): Model tag of the pretrained models.
+                Currently, the tags of espnet_model_zoo are supported.
+
+        Returns:
+            Speech2Diarization: Speech2Diarization instance.
+
+        """
+        if model_tag is not None:
+            try:
+                from espnet_model_zoo.downloader import ModelDownloader
+
+            except ImportError:
+                logging.error(
+                    "`espnet_model_zoo` is not installed. "
+                    "Please install via `pip install -U espnet_model_zoo`."
+                )
+                raise
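+            # download_and_unpack resolves the tag to local files; the returned
+            # paths (e.g. diar_train_config, diar_model_file) override the kwargs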
+            d = ModelDownloader()
+            kwargs.update(**d.download_and_unpack(model_tag))
+
+        return Speech2Diarization(**kwargs)
+
+
+def inference_modelscope(
+    diar_train_config: str,
+    diar_model_file: str,
+    output_dir: Optional[str] = None,
+    batch_size: int = 1,
+    dtype: str = "float32",
+    ngpu: int = 0,
+    seed: int = 0,
+    num_workers: int = 0,
+    log_level: Union[int, str] = "INFO",
+    key_file: Optional[str] = None,
+    model_tag: Optional[str] = None,
+    allow_variable_data_keys: bool = True,
+    streaming: bool = False,
+    smooth_size: int = 83,
+    dur_threshold: int = 10,
+    out_format: str = "vad",
+    param_dict: Optional[dict] = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    logging.info("param_dict: {}".format(param_dict))
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2a. Build speech2xvec [Optional]
+    if param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
+        assert "sv_train_config" in param_dict, "sv_train_config must be provided in param_dict."
+        assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
+        sv_train_config = param_dict["sv_train_config"]
+        sv_model_file = param_dict["sv_model_file"]
+        from funasr.bin.sv_inference import Speech2Xvector
+        speech2xvector_kwargs = dict(
+            sv_train_config=sv_train_config,
+            sv_model_file=sv_model_file,
+            device=device,
+            dtype=dtype,
+            streaming=streaming,
+            embedding_node="resnet1_dense"
+        )
+        logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
+        speech2xvector = Speech2Xvector.from_pretrained(
+            model_tag=model_tag,
+            **speech2xvector_kwargs,
+        )
+        speech2xvector.sv_model.eval()
+
+    # 2b. Build speech2diar
+    speech2diar_kwargs = dict(
+        diar_train_config=diar_train_config,
+        diar_model_file=diar_model_file,
+        device=device,
+        dtype=dtype,
+        streaming=streaming,
+        smooth_size=smooth_size,
+        dur_threshold=dur_threshold,
+    )
+    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
+    speech2diar = Speech2Diarization.from_pretrained(
+        model_tag=model_tag,
+        **speech2diar_kwargs,
+    )
+    speech2diar.diar_model.eval()
+
+    def output_results_str(results: dict, uttid: str):
+        rst = []
+        mid = uttid.rsplit("-", 1)[0]
+        # convert frame indices (10 ms) to seconds
+        for key in results:
+            results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
+        if out_format == "vad":
+            for spk, segs in results.items():
+                rst.append("{} {}".format(spk, segs))
+        else:
+            template = "SPEAKER {} 0 {:.2f} {:.2f} {} "
+            for spk, segs in results.items():
+                rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
+
+        return "\n".join(rst)
+
+    def _forward(
+        data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+        raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str]]] = None,
+        output_dir_v2: Optional[str] = None,
+        param_dict: Optional[dict] = None,
+    ):
+        logging.info("param_dict: {}".format(param_dict))
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, (list, tuple)):
+                assert all([len(example) >= 2 for example in raw_inputs]), \
+                    "Each test case in raw_inputs must contain at least two items (>=2)."
+
+                def prepare_dataset():
+                    for idx, example in enumerate(raw_inputs):
+                        # read waveform file
+                        example = [soundfile.read(x)[0] if isinstance(x, str) else x
+                                   for x in example]
+                        # convert torch tensor to numpy array
+                        example = [x.numpy() if isinstance(x, torch.Tensor) else x
+                                   for x in example]
+                        speech = example[0]
+                        logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
+                        profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
+                        profile = torch.cat(profile, dim=0)
+                        yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
+
+                loader = prepare_dataset()
+            else:
+                raise TypeError("raw_inputs must be a list or tuple of [speech, profile1, profile2, ...]")
+        else:
+            # 3. Build data-iterator
+            loader = ASRTask.build_streaming_iterator(
+                data_path_and_name_and_type,
+                dtype=dtype,
+                batch_size=batch_size,
+                key_file=key_file,
+                num_workers=num_workers,
+                preprocess_fn=None,
+                collate_fn=None,
+                allow_variable_data_keys=allow_variable_data_keys,
+                inference=True,
+            )
+
+        # 7. Start for-loop
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if output_path is not None:
+            os.makedirs(output_path, exist_ok=True)
+            output_writer = open("{}/result.txt".format(output_path), "w")
+            pse_label_writer = open("{}/labels.txt".format(output_path), "w")
+        logging.info("Start to diarize...")
+        result_list = []
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+            results, pse_labels = speech2diar(**batch)
+            # Only supporting batch_size==1
+            key, value = keys[0], output_results_str(results, keys[0])
+            item = {"key": key, "value": value}
+            result_list.append(item)
+            if output_path is not None:
+                output_writer.write(value)
+                output_writer.flush()
+                pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
+                pse_label_writer.flush()
+
+        if output_path is not None:
+            output_writer.close()
+            pse_label_writer.close()
+
+        return result_list
+
+    return _forward
+
+
+def inference(
+    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+    diar_train_config: Optional[str],
+    diar_model_file: Optional[str],
+    output_dir: Optional[str] = None,
+    batch_size: int = 1,
+    dtype: str = "float32",
+    ngpu: int = 0,
+    seed: int = 0,
+    num_workers: int = 1,
+    log_level: Union[int, str] = "INFO",
+    key_file: Optional[str] = None,
+    model_tag: Optional[str] = None,
+    allow_variable_data_keys: bool = True,
+    streaming: bool = False,
+    smooth_size: int = 83,
+    dur_threshold: int = 10,
+    out_format: str = "vad",
+    **kwargs,
+):
+    inference_pipeline = inference_modelscope(
+        diar_train_config=diar_train_config,
+        diar_model_file=diar_model_file,
+        output_dir=output_dir,
+        batch_size=batch_size,
+        dtype=dtype,
+        ngpu=ngpu,
+        seed=seed,
+        num_workers=num_workers,
+        log_level=log_level,
+        key_file=key_file,
+        model_tag=model_tag,
+        allow_variable_data_keys=allow_variable_data_keys,
+        streaming=streaming,
+        smooth_size=smooth_size,
+        dur_threshold=dur_threshold,
+        out_format=out_format,
+        **kwargs,
+    )
+
+    return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="Speaker diarization (SOND) inference",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=False)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument(
+        "--gpuid_list",
+        type=str,
+        default="",
+        help="The visible gpus",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=False,
+        action="append",
+    )
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--diar_train_config",
+        type=str,
+        help="Diarization training configuration",
+    )
+    group.add_argument(
+        "--diar_model_file",
+        type=str,
+        help="Diarization model parameter file",
+    )
+    group.add_argument(
+        "--dur_threshold",
+        type=int,
+        default=10,
+        help="The threshold for dropping short segments, in number of frames"
+    )
+    group.add_argument(
+        "--smooth_size",
+        type=int,
+        default=83,
+        help="The smoothing window length in number of frames"
+    )
+    group.add_argument(
+        "--model_tag",
+        type=str,
+        help="Pretrained model tag. If this option is specified, *_train_config and "
+             "*_file will be overwritten",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    parser.add_argument("--streaming", type=str2bool, default=False)
+
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    kwargs.pop("config", None)
+    logging.info("args: {}".format(kwargs))
+    if args.output_dir is None:
+        jobid, n_gpu = 1, 1
+        gpuid = args.gpuid_list.split(",")[jobid-1]
+    else:
+        jobid = int(args.output_dir.split(".")[-1])
+        n_gpu = len(args.gpuid_list.split(","))
+        gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+    results_list = inference(**kwargs)
+    for results in results_list:
+        print("{} {}".format(results["key"], results["value"]))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/funasr/bin/sv_inference.py b/funasr/bin/sv_inference.py
index 57ce91d6d..a78bccded 100755
--- a/funasr/bin/sv_inference.py
+++ b/funasr/bin/sv_inference.py
@@ -1,4 +1,7 @@
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT) + import argparse import logging import os @@ -26,7 +29,7 @@ from funasr.utils import config_argparse from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none - +from funasr.utils.misc import statistic_model_parameters class Speech2Xvector: """Speech2Xvector class @@ -59,6 +62,7 @@ class Speech2Xvector: device=device ) logging.info("sv_model: {}".format(sv_model)) + logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model))) logging.info("sv_train_args: {}".format(sv_train_args)) sv_model.to(dtype=getattr(torch, dtype)).eval() @@ -156,17 +160,17 @@ class Speech2Xvector: def inference_modelscope( - output_dir: Optional[str], - batch_size: int, - dtype: str, - ngpu: int, - seed: int, - num_workers: int, - log_level: Union[int, str], - key_file: Optional[str], - sv_train_config: Optional[str], - sv_model_file: Optional[str], - model_tag: Optional[str], + output_dir: Optional[str] = None, + batch_size: int = 1, + dtype: str = "float32", + ngpu: int = 1, + seed: int = 0, + num_workers: int = 0, + log_level: Union[int, str] = "INFO", + key_file: Optional[str] = None, + sv_train_config: Optional[str] = "sv.yaml", + sv_model_file: Optional[str] = "sv.pth", + model_tag: Optional[str] = None, allow_variable_data_keys: bool = True, streaming: bool = False, embedding_node: str = "resnet1_dense", @@ -214,7 +218,6 @@ def inference_modelscope( data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None, raw_inputs: Union[np.ndarray, torch.Tensor] = None, output_dir_v2: Optional[str] = None, - fs: dict = None, param_dict: Optional[dict] = None, ): logging.info("param_dict: {}".format(param_dict)) diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py index c511dc717..1205d194d 100755 --- a/funasr/bin/sv_inference_launch.py +++ b/funasr/bin/sv_inference_launch.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) import argparse import logging diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py new file mode 100644 index 000000000..d29ffe5c6 --- /dev/null +++ b/funasr/models/e2e_diar_sond.py @@ -0,0 +1,402 @@ +#!/usr/bin/env python3 +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT)
+
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from itertools import permutations
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from typeguard import check_argument_types
+
+from funasr.modules.nets_utils import to_device
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+class DiarSondModel(AbsESPnetModel):
+    """Speaker overlap-aware neural diarization model
+    reference: https://arxiv.org/abs/2211.10243
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        encoder: AbsEncoder,
+        speaker_encoder: AbsEncoder,
+        ci_scorer: torch.nn.Module,
+        cd_scorer: torch.nn.Module,
+        decoder: torch.nn.Module,
+        token_list: list,
+        lsm_weight: float = 0.1,
+        length_normalized_loss: bool = False,
+        max_spk_num: int = 16,
+        label_aggregator: Optional[torch.nn.Module] = None,
+        normalize_speech_speaker: bool = False,
+    ):
+        assert check_argument_types()
+
+        super().__init__()
+
+        self.encoder = encoder
+        self.speaker_encoder = speaker_encoder
+        self.ci_scorer = ci_scorer
+        self.cd_scorer = cd_scorer
+        self.normalize = normalize
+        self.frontend = frontend
+        self.specaug = specaug
+        self.label_aggregator = label_aggregator
+        self.decoder = decoder
+        self.token_list = token_list
+        self.max_spk_num = max_spk_num
+        self.normalize_speech_speaker = normalize_speech_speaker
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor = None,
+        profile: torch.Tensor = None,
+        profile_lengths: torch.Tensor = None,
+        spk_labels: torch.Tensor = None,
+        spk_labels_lengths: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
+
+        Args:
+            speech: (Batch, samples)
+            speech_lengths: (Batch,) default None for the chunk iterator,
+                because the chunk iterator does not return speech_lengths.
+                see espnet2/iterators/chunk_iter_factory.py
+            profile: (Batch, N_spk, dim)
+            profile_lengths: (Batch,)
+            spk_labels: (Batch, )
+        """
+        assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
+        batch_size = speech.shape[0]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        if self.attractor is None:
+            # 2a. Decoder (basically a prediction layer after encoder_out)
+            pred = self.decoder(encoder_out, encoder_out_lens)
+        else:
+            # 2b.
Encoder Decoder Attractors + # Shuffle the chronological order of encoder_out, then calculate attractor + encoder_out_shuffled = encoder_out.clone() + for i in range(len(encoder_out_lens)): + encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[ + i, torch.randperm(encoder_out_lens[i]), : + ] + attractor, att_prob = self.attractor( + encoder_out_shuffled, + encoder_out_lens, + to_device( + self, + torch.zeros( + encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2) + ), + ), + ) + # Remove the final attractor which does not correspond to a speaker + # Then multiply the attractors and encoder_out + pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1)) + # 3. Aggregate time-domain labels + if self.label_aggregator is not None: + spk_labels, spk_labels_lengths = self.label_aggregator( + spk_labels, spk_labels_lengths + ) + + # If encoder uses conv* as input_layer (i.e., subsampling), + # the sequence length of 'pred' might be slighly less than the + # length of 'spk_labels'. Here we force them to be equal. + length_diff_tolerance = 2 + length_diff = spk_labels.shape[1] - pred.shape[1] + if length_diff > 0 and length_diff <= length_diff_tolerance: + spk_labels = spk_labels[:, 0 : pred.shape[1], :] + + if self.attractor is None: + loss_pit, loss_att = None, None + loss, perm_idx, perm_list, label_perm = self.pit_loss( + pred, spk_labels, encoder_out_lens + ) + else: + loss_pit, perm_idx, perm_list, label_perm = self.pit_loss( + pred, spk_labels, encoder_out_lens + ) + loss_att = self.attractor_loss(att_prob, spk_labels) + loss = loss_pit + self.attractor_weight * loss_att + ( + correct, + num_frames, + speech_scored, + speech_miss, + speech_falarm, + speaker_scored, + speaker_miss, + speaker_falarm, + speaker_error, + ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens) + + if speech_scored > 0 and num_frames > 0: + sad_mr, sad_fr, mi, fa, cf, acc, der = ( + speech_miss / speech_scored, + speech_falarm / speech_scored, + speaker_miss / speaker_scored, + speaker_falarm / speaker_scored, + speaker_error / speaker_scored, + correct / num_frames, + (speaker_miss + speaker_falarm + speaker_error) / speaker_scored, + ) + else: + sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0 + + stats = dict( + loss=loss.detach(), + loss_att=loss_att.detach() if loss_att is not None else None, + loss_pit=loss_pit.detach() if loss_pit is not None else None, + sad_mr=sad_mr, + sad_fr=sad_fr, + mi=mi, + fa=fa, + cf=cf, + acc=acc, + der=der, + ) + + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + spk_labels: torch.Tensor = None, + spk_labels_lengths: torch.Tensor = None, + ) -> Dict[str, torch.Tensor]: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode_speaker( + self, + profile: torch.Tensor, + profile_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + with autocast(False): + if profile.shape[1] < self.max_spk_num: + profile = F.pad(profile, [0, 0, 0, self.max_spk_num-profile.shape[1], 0, 0], "constant", 0.0) + profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float() + profile = F.normalize(profile, dim=2) + if self.speaker_encoder is not None: + profile = self.speaker_encoder(profile, profile_lengths)[0] + return profile * profile_mask, profile_lengths + else: + return profile, 
profile_lengths + + def encode_speech( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.encoder is not None: + speech, speech_lengths = self.encode(speech, speech_lengths) + speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1]) + speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float() + return speech * speech_mask, speech_lengths + else: + return speech, speech_lengths + + @staticmethod + def concate_speech_ivc( + speech: torch.Tensor, + ivc: torch.Tensor + ) -> torch.Tensor: + nn, tt = ivc.shape[1], speech.shape[1] + speech = speech.unsqueeze(dim=1) # B x 1 x T x D + speech = speech.expand(-1, nn, -1, -1) # B x N x T x D + ivc = ivc.unsqueeze(dim=2) # B x N x 1 x D + ivc = ivc.expand(-1, -1, tt, -1) # B x N x T x D + sd_in = torch.cat([speech, ivc], dim=3) # B x N x T x 2D + return sd_in + + def calc_similarity( + self, + speech_encoder_outputs: torch.Tensor, + speaker_encoder_outputs: torch.Tensor, + seq_len: torch.Tensor = None, + spk_len: torch.Tensor = None, + ) -> torch.Tensor: + bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1] + d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2] + if self.normalize_speech_speaker: + speech_encoder_outputs = F.normalize(speech_encoder_outputs, dim=2) + speaker_encoder_outputs = F.normalize(speaker_encoder_outputs, dim=2) + ge_in = self.concate_speech_ivc(speech_encoder_outputs, speaker_encoder_outputs) + ge_in = torch.reshape(ge_in, [bb * self.max_spk_num, tt, d_sph + d_spk]) + ge_len = seq_len.unsqueeze(1).expand(-1, self.max_spk_num) + ge_len = torch.reshape(ge_len, [bb * self.max_spk_num]) + cd_simi = self.cd_scorer(ge_in, ge_len)[0] + cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1]) + cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1]) + + if isinstance(self.ci_scorer, AbsEncoder): + ci_simi = self.ci_scorer(ge_in, ge_len)[0] + else: + ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs) + simi = torch.cat([cd_simi, ci_simi], dim=2) + + return simi + + def post_net_forward(self, simi, seq_len): + logits = self.decoder(simi, seq_len)[0] + + return logits + + def prediction_forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + profile: torch.Tensor, + profile_lengths: torch.Tensor, + ) -> torch.Tensor: + # speech encoding + speech, speech_lengths = self.encode_speech(speech, speech_lengths) + # speaker encoding + profile, profile_lengths = self.encode_speaker(profile, profile_lengths) + # calculating similarity + similarity = self.calc_similarity(speech, profile, speech_lengths, profile_lengths) + # post net forward + logits = self.post_net_forward(similarity, speech_lengths) + + return logits + + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch,) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # 4. 
Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim) + encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + return encoder_out, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = speech.shape[0] + speech_lengths = ( + speech_lengths + if speech_lengths is not None + else torch.ones(batch_size).int() * speech.shape[1] + ) + + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + # Frontend + # e.g. STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + @staticmethod + def calc_diarization_error(pred, label, length): + # Note (jiatong): Credit to https://github.com/hitachi-speech/EEND + + (batch_size, max_len, num_output) = label.size() + # mask the padding part + mask = np.zeros((batch_size, max_len, num_output)) + for i in range(batch_size): + mask[i, : length[i], :] = 1 + + # pred and label have the shape (batch_size, max_len, num_output) + label_np = label.data.cpu().numpy().astype(int) + pred_np = (pred.data.cpu().numpy() > 0).astype(int) + label_np = label_np * mask + pred_np = pred_np * mask + length = length.data.cpu().numpy() + + # compute speech activity detection error + n_ref = np.sum(label_np, axis=2) + n_sys = np.sum(pred_np, axis=2) + speech_scored = float(np.sum(n_ref > 0)) + speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0))) + speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0))) + + # compute speaker diarization error + speaker_scored = float(np.sum(n_ref)) + speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0))) + speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0))) + n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2) + speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map)) + correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output) + num_frames = np.sum(length) + return ( + correct, + num_frames, + speech_scored, + speech_miss, + speech_falarm, + speaker_scored, + speaker_miss, + speaker_falarm, + speaker_error, + ) diff --git a/funasr/models/encoder/opennmt_encoders/__init__.py b/funasr/models/encoder/opennmt_encoders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models/encoder/opennmt_encoders/ci_scorers.py b/funasr/models/encoder/opennmt_encoders/ci_scorers.py new file mode 100644 index 000000000..50056ee28 --- /dev/null +++ b/funasr/models/encoder/opennmt_encoders/ci_scorers.py @@ -0,0 +1,38 @@ +import torch +from torch.nn import functional as F + + +class DotScorer(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + xs_pad: torch.Tensor, + spk_emb: torch.Tensor, + ): + # xs_pad: B, T, D + # spk_emb: B, N, D + scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2)) + return scores + + def convert_tf2torch(self, var_dict_tf, var_dict_torch): + return {} + + +class 
CosScorer(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + xs_pad: torch.Tensor, + spk_emb: torch.Tensor, + ): + # xs_pad: B, T, D + # spk_emb: B, N, D + scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1) + return scores + + def convert_tf2torch(self, var_dict_tf, var_dict_torch): + return {} diff --git a/funasr/models/encoder/opennmt_encoders/conv_encoder.py b/funasr/models/encoder/opennmt_encoders/conv_encoder.py new file mode 100644 index 000000000..40967437b --- /dev/null +++ b/funasr/models/encoder/opennmt_encoders/conv_encoder.py @@ -0,0 +1,277 @@ +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +import logging +import torch +import torch.nn as nn +from torch.nn import functional as F +from typeguard import check_argument_types +import numpy as np +from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.layer_norm import LayerNorm +from funasr.models.encoder.abs_encoder import AbsEncoder +import math +from funasr.modules.repeat import repeat + + +class EncoderLayer(nn.Module): + def __init__( + self, + input_units, + num_units, + kernel_size=3, + activation="tanh", + stride=1, + include_batch_norm=False, + residual=False + ): + super().__init__() + left_padding = math.ceil((kernel_size - stride) / 2) + right_padding = kernel_size - stride - left_padding + self.conv_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0) + self.conv1d = nn.Conv1d( + input_units, + num_units, + kernel_size, + stride, + ) + self.activation = self.get_activation(activation) + if include_batch_norm: + self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3) + self.residual = residual + self.include_batch_norm = include_batch_norm + self.input_units = input_units + self.num_units = num_units + self.stride = stride + + @staticmethod + def get_activation(activation): + if activation == "tanh": + return nn.Tanh() + else: + return nn.ReLU() + + def forward(self, xs_pad, ilens=None): + outputs = self.conv1d(self.conv_padding(xs_pad)) + if self.residual and self.stride == 1 and self.input_units == self.num_units: + outputs = outputs + xs_pad + + if self.include_batch_norm: + outputs = self.bn(outputs) + + # add parenthesis for repeat module + return self.activation(outputs), ilens + + +class ConvEncoder(AbsEncoder): + """ + author: Speech Lab, Alibaba Group, China + Convolution encoder in OpenNMT framework + """ + + def __init__( + self, + num_layers, + input_units, + num_units, + kernel_size=3, + dropout_rate=0.3, + position_encoder=None, + activation='tanh', + auxiliary_states=True, + out_units=None, + out_norm=False, + out_residual=False, + include_batchnorm=False, + regularization_weight=0.0, + stride=1, + tf2torch_tensor_name_prefix_torch: str = "speaker_encoder", + tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder", + ): + assert check_argument_types() + super().__init__() + self._output_size = num_units + + self.num_layers = num_layers + self.input_units = input_units + self.num_units = num_units + self.kernel_size = kernel_size + self.dropout_rate = dropout_rate + self.position_encoder = position_encoder + self.out_units = out_units + self.auxiliary_states = auxiliary_states + self.out_norm = out_norm + self.activation = activation + self.out_residual = out_residual + self.include_batch_norm = include_batchnorm + self.regularization_weight = regularization_weight + self.tf2torch_tensor_name_prefix_torch = 
tf2torch_tensor_name_prefix_torch + self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf + if isinstance(stride, int): + self.stride = [stride] * self.num_layers + else: + self.stride = stride + self.downsample_rate = 1 + for s in self.stride: + self.downsample_rate *= s + + self.dropout = nn.Dropout(dropout_rate) + self.cnn_a = repeat( + self.num_layers, + lambda lnum: EncoderLayer( + input_units if lnum == 0 else num_units, + num_units, + kernel_size, + activation, + self.stride[lnum], + include_batchnorm, + residual=True if lnum > 0 else False + ) + ) + + if self.out_units is not None: + left_padding = math.ceil((kernel_size - stride) / 2) + right_padding = kernel_size - stride - left_padding + self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0) + self.conv_out = nn.Conv1d( + num_units, + num_units, + kernel_size, + ) + + if self.out_norm: + self.after_norm = LayerNorm(num_units) + + def output_size(self) -> int: + return self.num_units + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + + inputs = xs_pad + if self.position_encoder is not None: + inputs = self.position_encoder(inputs) + + if self.dropout_rate > 0: + inputs = self.dropout(inputs) + + outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens) + + if self.out_units is not None: + outputs = self.conv_out(self.out_padding(outputs)) + + outputs = outputs.transpose(1, 2) + if self.out_norm: + outputs = self.after_norm(outputs) + + if self.out_residual: + outputs = outputs + inputs + + return outputs, ilens, None + + def gen_tf2torch_map_dict(self): + tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch + tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf + map_dict_local = { + # torch: conv1d.weight in "out_channel in_channel kernel_size" + # tf : conv1d.weight in "kernel_size in_channel out_channel" + # torch: linear.weight in "out_channel in_channel" + # tf : dense.weight in "in_channel out_channel" + "{}.cnn_a.0.conv1d.weight".format(tensor_name_prefix_torch): + {"name": "{}/cnn_a/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (2, 1, 0), + }, + "{}.cnn_a.0.conv1d.bias".format(tensor_name_prefix_torch): + {"name": "{}/cnn_a/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + + "{}.cnn_a.layeridx.conv1d.weight".format(tensor_name_prefix_torch): + {"name": "{}/cnn_a/conv1d_layeridx/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (2, 1, 0), + }, + "{}.cnn_a.layeridx.conv1d.bias".format(tensor_name_prefix_torch): + {"name": "{}/cnn_a/conv1d_layeridx/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + } + if self.out_units is not None: + # add output layer + map_dict_local.update({ + "{}.conv_out.weight".format(tensor_name_prefix_torch): + {"name": "{}/cnn_a/conv1d_{}/kernel".format(tensor_name_prefix_tf, self.num_layers), + "squeeze": None, + "transpose": (2, 1, 0), + }, # tf: (1, 256, 256) -> torch: (256, 256, 1) + "{}.conv_out.bias".format(tensor_name_prefix_torch): + {"name": "{}/cnn_a/conv1d_{}/bias".format(tensor_name_prefix_tf, self.num_layers), + "squeeze": None, + "transpose": None, + }, # tf: (256,) -> torch: (256,) + }) + + return map_dict_local + + def convert_tf2torch(self, + var_dict_tf, + var_dict_torch, + ): + + map_dict = self.gen_tf2torch_map_dict() + + var_dict_torch_update = dict() + for name in 
sorted(var_dict_torch.keys(), reverse=False): + if name.startswith(self.tf2torch_tensor_name_prefix_torch): + # process special (first and last) layers + if name in map_dict: + name_tf = map_dict[name]["name"] + data_tf = var_dict_tf[name_tf] + if map_dict[name]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"]) + if map_dict[name]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), \ + "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[name].size(), data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format( + name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape + )) + # process general layers + else: + # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case + names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.') + layeridx = int(names[2]) + name_q = name.replace(".{}.".format(layeridx), ".layeridx.") + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), \ + "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[name].size(), data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format( + name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape + )) + else: + logging.warning("{} is missed from tf checkpoint".format(name)) + + return var_dict_torch_update + diff --git a/funasr/models/encoder/opennmt_encoders/fsmn_encoder.py b/funasr/models/encoder/opennmt_encoders/fsmn_encoder.py new file mode 100644 index 000000000..e41b2aac5 --- /dev/null +++ b/funasr/models/encoder/opennmt_encoders/fsmn_encoder.py @@ -0,0 +1,335 @@ +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +import logging +import torch +import torch.nn as nn +from torch.nn import functional as F +from typeguard import check_argument_types +import numpy as np +from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.layer_norm import LayerNorm +from funasr.models.encoder.abs_encoder import AbsEncoder +import math +from funasr.modules.repeat import repeat +from funasr.modules.multi_layer_conv import FsmnFeedForward + + +class FsmnBlock(torch.nn.Module): + def __init__( + self, + n_feat, + dropout_rate, + kernel_size, + fsmn_shift=0, + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout_rate) + self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, + padding=0, groups=n_feat, bias=False) + # padding + left_padding = (kernel_size - 1) // 2 + if fsmn_shift > 0: + left_padding = left_padding + fsmn_shift + right_padding = kernel_size - 1 - left_padding + self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) + + def forward(self, inputs, mask, mask_shfit_chunk=None): + b, t, d = inputs.size() + if mask is not None: + mask = torch.reshape(mask, (b, -1, 1)) + 
if mask_shfit_chunk is not None: + mask = mask * mask_shfit_chunk + + inputs = inputs * mask + x = inputs.transpose(1, 2) + x = self.pad_fn(x) + x = self.fsmn_block(x) + x = x.transpose(1, 2) + x = x + inputs + x = self.dropout(x) + return x * mask + + +class EncoderLayer(torch.nn.Module): + def __init__( + self, + in_size, + size, + feed_forward, + fsmn_block, + dropout_rate=0.0 + ): + super().__init__() + self.in_size = in_size + self.size = size + self.ffn = feed_forward + self.memory = fsmn_block + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + xs_pad: torch.Tensor, + mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # xs_pad in Batch, Time, Dim + + context = self.ffn(xs_pad)[0] + memory = self.memory(context, mask) + + memory = self.dropout(memory) + if self.in_size == self.size: + return memory + xs_pad, mask + + return memory, mask + + +class FsmnEncoder(AbsEncoder): + """Encoder using Fsmn + """ + + def __init__(self, + in_units, + filter_size, + fsmn_num_layers, + dnn_num_layers, + num_memory_units=512, + ffn_inner_dim=2048, + dropout_rate=0.0, + shift=0, + position_encoder=None, + sample_rate=1, + out_units=None, + tf2torch_tensor_name_prefix_torch="post_net", + tf2torch_tensor_name_prefix_tf="EAND/post_net" + ): + """Initializes the parameters of the encoder. + + Args: + filter_size: the total order of memory block + fsmn_num_layers: The number of fsmn layers. + dnn_num_layers: The number of dnn layers + num_units: The number of memory units. + ffn_inner_dim: The number of units of the inner linear transformation + in the feed forward layer. + dropout_rate: The probability to drop units from the outputs. + shift: left padding, to control delay + position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to + apply on inputs or ``None``. 
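+            sample_rate: The sample rate of each fsmn layer; an int is
+                broadcast to every layer.
+            out_units: If not None, a final 1x1 conv1d projects the memory
+                units to this dimension.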
+ """ + super(FsmnEncoder, self).__init__() + self.in_units = in_units + self.filter_size = filter_size + self.fsmn_num_layers = fsmn_num_layers + self.dnn_num_layers = dnn_num_layers + self.num_memory_units = num_memory_units + self.ffn_inner_dim = ffn_inner_dim + self.dropout_rate = dropout_rate + self.shift = shift + if not isinstance(shift, list): + self.shift = [shift for _ in range(self.fsmn_num_layers)] + self.sample_rate = sample_rate + if not isinstance(sample_rate, list): + self.sample_rate = [sample_rate for _ in range(self.fsmn_num_layers)] + self.position_encoder = position_encoder + self.dropout = nn.Dropout(dropout_rate) + self.out_units = out_units + self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch + self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf + + self.fsmn_layers = repeat( + self.fsmn_num_layers, + lambda lnum: EncoderLayer( + in_units if lnum == 0 else num_memory_units, + num_memory_units, + FsmnFeedForward( + in_units if lnum == 0 else num_memory_units, + ffn_inner_dim, + num_memory_units, + 1, + dropout_rate + ), + FsmnBlock( + num_memory_units, + dropout_rate, + filter_size, + self.shift[lnum] + ) + ), + ) + + self.dnn_layers = repeat( + dnn_num_layers, + lambda lnum: FsmnFeedForward( + num_memory_units, + ffn_inner_dim, + num_memory_units, + 1, + dropout_rate, + ) + ) + if out_units is not None: + self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1) + + def output_size(self) -> int: + return self.num_memory_units + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + inputs = xs_pad + if self.position_encoder is not None: + inputs = self.position_encoder(inputs) + + inputs = self.dropout(inputs) + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + inputs = self.fsmn_layers(inputs, masks)[0] + inputs = self.dnn_layers(inputs)[0] + + if self.out_units is not None: + inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2) + + return inputs, ilens, None + + def gen_tf2torch_map_dict(self): + tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch + tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf + map_dict_local = { + # torch: conv1d.weight in "out_channel in_channel kernel_size" + # tf : conv1d.weight in "kernel_size in_channel out_channel" + # torch: linear.weight in "out_channel in_channel" + # tf : dense.weight in "in_channel out_channel" + # for fsmn_layers + "{}.fsmn_layers.layeridx.ffn.norm.bias".format(tensor_name_prefix_torch): + {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.fsmn_layers.layeridx.ffn.norm.weight".format(tensor_name_prefix_torch): + {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.fsmn_layers.layeridx.ffn.w_1.bias".format(tensor_name_prefix_torch): + {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.fsmn_layers.layeridx.ffn.w_1.weight".format(tensor_name_prefix_torch): + {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (2, 1, 0), + }, + "{}.fsmn_layers.layeridx.ffn.w_2.weight".format(tensor_name_prefix_torch): + {"name": "{}/fsmn_layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + 
"transpose": (2, 1, 0), + }, + "{}.fsmn_layers.layeridx.memory.fsmn_block.weight".format(tensor_name_prefix_torch): + {"name": "{}/fsmn_layer_layeridx/memory/depth_conv_w".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 2, 0), + }, # (1, 31, 512, 1) -> (31, 512, 1) -> (512, 1, 31) + + # for dnn_layers + "{}.dnn_layers.layeridx.norm.bias".format(tensor_name_prefix_torch): + {"name": "{}/dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.dnn_layers.layeridx.norm.weight".format(tensor_name_prefix_torch): + {"name": "{}/dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.dnn_layers.layeridx.w_1.bias".format(tensor_name_prefix_torch): + {"name": "{}/dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.dnn_layers.layeridx.w_1.weight".format(tensor_name_prefix_torch): + {"name": "{}/dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (2, 1, 0), + }, + "{}.dnn_layers.layeridx.w_2.weight".format(tensor_name_prefix_torch): + {"name": "{}/dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (2, 1, 0), + }, + + } + if self.out_units is not None: + # add output layer + map_dict_local.update({ + "{}.conv1d.weight".format(tensor_name_prefix_torch): + {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (2, 1, 0), + }, + "{}.conv1d.bias".format(tensor_name_prefix_torch): + {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + }) + + return map_dict_local + + def convert_tf2torch(self, + var_dict_tf, + var_dict_torch, + ): + + map_dict = self.gen_tf2torch_map_dict() + + var_dict_torch_update = dict() + for name in sorted(var_dict_torch.keys(), reverse=False): + if name.startswith(self.tf2torch_tensor_name_prefix_torch): + # process special (first and last) layers + if name in map_dict: + name_tf = map_dict[name]["name"] + data_tf = var_dict_tf[name_tf] + if map_dict[name]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"]) + if map_dict[name]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), \ + "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[name].size(), data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format( + name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape + )) + # process general layers + else: + # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case + names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.') + layeridx = int(names[2]) + name_q = name.replace(".{}.".format(layeridx), ".layeridx.") + if name_q in map_dict.keys(): + name_v = map_dict[name_q]["name"] + name_tf = name_v.replace("layeridx", "{}".format(layeridx)) + data_tf = var_dict_tf[name_tf] + if map_dict[name_q]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) + if map_dict[name_q]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() 
== data_tf.size(), \ + "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[name].size(), data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format( + name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape + )) + else: + logging.warning("{} is missed from tf checkpoint".format(name)) + + return var_dict_torch_update diff --git a/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py new file mode 100644 index 000000000..443b37ae3 --- /dev/null +++ b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py @@ -0,0 +1,480 @@ +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +import logging +import torch +import torch.nn as nn +from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk +from typeguard import check_argument_types +import numpy as np +from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM +from funasr.modules.embedding import SinusoidalPositionEncoder +from funasr.modules.layer_norm import LayerNorm +from funasr.modules.multi_layer_conv import Conv1dLinear +from funasr.modules.multi_layer_conv import MultiLayeredConv1d +from funasr.modules.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from funasr.modules.repeat import repeat +from funasr.modules.subsampling import Conv2dSubsampling +from funasr.modules.subsampling import Conv2dSubsampling2 +from funasr.modules.subsampling import Conv2dSubsampling6 +from funasr.modules.subsampling import Conv2dSubsampling8 +from funasr.modules.subsampling import TooShortUttError +from funasr.modules.subsampling import check_short_utt +from funasr.models.ctc import CTC +from funasr.models.encoder.abs_encoder import AbsEncoder + + +class EncoderLayer(nn.Module): + def __init__( + self, + in_size, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + stochastic_depth_rate=0.0, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(in_size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.in_size = in_size + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + self.dropout_rate = dropout_rate + + def forward(self, x, mask, cache=None, mask_att_chunk_encoder=None): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. 
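+        # Scaling the kept branch by 1 / (1 - p) preserves its expected value
+        # under random dropping, so no rescaling is needed at inference.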
+ stoch_layer_coeff = 1.0 + if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + if cache is not None: + x = torch.cat([cache, x], dim=1) + return x, mask + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1) + if self.in_size == self.size: + x = residual + stoch_layer_coeff * self.concat_linear(x_concat) + else: + x = stoch_layer_coeff * self.concat_linear(x_concat) + else: + if self.in_size == self.size: + x = residual + stoch_layer_coeff * self.dropout( + self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder) + ) + else: + x = stoch_layer_coeff * self.dropout( + self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder) + ) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + return x, mask, cache, mask_att_chunk_encoder + + +class SelfAttentionEncoder(AbsEncoder): + """ + author: Speech Lab, Alibaba Group, China + Self attention encoder in OpenNMT framework + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: Optional[str] = "conv2d", + pos_enc_class=SinusoidalPositionEncoder, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + tf2torch_tensor_name_prefix_torch: str = "encoder", + tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder", + out_units=None, + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) + elif input_layer == "conv2d2": + self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate) + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), + SinusoidalPositionEncoder(), + ) + elif input_layer is None: + if input_size == output_size: + self.embed = None + else: + self.embed = torch.nn.Linear(input_size, output_size) + elif input_layer == "pe": + self.embed = SinusoidalPositionEncoder() + elif input_layer == "null": + self.embed = None + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args 
= ( + output_size, + linear_units, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + output_size, + output_size, + MultiHeadSelfAttention( + attention_heads, + output_size, + output_size, + attention_dropout_rate, + ), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ) if lnum > 0 else EncoderLayer( + input_size, + output_size, + MultiHeadSelfAttention( + attention_heads, + input_size if input_layer == "pe" or input_layer == "null" else output_size, + output_size, + attention_dropout_rate, + ), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + self.dropout = nn.Dropout(dropout_rate) + self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch + self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf + self.out_units = out_units + if out_units is not None: + self.output_linear = nn.Linear(output_size, out_units) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. 
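+            ctc: CTC module, used only when interctc_use_conditioning is enabled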
+ Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + xs_pad *= self.output_size()**0.5 + if self.embed is None: + xs_pad = xs_pad + elif ( + isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8) + ): + short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f"has {xs_pad.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + xs_pad = self.dropout(xs_pad) + # encoder_outs = self.encoders0(xs_pad, masks) + # xs_pad, masks = encoder_outs[0], encoder_outs[1] + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + encoder_outs = self.encoders(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + if self.out_units is not None: + xs_pad = self.output_linear(xs_pad) + olens = masks.squeeze(1).sum(1) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens, None + + def gen_tf2torch_map_dict(self): + tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch + tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf + map_dict_local = { + # cicd + # torch: conv1d.weight in "out_channel in_channel kernel_size" + # tf : conv1d.weight in "kernel_size in_channel out_channel" + # torch: linear.weight in "out_channel in_channel" + # tf : dense.weight in "in_channel out_channel" + "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (768,256),(1,256,768) + "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (768,),(768,) + "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,256),(1,256,256) + 
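# tf stores these 1x1 conv kernels as (1, in_dim, out_dim): squeeze the kernel axis, then transpose (1, 0) to torch's (out_dim, in_dim) +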
"{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + # ffn + "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (1024,256),(1,256,1024) + "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (1024,),(1024,) + "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, # (256,1024),(1,1024,256) + "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch): + {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + # out norm + "{}.after_norm.weight".format(tensor_name_prefix_torch): + {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + "{}.after_norm.bias".format(tensor_name_prefix_torch): + {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + } + if self.out_units is not None: + map_dict_local.update({ + "{}.output_linear.weight".format(tensor_name_prefix_torch): + {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf), + "squeeze": 0, + "transpose": (1, 0), + }, + "{}.output_linear.bias".format(tensor_name_prefix_torch): + {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, # (256,),(256,) + }) + + return map_dict_local + + def convert_tf2torch(self, + var_dict_tf, + var_dict_torch, + ): + + map_dict = self.gen_tf2torch_map_dict() + + var_dict_torch_update = dict() + for name in sorted(var_dict_torch.keys(), reverse=False): + if name.startswith(self.tf2torch_tensor_name_prefix_torch): + # process special (first and last) layers + if name in map_dict: + name_tf = map_dict[name]["name"] + data_tf = var_dict_tf[name_tf] + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + if map_dict[name]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"]) + if map_dict[name]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name]["transpose"]) + assert var_dict_torch[name].size() == data_tf.size(), \ + "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[name].size(), data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format( + name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape + )) + # process general layers + else: + # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case + names = 
name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), \
+                            "{}, {}, {} != {}".format(name, name_tf,
+                                                      var_dict_torch[name].size(), data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
+                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
+                        ))
+                    else:
+                        logging.warning("{} is missed from tf checkpoint".format(name))
+
+        return var_dict_torch_update
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 66e446c7a..952ce1597 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -1,7 +1,11 @@
 import torch
 from torch.nn import functional as F
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from typing import Tuple
+from typing import Tuple, Optional
+from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
+from collections import OrderedDict
+import logging
+import numpy as np
 
 
 class BasicLayer(torch.nn.Module):
@@ -116,10 +120,18 @@ class ResNet34(AbsEncoder):
         self.resnet0_dense = torch.nn.Conv2d(filters_in_block[-1], num_nodes_pooling_layer, 1)
         self.resnet0_bn = torch.nn.BatchNorm2d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
 
+        self.time_ds_ratio = 8
+
     def output_size(self) -> int:
         return self.num_nodes_pooling_layer
 
-    def forward(self, xs_pad: torch.Tensor, ilens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            prev_states: torch.Tensor = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
         features = xs_pad
         assert features.size(-1) == self.input_size, \
             "Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
@@ -141,4 +153,463 @@ class ResNet34(AbsEncoder):
         features = F.relu(features)
         features = self.resnet0_bn(features)
 
-        return features, ilens // 8
+        return features, resnet_out_lens
+
+# Note: For training, this implementation is not equivalent to the tf version because of the kernel_regularizer in tf.layers.
+# TODO: implement kernel_regularizer in torch with manual loss addition or weight_decay in the optimizer
+class ResNet34_SP_L2Reg(AbsEncoder):
+    def __init__(
+            self,
+            input_size,
+            use_head_conv=True,
+            batchnorm_momentum=0.5,
+            use_head_maxpool=False,
+            num_nodes_pooling_layer=256,
+            layers_in_block=(3, 4, 6, 3),
+            filters_in_block=(32, 64, 128, 256),
+            tf2torch_tensor_name_prefix_torch="encoder",
+            tf2torch_tensor_name_prefix_tf="EAND/speech_encoder",
+            tf_train_steps=720000,
+    ):
+        super(ResNet34_SP_L2Reg, self).__init__()
+
+        self.use_head_conv = use_head_conv
+        self.use_head_maxpool = use_head_maxpool
+        self.num_nodes_pooling_layer = num_nodes_pooling_layer
+        self.layers_in_block = layers_in_block
+        self.filters_in_block = filters_in_block
+        self.input_size = input_size
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+        self.tf_train_steps = tf_train_steps
+
+        pre_filters = filters_in_block[0]
+        if use_head_conv:
+            self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
+            self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)
+
+        if use_head_maxpool:
+            self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)
+
+        for i in range(len(layers_in_block)):
+            if i == 0:
+                in_filters = pre_filters if self.use_head_conv else 1
+            else:
+                in_filters = filters_in_block[i-1]
+
+            block = BasicBlock(in_filters,
+                               filters=filters_in_block[i],
+                               num_layer=layers_in_block[i],
+                               stride=1 if i == 0 else 2,
+                               bn_momentum=batchnorm_momentum)
+            self.add_module("block_{}".format(i), block)
+
+        self.resnet0_dense = torch.nn.Conv1d(filters_in_block[-1] * input_size // 8, num_nodes_pooling_layer, 1)
+        self.resnet0_bn = torch.nn.BatchNorm1d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
+
+        self.time_ds_ratio = 8
+
+    def output_size(self) -> int:
+        return self.num_nodes_pooling_layer
+
+    def forward(
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            prev_states: torch.Tensor = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        features = xs_pad
+        assert features.size(-1) == self.input_size, \
+            "Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
+        features = torch.unsqueeze(features, dim=1)
+        if self.use_head_conv:
+            features = self.pre_conv(features)
+            features = self.pre_conv_bn(features)
+            features = F.relu(features)
+
+        if self.use_head_maxpool:
+            features = self.head_maxpool(features)
+
+        resnet_outs, resnet_out_lens = features, ilens
+        for i in range(len(self.layers_in_block)):
+            block = self._modules["block_{}".format(i)]
+            resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)
+
+        # B, C, T, F
+        bb, cc, tt, ff = resnet_outs.shape
+        resnet_outs = torch.reshape(resnet_outs.permute(0, 3, 1, 2), [bb, ff*cc, tt])
+        features = self.resnet0_dense(resnet_outs)
+        features = F.relu(features)
+        features = self.resnet0_bn(features)
+
+        return features, resnet_out_lens
+
+    def gen_tf2torch_map_dict(self):
+        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+        train_steps = self.tf_train_steps
+        map_dict_local = {
+            # torch: conv1d.weight in "out_channel in_channel kernel_size"
+            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
+            # torch: linear.weight in "out_channel in_channel"
+            # tf   : dense.weight in "in_channel out_channel"
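+            # torch: conv2d.weight in "out_channel in_channel kernel_h kernel_w"
+            # tf   : conv2d.kernel in "kernel_h kernel_w in_channel out_channel" -> transpose (3, 2, 0, 1) +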
"{}.pre_conv.weight".format(tensor_name_prefix_torch): + {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (3, 2, 0, 1), + }, + "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch): + {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch): + {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch): + {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch): + {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps + } + for layer_idx in range(3): + map_dict_local.update({ + "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0), + }, + "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": None, + }, + "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": None, + }, + "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": None, + }, + "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": None, + }, + "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": None, + }, + "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps + }) + + for block_idx in range(len(self.layers_in_block)): + for layer_idx in range(self.layers_in_block[block_idx]): + for i in ["1", "2", "_sc"]: + map_dict_local.update({ + "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i): + {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i), + "squeeze": None, + "transpose": (3, 2, 0, 1), + }, + "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i): + {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i), + "squeeze": None, + "transpose": None, + }, + "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i): + {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i), + "squeeze": None, + "transpose": None, + }, + "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i): + {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i), + "squeeze": None, + "transpose": None, + }, + 
"{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i): + {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i), + "squeeze": None, + "transpose": None, + }, + "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps, + }) + + return map_dict_local + + def convert_tf2torch(self, + var_dict_tf, + var_dict_torch, + ): + + map_dict = self.gen_tf2torch_map_dict() + + var_dict_torch_update = dict() + for name in sorted(var_dict_torch.keys(), reverse=False): + if name.startswith(self.tf2torch_tensor_name_prefix_torch): + if name in map_dict: + if "num_batches_tracked" not in name: + name_tf = map_dict[name]["name"] + data_tf = var_dict_tf[name_tf] + if map_dict[name]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"]) + if map_dict[name]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), \ + "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[name].size(), data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format( + name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape + )) + else: + var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu") + logging.info("torch tensor: {}, manually assigning to: {}".format( + name, map_dict[name] + )) + else: + logging.warning("{} is missed from tf checkpoint".format(name)) + + return var_dict_torch_update + + + +class ResNet34Diar(ResNet34): + def __init__( + self, + input_size, + embedding_node="resnet1_dense", + use_head_conv=True, + batchnorm_momentum=0.5, + use_head_maxpool=False, + num_nodes_pooling_layer=256, + layers_in_block=(3, 4, 6, 3), + filters_in_block=(32, 64, 128, 256), + num_nodes_resnet1=256, + num_nodes_last_layer=256, + pooling_type="window_shift", + pool_size=20, + stride=1, + tf2torch_tensor_name_prefix_torch="encoder", + tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder" + ): + super(ResNet34Diar, self).__init__( + input_size, + use_head_conv=use_head_conv, + batchnorm_momentum=batchnorm_momentum, + use_head_maxpool=use_head_maxpool, + num_nodes_pooling_layer=num_nodes_pooling_layer, + layers_in_block=layers_in_block, + filters_in_block=filters_in_block, + ) + + self.embedding_node = embedding_node + self.num_nodes_resnet1 = num_nodes_resnet1 + self.num_nodes_last_layer = num_nodes_last_layer + self.pooling_type = pooling_type + self.pool_size = pool_size + self.stride = stride + self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch + self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf + + self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1) + self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum) + + self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer) + self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum) + + def output_size(self) -> int: + if self.embedding_node.startswith("resnet1"): + return self.num_nodes_resnet1 + elif self.embedding_node.startswith("resnet2"): + return self.num_nodes_last_layer + + return self.num_nodes_pooling_layer + + def forward( + self, + xs_pad: torch.Tensor, + 
ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + + endpoints = OrderedDict() + res_out, ilens = super().forward(xs_pad, ilens) + endpoints["resnet0_bn"] = res_out + if self.pooling_type == "frame_gsp": + features = statistic_pooling(res_out, ilens, (3, )) + else: + features, ilens = windowed_statistic_pooling(res_out, ilens, (2, 3), self.pool_size, self.stride) + features = features.transpose(1, 2) + endpoints["pooling"] = features + + features = self.resnet1_dense(features) + endpoints["resnet1_dense"] = features + features = F.relu(features) + endpoints["resnet1_relu"] = features + features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2) + endpoints["resnet1_bn"] = features + + features = self.resnet2_dense(features) + endpoints["resnet2_dense"] = features + features = F.relu(features) + endpoints["resnet2_relu"] = features + features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2) + endpoints["resnet2_bn"] = features + + return endpoints[self.embedding_node], ilens, None + + def gen_tf2torch_map_dict(self): + tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch + tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf + train_steps = 300000 + map_dict_local = { + # torch: conv1d.weight in "out_channel in_channel kernel_size" + # tf : conv1d.weight in "kernel_size in_channel out_channel" + # torch: linear.weight in "out_channel in_channel" + # tf : dense.weight in "in_channel out_channel" + "{}.pre_conv.weight".format(tensor_name_prefix_torch): + {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": (3, 2, 0, 1), + }, + "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch): + {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch): + {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch): + {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch): + {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf), + "squeeze": None, + "transpose": None, + }, + "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps + } + for layer_idx in range(3): + map_dict_local.update({ + "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": (3, 2, 0, 1) if layer_idx == 0 else (1, 0), + }, + "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": None, + }, + "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": None, + }, + "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": None, + }, + "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + 
"transpose": None, + }, + "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx): + {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx), + "squeeze": None, + "transpose": None, + }, + "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps + }) + + for block_idx in range(len(self.layers_in_block)): + for layer_idx in range(self.layers_in_block[block_idx]): + for i in ["1", "2", "_sc"]: + map_dict_local.update({ + "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i): + {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i), + "squeeze": None, + "transpose": (3, 2, 0, 1), + }, + "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i): + {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i), + "squeeze": None, + "transpose": None, + }, + "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i): + {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i), + "squeeze": None, + "transpose": None, + }, + "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i): + {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i), + "squeeze": None, + "transpose": None, + }, + "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i): + {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i), + "squeeze": None, + "transpose": None, + }, + "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps, + }) + + return map_dict_local + + def convert_tf2torch(self, + var_dict_tf, + var_dict_torch, + ): + + map_dict = self.gen_tf2torch_map_dict() + + var_dict_torch_update = dict() + for name in sorted(var_dict_torch.keys(), reverse=False): + if name.startswith(self.tf2torch_tensor_name_prefix_torch): + if name in map_dict: + if "num_batches_tracked" not in name: + name_tf = map_dict[name]["name"] + data_tf = var_dict_tf[name_tf] + if map_dict[name]["squeeze"] is not None: + data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"]) + if map_dict[name]["transpose"] is not None: + data_tf = np.transpose(data_tf, map_dict[name]["transpose"]) + data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") + assert var_dict_torch[name].size() == data_tf.size(), \ + "{}, {}, {} != {}".format(name, name_tf, + var_dict_torch[name].size(), data_tf.size()) + var_dict_torch_update[name] = data_tf + logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format( + name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape + )) + else: + var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu") + logging.info("torch tensor: {}, manually assigning to: {}".format( + name, map_dict[name] + )) + else: + logging.warning("{} is missed from tf checkpoint".format(name)) + + return var_dict_torch_update diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py index 57c5976f1..7a6425be3 100644 --- a/funasr/models/frontend/wav_frontend.py +++ b/funasr/models/frontend/wav_frontend.py @@ -90,7 +90,9 @@ class WavFrontend(AbsFrontend): filter_length_max: int = -1, lfr_m: int = 1, lfr_n: int = 1, - 
dither: float = 1.0 + dither: float = 1.0, + snip_edges: bool = True, + upsacle_samples: bool = True, ): assert check_argument_types() super().__init__() @@ -105,6 +107,8 @@ class WavFrontend(AbsFrontend): self.lfr_n = lfr_n self.cmvn_file = cmvn_file self.dither = dither + self.snip_edges = snip_edges + self.upsacle_samples = upsacle_samples def output_size(self) -> int: return self.n_mels * self.lfr_m @@ -119,7 +123,8 @@ class WavFrontend(AbsFrontend): for i in range(batch_size): waveform_length = input_lengths[i] waveform = input[i][:waveform_length] - waveform = waveform * (1 << 15) + if self.upsacle_samples: + waveform = waveform * (1 << 15) waveform = waveform.unsqueeze(0) mat = kaldi.fbank(waveform, num_mel_bins=self.n_mels, @@ -128,7 +133,8 @@ class WavFrontend(AbsFrontend): dither=self.dither, energy_floor=0.0, window_type=self.window, - sample_frequency=self.fs) + sample_frequency=self.fs, + snip_edges=self.snip_edges) if self.lfr_m != 1 or self.lfr_n != 1: mat = apply_lfr(mat, self.lfr_m, self.lfr_n) diff --git a/funasr/models/pooling/statistic_pooling.py b/funasr/models/pooling/statistic_pooling.py index eeaed7d5d..dc8c98f0d 100644 --- a/funasr/models/pooling/statistic_pooling.py +++ b/funasr/models/pooling/statistic_pooling.py @@ -2,7 +2,10 @@ import torch from typing import Tuple from typing import Union from funasr.modules.nets_utils import make_non_pad_mask +from torch.nn import functional as F +import math +VAR2STD_EPSILON = 1e-12 class StatisticPooling(torch.nn.Module): def __init__(self, pooling_dim: Union[int, Tuple] = 2, eps=1e-12): @@ -34,3 +37,59 @@ class StatisticPooling(torch.nn.Module): stat_pooling = torch.cat([mean, stddev], dim=1) return stat_pooling + + def convert_tf2torch(self, var_dict_tf, var_dict_torch): + return {} + + +def statistic_pooling( + xs_pad: torch.Tensor, + ilens: torch.Tensor = None, + pooling_dim: Tuple = (2, 3) +) -> torch.Tensor: + # xs_pad in (Batch, Channel, Time, Frequency) + + if ilens is None: + seq_mask = torch.ones_like(xs_pad).to(xs_pad) + else: + seq_mask = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad) + mean = (torch.sum(xs_pad, dim=pooling_dim, keepdim=True) / + torch.sum(seq_mask, dim=pooling_dim, keepdim=True)) + squared_difference = torch.pow(xs_pad - mean, 2.0) + variance = (torch.sum(squared_difference, dim=pooling_dim, keepdim=True) / + torch.sum(seq_mask, dim=pooling_dim, keepdim=True)) + for i in reversed(pooling_dim): + mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i) + + value_mask = torch.less_equal(variance, VAR2STD_EPSILON).float() + variance = (1.0 - value_mask) * variance + value_mask * VAR2STD_EPSILON + stddev = torch.sqrt(variance) + + stat_pooling = torch.cat([mean, stddev], dim=1) + + return stat_pooling + + +def windowed_statistic_pooling( + xs_pad: torch.Tensor, + ilens: torch.Tensor = None, + pooling_dim: Tuple = (2, 3), + pooling_size: int = 20, + pooling_stride: int = 1 +) -> Tuple[torch.Tensor, int]: + # xs_pad in (Batch, Channel, Time, Frequency) + + tt = xs_pad.shape[2] + num_chunk = int(math.ceil(tt / pooling_stride)) + pad = pooling_size // 2 + features = F.pad(xs_pad, (0, 0, pad, pad), "reflect") + stat_list = [] + + for i in range(num_chunk): + # B x C + st, ed = i*pooling_stride, i*pooling_stride+pooling_size + stat = statistic_pooling(features[:, :, st: ed, :], pooling_dim=pooling_dim) + stat_list.append(stat.unsqueeze(2)) + + # B x C x T + return torch.cat(stat_list, dim=2), ilens / pooling_stride diff --git a/funasr/modules/attention.py 
b/funasr/modules/attention.py
index e3ad56a5a..c47d96d06 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -622,4 +622,108 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
         q_h, k_h, v_h = self.forward_qkv(x, memory)
         q_h = q_h * self.d_k ** (-0.5)
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        return self.forward_attention(v_h, scores, memory_mask)
\ No newline at end of file
+        return self.forward_attention(v_h, scores, memory_mask)
+
+
+class MultiHeadSelfAttention(nn.Module):
+    """Multi-Head Self-Attention layer with a fused q/k/v projection.
+
+    Args:
+        n_head (int): The number of heads.
+        in_feat (int): The number of input features.
+        n_feat (int): The number of attention features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
+        """Construct a MultiHeadSelfAttention object."""
+        super(MultiHeadSelfAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, x):
+        """Project the input into query, key and value.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Unreshaped value tensor (#batch, time, n_feat).
+
+        """
+        b, t, d = x.size()
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q_h, k_h, v_h, v
+
+    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            if mask_att_chunk_encoder is not None:
+                mask = mask * mask_att_chunk_encoder
+
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, mask, mask_att_chunk_encoder=None):
+        """Compute scaled dot product self-attention.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time1, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask,
+                multiplied into `mask` when given.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        return att_outs
diff --git a/funasr/modules/multi_layer_conv.py b/funasr/modules/multi_layer_conv.py
index 5fb0717b0..9d269ab56 100644
--- a/funasr/modules/multi_layer_conv.py
+++ b/funasr/modules/multi_layer_conv.py
@@ -63,6 +63,58 @@ class MultiLayeredConv1d(torch.nn.Module):
         return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
 
 
+class FsmnFeedForward(torch.nn.Module):
+    """Position-wise feed forward for FSMN blocks.
+
+    A multi-layered conv1d module designed to serve as the
+    position-wise feed-forward network inside FSMN blocks.
+    """
+
+    def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
+        """Initialize FsmnFeedForward module.
+
+        Args:
+            in_chans (int): Number of input channels.
+            hidden_chans (int): Number of hidden channels.
+            out_chans (int): Number of output channels.
+            kernel_size (int): Kernel size of conv1d.
+            dropout_rate (float): Dropout rate.
+
+        """
+        super(FsmnFeedForward, self).__init__()
+        self.w_1 = torch.nn.Conv1d(
+            in_chans,
+            hidden_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+        )
+        self.w_2 = torch.nn.Conv1d(
+            hidden_chans,
+            out_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            bias=False
+        )
+        self.norm = torch.nn.LayerNorm(hidden_chans)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def forward(self, x, ilens=None):
+        """Calculate forward propagation.
+
+        Args:
+            x (torch.Tensor): Batch of input tensors (B, T, in_chans).
+
+        Returns:
+            torch.Tensor: Batch of output tensors (B, T, out_chans).
+
+        """
+        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+        return self.w_2(self.norm(self.dropout(x)).transpose(-1, 1)).transpose(-1, 1), ilens
+
+
 class Conv1dLinear(torch.nn.Module):
     """Conv1D + Linear for Transformer block.
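
For orientation, here is a minimal usage sketch of the FsmnFeedForward block added above; the sizes, batch shape, and dropout value are illustrative, not taken from this patch:

    import torch
    from funasr.modules.multi_layer_conv import FsmnFeedForward

    # hypothetical sizes: 80-dim input features, 256 hidden channels, 128 output channels
    ffn = FsmnFeedForward(in_chans=80, hidden_chans=256, out_chans=128,
                          kernel_size=1, dropout_rate=0.1)
    x = torch.randn(4, 100, 80)  # (batch, time, in_chans)
    y, ilens = ffn(x)            # y: (4, 100, 128); ilens is passed through unchanged

Like MultiLayeredConv1d, the block keeps a (batch, time, channel) interface by transposing around each Conv1d; the LayerNorm is applied over the hidden channels before the second convolution.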
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
new file mode 100644
index 000000000..f3212f1b1
--- /dev/null
+++ b/funasr/tasks/diar.py
@@ -0,0 +1,585 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+import yaml
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.layers.label_aggregation import LabelAggregate
+from funasr.models.ctc import CTC
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
+from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
+from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
+from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
+from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
+from funasr.models.e2e_diar_sond import DiarSondModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+    HuggingFaceTransformersPostEncoder,  # noqa: H301
+)
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.tasks.abs_task import AbsTask
+from funasr.torch_utils.initialize import initialize
+from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(
+        default=DefaultFrontend,
+        sliding_window=SlidingWindow,
+        s3prl=S3prlFrontend,
+        fused=FusedFrontends,
+        wav_frontend=WavFrontend,
+    ),
+    type_check=AbsFrontend,
+    default="default",
+)
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(
+        specaug=SpecAug,
+        specaug_lfr=SpecAugLFR,
+    ),
+    type_check=AbsSpecAug,
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    type_check=AbsNormalize,
+    default=None,
+    optional=True,
+)
+label_aggregator_choices = ClassChoices(
+    "label_aggregator",
+    classes=dict(
+        label_aggregator=LabelAggregate
+    ),
+    type_check=torch.nn.Module,
+    default=None,
+    optional=True,
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        sond=DiarSondModel,
+    ),
+    type_check=AbsESPnetModel,
+    default="sond",
+)
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        san=SelfAttentionEncoder,
+        fsmn=FsmnEncoder,
+        conv=ConvEncoder,
+        resnet34=ResNet34Diar,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="resnet34",
+)
+speaker_encoder_choices = ClassChoices(
+    "speaker_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        san=SelfAttentionEncoder,
+        fsmn=FsmnEncoder,
+        conv=ConvEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+    ),
+    type_check=AbsEncoder,
+    default=None,
+    optional=True
+)
+cd_scorer_choices = ClassChoices(
+    "cd_scorer",
+    classes=dict(
+        san=SelfAttentionEncoder,
+    ),
+    type_check=AbsEncoder,
+    default=None,
+    optional=True,
+)
+ci_scorer_choices = ClassChoices(
+    "ci_scorer",
+    classes=dict(
+        dot=DotScorer,
+        cosine=CosScorer,
+    ),
+    type_check=torch.nn.Module,
+    default=None,
+    optional=True,
+)
+# decoder is used for output (e.g. post_net in SOND)
+decoder_choices = ClassChoices(
+    "decoder",
+    classes=dict(
+        rnn=RNNEncoder,
+        fsmn=FsmnEncoder,
+    ),
+    type_check=torch.nn.Module,
+    default="fsmn",
+)
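
Each of these ClassChoices registries maps a CLI name to a component class so the whole task can be assembled from a config file. Below is a hypothetical, stripped-down sketch of the pattern; MiniClassChoices and the demo classes are inventions for illustration, while the real class in funasr.train.class_choices also registers the --<name> and --<name>_conf command-line options and type-checks the selected class.

```python
# Minimal illustration (not FunASR code) of the name -> class registry pattern.
class MiniClassChoices:
    def __init__(self, name, classes, default=None, optional=False):
        self.name, self.classes = name, classes
        self.default, self.optional = default, optional

    def get_class(self, key):
        if key is None and self.optional:
            return None  # optional components may be disabled entirely
        return self.classes[key]

class FsmnEncoderDemo: ...
class ResNet34Demo: ...

encoder_demo = MiniClassChoices(
    "encoder", dict(fsmn=FsmnEncoderDemo, resnet34=ResNet34Demo), default="resnet34"
)
print(encoder_demo.get_class("fsmn"))  # -> <class 'FsmnEncoderDemo'>
```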
+
+
+class DiarTask(AbsTask):
+    # If you need more than 1 optimizer, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --model and --model_conf
+        model_choices,
+        # --encoder and --encoder_conf
+        encoder_choices,
+        # --speaker_encoder and --speaker_encoder_conf
+        speaker_encoder_choices,
+        # --cd_scorer and --cd_scorer_conf
+        cd_scorer_choices,
+        # --ci_scorer and --ci_scorer_conf
+        ci_scorer_choices,
+        # --decoder and --decoder_conf
+        decoder_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(description="Task related")
+
+        # NOTE(kamo): add_arguments(..., required=True) can't be used
+        # to provide --print_config mode. Instead of it, do as
+        # required = parser.get_default("required")
+        # required += ["token_list"]
+
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
+        group.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="Whether to split text using spaces",
+        )
+        group.add_argument(
+            "--seg_dict_file",
+            type=str,
+            default=None,
+            help="seg_dict_file for text processing",
+        )
+        group.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+
+        group = parser.add_argument_group(description="Preprocess related")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Apply preprocessing to data or not",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default="char",
+            choices=["char"],
+            help="The text will be tokenized in the specified level token",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Scale the maximum amplitude to the given value.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of the rir scp file.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of applying RIR convolution.",
+        )
+        parser.add_argument(
+            "--cmvn_file",
+            type=str_or_none,
+            default=None,
+            help="The file path of the cmvn file.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of the noise scp file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of applying noise addition.",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of noise decibel level.",
+        )
+
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf,
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        assert check_argument_types()
+        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
+        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+    @classmethod
+    def build_preprocess_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+        assert check_argument_types()
+        if args.use_preprocessor:
+            retval = CommonPreprocessor(
+                train=train,
+                token_type=args.token_type,
+                token_list=args.token_list,
+                bpemodel=None,
+                non_linguistic_symbols=None,
+                text_cleaner=None,
+                g2p_type=None,
+                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+                # NOTE(kamo): Check attribute existence for backward compatibility
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "speech_volume_normalize")
+                else None,
+            )
+        else:
+            retval = None
+        assert check_return_type(retval)
+        return retval
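
The hasattr chain in build_preprocess_fn keeps old configs loadable: any option saved before it existed simply falls back to a default. A minimal self-contained sketch of the same pattern (the namespace below is an invention for illustration, not FunASR code):

```python
import argparse

# A config saved before "noise_apply_prob" was introduced has no such attribute.
old_args = argparse.Namespace(token_type="char")

noise_apply_prob = (
    old_args.noise_apply_prob if hasattr(old_args, "noise_apply_prob") else 1.0
)
print(noise_apply_prob)  # -> 1.0, the default for configs that predate the option
```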
+
+    @classmethod
+    def required_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        if not inference:
+            retval = ("speech", "profile", "label")
+        else:
+            # Recognition mode
+            retval = ("speech", "profile")
+        return retval
+
+    @classmethod
+    def optional_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        retval = ()
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+
+        # 1. Frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            if args.frontend == 'wav_frontend':
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Encoder
+        encoder_class = encoder_choices.get_class(args.encoder)
+        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+        # 5. Speaker encoder
+        if getattr(args, "speaker_encoder", None) is not None:
+            speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
+            speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
+        else:
+            speaker_encoder = None
+
+        # 6. CI & CD scorers
+        if getattr(args, "ci_scorer", None) is not None:
+            ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
+            ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
+        else:
+            ci_scorer = None
+
+        if getattr(args, "cd_scorer", None) is not None:
+            cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
+            cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
+        else:
+            cd_scorer = None
+
+        # 7. Decoder
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(**args.decoder_conf)
+
+        # 8. Label aggregator
+        if getattr(args, "label_aggregator", None) is not None:
+            label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
+            label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
+        else:
+            label_aggregator = None
+
+        # 9. Build model
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            label_aggregator=label_aggregator,
+            encoder=encoder,
+            speaker_encoder=speaker_encoder,
+            ci_scorer=ci_scorer,
+            cd_scorer=cd_scorer,
+            decoder=decoder,
+            token_list=token_list,
+            **args.model_conf,
+        )
+
+        # 10. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model
+
+    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
+    @classmethod
+    def build_model_from_file(
+        cls,
+        config_file: Union[Path, str] = None,
+        model_file: Union[Path, str] = None,
+        cmvn_file: Union[Path, str] = None,
+        device: str = "cpu",
+    ):
+        """Build model from the files.
+
+        This method is used for inference or fine-tuning.
+
+        Args:
+            config_file: The yaml file saved when training.
+            model_file: The model file saved when training.
+            cmvn_file: The cmvn file for the front-end.
+            device: Device type, "cpu", "cuda", or "cuda:N".
+
+        """
+        assert check_argument_types()
+        if config_file is None:
+            assert model_file is not None, (
+                "The argument 'model_file' must be provided "
+                "if the argument 'config_file' is not specified."
+            )
+            config_file = Path(model_file).parent / "config.yaml"
+        else:
+            config_file = Path(config_file)
+
+        with config_file.open("r", encoding="utf-8") as f:
+            args = yaml.safe_load(f)
+        if cmvn_file is not None:
+            args["cmvn_file"] = cmvn_file
+        args = argparse.Namespace(**args)
+        model = cls.build_model(args)
+        if not isinstance(model, AbsESPnetModel):
+            raise RuntimeError(
+                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+            )
+        model.to(device)
+        model_dict = dict()
+        model_name_pth = None
+        if model_file is not None:
+            logging.info("model_file is {}".format(model_file))
+            if device == "cuda":
+                device = f"cuda:{torch.cuda.current_device()}"
+            model_dir = os.path.dirname(model_file)
+            model_name = os.path.basename(model_file)
+            if "model.ckpt-" in model_name or ".bin" in model_name:
+                # TensorFlow checkpoint: convert once, then cache the result as a torch file
+                if ".bin" in model_name:
+                    model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
+                else:
+                    model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
+                if os.path.exists(model_name_pth):
+                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
+                    model_dict = torch.load(model_name_pth, map_location=device)
+                else:
+                    model_dict = cls.convert_tf2torch(model, model_file)
+                model.load_state_dict(model_dict)
+            else:
+                model_dict = torch.load(model_file, map_location=device)
+                model.load_state_dict(model_dict)
+        if model_name_pth is not None and not os.path.exists(model_name_pth):
+            torch.save(model_dict, model_name_pth)
+            logging.info("model_file is saved to pth: {}".format(model_name_pth))
+
+        return model, args
+
+    @classmethod
+    def convert_tf2torch(
+        cls,
+        model,
+        ckpt,
+    ):
+        logging.info("start converting tf model to torch model")
+        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
+        var_dict_tf = load_tf_dict(ckpt)
+        var_dict_torch = model.state_dict()
+        var_dict_torch_update = dict()
+        # speech encoder
+        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # speaker encoder
+        var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # cd scorer
+        var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # ci scorer
+        var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+
+        return var_dict_torch_update
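
Downstream, inference code can restore the trained model from the experiment directory. A hedged usage sketch follows; the paths are placeholders for your own experiment directory, not files shipped with this patch. The checkpoint may be a torch file or a TensorFlow "model.ckpt-*" / ".bin" checkpoint, which is converted via convert_tf2torch and cached as a torch file.

```python
from funasr.tasks.diar import DiarTask

# Placeholder paths: config.yaml is the training config saved next to the model.
model, args = DiarTask.build_model_from_file(
    config_file="exp/sond/config.yaml",
    model_file="exp/sond/valid.acc.best.pth",
    device="cpu",
)
model.eval()
```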
diff --git a/funasr/utils/job_runner.py b/funasr/utils/job_runner.py
new file mode 100644
index 000000000..a35d49c4c
--- /dev/null
+++ b/funasr/utils/job_runner.py
@@ -0,0 +1,103 @@
+from __future__ import print_function
+from multiprocessing import Pool
+import argparse
+from tqdm import tqdm
+import math
+
+
+class MultiProcessRunner:
+    def __init__(self, fn):
+        self.args = None
+        self.process = fn
+
+    def run(self):
+        parser = argparse.ArgumentParser("")
+        # Task-independent options
+        parser.add_argument("--nj", type=int, default=16)
+        parser.add_argument("--debug", action="store_true", default=False)
+        parser.add_argument("--no_pbar", action="store_true", default=False)
+        parser.add_argument("--verbose", action="store_true", default=False)
+
+        task_list, args = self.prepare(parser)
+        result_list = self.pool_run(task_list, args)
+        self.post(result_list, args)
+
+    def prepare(self, parser):
+        raise NotImplementedError("Please implement the prepare function.")
+
+    def post(self, result_list, args):
+        raise NotImplementedError("Please implement the post function.")
+
+    def pool_run(self, tasks, args):
+        results = []
+        if args.debug:
+            one_result = self.process(tasks[0])
+            results.append(one_result)
+        else:
+            pool = Pool(args.nj)
+            for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
+                results.append(one_result)
+            pool.close()
+            pool.join()  # wait for workers to exit cleanly
+
+        return results
+
+
+class MultiProcessRunnerV2:
+    def __init__(self, fn):
+        self.args = None
+        self.process = fn
+
+    def run(self):
+        parser = argparse.ArgumentParser("")
+        # Task-independent options
+        parser.add_argument("--nj", type=int, default=16)
+        parser.add_argument("--debug", action="store_true", default=False)
+        parser.add_argument("--no_pbar", action="store_true", default=False)
+        parser.add_argument("--verbose", action="store_true", default=False)
+
+        task_list, args = self.prepare(parser)
+        chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
+        if args.verbose:
+            print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
+        subtask_list = [task_list[i*chunk_size: (i+1)*chunk_size] for i in range(args.nj)]
+        result_list = self.pool_run(subtask_list, args)
+        self.post(result_list, args)
+
+    def prepare(self, parser):
+        raise NotImplementedError("Please implement the prepare function.")
+
+    def post(self, result_list, args):
+        raise NotImplementedError("Please implement the post function.")
+
+    def pool_run(self, tasks, args):
+        results = []
+        if args.debug:
+            one_result = self.process(tasks[0])
+            results.append(one_result)
+        else:
+            pool = Pool(args.nj)
+            for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
+                results.append(one_result)
+            pool.close()
+            pool.join()  # wait for workers to exit cleanly
+
+        return results
+
+
+class MultiProcessRunnerV3(MultiProcessRunnerV2):
+    def run(self):
+        parser = argparse.ArgumentParser("")
+        # Task-independent options
+        parser.add_argument("--nj", type=int, default=16)
+        parser.add_argument("--debug", action="store_true", default=False)
+        parser.add_argument("--no_pbar", action="store_true", default=False)
+        parser.add_argument("--verbose", action="store_true", default=False)
+        parser.add_argument("--sr", type=int, default=16000)
+
+        task_list, shared_param, args = self.prepare(parser)
+        chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
+        if args.verbose:
+            print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
+        subtask_list = [(i, task_list[i * chunk_size: (i + 1) * chunk_size], shared_param, args)
+                        for i in range(args.nj)]
+        result_list = self.pool_run(subtask_list, args)
+        self.post(result_list, args)
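
A hypothetical subclass showing the intended V2 workflow: prepare() builds the task list and parses the options, the function passed to __init__ processes one chunk of tasks in a worker process, and post() merges the per-chunk results. All names below are illustrative, not part of this patch.

```python
from funasr.utils.job_runner import MultiProcessRunnerV2

def square_chunk(chunk):
    # Each worker receives one sub-list of tasks; it must be a module-level
    # function so multiprocessing can pickle it.
    return [x * x for x in chunk]

class SquareRunner(MultiProcessRunnerV2):
    def prepare(self, parser):
        args = parser.parse_args()
        return list(range(100)), args

    def post(self, result_list, args):
        merged = [y for chunk in result_list for y in chunk]
        print(len(merged), "results")

if __name__ == "__main__":
    SquareRunner(square_chunk).run()  # e.g. python square_runner.py --nj 4
```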
diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py
new file mode 100644
index 000000000..f27a63c4e
--- /dev/null
+++ b/funasr/utils/misc.py
@@ -0,0 +1,48 @@
+import io
+from collections import OrderedDict
+import numpy as np
+
+
+def statistic_model_parameters(model, prefix=None):
+    var_dict = model.state_dict()
+    numel = 0
+    for i, key in enumerate(sorted(list([x for x in var_dict.keys() if "num_batches_tracked" not in x]))):
+        if prefix is None or key.startswith(prefix):
+            numel += var_dict[key].numel()
+    return numel
+
+
+def int2vec(x, vec_dim=8, dtype=int):
+    # use the built-in int: the np.int alias was removed in NumPy 1.24
+    b = ('{:0' + str(vec_dim) + 'b}').format(x)
+    # little-endian order: lower bit first
+    return (np.array(list(b)[::-1]) == '1').astype(dtype)
+
+
+def seq2arr(seq, vec_dim=8):
+    return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
+
+
+def load_scp_as_dict(scp_path, value_type='str', kv_sep=" "):
+    with io.open(scp_path, 'r', encoding='utf-8') as f:
+        ret_dict = OrderedDict()
+        for one_line in f.readlines():
+            one_line = one_line.strip()
+            pos = one_line.find(kv_sep)
+            key, value = one_line[:pos], one_line[pos + 1:]
+            if value_type == 'list':
+                value = value.split(' ')
+            ret_dict[key] = value
+        return ret_dict
+
+
+def load_scp_as_list(scp_path, value_type='str', kv_sep=" "):
+    with io.open(scp_path, 'r', encoding='utf-8') as f:
+        ret_list = []
+        for one_line in f.readlines():
+            one_line = one_line.strip()
+            pos = one_line.find(kv_sep)
+            key, value = one_line[:pos], one_line[pos + 1:]
+            if value_type == 'list':
+                value = value.split(' ')
+            ret_list.append((key, value))
+        return ret_list
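
A small demo of the bit-vector helpers; the outputs shown are what the code above computes.

```python
from funasr.utils.misc import int2vec, seq2arr

# int2vec encodes an integer as a little-endian bit vector: for 6 = 0b110,
# the lowest-order bit comes first, so the vector reads [0, 1, 1, 0, ...].
print(int2vec(6, vec_dim=8))            # -> [0 1 1 0 0 0 0 0]

# seq2arr stacks one bit vector per token, turning a sequence of integer
# labels into a multi-hot matrix (one row per frame).
print(seq2arr("0 1 3".split(), vec_dim=4))
# -> [[0 0 0 0]
#     [1 0 0 0]
#     [1 1 0 0]]
```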