TOLD/SOND: update SequenceBinaryCrossEntropy loss

This commit is contained in:
志浩 2023-08-01 21:00:50 +08:00
parent 8e4ff62a72
commit 66880c2a1a

View File

@ -75,10 +75,10 @@ class SequenceBinaryCrossEntropy(nn.Module):
self.criterion = criterion
def forward(self, pred, label, lengths):
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
return loss.masked_fill(pad_mask, 0).sum() / denom
return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
class NllLoss(nn.Module):