#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Label smoothing module."""
import torch
from torch import nn

from funasr.modules.nets_utils import make_pad_mask


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    :param int size: the number of classes
    :param int padding_idx: ignored class id
    :param float smoothing: smoothing rate (0.0 means the conventional CE)
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function to be smoothed
    """

    def __init__(
        self,
        size,
        padding_idx,
        smoothing,
        normalize_length=False,
        criterion=nn.KLDivLoss(reduction="none"),
    ):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype: torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            true_dist = x.clone()
            # Spread the smoothing mass uniformly over the non-target classes.
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
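

# A minimal usage sketch (hypothetical shapes and hyperparameters, not part
# of the original module). The loss takes raw logits, since forward() applies
# log_softmax itself; padded target positions carry padding_idx:
#
#     loss_fn = LabelSmoothingLoss(size=5000, padding_idx=-1, smoothing=0.1)
#     logits = torch.randn(2, 10, 5000)         # (batch, seqlen, class)
#     target = torch.randint(0, 5000, (2, 10))  # (batch, seqlen)
#     target[:, -2:] = -1                       # mark padded positions
#     loss = loss_fn(logits, target)            # scalar tensor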


class SequenceBinaryCrossEntropy(nn.Module):
    """Binary cross-entropy over padded sequences.

    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function
    """

    def __init__(
        self,
        normalize_length=False,
        criterion=nn.BCEWithLogitsLoss(reduction="none"),
    ):
        super().__init__()
        self.normalize_length = normalize_length
        self.criterion = criterion

    def forward(self, pred, label, lengths):
        # Zero out padded frames so they do not contribute to the loss.
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
        loss = self.criterion(pred, label)
        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
        return loss.masked_fill(pad_mask, 0).sum() / denom
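

# A minimal usage sketch (hypothetical shapes, not part of the original
# module). pred holds per-frame logits, label the 0/1 targets, and lengths
# the number of valid frames per sequence; padded frames are masked out:
#
#     bce = SequenceBinaryCrossEntropy()
#     pred = torch.randn(2, 10)                     # (batch, seqlen) logits
#     label = torch.randint(0, 2, (2, 10)).float()  # (batch, seqlen) targets
#     lengths = torch.tensor([10, 7])               # valid frames per sequence
#     loss = bce(pred, label, lengths)              # scalar tensor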


class NllLoss(nn.Module):
    """Negative log-likelihood (NLL) loss.

    :param int size: the number of classes
    :param int padding_idx: ignored class id
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=nn.NLLLoss(reduction="none"),
    ):
        """Construct an NllLoss object."""
        super(NllLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype: torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
        nll = self.criterion(x, target)
        denom = total if self.normalize_length else batch_size
        return nll.masked_fill(ignore, 0).sum() / denom
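

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original module); shapes and
    # vocabulary size are arbitrary. Note that nn.NLLLoss expects
    # log-probabilities, so log_softmax is applied here rather than inside
    # forward().
    torch.manual_seed(0)
    nll_loss = NllLoss(size=50, padding_idx=-1)
    log_probs = torch.log_softmax(torch.randn(2, 8, 50), dim=-1)
    target = torch.randint(0, 50, (2, 8))
    target[0, -3:] = -1  # mark padded positions with padding_idx
    print(nll_loss(log_probs, target))  # scalar loss, averaged over the batch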