Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
update sensevoice onnx
This commit is contained in:
Parent: 1368a9bca4
Commit: a62cd7a3fd
@ -3,8 +3,6 @@
|
||||
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
|
||||
import torch
|
||||
import os.path
|
||||
import librosa
|
||||
import numpy as np
|
||||
@ -181,12 +179,12 @@ class SenseVoiceSmall:
|
||||
)
|
||||
for b in range(feats.shape[0]):
|
||||
# back to torch.Tensor
|
||||
if isinstance(ctc_logits, np.ndarray):
|
||||
ctc_logits = torch.from_numpy(ctc_logits).float()
|
||||
# if isinstance(ctc_logits, np.ndarray):
|
||||
# ctc_logits = torch.from_numpy(ctc_logits).float()
|
||||
# support batch_size=1 only currently
|
||||
x = ctc_logits[b, : encoder_out_lens[b].item(), :]
|
||||
yseq = x.argmax(dim=-1)
|
||||
yseq = torch.unique_consecutive(yseq, dim=-1)
|
||||
yseq = np.unique(yseq)
|
||||
|
||||
mask = yseq != self.blank_id
|
||||
token_int = yseq[mask].tolist()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user