mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
sond pipeline
This commit is contained in:
parent
2d015034a3
commit
fe5c955063
@ -31,7 +31,11 @@ class MyRunner(MultiProcessRunnerV3):
|
||||
return task_list, None, args
|
||||
|
||||
def post(self, result_list, args):
|
||||
pass
|
||||
count = [0, 0]
|
||||
for result in result_list:
|
||||
count[0] += result[0]
|
||||
count[1] += result[1]
|
||||
print("Found {} speakers, extracted {}.".format(count[1], count[0]))
|
||||
|
||||
|
||||
# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
|
||||
@ -59,18 +63,28 @@ def get_nonoverlap_turns(multi_label, spk_list):
|
||||
if not in_turn and label[i]:
|
||||
st, in_turn = i, True
|
||||
spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
|
||||
if in_turn and not label[i]:
|
||||
in_turn = False
|
||||
turns.append([st, i, spk])
|
||||
if in_turn:
|
||||
if not label[i]:
|
||||
in_turn = False
|
||||
turns.append([st, i, spk])
|
||||
elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
|
||||
turns.append([st, i, spk])
|
||||
st, in_turn = i, True
|
||||
spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
|
||||
if in_turn:
|
||||
turns.append([st, len(label), spk])
|
||||
return turns
|
||||
|
||||
|
||||
def process(task_args):
|
||||
task_id, task_list, _, args = task_args
|
||||
spk_count = [0, 0]
|
||||
for mid, wav_path, rttm_path in task_list:
|
||||
wav = librosa.load(wav_path, args.sr)[0] * 32767
|
||||
wav, sr = sf.read(wav_path, dtype="int16")
|
||||
assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
|
||||
multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num)
|
||||
turns = get_nonoverlap_turns(multi_label, spk_list)
|
||||
extracted_spk = []
|
||||
count = 1
|
||||
for st, ed, spk in tqdm(turns, total=len(turns), ascii=True):
|
||||
if (ed - st) >= args.min_dur * args.sr:
|
||||
@ -80,7 +94,15 @@ def process(task_args):
|
||||
os.makedirs(os.path.dirname(save_path))
|
||||
sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
|
||||
count += 1
|
||||
return None
|
||||
if spk not in extracted_spk:
|
||||
extracted_spk.append(spk)
|
||||
if len(extracted_spk) != len(spk_list):
|
||||
print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
|
||||
mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
|
||||
))
|
||||
spk_count[0] += len(extracted_spk)
|
||||
spk_count[1] += len(spk_list)
|
||||
return spk_count
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Loading…
Reference in New Issue
Block a user