mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
rollback cif_v1 for training bug
This commit is contained in:
parent
0df672a2d0
commit
362eed972c
@ -80,7 +80,7 @@ class CifPredictor(torch.nn.Module):
|
|||||||
hidden, alphas, token_num, mask=mask
|
hidden, alphas, token_num, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||||
|
|
||||||
if target_length is None and self.tail_threshold > 0.0:
|
if target_length is None and self.tail_threshold > 0.0:
|
||||||
token_num_int = torch.max(token_num).type(torch.int32).item()
|
token_num_int = torch.max(token_num).type(torch.int32).item()
|
||||||
@ -245,7 +245,7 @@ class CifPredictorV2(torch.nn.Module):
|
|||||||
hidden, alphas, token_num, mask=None
|
hidden, alphas, token_num, mask=None
|
||||||
)
|
)
|
||||||
|
|
||||||
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||||
if target_length is None and self.tail_threshold > 0.0:
|
if target_length is None and self.tail_threshold > 0.0:
|
||||||
token_num_int = torch.max(token_num).type(torch.int32).item()
|
token_num_int = torch.max(token_num).type(torch.int32).item()
|
||||||
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user