From cbe2ea7e07cbf364827bd89cefc42b3f643ea3be Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Mon, 18 Mar 2024 23:59:09 +0800
Subject: [PATCH] paraformer streaming bugfix

---
 funasr/models/paraformer_streaming/model.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 5daa73a4b..499b48752 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -532,11 +532,13 @@ class ParaformerStreaming(Paraformer):
             kwargs["is_final"] = _is_final and i == n -1
             audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
             if kwargs["is_final"] and len(audio_sample_i) < 960:
-                continue
-
-            # extract fbank feats
-            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
+                cache["encoder"]["tail_chunk"] = True
+                speech = cache["encoder"]["feats"]
+                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(speech.device)
+            else:
+                # extract fbank feats
+                speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+                                                       frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
             time3 = time.perf_counter()
             meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
             meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
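
The change replaces the old behaviour of silently skipping a final chunk shorter than 960 samples (60 ms at 16 kHz, too short to extract new fbank features) with explicit tail-chunk handling: the encoder cache is flagged with `tail_chunk = True` and the feature window already stored in `cache["encoder"]["feats"]` is reused, so the last partial chunk still gets decoded. The following is a minimal sketch of that control flow, not the actual FunASR method; the helper name `prepare_chunk_features`, the constant name, and passing `extract_fbank` in as a callable are assumptions made for illustration, while the cache keys mirror the ones in the diff above.

```python
import torch

CHUNK_MIN_SAMPLES = 960  # 60 ms at 16 kHz; below this, no new fbank frame can be produced


def prepare_chunk_features(audio_sample_i, cache, is_final, extract_fbank, frontend, data_type="sound"):
    """Return (speech, speech_lengths) for one streaming chunk (illustrative sketch)."""
    if is_final and len(audio_sample_i) < CHUNK_MIN_SAMPLES:
        # Tail chunk too short for feature extraction: instead of dropping it
        # (old behaviour: `continue`), mark it and reuse the cached encoder feats.
        cache["encoder"]["tail_chunk"] = True
        speech = cache["encoder"]["feats"]
        speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(speech.device)
    else:
        # Normal path: extract fbank features for this chunk.
        speech, speech_lengths = extract_fbank(
            [audio_sample_i],
            data_type=data_type,
            frontend=frontend,
            cache=cache["frontend"],
            is_final=is_final,
        )
    return speech, speech_lengths
```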