From 7df8452a8520cf3e4609114b47510371b2c621a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=9D=E8=80=B3?= Date: Thu, 30 Mar 2023 16:15:49 +0800 Subject: [PATCH] fix --- funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py index 034475c6c..c00a3d7f2 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py @@ -76,9 +76,8 @@ class TargetDelayTransformer(): try: outputs = self.infer(data['text'], data['text_lengths']) y = outputs[0] - _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1) - punctuations = indices - assert punctuations.size()[0] == len(mini_sentence) + punctuations = np.argmax(y,axis=-1)[0] + assert punctuations.size == len(mini_sentence) except ONNXRuntimeError: logging.warning("error")