From ab828bcf7badb228fdc59647f5c9c75e33acce9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Fri, 17 Feb 2023 00:04:14 +0800 Subject: [PATCH] add scripts for simu data --- .../scripts/extract_nonoverlap_segments_v2.py | 4 - egs/mars/sd/scripts/simu_chunk_with_labels.py | 83 +++++++++++++++++++ 2 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 egs/mars/sd/scripts/simu_chunk_with_labels.py diff --git a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py index 56ad78702..cd1ec7b34 100644 --- a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py +++ b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py @@ -1,13 +1,9 @@ import numpy as np import os -import sys import argparse from funasr.utils.job_runner import MultiProcessRunnerV3 from funasr.utils.misc import load_scp_as_list, load_scp_as_dict -import librosa import soundfile as sf -from copy import deepcopy -import json from tqdm import tqdm diff --git a/egs/mars/sd/scripts/simu_chunk_with_labels.py b/egs/mars/sd/scripts/simu_chunk_with_labels.py new file mode 100644 index 000000000..d1d9a2ff1 --- /dev/null +++ b/egs/mars/sd/scripts/simu_chunk_with_labels.py @@ -0,0 +1,83 @@ +import logging +import numpy as np +import soundfile +import kaldiio +from funasr.utils.job_runner import MultiProcessRunnerV3 +from funasr.utils.misc import load_scp_as_list, load_scp_as_dict +import os +import argparse +from collections import OrderedDict +import random + + +class MyRunner(MultiProcessRunnerV3): + + def prepare(self, parser: argparse.ArgumentParser): + parser.add_argument("--label_scp", type=str, required=True) + parser.add_argument("--wav_scp", type=str, required=True) + parser.add_argument("--utt2spk", type=str, required=True) + parser.add_argument("--spk2meeting", type=str, required=True) + parser.add_argument("--utt2xvec", type=str, required=True) + parser.add_argument("--out_dir", type=str, required=True) + parser.add_argument("--chunk_size", type=int, default=16) + parser.add_argument("--chunk_shift", type=int, default=4) + parser.add_argument("--frame_shift", type=float, default=0.01) + args = parser.parse_args() + + if not os.path.exists(args.out_dir): + os.makedirs(args.out_dir) + + label_list = load_scp_as_list(args.label_scp) + wav_scp = load_scp_as_dict(args.wav_scp) + utt2spk = load_scp_as_dict(args.utt2spk) + utt2xvec = load_scp_as_dict(args.utt2xvec) + spk2meeting = load_scp_as_dict(args.spk2meeting) + + meeting2spks = OrderedDict() + for spk, meeting in spk2meeting.items(): + if meeting not in meeting2spks: + meeting2spks[meeting] = [] + meeting2spks[meeting].append(spk) + + spk2utts = OrderedDict() + for utt, spk in utt2spk.items(): + if spk not in spk2utts: + spk2utts[spk] = [] + spk2utts[spk].append(utt) + + return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args + + def post(self, results_list, args): + pass + + +def process(task_args): + task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args + out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1)) + wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path)) + + out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1)) + wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path)) + + out_path = os.path.join(args.out_dir, "label.{}".format(task_idx + 1)) + label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path)) + + idx = 0 + for _, label_path in task_list: + rand_shift = random.randint(0, int(args.chunk_shift / args.frame_shift)) + whole_label = kaldiio.load_mat(label_path) + whole_label = whole_label[rand_shift:] + num_chunk = (whole_label.shape[0] - args.chunk_size) // args.chunk_shift + for i in range(num_chunk): + utt_id = "part{}_utt{:10d}".format(task_idx + 1, idx + 1) + + + wav_mix_writer.close() + wav_sep_writer.close() + label_writer.close() + return None + + +if __name__ == '__main__': + my_runner = MyRunner(process) + my_runner.run()