From a3fe16f871528864288737ef4a2df6a18000e138 Mon Sep 17 00:00:00 2001 From: lzr265946 Date: Tue, 7 Feb 2023 20:35:56 +0800 Subject: [PATCH] fix bug in predictor tail_process_fn --- funasr/models/predictor/cif.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index 60cf902f1..00c5a3e92 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -208,7 +208,8 @@ class CifPredictorV2(nn.Module): mask_2 = torch.cat([ones_t, mask], dim=1) mask = mask_2 - mask_1 tail_threshold = mask * tail_threshold - alphas = torch.cat([alphas, tail_threshold], dim=1) + alphas = torch.cat([alphas, zeros_t], dim=1) + alphas = torch.add(alphas, tail_threshold) else: tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device) tail_threshold = torch.reshape(tail_threshold, (1, 1)) @@ -654,4 +655,4 @@ class CifPredictorV3(nn.Module): predictor_alignments = index_div_bool_zeros_count_tile_out predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype) - return predictor_alignments.detach(), predictor_alignments_length.detach() \ No newline at end of file + return predictor_alignments.detach(), predictor_alignments_length.detach()