diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index ff8bb8c77..4aae8e970 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -20,6 +20,7 @@ from typing import List
 
 import numpy as np
 import torch
+import torchaudio
 from typeguard import check_argument_types
 
 from funasr.fileio.datadir_writer import DatadirWriter
@@ -515,6 +516,8 @@
         if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
             raw_inputs = _load_bytes(data_path_and_name_and_type[0])
             raw_inputs = torch.tensor(raw_inputs)
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, np.ndarray):
                 raw_inputs = torch.tensor(raw_inputs)
@@ -531,13 +534,32 @@
         # 7 .Start for-loop
         # FIXME(kamo): The output format should be discussed about
         raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
-        input_lens = torch.tensor([raw_inputs.shape[1]])
         asr_result_list = []
-        cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
-        cache["encoder"]["is_final"] = is_final
-        asr_result = speech2text(cache, raw_inputs, input_lens)
-        item = {'key': "utt", 'value': asr_result}
+        item = {}
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+            sample_offset = 0
+            speech_length = raw_inputs.shape[1]
+            stride_size = chunk_size[1] * 960
+            cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
+            final_result = ""
+            for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
+                if sample_offset + stride_size >= speech_length - 1:
+                    stride_size = speech_length - sample_offset
+                    cache["encoder"]["is_final"] = True
+                else:
+                    cache["encoder"]["is_final"] = False
+                input_lens = torch.tensor([stride_size])
+                asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
+                if len(asr_result) != 0:
+                    final_result += asr_result[0]
+            item = {'key': "utt", 'value': [final_result]}
+        else:
+            input_lens = torch.tensor([raw_inputs.shape[1]])
+            cache["encoder"]["is_final"] = is_final
+            asr_result = speech2text(cache, raw_inputs, input_lens)
+            item = {'key': "utt", 'value': asr_result}
+        asr_result_list.append(item)
         if is_final:
             cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)