From ce7914034dd8496409af3b6b368218be1c71d3a1 Mon Sep 17 00:00:00 2001 From: "mengzhe.cmz" Date: Wed, 17 May 2023 17:12:14 +0800 Subject: [PATCH] increase vad realtime punc --- .../demo.py | 2 +- funasr/export/test/test_onnx_punc_vadrealtime.py | 2 +- funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py index cf115b16d..c449ab296 100644 --- a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py +++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py @@ -9,7 +9,7 @@ logger.setLevel(logging.CRITICAL) inference_pipeline = pipeline( task=Tasks.punctuation, model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727', - output_dir="./tmp/" + model_revision = 'v1.0.2' ) ##################text二进制数据##################### diff --git a/funasr/export/test/test_onnx_punc_vadrealtime.py b/funasr/export/test/test_onnx_punc_vadrealtime.py index 86be026dc..507226eb8 100644 --- a/funasr/export/test/test_onnx_punc_vadrealtime.py +++ b/funasr/export/test/test_onnx_punc_vadrealtime.py @@ -12,7 +12,7 @@ if __name__ == '__main__': return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32), 'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32), - 'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) + 'sub_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32), } def _run(feed_dict): diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py index 8890714e6..035dd0026 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py @@ -186,11 +186,12 @@ class CT_Transformer_VadRealtime(CT_Transformer): mini_sentence = cache_sent + mini_sentence mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32') text_length = len(mini_sentence_id) + vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32) data = { "input": mini_sentence_id[None,:], "text_lengths": np.array([text_length], dtype='int32'), - "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32), - "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) + "vad_mask": vad_mask + "sub_masks": vad_mask } try: outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])