# abs_model.py
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from typing import Tuple
  4. import torch
  5. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  6. from typing import Dict
  7. from typing import Optional
  8. from typing import Tuple
  9. import torch
  10. import torch.nn.functional as F
  11. from typeguard import check_argument_types
  12. from funasr.modules.nets_utils import make_pad_mask
  13. from funasr.lm.abs_model import AbsLM
  14. from funasr.torch_utils.device_funcs import force_gatherable
  15. from funasr.train.abs_espnet_model import AbsESPnetModel
  16. class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
  17. """The abstract LM class
  18. To share the loss calculation way among different models,
  19. We uses delegate pattern here:
  20. The instance of this class should be passed to "LanguageModel"
  21. >>> from funasr.lm.abs_model import AbsLM
  22. >>> lm = AbsLM()
  23. >>> model = LanguageESPnetModel(lm=lm)
  24. This "model" is one of mediator objects for "Task" class.
  25. """
  26. @abstractmethod
  27. def forward(
  28. self, input: torch.Tensor, hidden: torch.Tensor
  29. ) -> Tuple[torch.Tensor, torch.Tensor]:
  30. raise NotImplementedError
  31. class LanguageModel(AbsESPnetModel):
  32. def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
  33. assert check_argument_types()
  34. super().__init__()
  35. self.lm = lm
  36. self.sos = 1
  37. self.eos = 2
  38. # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
  39. self.ignore_id = ignore_id
  40. def nll(
  41. self,
  42. text: torch.Tensor,
  43. text_lengths: torch.Tensor,
  44. max_length: Optional[int] = None,
  45. ) -> Tuple[torch.Tensor, torch.Tensor]:
  46. """Compute negative log likelihood(nll)
  47. Normally, this function is called in batchify_nll.
  48. Args:
  49. text: (Batch, Length)
  50. text_lengths: (Batch,)
  51. max_lengths: int
  52. """
  53. batch_size = text.size(0)
  54. # For data parallel
  55. if max_length is None:
  56. text = text[:, : text_lengths.max()]
  57. else:
  58. text = text[:, :max_length]
  59. # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
  60. # text: (Batch, Length) -> x, y: (Batch, Length + 1)
  61. x = F.pad(text, [1, 0], "constant", self.sos)
  62. t = F.pad(text, [0, 1], "constant", self.ignore_id)
  63. for i, l in enumerate(text_lengths):
  64. t[i, l] = self.eos
  65. x_lengths = text_lengths + 1
  66. # 2. Forward Language model
  67. # x: (Batch, Length) -> y: (Batch, Length, NVocab)
  68. y, _ = self.lm(x, None)
  69. # 3. Calc negative log likelihood
  70. # nll: (BxL,)
  71. nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
  72. # nll: (BxL,) -> (BxL,)
  73. if max_length is None:
  74. nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
  75. else:
  76. nll.masked_fill_(
  77. make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
  78. 0.0,
  79. )
  80. # nll: (BxL,) -> (B, L)
  81. nll = nll.view(batch_size, -1)
  82. return nll, x_lengths
  83. def batchify_nll(
  84. self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
  85. ) -> Tuple[torch.Tensor, torch.Tensor]:
  86. """Compute negative log likelihood(nll) from transformer language model
  87. To avoid OOM, this fuction seperate the input into batches.
  88. Then call nll for each batch and combine and return results.
  89. Args:
  90. text: (Batch, Length)
  91. text_lengths: (Batch,)
  92. batch_size: int, samples each batch contain when computing nll,
  93. you may change this to avoid OOM or increase
  94. """
  95. total_num = text.size(0)
  96. if total_num <= batch_size:
  97. nll, x_lengths = self.nll(text, text_lengths)
  98. else:
  99. nlls = []
  100. x_lengths = []
  101. max_length = text_lengths.max()
  102. start_idx = 0
  103. while True:
  104. end_idx = min(start_idx + batch_size, total_num)
  105. batch_text = text[start_idx:end_idx, :]
  106. batch_text_lengths = text_lengths[start_idx:end_idx]
  107. # batch_nll: [B * T]
  108. batch_nll, batch_x_lengths = self.nll(
  109. batch_text, batch_text_lengths, max_length=max_length
  110. )
  111. nlls.append(batch_nll)
  112. x_lengths.append(batch_x_lengths)
  113. start_idx = end_idx
  114. if start_idx == total_num:
  115. break
  116. nll = torch.cat(nlls)
  117. x_lengths = torch.cat(x_lengths)
  118. assert nll.size(0) == total_num
  119. assert x_lengths.size(0) == total_num
  120. return nll, x_lengths
  121. def forward(
  122. self, text: torch.Tensor, text_lengths: torch.Tensor
  123. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  124. nll, y_lengths = self.nll(text, text_lengths)
  125. ntokens = y_lengths.sum()
  126. loss = nll.sum() / ntokens
  127. stats = dict(loss=loss.detach())
  128. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  129. loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
  130. return loss, stats, weight
  131. def collect_feats(
  132. self, text: torch.Tensor, text_lengths: torch.Tensor
  133. ) -> Dict[str, torch.Tensor]:
  134. return {}