mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update asr_spk inference for short utt
This commit is contained in:
parent
34b2682fba
commit
72fecc8e03
@ -956,24 +956,29 @@ def inference_paraformer_vad_speaker(
|
||||
ed = int(vadsegment[1]) / 1000
|
||||
vad_segments.append(
|
||||
[st, ed, audio[int(st * 16000):int(ed * 16000)]])
|
||||
check_audio_list(vad_segments)
|
||||
# sv pipeline
|
||||
segments = sv_chunk(vad_segments)
|
||||
embeddings = []
|
||||
for s in segments:
|
||||
#_, embs = self.sv_pipeline([s[2]], output_emb=True)
|
||||
# embeddings.append(embs)
|
||||
wavs = sv_preprocess([s[2]])
|
||||
# embs = self.forward(wavs)
|
||||
embs = []
|
||||
for x in wavs:
|
||||
x = extract_feature([x])
|
||||
embs.append(sv_model(x))
|
||||
embs = torch.cat(embs)
|
||||
embeddings.append(embs.detach().numpy())
|
||||
embeddings = np.concatenate(embeddings)
|
||||
labels = cb_model(embeddings)
|
||||
sv_output = postprocess(segments, vad_segments, labels, embeddings)
|
||||
audio_dur = check_audio_list(vad_segments)
|
||||
if audio_dur > 5:
|
||||
# sv pipeline
|
||||
segments = sv_chunk(vad_segments)
|
||||
embeddings = []
|
||||
for s in segments:
|
||||
#_, embs = self.sv_pipeline([s[2]], output_emb=True)
|
||||
# embeddings.append(embs)
|
||||
wavs = sv_preprocess([s[2]])
|
||||
# embs = self.forward(wavs)
|
||||
embs = []
|
||||
for x in wavs:
|
||||
x = extract_feature([x])
|
||||
embs.append(sv_model(x))
|
||||
embs = torch.cat(embs)
|
||||
embeddings.append(embs.detach().numpy())
|
||||
embeddings = np.concatenate(embeddings)
|
||||
labels = cb_model(embeddings)
|
||||
sv_output = postprocess(segments, vad_segments, labels, embeddings)
|
||||
else:
|
||||
# fake speaker result for a too-short utterance
|
||||
sv_output = [[0.0, vadsegments[-1][-1]/1000.0, 0]]
|
||||
logging.warning("Too short utterence found: {}, return default speaker results.".format(keys))
|
||||
|
||||
speech, speech_lengths = batch["speech"], batch["speech_lengths"]
|
||||
|
||||
|
||||
@ -35,7 +35,8 @@ def check_audio_list(audio: list):
|
||||
assert seg[0] >= audio[
|
||||
i - 1][1], 'modelscope error: Wrong time stamps.'
|
||||
audio_dur += seg[1] - seg[0]
|
||||
assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
|
||||
return audio_dur
|
||||
# assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
|
||||
|
||||
|
||||
def sv_preprocess(inputs: Union[np.ndarray, list]):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user