From ae7aff2e9cf6c79cf61e18cdc60ce68cf9f98400 Mon Sep 17 00:00:00 2001 From: majic31 Date: Tue, 24 Dec 2024 17:51:11 +0800 Subject: [PATCH] fix: solve problems in sensevoice_bin.py related to argmax and unique, as mentioned in issue #2331 (#2332) --- runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py b/runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py index 9cfe46f81..6a06ed105 100644 --- a/runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py +++ b/runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py @@ -183,8 +183,10 @@ class SenseVoiceSmall: # ctc_logits = torch.from_numpy(ctc_logits).float() # support batch_size=1 only currently x = ctc_logits[b, : encoder_out_lens[b].item(), :] - yseq = x.argmax(dim=-1) - yseq = np.unique(yseq) + yseq = np.argmax(x, axis=-1) + # Use np.diff and np.where instead of torch.unique_consecutive. + mask = np.concatenate(([True], np.diff(yseq) != 0)) + yseq = yseq[mask] mask = yseq != self.blank_id token_int = yseq[mask].tolist()