mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
TOLD/SOND: update SequenceBinaryCrossEntropy loss
This commit is contained in:
parent
8e4ff62a72
commit
66880c2a1a
@ -75,10 +75,10 @@ class SequenceBinaryCrossEntropy(nn.Module):
|
|||||||
self.criterion = criterion
|
self.criterion = criterion
|
||||||
|
|
||||||
def forward(self, pred, label, lengths):
|
def forward(self, pred, label, lengths):
|
||||||
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
|
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
|
||||||
loss = self.criterion(pred, label)
|
loss = self.criterion(pred, label)
|
||||||
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
|
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
|
||||||
return loss.masked_fill(pad_mask, 0).sum() / denom
|
return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
|
||||||
|
|
||||||
|
|
||||||
class NllLoss(nn.Module):
|
class NllLoss(nn.Module):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user