mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #73 from alibaba-damo-academy/main
fix bug, batch cif predictor tail
This commit is contained in:
commit
8871dcb93a
@ -208,7 +208,8 @@ class CifPredictorV2(nn.Module):
|
|||||||
mask_2 = torch.cat([ones_t, mask], dim=1)
|
mask_2 = torch.cat([ones_t, mask], dim=1)
|
||||||
mask = mask_2 - mask_1
|
mask = mask_2 - mask_1
|
||||||
tail_threshold = mask * tail_threshold
|
tail_threshold = mask * tail_threshold
|
||||||
alphas = torch.cat([alphas, tail_threshold], dim=1)
|
alphas = torch.cat([alphas, zeros_t], dim=1)
|
||||||
|
alphas = torch.add(alphas, tail_threshold)
|
||||||
else:
|
else:
|
||||||
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
|
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
|
||||||
tail_threshold = torch.reshape(tail_threshold, (1, 1))
|
tail_threshold = torch.reshape(tail_threshold, (1, 1))
|
||||||
@ -654,4 +655,4 @@ class CifPredictorV3(nn.Module):
|
|||||||
|
|
||||||
predictor_alignments = index_div_bool_zeros_count_tile_out
|
predictor_alignments = index_div_bool_zeros_count_tile_out
|
||||||
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
|
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
|
||||||
return predictor_alignments.detach(), predictor_alignments_length.detach()
|
return predictor_alignments.detach(), predictor_alignments_length.detach()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user