update sensevoice onnx

This commit is contained in:
shixian 2024-12-20 10:50:51 +08:00
parent 1368a9bca4
commit a62cd7a3fd

View File

@ -3,8 +3,6 @@
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
import os.path
import librosa
import numpy as np
@ -181,12 +179,12 @@ class SenseVoiceSmall:
)
for b in range(feats.shape[0]):
# back to torch.Tensor
if isinstance(ctc_logits, np.ndarray):
ctc_logits = torch.from_numpy(ctc_logits).float()
# if isinstance(ctc_logits, np.ndarray):
# ctc_logits = torch.from_numpy(ctc_logits).float()
# support batch_size=1 only currently
x = ctc_logits[b, : encoder_out_lens[b].item(), :]
yseq = x.argmax(dim=-1)
yseq = torch.unique_consecutive(yseq, dim=-1)
yseq = np.unique(yseq)
mask = yseq != self.blank_id
token_int = yseq[mask].tolist()