update asr_spk inference for short utt

This commit is contained in:
shixian.shi 2023-11-24 14:29:33 +08:00
parent 34b2682fba
commit 72fecc8e03
2 changed files with 25 additions and 19 deletions

View File

@ -956,24 +956,29 @@ def inference_paraformer_vad_speaker(
ed = int(vadsegment[1]) / 1000
vad_segments.append(
[st, ed, audio[int(st * 16000):int(ed * 16000)]])
check_audio_list(vad_segments)
# sv pipeline
segments = sv_chunk(vad_segments)
embeddings = []
for s in segments:
#_, embs = self.sv_pipeline([s[2]], output_emb=True)
# embeddings.append(embs)
wavs = sv_preprocess([s[2]])
# embs = self.forward(wavs)
embs = []
for x in wavs:
x = extract_feature([x])
embs.append(sv_model(x))
embs = torch.cat(embs)
embeddings.append(embs.detach().numpy())
embeddings = np.concatenate(embeddings)
labels = cb_model(embeddings)
sv_output = postprocess(segments, vad_segments, labels, embeddings)
audio_dur = check_audio_list(vad_segments)
if audio_dur > 5:
# sv pipeline
segments = sv_chunk(vad_segments)
embeddings = []
for s in segments:
#_, embs = self.sv_pipeline([s[2]], output_emb=True)
# embeddings.append(embs)
wavs = sv_preprocess([s[2]])
# embs = self.forward(wavs)
embs = []
for x in wavs:
x = extract_feature([x])
embs.append(sv_model(x))
embs = torch.cat(embs)
embeddings.append(embs.detach().numpy())
embeddings = np.concatenate(embeddings)
labels = cb_model(embeddings)
sv_output = postprocess(segments, vad_segments, labels, embeddings)
else:
# fake speaker res for too short utterance
sv_output = [[0.0, vadsegments[-1][-1]/1000.0, 0]]
logging.warning("Too short utterence found: {}, return default speaker results.".format(keys))
speech, speech_lengths = batch["speech"], batch["speech_lengths"]

View File

@ -35,7 +35,8 @@ def check_audio_list(audio: list):
assert seg[0] >= audio[
i - 1][1], 'modelscope error: Wrong time stamps.'
audio_dur += seg[1] - seg[0]
assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
return audio_dur
# assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
def sv_preprocess(inputs: Union[np.ndarray, list]):