From 362eed972c885bd3526b75df6e1527925abe06c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BB=B4=E7=9F=B3?= Date: Fri, 21 Jun 2024 15:21:33 +0800 Subject: [PATCH] rollback cif_v1 for training bug --- funasr/models/paraformer/cif_predictor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py index 0856eede2..83ca464da 100644 --- a/funasr/models/paraformer/cif_predictor.py +++ b/funasr/models/paraformer/cif_predictor.py @@ -80,7 +80,7 @@ class CifPredictor(torch.nn.Module): hidden, alphas, token_num, mask=mask ) - acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold) + acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) if target_length is None and self.tail_threshold > 0.0: token_num_int = torch.max(token_num).type(torch.int32).item() @@ -245,7 +245,7 @@ class CifPredictorV2(torch.nn.Module): hidden, alphas, token_num, mask=None ) - acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold) + acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) if target_length is None and self.tail_threshold > 0.0: token_num_int = torch.max(token_num).type(torch.int32).item() acoustic_embeds = acoustic_embeds[:, :token_num_int, :]