# abs_model.py
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  4. from typing import Dict
  5. from typing import Optional
  6. from typing import Tuple
  7. import torch
  8. import torch.nn.functional as F
  9. from funasr.modules.nets_utils import make_pad_mask
  10. from funasr.torch_utils.device_funcs import force_gatherable
  11. from funasr.models.base_model import FunASRModel
  12. class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
  13. """The abstract LM class
  14. To share the loss calculation way among different models,
  15. We uses delegate pattern here:
  16. The instance of this class should be passed to "LanguageModel"
  17. This "model" is one of mediator objects for "Task" class.
  18. """
  19. @abstractmethod
  20. def forward(
  21. self, input: torch.Tensor, hidden: torch.Tensor
  22. ) -> Tuple[torch.Tensor, torch.Tensor]:
  23. raise NotImplementedError
  24. class LanguageModel(FunASRModel):
  25. def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
  26. super().__init__()
  27. self.lm = lm
  28. self.sos = 1
  29. self.eos = 2
  30. # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
  31. self.ignore_id = ignore_id
  32. def nll(
  33. self,
  34. text: torch.Tensor,
  35. text_lengths: torch.Tensor,
  36. max_length: Optional[int] = None,
  37. ) -> Tuple[torch.Tensor, torch.Tensor]:
  38. """Compute negative log likelihood(nll)
  39. Normally, this function is called in batchify_nll.
  40. Args:
  41. text: (Batch, Length)
  42. text_lengths: (Batch,)
  43. max_lengths: int
  44. """
  45. batch_size = text.size(0)
  46. # For data parallel
  47. if max_length is None:
  48. text = text[:, : text_lengths.max()]
  49. else:
  50. text = text[:, :max_length]
  51. # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
  52. # text: (Batch, Length) -> x, y: (Batch, Length + 1)
  53. x = F.pad(text, [1, 0], "constant", self.sos)
  54. t = F.pad(text, [0, 1], "constant", self.ignore_id)
  55. for i, l in enumerate(text_lengths):
  56. t[i, l] = self.eos
  57. x_lengths = text_lengths + 1
  58. # 2. Forward Language model
  59. # x: (Batch, Length) -> y: (Batch, Length, NVocab)
  60. y, _ = self.lm(x, None)
  61. # 3. Calc negative log likelihood
  62. # nll: (BxL,)
  63. nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
  64. # nll: (BxL,) -> (BxL,)
  65. if max_length is None:
  66. nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
  67. else:
  68. nll.masked_fill_(
  69. make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
  70. 0.0,
  71. )
  72. # nll: (BxL,) -> (B, L)
  73. nll = nll.view(batch_size, -1)
  74. return nll, x_lengths
  75. def batchify_nll(
  76. self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
  77. ) -> Tuple[torch.Tensor, torch.Tensor]:
  78. """Compute negative log likelihood(nll) from transformer language model
  79. To avoid OOM, this fuction seperate the input into batches.
  80. Then call nll for each batch and combine and return results.
  81. Args:
  82. text: (Batch, Length)
  83. text_lengths: (Batch,)
  84. batch_size: int, samples each batch contain when computing nll,
  85. you may change this to avoid OOM or increase
  86. """
  87. total_num = text.size(0)
  88. if total_num <= batch_size:
  89. nll, x_lengths = self.nll(text, text_lengths)
  90. else:
  91. nlls = []
  92. x_lengths = []
  93. max_length = text_lengths.max()
  94. start_idx = 0
  95. while True:
  96. end_idx = min(start_idx + batch_size, total_num)
  97. batch_text = text[start_idx:end_idx, :]
  98. batch_text_lengths = text_lengths[start_idx:end_idx]
  99. # batch_nll: [B * T]
  100. batch_nll, batch_x_lengths = self.nll(
  101. batch_text, batch_text_lengths, max_length=max_length
  102. )
  103. nlls.append(batch_nll)
  104. x_lengths.append(batch_x_lengths)
  105. start_idx = end_idx
  106. if start_idx == total_num:
  107. break
  108. nll = torch.cat(nlls)
  109. x_lengths = torch.cat(x_lengths)
  110. assert nll.size(0) == total_num
  111. assert x_lengths.size(0) == total_num
  112. return nll, x_lengths
  113. def forward(
  114. self, text: torch.Tensor, text_lengths: torch.Tensor
  115. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  116. nll, y_lengths = self.nll(text, text_lengths)
  117. ntokens = y_lengths.sum()
  118. loss = nll.sum() / ntokens
  119. stats = dict(loss=loss.detach())
  120. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  121. loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
  122. return loss, stats, weight
  123. def collect_feats(
  124. self, text: torch.Tensor, text_lengths: torch.Tensor
  125. ) -> Dict[str, torch.Tensor]:
  126. return {}
  127. class PunctuationModel(FunASRModel):
  128. def __init__(self, punc_model: torch.nn.Module, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
  129. super().__init__()
  130. self.punc_model = punc_model
  131. self.punc_weight = torch.Tensor(punc_weight)
  132. self.sos = 1
  133. self.eos = 2
  134. # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
  135. self.ignore_id = ignore_id
  136. # if self.punc_model.with_vad():
  137. # print("This is a vad puncuation model.")
  138. def nll(
  139. self,
  140. text: torch.Tensor,
  141. punc: torch.Tensor,
  142. text_lengths: torch.Tensor,
  143. punc_lengths: torch.Tensor,
  144. max_length: Optional[int] = None,
  145. vad_indexes: Optional[torch.Tensor] = None,
  146. vad_indexes_lengths: Optional[torch.Tensor] = None,
  147. ) -> Tuple[torch.Tensor, torch.Tensor]:
  148. """Compute negative log likelihood(nll)
  149. Normally, this function is called in batchify_nll.
  150. Args:
  151. text: (Batch, Length)
  152. punc: (Batch, Length)
  153. text_lengths: (Batch,)
  154. max_lengths: int
  155. """
  156. batch_size = text.size(0)
  157. # For data parallel
  158. if max_length is None:
  159. text = text[:, :text_lengths.max()]
  160. punc = punc[:, :text_lengths.max()]
  161. else:
  162. text = text[:, :max_length]
  163. punc = punc[:, :max_length]
  164. if self.punc_model.with_vad():
  165. # Should be VadRealtimeTransformer
  166. assert vad_indexes is not None
  167. y, _ = self.punc_model(text, text_lengths, vad_indexes)
  168. else:
  169. # Should be TargetDelayTransformer,
  170. y, _ = self.punc_model(text, text_lengths)
  171. # Calc negative log likelihood
  172. # nll: (BxL,)
  173. if self.training == False:
  174. _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
  175. from sklearn.metrics import f1_score
  176. f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
  177. indices.squeeze(-1).detach().cpu().numpy(),
  178. average='micro')
  179. nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
  180. return nll, text_lengths
  181. else:
  182. self.punc_weight = self.punc_weight.to(punc.device)
  183. nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
  184. ignore_index=self.ignore_id)
  185. # nll: (BxL,) -> (BxL,)
  186. if max_length is None:
  187. nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
  188. else:
  189. nll.masked_fill_(
  190. make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
  191. 0.0,
  192. )
  193. # nll: (BxL,) -> (B, L)
  194. nll = nll.view(batch_size, -1)
  195. return nll, text_lengths
  196. def batchify_nll(self,
  197. text: torch.Tensor,
  198. punc: torch.Tensor,
  199. text_lengths: torch.Tensor,
  200. punc_lengths: torch.Tensor,
  201. batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
  202. """Compute negative log likelihood(nll) from transformer language model
  203. To avoid OOM, this fuction seperate the input into batches.
  204. Then call nll for each batch and combine and return results.
  205. Args:
  206. text: (Batch, Length)
  207. punc: (Batch, Length)
  208. text_lengths: (Batch,)
  209. batch_size: int, samples each batch contain when computing nll,
  210. you may change this to avoid OOM or increase
  211. """
  212. total_num = text.size(0)
  213. if total_num <= batch_size:
  214. nll, x_lengths = self.nll(text, punc, text_lengths)
  215. else:
  216. nlls = []
  217. x_lengths = []
  218. max_length = text_lengths.max()
  219. start_idx = 0
  220. while True:
  221. end_idx = min(start_idx + batch_size, total_num)
  222. batch_text = text[start_idx:end_idx, :]
  223. batch_punc = punc[start_idx:end_idx, :]
  224. batch_text_lengths = text_lengths[start_idx:end_idx]
  225. # batch_nll: [B * T]
  226. batch_nll, batch_x_lengths = self.nll(batch_text, batch_punc, batch_text_lengths, max_length=max_length)
  227. nlls.append(batch_nll)
  228. x_lengths.append(batch_x_lengths)
  229. start_idx = end_idx
  230. if start_idx == total_num:
  231. break
  232. nll = torch.cat(nlls)
  233. x_lengths = torch.cat(x_lengths)
  234. assert nll.size(0) == total_num
  235. assert x_lengths.size(0) == total_num
  236. return nll, x_lengths
  237. def forward(
  238. self,
  239. text: torch.Tensor,
  240. punc: torch.Tensor,
  241. text_lengths: torch.Tensor,
  242. punc_lengths: torch.Tensor,
  243. vad_indexes: Optional[torch.Tensor] = None,
  244. vad_indexes_lengths: Optional[torch.Tensor] = None,
  245. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  246. nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
  247. ntokens = y_lengths.sum()
  248. loss = nll.sum() / ntokens
  249. stats = dict(loss=loss.detach())
  250. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  251. loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
  252. return loss, stats, weight
  253. def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
  254. text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
  255. return {}
  256. def inference(self,
  257. text: torch.Tensor,
  258. text_lengths: torch.Tensor,
  259. vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
  260. if self.punc_model.with_vad():
  261. assert vad_indexes is not None
  262. return self.punc_model(text, text_lengths, vad_indexes)
  263. else:
  264. return self.punc_model(text, text_lengths)