# abs_model.py
from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import Optional
from typing import Tuple

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
    """The abstract LM class.

    To share the loss calculation way among different models,
    we use the delegate pattern here:
    the instance of this class should be passed to "LanguageModel".

    >>> from funasr.lm.abs_model import AbsLM
    >>> lm = AbsLM()
    >>> model = LanguageModel(lm=lm)

    This "model" is one of the mediator objects for the "Task" class.
    """

    @abstractmethod
    def forward(
        self, input: torch.Tensor, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the LM on a batch of token ids.

        Args:
            input: token id tensor fed to the LM.
            hidden: hidden state carried between calls
                (may be ``None`` on the first call — see LanguageModel.nll).

        Returns:
            Tuple of (output logits, new hidden state).
        """
        raise NotImplementedError
  30. class LanguageModel(AbsESPnetModel):
  31. def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
  32. assert check_argument_types()
  33. super().__init__()
  34. self.lm = lm
  35. self.sos = 1
  36. self.eos = 2
  37. # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
  38. self.ignore_id = ignore_id
  39. def nll(
  40. self,
  41. text: torch.Tensor,
  42. text_lengths: torch.Tensor,
  43. max_length: Optional[int] = None,
  44. ) -> Tuple[torch.Tensor, torch.Tensor]:
  45. """Compute negative log likelihood(nll)
  46. Normally, this function is called in batchify_nll.
  47. Args:
  48. text: (Batch, Length)
  49. text_lengths: (Batch,)
  50. max_lengths: int
  51. """
  52. batch_size = text.size(0)
  53. # For data parallel
  54. if max_length is None:
  55. text = text[:, : text_lengths.max()]
  56. else:
  57. text = text[:, :max_length]
  58. # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
  59. # text: (Batch, Length) -> x, y: (Batch, Length + 1)
  60. x = F.pad(text, [1, 0], "constant", self.sos)
  61. t = F.pad(text, [0, 1], "constant", self.ignore_id)
  62. for i, l in enumerate(text_lengths):
  63. t[i, l] = self.eos
  64. x_lengths = text_lengths + 1
  65. # 2. Forward Language model
  66. # x: (Batch, Length) -> y: (Batch, Length, NVocab)
  67. y, _ = self.lm(x, None)
  68. # 3. Calc negative log likelihood
  69. # nll: (BxL,)
  70. nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
  71. # nll: (BxL,) -> (BxL,)
  72. if max_length is None:
  73. nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
  74. else:
  75. nll.masked_fill_(
  76. make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
  77. 0.0,
  78. )
  79. # nll: (BxL,) -> (B, L)
  80. nll = nll.view(batch_size, -1)
  81. return nll, x_lengths
  82. def batchify_nll(
  83. self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
  84. ) -> Tuple[torch.Tensor, torch.Tensor]:
  85. """Compute negative log likelihood(nll) from transformer language model
  86. To avoid OOM, this fuction seperate the input into batches.
  87. Then call nll for each batch and combine and return results.
  88. Args:
  89. text: (Batch, Length)
  90. text_lengths: (Batch,)
  91. batch_size: int, samples each batch contain when computing nll,
  92. you may change this to avoid OOM or increase
  93. """
  94. total_num = text.size(0)
  95. if total_num <= batch_size:
  96. nll, x_lengths = self.nll(text, text_lengths)
  97. else:
  98. nlls = []
  99. x_lengths = []
  100. max_length = text_lengths.max()
  101. start_idx = 0
  102. while True:
  103. end_idx = min(start_idx + batch_size, total_num)
  104. batch_text = text[start_idx:end_idx, :]
  105. batch_text_lengths = text_lengths[start_idx:end_idx]
  106. # batch_nll: [B * T]
  107. batch_nll, batch_x_lengths = self.nll(
  108. batch_text, batch_text_lengths, max_length=max_length
  109. )
  110. nlls.append(batch_nll)
  111. x_lengths.append(batch_x_lengths)
  112. start_idx = end_idx
  113. if start_idx == total_num:
  114. break
  115. nll = torch.cat(nlls)
  116. x_lengths = torch.cat(x_lengths)
  117. assert nll.size(0) == total_num
  118. assert x_lengths.size(0) == total_num
  119. return nll, x_lengths
  120. def forward(
  121. self, text: torch.Tensor, text_lengths: torch.Tensor
  122. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  123. nll, y_lengths = self.nll(text, text_lengths)
  124. ntokens = y_lengths.sum()
  125. loss = nll.sum() / ntokens
  126. stats = dict(loss=loss.detach())
  127. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  128. loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
  129. return loss, stats, weight
  130. def collect_feats(
  131. self, text: torch.Tensor, text_lengths: torch.Tensor
  132. ) -> Dict[str, torch.Tensor]:
  133. return {}