# espnet_model.py
  1. from typing import Dict
  2. from typing import Optional
  3. from typing import Tuple
  4. import torch
  5. import torch.nn.functional as F
  6. from typeguard import check_argument_types
  7. from funasr.modules.nets_utils import make_pad_mask
  8. from funasr.lm.abs_model import AbsLM
  9. from funasr.torch_utils.device_funcs import force_gatherable
  10. from funasr.train.abs_espnet_model import AbsESPnetModel
  11. class ESPnetLanguageModel(AbsESPnetModel):
  12. def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
  13. assert check_argument_types()
  14. super().__init__()
  15. self.lm = lm
  16. self.sos = 1
  17. self.eos = 2
  18. # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
  19. self.ignore_id = ignore_id
  20. def nll(
  21. self,
  22. text: torch.Tensor,
  23. text_lengths: torch.Tensor,
  24. max_length: Optional[int] = None,
  25. ) -> Tuple[torch.Tensor, torch.Tensor]:
  26. """Compute negative log likelihood(nll)
  27. Normally, this function is called in batchify_nll.
  28. Args:
  29. text: (Batch, Length)
  30. text_lengths: (Batch,)
  31. max_lengths: int
  32. """
  33. batch_size = text.size(0)
  34. # For data parallel
  35. if max_length is None:
  36. text = text[:, : text_lengths.max()]
  37. else:
  38. text = text[:, :max_length]
  39. # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
  40. # text: (Batch, Length) -> x, y: (Batch, Length + 1)
  41. x = F.pad(text, [1, 0], "constant", self.eos)
  42. t = F.pad(text, [0, 1], "constant", self.ignore_id)
  43. for i, l in enumerate(text_lengths):
  44. t[i, l] = self.sos
  45. x_lengths = text_lengths + 1
  46. # 2. Forward Language model
  47. # x: (Batch, Length) -> y: (Batch, Length, NVocab)
  48. y, _ = self.lm(x, None)
  49. # 3. Calc negative log likelihood
  50. # nll: (BxL,)
  51. nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
  52. # nll: (BxL,) -> (BxL,)
  53. if max_length is None:
  54. nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
  55. else:
  56. nll.masked_fill_(
  57. make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
  58. 0.0,
  59. )
  60. # nll: (BxL,) -> (B, L)
  61. nll = nll.view(batch_size, -1)
  62. return nll, x_lengths
  63. def batchify_nll(
  64. self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
  65. ) -> Tuple[torch.Tensor, torch.Tensor]:
  66. """Compute negative log likelihood(nll) from transformer language model
  67. To avoid OOM, this fuction seperate the input into batches.
  68. Then call nll for each batch and combine and return results.
  69. Args:
  70. text: (Batch, Length)
  71. text_lengths: (Batch,)
  72. batch_size: int, samples each batch contain when computing nll,
  73. you may change this to avoid OOM or increase
  74. """
  75. total_num = text.size(0)
  76. if total_num <= batch_size:
  77. nll, x_lengths = self.nll(text, text_lengths)
  78. else:
  79. nlls = []
  80. x_lengths = []
  81. max_length = text_lengths.max()
  82. start_idx = 0
  83. while True:
  84. end_idx = min(start_idx + batch_size, total_num)
  85. batch_text = text[start_idx:end_idx, :]
  86. batch_text_lengths = text_lengths[start_idx:end_idx]
  87. # batch_nll: [B * T]
  88. batch_nll, batch_x_lengths = self.nll(
  89. batch_text, batch_text_lengths, max_length=max_length
  90. )
  91. nlls.append(batch_nll)
  92. x_lengths.append(batch_x_lengths)
  93. start_idx = end_idx
  94. if start_idx == total_num:
  95. break
  96. nll = torch.cat(nlls)
  97. x_lengths = torch.cat(x_lengths)
  98. assert nll.size(0) == total_num
  99. assert x_lengths.size(0) == total_num
  100. return nll, x_lengths
  101. def forward(
  102. self, text: torch.Tensor, text_lengths: torch.Tensor
  103. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  104. nll, y_lengths = self.nll(text, text_lengths)
  105. ntokens = y_lengths.sum()
  106. loss = nll.sum() / ntokens
  107. stats = dict(loss=loss.detach())
  108. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  109. loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
  110. return loss, stats, weight
  111. def collect_feats(
  112. self, text: torch.Tensor, text_lengths: torch.Tensor
  113. ) -> Dict[str, torch.Tensor]:
  114. return {}