diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py index 0856eede2..83ca464da 100644 --- a/funasr/models/paraformer/cif_predictor.py +++ b/funasr/models/paraformer/cif_predictor.py @@ -80,7 +80,7 @@ class CifPredictor(torch.nn.Module): hidden, alphas, token_num, mask=mask ) - acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold) + acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) if target_length is None and self.tail_threshold > 0.0: token_num_int = torch.max(token_num).type(torch.int32).item() @@ -245,7 +245,7 @@ class CifPredictorV2(torch.nn.Module): hidden, alphas, token_num, mask=None ) - acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold) + acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) if target_length is None and self.tail_threshold > 0.0: token_num_int = torch.max(token_num).type(torch.int32).item() acoustic_embeds = acoustic_embeds[:, :token_num_int, :]