# espnet_model.py
  1. from typing import Dict
  2. from typing import Optional
  3. from typing import Tuple
  4. import torch
  5. import torch.nn.functional as F
  6. from typeguard import check_argument_types
  7. from funasr.modules.nets_utils import make_pad_mask
  8. from funasr.punctuation.abs_model import AbsPunctuation
  9. from funasr.torch_utils.device_funcs import force_gatherable
  10. from funasr.train.abs_espnet_model import AbsESPnetModel
  11. class ESPnetPunctuationModel(AbsESPnetModel):
  12. def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
  13. assert check_argument_types()
  14. super().__init__()
  15. self.punc_model = punc_model
  16. self.punc_weight = torch.Tensor(punc_weight)
  17. self.sos = 1
  18. self.eos = 2
  19. # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
  20. self.ignore_id = ignore_id
  21. if self.punc_model.with_vad():
  22. print("This is a vad puncuation model.")
  23. def nll(
  24. self,
  25. text: torch.Tensor,
  26. punc: torch.Tensor,
  27. text_lengths: torch.Tensor,
  28. punc_lengths: torch.Tensor,
  29. max_length: Optional[int] = None,
  30. vad_indexes: Optional[torch.Tensor] = None,
  31. vad_indexes_lengths: Optional[torch.Tensor] = None,
  32. ) -> Tuple[torch.Tensor, torch.Tensor]:
  33. """Compute negative log likelihood(nll)
  34. Normally, this function is called in batchify_nll.
  35. Args:
  36. text: (Batch, Length)
  37. punc: (Batch, Length)
  38. text_lengths: (Batch,)
  39. max_lengths: int
  40. """
  41. batch_size = text.size(0)
  42. # For data parallel
  43. if max_length is None:
  44. text = text[:, :text_lengths.max()]
  45. punc = punc[:, :text_lengths.max()]
  46. else:
  47. text = text[:, :max_length]
  48. punc = punc[:, :max_length]
  49. if self.punc_model.with_vad():
  50. # Should be VadRealtimeTransformer
  51. assert vad_indexes is not None
  52. y, _ = self.punc_model(text, text_lengths, vad_indexes)
  53. else:
  54. # Should be TargetDelayTransformer,
  55. y, _ = self.punc_model(text, text_lengths)
  56. # Calc negative log likelihood
  57. # nll: (BxL,)
  58. if self.training == False:
  59. _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
  60. from sklearn.metrics import f1_score
  61. f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
  62. indices.squeeze(-1).detach().cpu().numpy(),
  63. average='micro')
  64. nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
  65. return nll, text_lengths
  66. else:
  67. self.punc_weight = self.punc_weight.to(punc.device)
  68. nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
  69. # nll: (BxL,) -> (BxL,)
  70. if max_length is None:
  71. nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
  72. else:
  73. nll.masked_fill_(
  74. make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
  75. 0.0,
  76. )
  77. # nll: (BxL,) -> (B, L)
  78. nll = nll.view(batch_size, -1)
  79. return nll, text_lengths
  80. def batchify_nll(self,
  81. text: torch.Tensor,
  82. punc: torch.Tensor,
  83. text_lengths: torch.Tensor,
  84. punc_lengths: torch.Tensor,
  85. batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
  86. """Compute negative log likelihood(nll) from transformer language model
  87. To avoid OOM, this fuction seperate the input into batches.
  88. Then call nll for each batch and combine and return results.
  89. Args:
  90. text: (Batch, Length)
  91. punc: (Batch, Length)
  92. text_lengths: (Batch,)
  93. batch_size: int, samples each batch contain when computing nll,
  94. you may change this to avoid OOM or increase
  95. """
  96. total_num = text.size(0)
  97. if total_num <= batch_size:
  98. nll, x_lengths = self.nll(text, punc, text_lengths)
  99. else:
  100. nlls = []
  101. x_lengths = []
  102. max_length = text_lengths.max()
  103. start_idx = 0
  104. while True:
  105. end_idx = min(start_idx + batch_size, total_num)
  106. batch_text = text[start_idx:end_idx, :]
  107. batch_punc = punc[start_idx:end_idx, :]
  108. batch_text_lengths = text_lengths[start_idx:end_idx]
  109. # batch_nll: [B * T]
  110. batch_nll, batch_x_lengths = self.nll(batch_text, batch_punc, batch_text_lengths, max_length=max_length)
  111. nlls.append(batch_nll)
  112. x_lengths.append(batch_x_lengths)
  113. start_idx = end_idx
  114. if start_idx == total_num:
  115. break
  116. nll = torch.cat(nlls)
  117. x_lengths = torch.cat(x_lengths)
  118. assert nll.size(0) == total_num
  119. assert x_lengths.size(0) == total_num
  120. return nll, x_lengths
  121. def forward(
  122. self,
  123. text: torch.Tensor,
  124. punc: torch.Tensor,
  125. text_lengths: torch.Tensor,
  126. punc_lengths: torch.Tensor,
  127. vad_indexes: Optional[torch.Tensor] = None,
  128. vad_indexes_lengths: Optional[torch.Tensor] = None,
  129. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  130. nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
  131. ntokens = y_lengths.sum()
  132. loss = nll.sum() / ntokens
  133. stats = dict(loss=loss.detach())
  134. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  135. loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
  136. return loss, stats, weight
  137. def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
  138. text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
  139. return {}
  140. def inference(self,
  141. text: torch.Tensor,
  142. text_lengths: torch.Tensor,
  143. vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
  144. if self.punc_model.with_vad():
  145. assert vad_indexes is not None
  146. return self.punc_model(text, text_lengths, vad_indexes)
  147. else:
  148. return self.punc_model(text, text_lengths)