diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py index 8ea4517ec..034475c6c 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py @@ -70,8 +70,8 @@ class TargetDelayTransformer(): mini_sentence = cache_sent + mini_sentence mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0) data = { - "text": mini_sentence_id, - "text_lengths": len(mini_sentence_id), + "text": mini_sentence_id[None,:].astype(np.int64), + "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'), } try: outputs = self.infer(data['text'], data['text_lengths']) @@ -125,8 +125,8 @@ class TargetDelayTransformer(): new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period] return new_mini_sentence_out, new_mini_sentence_punc_out - def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]: - - outputs = self.ort_infer(feats) + def infer(self, feats: np.ndarray, + feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + outputs = self.ort_infer([feats, feats_len]) return outputs