From fe5c955063e561eb4cd14e050f15dc43e79f20ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Thu, 23 Feb 2023 22:28:45 +0800 Subject: [PATCH] sond pipeline --- .../extract_nonoverlap_segments.py | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py index ff7208664..1d6f53e92 100644 --- a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py +++ b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py @@ -31,7 +31,11 @@ class MyRunner(MultiProcessRunnerV3): return task_list, None, args def post(self, result_list, args): - pass + count = [0, 0] + for result in result_list: + count[0] += result[0] + count[1] += result[1] + print("Found {} speakers, extracted {}.".format(count[1], count[0])) # SPEAKER R8001_M8004_MS801 1 6.90 11.39 1 @@ -59,18 +63,28 @@ def get_nonoverlap_turns(multi_label, spk_list): if not in_turn and label[i]: st, in_turn = i, True spk = spk_list[np.argmax(multi_label[:, i], axis=0)] - if in_turn and not label[i]: - in_turn = False - turns.append([st, i, spk]) + if in_turn: + if not label[i]: + in_turn = False + turns.append([st, i, spk]) + elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]: + turns.append([st, i, spk]) + st, in_turn = i, True + spk = spk_list[np.argmax(multi_label[:, i], axis=0)] + if in_turn: + turns.append([st, len(label), spk]) return turns def process(task_args): task_id, task_list, _, args = task_args + spk_count = [0, 0] for mid, wav_path, rttm_path in task_list: - wav = librosa.load(wav_path, args.sr)[0] * 32767 + wav, sr = sf.read(wav_path, dtype="int16") + assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr) multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num) turns = get_nonoverlap_turns(multi_label, spk_list) + extracted_spk = [] count = 1 for st, ed, spk in tqdm(turns, total=len(turns), ascii=True): if (ed - st) >= args.min_dur * args.sr: @@ -80,7 +94,15 @@ def process(task_args): os.makedirs(os.path.dirname(save_path)) sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True) count += 1 - return None + if spk not in extracted_spk: + extracted_spk.append(spk) + if len(extracted_spk) != len(spk_list): + print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format( + mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk]) + )) + spk_count[0] += len(extracted_spk) + spk_count[1] += len(spk_list) + return spk_count if __name__ == '__main__':