fix bug in predictor tail_process_fn

This commit is contained in:
lzr265946 2023-02-07 20:35:56 +08:00
parent 59eac0cc05
commit a3fe16f871

View File

@ -208,7 +208,8 @@ class CifPredictorV2(nn.Module):
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
alphas = torch.cat([alphas, tail_threshold], dim=1)
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))