mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
rollback cif_v1 for training bug
This commit is contained in:
parent
0df672a2d0
commit
362eed972c
@ -80,7 +80,7 @@ class CifPredictor(torch.nn.Module):
|
||||
hidden, alphas, token_num, mask=mask
|
||||
)
|
||||
|
||||
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
|
||||
if target_length is None and self.tail_threshold > 0.0:
|
||||
token_num_int = torch.max(token_num).type(torch.int32).item()
|
||||
@ -245,7 +245,7 @@ class CifPredictorV2(torch.nn.Module):
|
||||
hidden, alphas, token_num, mask=None
|
||||
)
|
||||
|
||||
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
if target_length is None and self.tail_threshold > 0.0:
|
||||
token_num_int = torch.max(token_num).type(torch.int32).item()
|
||||
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user