From 3cde6fe8e0457a0cd3d65797d6d53bd947047a44 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Mon, 22 May 2023 16:25:48 +0800
Subject: [PATCH] bugfix paraformer large long

---
 .../demo.py                        |  5 ++--
 funasr/bin/asr_inference_launch.py | 28 ++++++++++++++++---
 funasr/version.txt                 |  2 +-
 3 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py
index 9b474dd7d..91212c065 100644
--- a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py
+++ b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py
@@ -3,15 +3,14 @@ from modelscope.utils.constant import Tasks
 
 if __name__ == '__main__':
     audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav'
-    output_dir = None
+    output_dir = "./results"
     inference_pipeline = pipeline(
         task=Tasks.auto_speech_recognition,
         model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
         vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
         punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
         output_dir=output_dir,
-        batch_size=64,
     )
-    rec_result = inference_pipeline(audio_in=audio_in)
+    rec_result = inference_pipeline(audio_in=audio_in, batch_size_token=5000)
     print(rec_result)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index dbbb3ed1e..ec5e17535 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -600,6 +600,9 @@ def inference_paraformer_vad_punc(
     if 'hotword' in kwargs:
         hotword_list_or_file = kwargs['hotword']
 
+    batch_size_token = kwargs.get("batch_size_token", 6000)
+    print("batch_size_token: ", batch_size_token)
+
     if speech2text.hotword_list is None:
         speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
 
@@ -642,8 +645,10 @@
         assert all(isinstance(s, str) for s in keys), keys
         _bs = len(next(iter(batch.values())))
         assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
+        beg_vad = time.time()
         vad_results = speech2vadsegment(**batch)
+        end_vad = time.time()
+        print("time cost vad: ", end_vad-beg_vad)
         _, vadsegments = vad_results[0], vad_results[1][0]
         speech, speech_lengths = batch["speech"], batch["speech_lengths"]
@@ -652,17 +657,29 @@
         data_with_index = [(vadsegments[i], i) for i in range(n)]
         sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
         results_sorted = []
-        for j, beg_idx in enumerate(range(0, n, batch_size)):
-            end_idx = min(n, beg_idx + batch_size)
+        batch_size_token_ms = batch_size_token*60
+        batch_size_token_ms_cum = 0
+        beg_idx = 0
+        for j, _ in enumerate(range(0, n)):
+            batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
+            if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0]) < batch_size_token_ms:
+                continue
         if len(word_lists) > 0 and text2punc is not None:
+            beg_punc = time.time()
             text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+            end_punc = time.time()
+            print("time cost punc: ", end_punc-beg_punc)
             item = {'key': key, 'value': text_postprocessed_punc}
             if text_postprocessed != "":
diff --git a/funasr/version.txt b/funasr/version.txt
index 7d8568351..d1d899fa3 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.5.4
+0.5.5
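
Note on the batching change in asr_inference_launch.py: the patch replaces the fixed batch_size grouping of VAD segments with a cumulative-duration budget of batch_size_token * 60 ms of audio per batch. The sketch below illustrates that grouping idea in isolation. It is a simplified, standalone example and not the FunASR implementation; the helper name group_segments_by_duration and the flush-then-append strategy are illustrative assumptions only.

    # Minimal sketch (assumed names) of batching VAD segments under a duration budget.
    from typing import List, Tuple

    Segment = Tuple[int, int]  # (begin_ms, end_ms) as produced by the VAD stage

    def group_segments_by_duration(vadsegments: List[Segment],
                                   batch_size_token: int = 6000) -> List[List[Segment]]:
        budget_ms = batch_size_token * 60  # same budget as batch_size_token_ms in the patch
        # Sort by duration so segments of similar length end up padded together.
        sorted_segments = sorted(vadsegments, key=lambda seg: seg[1] - seg[0])
        batches: List[List[Segment]] = []
        current: List[Segment] = []
        current_ms = 0
        for seg in sorted_segments:
            seg_ms = seg[1] - seg[0]
            if current and current_ms + seg_ms > budget_ms:
                # Budget exceeded: close the current batch and start a new one.
                batches.append(current)
                current, current_ms = [], 0
            current.append(seg)
            current_ms += seg_ms
        if current:
            batches.append(current)
        return batches

    if __name__ == "__main__":
        # Three short segments fit one budget; the longest overflows into a second batch.
        segments = [(0, 2000), (2500, 4000), (5000, 200000), (210000, 500000)]
        for i, batch in enumerate(group_segments_by_duration(segments, batch_size_token=5000)):
            print(i, batch)

With this scheme the per-batch cost tracks total audio duration rather than segment count, which is why the demo passes batch_size_token=5000 to the pipeline instead of the old batch_size=64.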