diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index c28146062..c34759d0d 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -598,7 +598,8 @@ class CifPredictorV3(nn.Module): mask_2 = torch.cat([ones_t, mask], dim=1) mask = mask_2 - mask_1 tail_threshold = mask * tail_threshold - alphas = torch.cat([alphas, tail_threshold], dim=1) + alphas = torch.cat([alphas, zeros_t], dim=1) + alphas = torch.add(alphas, tail_threshold) else: tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device) tail_threshold = torch.reshape(tail_threshold, (1, 1))