rollback cif_v1 for training bug

This commit is contained in:
维石 2024-06-21 15:21:33 +08:00
parent 0df672a2d0
commit 362eed972c

View File

@ -80,7 +80,7 @@ class CifPredictor(torch.nn.Module):
hidden, alphas, token_num, mask=mask
)
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
@ -245,7 +245,7 @@ class CifPredictorV2(torch.nn.Module):
hidden, alphas, token_num, mask=None
)
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]