# abs_model.py
from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import Optional
from typing import Tuple

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
    """Abstract base class for punctuation prediction models.

    To share the loss-calculation logic among different model
    implementations, the delegate pattern is used: an instance of this
    class is passed to a wrapper model (``PunctuationModel`` in this
    file), which acts as a mediator object for the "Task" class.
    """

    @abstractmethod
    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the model output; returns a (output, hidden) pair."""
        raise NotImplementedError

    @abstractmethod
    def with_vad(self) -> bool:
        """Return True if this model consumes VAD index information."""
        raise NotImplementedError
  28. class PunctuationModel(AbsESPnetModel):
  29. def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
  30. assert check_argument_types()
  31. super().__init__()
  32. self.punc_model = punc_model
  33. self.punc_weight = torch.Tensor(punc_weight)
  34. self.sos = 1
  35. self.eos = 2
  36. # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
  37. self.ignore_id = ignore_id
  38. # if self.punc_model.with_vad():
  39. # print("This is a vad puncuation model.")
  40. def nll(
  41. self,
  42. text: torch.Tensor,
  43. punc: torch.Tensor,
  44. text_lengths: torch.Tensor,
  45. punc_lengths: torch.Tensor,
  46. max_length: Optional[int] = None,
  47. vad_indexes: Optional[torch.Tensor] = None,
  48. vad_indexes_lengths: Optional[torch.Tensor] = None,
  49. ) -> Tuple[torch.Tensor, torch.Tensor]:
  50. """Compute negative log likelihood(nll)
  51. Normally, this function is called in batchify_nll.
  52. Args:
  53. text: (Batch, Length)
  54. punc: (Batch, Length)
  55. text_lengths: (Batch,)
  56. max_lengths: int
  57. """
  58. batch_size = text.size(0)
  59. # For data parallel
  60. if max_length is None:
  61. text = text[:, :text_lengths.max()]
  62. punc = punc[:, :text_lengths.max()]
  63. else:
  64. text = text[:, :max_length]
  65. punc = punc[:, :max_length]
  66. if self.punc_model.with_vad():
  67. # Should be VadRealtimeTransformer
  68. assert vad_indexes is not None
  69. y, _ = self.punc_model(text, text_lengths, vad_indexes)
  70. else:
  71. # Should be TargetDelayTransformer,
  72. y, _ = self.punc_model(text, text_lengths)
  73. # Calc negative log likelihood
  74. # nll: (BxL,)
  75. if self.training == False:
  76. _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
  77. from sklearn.metrics import f1_score
  78. f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
  79. indices.squeeze(-1).detach().cpu().numpy(),
  80. average='micro')
  81. nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
  82. return nll, text_lengths
  83. else:
  84. self.punc_weight = self.punc_weight.to(punc.device)
  85. nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
  86. ignore_index=self.ignore_id)
  87. # nll: (BxL,) -> (BxL,)
  88. if max_length is None:
  89. nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
  90. else:
  91. nll.masked_fill_(
  92. make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
  93. 0.0,
  94. )
  95. # nll: (BxL,) -> (B, L)
  96. nll = nll.view(batch_size, -1)
  97. return nll, text_lengths
  98. def batchify_nll(self,
  99. text: torch.Tensor,
  100. punc: torch.Tensor,
  101. text_lengths: torch.Tensor,
  102. punc_lengths: torch.Tensor,
  103. batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
  104. """Compute negative log likelihood(nll) from transformer language model
  105. To avoid OOM, this fuction seperate the input into batches.
  106. Then call nll for each batch and combine and return results.
  107. Args:
  108. text: (Batch, Length)
  109. punc: (Batch, Length)
  110. text_lengths: (Batch,)
  111. batch_size: int, samples each batch contain when computing nll,
  112. you may change this to avoid OOM or increase
  113. """
  114. total_num = text.size(0)
  115. if total_num <= batch_size:
  116. nll, x_lengths = self.nll(text, punc, text_lengths)
  117. else:
  118. nlls = []
  119. x_lengths = []
  120. max_length = text_lengths.max()
  121. start_idx = 0
  122. while True:
  123. end_idx = min(start_idx + batch_size, total_num)
  124. batch_text = text[start_idx:end_idx, :]
  125. batch_punc = punc[start_idx:end_idx, :]
  126. batch_text_lengths = text_lengths[start_idx:end_idx]
  127. # batch_nll: [B * T]
  128. batch_nll, batch_x_lengths = self.nll(batch_text, batch_punc, batch_text_lengths, max_length=max_length)
  129. nlls.append(batch_nll)
  130. x_lengths.append(batch_x_lengths)
  131. start_idx = end_idx
  132. if start_idx == total_num:
  133. break
  134. nll = torch.cat(nlls)
  135. x_lengths = torch.cat(x_lengths)
  136. assert nll.size(0) == total_num
  137. assert x_lengths.size(0) == total_num
  138. return nll, x_lengths
  139. def forward(
  140. self,
  141. text: torch.Tensor,
  142. punc: torch.Tensor,
  143. text_lengths: torch.Tensor,
  144. punc_lengths: torch.Tensor,
  145. vad_indexes: Optional[torch.Tensor] = None,
  146. vad_indexes_lengths: Optional[torch.Tensor] = None,
  147. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  148. nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
  149. ntokens = y_lengths.sum()
  150. loss = nll.sum() / ntokens
  151. stats = dict(loss=loss.detach())
  152. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  153. loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
  154. return loss, stats, weight
  155. def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
  156. text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
  157. return {}
  158. def inference(self,
  159. text: torch.Tensor,
  160. text_lengths: torch.Tensor,
  161. vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
  162. if self.punc_model.with_vad():
  163. assert vad_indexes is not None
  164. return self.punc_model(text, text_lengths, vad_indexes)
  165. else:
  166. return self.punc_model(text, text_lengths)