Bladeren bron

TOLD/SOND: update SequenceBinaryCrossEntropy loss

志浩 2 jaren geleden
bovenliggende
commit
66880c2a1a
1 gewijzigde bestanden met toevoegingen van 2 en 2 verwijderingen
  1. 2 2
      funasr/losses/label_smoothing_loss.py

+ 2 - 2
funasr/losses/label_smoothing_loss.py

@@ -75,10 +75,10 @@ class SequenceBinaryCrossEntropy(nn.Module):
         self.criterion = criterion
 
     def forward(self, pred, label, lengths):
-        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
+        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
         loss = self.criterion(pred, label)
         denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
-        return loss.masked_fill(pad_mask, 0).sum() / denom
+        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
 
 
 class NllLoss(nn.Module):