diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py new file mode 100644 index 000000000..f2024fbd2 --- /dev/null +++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py @@ -0,0 +1,52 @@ +import torch +import torchaudio +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks + +inference_pipeline = pipeline( + task=Tasks.auto_speech_recognition, + model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online', + model_revision='v1.0.2') + +waveform, sample_rate = torchaudio.load("asr_example_zh.wav") +speech_length = waveform.shape[1] +speech = waveform[0] + +cache_en = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None} +cache_de = {"decode_fsmn": None} +cache = {"encoder": cache_en, "decoder": cache_de} +param_dict = {} +param_dict["cache"] = cache + +first_chunk = True +speech_buffer = speech +speech_cache = [] +final_result = "" + +while len(speech_buffer) > 0: + if first_chunk: + if len(speech_buffer) >= 14400: + rec_result = inference_pipeline(audio_in=speech_buffer[0:14400], param_dict=param_dict) + speech_buffer = speech_buffer[4800:] + else: + cache_en["stride"] = len(speech_buffer) // 960 + cache_en["pad_right"] = 0 + rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict) + speech_buffer = [] + cache_en["start_idx"] = -5 + first_chunk = False + else: + cache_en["start_idx"] += 10 + if len(speech_buffer) >= 4800: + cache_en["pad_left"] = 5 + rec_result = inference_pipeline(audio_in=speech_buffer[:19200], param_dict=param_dict) + speech_buffer = speech_buffer[9600:] + else: + cache_en["stride"] = len(speech_buffer) // 960 + cache_en["pad_right"] = 0 + rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict) + speech_buffer = [] + if rec_result['text'] != "sil": + final_result += rec_result['text'] + print(rec_result) +print(final_result)