#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Label smoothing module."""

import torch
from torch import nn

from funasr.modules.nets_utils import make_pad_mask


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    :param int size: the number of classes
    :param int padding_idx: ignored class id
    :param float smoothing: smoothing rate (0.0 means the conventional CE)
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function to be smoothed
    """

    def __init__(
        self,
        size,
        padding_idx,
        smoothing,
        normalize_length=False,
        criterion=nn.KLDivLoss(reduction="none"),
    ):
        """Construct a LabelSmoothingLoss object."""
        super().__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype: torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            true_dist = x.clone()
            # spread the smoothing mass uniformly over the non-target classes
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (B,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            # place the remaining confidence mass on the target class
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom

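
# Minimal usage sketch for LabelSmoothingLoss. The vocabulary size (10),
# padding id (-1), and smoothing rate (0.1) are assumed demo values, not
# FunASR defaults.
def _demo_label_smoothing_loss():
    criterion = LabelSmoothingLoss(size=10, padding_idx=-1, smoothing=0.1)
    logits = torch.randn(2, 5, 10)         # (batch, seqlen, class)
    target = torch.randint(0, 10, (2, 5))  # (batch, seqlen)
    target[1, 3:] = -1                     # mark the tail of one sequence as padding
    return criterion(logits, target)       # scalar loss tensor
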

class SequenceBinaryCrossEntropy(nn.Module):
    """Binary cross-entropy over padded sequences.

    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: element-wise loss function
    """

    def __init__(
        self,
        normalize_length=False,
        criterion=nn.BCEWithLogitsLoss(reduction="none"),
    ):
        super().__init__()
        self.normalize_length = normalize_length
        self.criterion = criterion

    def forward(self, pred, label, lengths):
        """Compute masked BCE loss between pred and label.

        :param torch.Tensor pred: logits (batch, seqlen, dims)
        :param torch.Tensor label: binary targets, same shape as pred
        :param torch.Tensor lengths: valid length of each sequence (batch,)
        """
        # True at padded positions beyond each sequence's length
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        loss = self.criterion(pred, label)
        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom

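
# Minimal usage sketch for SequenceBinaryCrossEntropy, assuming frame-level
# logits and {0, 1} targets; all shapes below are demo values.
def _demo_sequence_bce():
    criterion = SequenceBinaryCrossEntropy()
    pred = torch.randn(2, 5, 3)                     # (batch, seqlen, dims) logits
    label = torch.randint(0, 2, (2, 5, 3)).float()  # binary targets
    lengths = torch.tensor([5, 3])                  # frames past each length are masked
    return criterion(pred, label, lengths)          # scalar loss tensor
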

class NllLoss(nn.Module):
    """Negative log-likelihood loss.

    :param int size: the number of classes
    :param int padding_idx: ignored class id
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=nn.NLLLoss(reduction="none"),
    ):
        """Construct an NllLoss object."""
        super().__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: log-probability prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype: torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (B,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
        # nn.NLLLoss expects log-probabilities, e.g. the output of log_softmax
        loss = self.criterion(x, target)
        denom = total if self.normalize_length else batch_size
        return loss.masked_fill(ignore, 0).sum() / denom
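
# Minimal usage sketch for NllLoss; nn.NLLLoss expects log-probabilities, so
# the prediction is passed through log_softmax first. Sizes are demo values.
def _demo_nll_loss():
    criterion = NllLoss(size=10, padding_idx=-1)
    log_probs = torch.log_softmax(torch.randn(2, 5, 10), dim=-1)
    target = torch.randint(0, 10, (2, 5))  # (batch, seqlen)
    target[0, 4:] = -1                     # padded positions contribute nothing
    return criterion(log_probs, target)    # scalar loss tensor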