mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Update asr_inference_paraformer_streaming.py
This commit is contained in:
parent
493dda8f98
commit
9624eba825
@ -8,6 +8,7 @@ import os
|
||||
import codecs
|
||||
import tempfile
|
||||
import requests
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
@ -462,13 +463,23 @@ def inference_modelscope(
|
||||
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
|
||||
return array
|
||||
|
||||
def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
|
||||
if not Path(yaml_path).exists():
|
||||
raise FileExistsError(f'The {yaml_path} does not exist.')
|
||||
|
||||
with open(str(yaml_path), 'rb') as f:
|
||||
data = yaml.load(f, Loader=yaml.Loader)
|
||||
return data
|
||||
|
||||
def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
|
||||
if len(cache) > 0:
|
||||
return cache
|
||||
|
||||
cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
|
||||
config = _read_yaml(asr_train_config)
|
||||
enc_output_size = config["encoder_conf"]["output_size"]
|
||||
feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
|
||||
cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
|
||||
"cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
|
||||
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False}
|
||||
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
|
||||
cache["encoder"] = cache_en
|
||||
|
||||
cache_de = {"decode_fsmn": None}
|
||||
@ -478,9 +489,12 @@ def inference_modelscope(
|
||||
|
||||
def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
|
||||
if len(cache) > 0:
|
||||
cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
|
||||
config = _read_yaml(asr_train_config)
|
||||
enc_output_size = config["encoder_conf"]["output_size"]
|
||||
feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
|
||||
cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
|
||||
"cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
|
||||
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False}
|
||||
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
|
||||
cache["encoder"] = cache_en
|
||||
|
||||
cache_de = {"decode_fsmn": None}
|
||||
@ -720,4 +734,3 @@ if __name__ == "__main__":
|
||||
#
|
||||
# rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
||||
# print(rec_result)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user