update sensevoice onnx

This commit is contained in:
shixian 2024-12-20 10:50:51 +08:00
parent 1368a9bca4
commit a62cd7a3fd

View File

@ -3,8 +3,6 @@
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. # Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT) # MIT License (https://opensource.org/licenses/MIT)
import torch
import os.path import os.path
import librosa import librosa
import numpy as np import numpy as np
@ -181,12 +179,12 @@ class SenseVoiceSmall:
) )
for b in range(feats.shape[0]): for b in range(feats.shape[0]):
# back to torch.Tensor # back to torch.Tensor
if isinstance(ctc_logits, np.ndarray): # if isinstance(ctc_logits, np.ndarray):
ctc_logits = torch.from_numpy(ctc_logits).float() # ctc_logits = torch.from_numpy(ctc_logits).float()
# support batch_size=1 only currently # support batch_size=1 only currently
x = ctc_logits[b, : encoder_out_lens[b].item(), :] x = ctc_logits[b, : encoder_out_lens[b].item(), :]
yseq = x.argmax(dim=-1) yseq = x.argmax(dim=-1)
yseq = torch.unique_consecutive(yseq, dim=-1) yseq = np.unique(yseq)
mask = yseq != self.blank_id mask = yseq != self.blank_id
token_int = yseq[mask].tolist() token_int = yseq[mask].tolist()