diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index d5d18737a..1b38f8f94 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -953,14 +953,15 @@ def inference_paraformer_online(
         # FIXME(kamo): The output format should be discussed about
         raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
         asr_result_list = []
-        cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
+        cache = _prepare_cache(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back,
+                               decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
         item = {}
         if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
             sample_offset = 0
             speech_length = raw_inputs.shape[1]
             stride_size = chunk_size[1] * 960
-            cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1,
-                                   encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
+            cache = _prepare_cache(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back,
+                                   decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
             final_result = ""
             for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
                 if sample_offset + stride_size >= speech_length - 1: