From 12496e559feea69af2e77eac6f22b32df3bf6762 Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Thu, 18 Jan 2024 23:21:12 +0800 Subject: [PATCH] streaming bugfix (#1271) * funasr1.0 funetine * funasr1.0 pbar * update with main (#1260) * Update websocket_protocol_zh.md * update --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi * update with main (#1264) * Funasr1.0 (#1261) * funasr1.0 funetine * funasr1.0 pbar * update with main (#1260) * Update websocket_protocol_zh.md * update --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi * bug fix --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi * funasr1.0 sanm scama * funasr1.0 infer_after_finetune * funasr1.0 fsmn-vad bug fix * funasr1.0 fsmn-vad bug fix * funasr1.0 fsmn-vad bug fix --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi --- .../paraformer_streaming/demo.py | 1 - funasr/models/fsmn_vad_streaming/model.py | 6 ++++-- funasr/models/paraformer_streaming/model.py | 3 +-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py index 07efde67c..68980301f 100644 --- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py +++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py @@ -10,7 +10,6 @@ encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-atten decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.2") -cache = {} res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py index 943cb476a..7c2156174 100644 --- a/funasr/models/fsmn_vad_streaming/model.py +++ b/funasr/models/fsmn_vad_streaming/model.py @@ -501,7 +501,9 @@ class FsmnVADStreaming(nn.Module): # self.AllResetDetection() return segments + def init_cache(self, cache: dict = {}, **kwargs): + cache["frontend"] = {} cache["prev_samples"] = torch.empty(0) cache["encoder"] = {} @@ -528,7 +530,7 @@ class FsmnVADStreaming(nn.Module): cache: dict = {}, **kwargs, ): - + if len(cache) == 0: self.init_cache(cache, **kwargs) @@ -583,7 +585,7 @@ class FsmnVADStreaming(nn.Module): cache["prev_samples"] = audio_sample[:-m] if _is_final: - cache = {} + self.init_cache(cache) ibest_writer = None if ibest_writer is None and kwargs.get("output_dir") is not None: diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py index bf4526947..9bf5d39b2 100644 --- a/funasr/models/paraformer_streaming/model.py +++ b/funasr/models/paraformer_streaming/model.py @@ -502,8 +502,7 @@ class ParaformerStreaming(Paraformer): logging.info("enable beam_search") self.init_beam_search(**kwargs) self.nbest = kwargs.get("nbest", 1) - - + if len(cache) == 0: self.init_cache(cache, **kwargs)