mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update sensevoice onnx
This commit is contained in:
parent
1368a9bca4
commit
a62cd7a3fd
@ -3,8 +3,6 @@
|
|||||||
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
|
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
|
||||||
# MIT License (https://opensource.org/licenses/MIT)
|
# MIT License (https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import os.path
|
import os.path
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -181,12 +179,12 @@ class SenseVoiceSmall:
|
|||||||
)
|
)
|
||||||
for b in range(feats.shape[0]):
|
for b in range(feats.shape[0]):
|
||||||
# back to torch.Tensor
|
# back to torch.Tensor
|
||||||
if isinstance(ctc_logits, np.ndarray):
|
# if isinstance(ctc_logits, np.ndarray):
|
||||||
ctc_logits = torch.from_numpy(ctc_logits).float()
|
# ctc_logits = torch.from_numpy(ctc_logits).float()
|
||||||
# support batch_size=1 only currently
|
# support batch_size=1 only currently
|
||||||
x = ctc_logits[b, : encoder_out_lens[b].item(), :]
|
x = ctc_logits[b, : encoder_out_lens[b].item(), :]
|
||||||
yseq = x.argmax(dim=-1)
|
yseq = x.argmax(dim=-1)
|
||||||
yseq = torch.unique_consecutive(yseq, dim=-1)
|
yseq = np.unique(yseq)
|
||||||
|
|
||||||
mask = yseq != self.blank_id
|
mask = yseq != self.blank_id
|
||||||
token_int = yseq[mask].tolist()
|
token_int = yseq[mask].tolist()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user