diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py index c70baf023..ff8bb8c77 100644 --- a/funasr/bin/asr_inference_paraformer_streaming.py +++ b/funasr/bin/asr_inference_paraformer_streaming.py @@ -8,6 +8,7 @@ import os import codecs import tempfile import requests +import yaml from pathlib import Path from typing import Optional from typing import Sequence @@ -462,13 +463,23 @@ def inference_modelscope( array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) return array + def _read_yaml(yaml_path: Union[str, Path]) -> Dict: + if not Path(yaml_path).exists(): + raise FileExistsError(f'The {yaml_path} does not exist.') + + with open(str(yaml_path), 'rb') as f: + data = yaml.load(f, Loader=yaml.Loader) + return data + def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1): if len(cache) > 0: return cache - - cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)), + config = _read_yaml(asr_train_config) + enc_output_size = config["encoder_conf"]["output_size"] + feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"] + cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)), "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, - "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False} + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False} cache["encoder"] = cache_en cache_de = {"decode_fsmn": None} @@ -478,9 +489,12 @@ def inference_modelscope( def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1): if len(cache) > 0: - cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)), + config = _read_yaml(asr_train_config) + enc_output_size = config["encoder_conf"]["output_size"] + feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"] + cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)), "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, - "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False} + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False} cache["encoder"] = cache_en cache_de = {"decode_fsmn": None} @@ -720,4 +734,3 @@ if __name__ == "__main__": # # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav') # print(rec_result) -