From a7814a7bc32aa62ed70631f6478d407fdc0ff488 Mon Sep 17 00:00:00 2001 From: "haoneng.lhn" Date: Wed, 17 May 2023 17:13:32 +0800 Subject: [PATCH] fix paraformer online last chunk decoding strategy --- funasr/bin/asr_infer.py | 17 ----------------- funasr/models/encoder/sanm_encoder.py | 11 +---------- funasr/models/predictor/cif.py | 5 +++-- 3 files changed, 4 insertions(+), 29 deletions(-) diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py index f6c5504b8..03145f859 100644 --- a/funasr/bin/asr_infer.py +++ b/funasr/bin/asr_infer.py @@ -762,23 +762,6 @@ class Speech2TextParaformerOnline: feats_len = speech_lengths if feats.shape[1] != 0: - if cache_en["is_final"]: - if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]: - cache_en["last_chunk"] = True - else: - # first chunk - feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :] - feats_len = torch.tensor([feats_chunk1.shape[1]]) - results_chunk1 = self.infer(feats_chunk1, feats_len, cache) - - # last chunk - cache_en["last_chunk"] = True - feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :] - feats_len = torch.tensor([feats_chunk2.shape[1]]) - results_chunk2 = self.infer(feats_chunk2, feats_len, cache) - - return [" ".join(results_chunk1 + results_chunk2)] - results = self.infer(feats, feats_len, cache) return results diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py index e071e575b..da675864c 100644 --- a/funasr/models/encoder/sanm_encoder.py +++ b/funasr/models/encoder/sanm_encoder.py @@ -355,18 +355,9 @@ class SANMEncoder(AbsEncoder): def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}): if len(cache) == 0: return feats - # process last chunk cache["feats"] = to_device(cache["feats"], device=feats.device) overlap_feats = torch.cat((cache["feats"], feats), dim=1) - if cache["is_final"]: - cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :] - if not cache["last_chunk"]: - padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1] - overlap_feats = overlap_feats.transpose(1, 2) - overlap_feats = F.pad(overlap_feats, (0, padding_length)) - overlap_feats = overlap_feats.transpose(1, 2) - else: - cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] + cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] return overlap_feats def forward_chunk(self, diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index c59e24502..3c363dbab 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -221,13 +221,14 @@ class CifPredictorV2(nn.Module): if cache is not None and "chunk_size" in cache: alphas[:, :cache["chunk_size"][0]] = 0.0 - alphas[:, sum(cache["chunk_size"][:2]):] = 0.0 + if "is_final" in cache and not cache["is_final"]: + alphas[:, sum(cache["chunk_size"][:2]):] = 0.0 if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache: cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device) cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device) hidden = torch.cat((cache["cif_hidden"], hidden), dim=1) alphas = torch.cat((cache["cif_alphas"], alphas), dim=1) - if cache is not None and "last_chunk" in cache and cache["last_chunk"]: + if cache is not None and "is_final" in cache and cache["is_final"]: tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device) tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device) tail_alphas = torch.tile(tail_alphas, (batch_size, 1))