diff --git a/egs/mars/sd/scripts/extract_nonoverlap_segments.py b/egs/mars/sd/scripts/extract_nonoverlap_segments.py new file mode 100644 index 000000000..ff7208664 --- /dev/null +++ b/egs/mars/sd/scripts/extract_nonoverlap_segments.py @@ -0,0 +1,88 @@ +from __future__ import print_function +import numpy as np +import os +import sys +import argparse +from funasr.utils.job_runner import MultiProcessRunnerV3 +from funasr.utils.misc import load_scp_as_list, load_scp_as_dict +import librosa +import soundfile as sf +from copy import deepcopy +import json +from tqdm import tqdm + + +class MyRunner(MultiProcessRunnerV3): + def prepare(self, parser): + assert isinstance(parser, argparse.ArgumentParser) + parser.add_argument("wav_scp", type=str) + parser.add_argument("rttm_scp", type=str) + parser.add_argument("out_dir", type=str) + parser.add_argument("--min_dur", type=float, default=2.0) + parser.add_argument("--max_spk_num", type=int, default=4) + args = parser.parse_args() + + if not os.path.exists(args.out_dir): + os.makedirs(args.out_dir) + + wav_scp = load_scp_as_list(args.wav_scp) + rttm_scp = load_scp_as_dict(args.rttm_scp) + task_list = [(mid, wav_path, rttm_scp[mid]) for (mid, wav_path) in wav_scp] + return task_list, None, args + + def post(self, result_list, args): + pass + + +# SPEAKER R8001_M8004_MS801 1 6.90 11.39 1 +def calc_multi_label(rttm_path, length, sr=16000, max_spk_num=4): + labels = np.zeros([max_spk_num, length], int) + spk_list = [] + for one_line in open(rttm_path, 'rt'): + parts = one_line.strip().split(" ") + mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7] + if spk_name.isdigit(): + spk_name = "{}_S{:03d}".format(mid, int(spk_name)) + if spk_name not in spk_list: + spk_list.append(spk_name) + st, dur = int(st*sr), int(dur*sr) + idx = spk_list.index(spk_name) + labels[idx, st:st+dur] = 1 + return labels, spk_list + + +def get_nonoverlap_turns(multi_label, spk_list): + turns = [] + label = np.sum(multi_label, axis=0) == 1 + spk, in_turn, st = None, False, 0 + for i in range(len(label)): + if not in_turn and label[i]: + st, in_turn = i, True + spk = spk_list[np.argmax(multi_label[:, i], axis=0)] + if in_turn and not label[i]: + in_turn = False + turns.append([st, i, spk]) + return turns + + +def process(task_args): + task_id, task_list, _, args = task_args + for mid, wav_path, rttm_path in task_list: + wav = librosa.load(wav_path, args.sr)[0] * 32767 + multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num) + turns = get_nonoverlap_turns(multi_label, spk_list) + count = 1 + for st, ed, spk in tqdm(turns, total=len(turns), ascii=True): + if (ed - st) >= args.min_dur * args.sr: + seg = wav[st: ed] + save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count)) + if not os.path.exists(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path)) + sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True) + count += 1 + return None + + +if __name__ == '__main__': + my_runner = MyRunner(process) + my_runner.run()