diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py index 07efde67c..68980301f 100644 --- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py +++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py @@ -10,7 +10,6 @@ encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-atten decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.2") -cache = {} res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py index 943cb476a..7c2156174 100644 --- a/funasr/models/fsmn_vad_streaming/model.py +++ b/funasr/models/fsmn_vad_streaming/model.py @@ -501,7 +501,9 @@ class FsmnVADStreaming(nn.Module): # self.AllResetDetection() return segments + def init_cache(self, cache: dict = {}, **kwargs): + cache["frontend"] = {} cache["prev_samples"] = torch.empty(0) cache["encoder"] = {} @@ -528,7 +530,7 @@ class FsmnVADStreaming(nn.Module): cache: dict = {}, **kwargs, ): - + if len(cache) == 0: self.init_cache(cache, **kwargs) @@ -583,7 +585,7 @@ class FsmnVADStreaming(nn.Module): cache["prev_samples"] = audio_sample[:-m] if _is_final: - cache = {} + self.init_cache(cache) ibest_writer = None if ibest_writer is None and kwargs.get("output_dir") is not None: diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py index bf4526947..9bf5d39b2 100644 --- a/funasr/models/paraformer_streaming/model.py +++ b/funasr/models/paraformer_streaming/model.py @@ -502,8 +502,7 @@ class ParaformerStreaming(Paraformer): logging.info("enable beam_search") self.init_beam_search(**kwargs) self.nbest = kwargs.get("nbest", 1) - - + if len(cache) == 0: self.init_cache(cache, **kwargs)