# abs_model.py
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  4. from typing import Dict
  5. from typing import Optional
  6. from typing import Tuple
  7. import torch
  8. import torch.nn.functional as F
  9. from typeguard import check_argument_types
  10. from funasr.modules.nets_utils import make_pad_mask
  11. from funasr.torch_utils.device_funcs import force_gatherable
  12. from funasr.models.base_model import FunASRModel
  13. class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
  14. """The abstract LM class
  15. To share the loss calculation way among different models,
  16. We uses delegate pattern here:
  17. The instance of this class should be passed to "LanguageModel"
  18. This "model" is one of mediator objects for "Task" class.
  19. """
  20. @abstractmethod
  21. def forward(
  22. self, input: torch.Tensor, hidden: torch.Tensor
  23. ) -> Tuple[torch.Tensor, torch.Tensor]:
  24. raise NotImplementedError
  25. class LanguageModel(FunASRModel):
  26. def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
  27. assert check_argument_types()
  28. super().__init__()
  29. self.lm = lm
  30. self.sos = 1
  31. self.eos = 2
  32. # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
  33. self.ignore_id = ignore_id
  34. def nll(
  35. self,
  36. text: torch.Tensor,
  37. text_lengths: torch.Tensor,
  38. max_length: Optional[int] = None,
  39. ) -> Tuple[torch.Tensor, torch.Tensor]:
  40. """Compute negative log likelihood(nll)
  41. Normally, this function is called in batchify_nll.
  42. Args:
  43. text: (Batch, Length)
  44. text_lengths: (Batch,)
  45. max_lengths: int
  46. """
  47. batch_size = text.size(0)
  48. # For data parallel
  49. if max_length is None:
  50. text = text[:, : text_lengths.max()]
  51. else:
  52. text = text[:, :max_length]
  53. # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
  54. # text: (Batch, Length) -> x, y: (Batch, Length + 1)
  55. x = F.pad(text, [1, 0], "constant", self.sos)
  56. t = F.pad(text, [0, 1], "constant", self.ignore_id)
  57. for i, l in enumerate(text_lengths):
  58. t[i, l] = self.eos
  59. x_lengths = text_lengths + 1
  60. # 2. Forward Language model
  61. # x: (Batch, Length) -> y: (Batch, Length, NVocab)
  62. y, _ = self.lm(x, None)
  63. # 3. Calc negative log likelihood
  64. # nll: (BxL,)
  65. nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
  66. # nll: (BxL,) -> (BxL,)
  67. if max_length is None:
  68. nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
  69. else:
  70. nll.masked_fill_(
  71. make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
  72. 0.0,
  73. )
  74. # nll: (BxL,) -> (B, L)
  75. nll = nll.view(batch_size, -1)
  76. return nll, x_lengths
  77. def batchify_nll(
  78. self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
  79. ) -> Tuple[torch.Tensor, torch.Tensor]:
  80. """Compute negative log likelihood(nll) from transformer language model
  81. To avoid OOM, this fuction seperate the input into batches.
  82. Then call nll for each batch and combine and return results.
  83. Args:
  84. text: (Batch, Length)
  85. text_lengths: (Batch,)
  86. batch_size: int, samples each batch contain when computing nll,
  87. you may change this to avoid OOM or increase
  88. """
  89. total_num = text.size(0)
  90. if total_num <= batch_size:
  91. nll, x_lengths = self.nll(text, text_lengths)
  92. else:
  93. nlls = []
  94. x_lengths = []
  95. max_length = text_lengths.max()
  96. start_idx = 0
  97. while True:
  98. end_idx = min(start_idx + batch_size, total_num)
  99. batch_text = text[start_idx:end_idx, :]
  100. batch_text_lengths = text_lengths[start_idx:end_idx]
  101. # batch_nll: [B * T]
  102. batch_nll, batch_x_lengths = self.nll(
  103. batch_text, batch_text_lengths, max_length=max_length
  104. )
  105. nlls.append(batch_nll)
  106. x_lengths.append(batch_x_lengths)
  107. start_idx = end_idx
  108. if start_idx == total_num:
  109. break
  110. nll = torch.cat(nlls)
  111. x_lengths = torch.cat(x_lengths)
  112. assert nll.size(0) == total_num
  113. assert x_lengths.size(0) == total_num
  114. return nll, x_lengths
  115. def forward(
  116. self, text: torch.Tensor, text_lengths: torch.Tensor
  117. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  118. nll, y_lengths = self.nll(text, text_lengths)
  119. ntokens = y_lengths.sum()
  120. loss = nll.sum() / ntokens
  121. stats = dict(loss=loss.detach())
  122. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  123. loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
  124. return loss, stats, weight
  125. def collect_feats(
  126. self, text: torch.Tensor, text_lengths: torch.Tensor
  127. ) -> Dict[str, torch.Tensor]:
  128. return {}
  129. class PunctuationModel(FunASRModel):
  130. def __init__(self, punc_model: torch.nn.Module, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
  131. assert check_argument_types()
  132. super().__init__()
  133. self.punc_model = punc_model
  134. self.punc_weight = torch.Tensor(punc_weight)
  135. self.sos = 1
  136. self.eos = 2
  137. # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
  138. self.ignore_id = ignore_id
  139. # if self.punc_model.with_vad():
  140. # print("This is a vad puncuation model.")
  141. def nll(
  142. self,
  143. text: torch.Tensor,
  144. punc: torch.Tensor,
  145. text_lengths: torch.Tensor,
  146. punc_lengths: torch.Tensor,
  147. max_length: Optional[int] = None,
  148. vad_indexes: Optional[torch.Tensor] = None,
  149. vad_indexes_lengths: Optional[torch.Tensor] = None,
  150. ) -> Tuple[torch.Tensor, torch.Tensor]:
  151. """Compute negative log likelihood(nll)
  152. Normally, this function is called in batchify_nll.
  153. Args:
  154. text: (Batch, Length)
  155. punc: (Batch, Length)
  156. text_lengths: (Batch,)
  157. max_lengths: int
  158. """
  159. batch_size = text.size(0)
  160. # For data parallel
  161. if max_length is None:
  162. text = text[:, :text_lengths.max()]
  163. punc = punc[:, :text_lengths.max()]
  164. else:
  165. text = text[:, :max_length]
  166. punc = punc[:, :max_length]
  167. if self.punc_model.with_vad():
  168. # Should be VadRealtimeTransformer
  169. assert vad_indexes is not None
  170. y, _ = self.punc_model(text, text_lengths, vad_indexes)
  171. else:
  172. # Should be TargetDelayTransformer,
  173. y, _ = self.punc_model(text, text_lengths)
  174. # Calc negative log likelihood
  175. # nll: (BxL,)
  176. if self.training == False:
  177. _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
  178. from sklearn.metrics import f1_score
  179. f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
  180. indices.squeeze(-1).detach().cpu().numpy(),
  181. average='micro')
  182. nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
  183. return nll, text_lengths
  184. else:
  185. self.punc_weight = self.punc_weight.to(punc.device)
  186. nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
  187. ignore_index=self.ignore_id)
  188. # nll: (BxL,) -> (BxL,)
  189. if max_length is None:
  190. nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
  191. else:
  192. nll.masked_fill_(
  193. make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
  194. 0.0,
  195. )
  196. # nll: (BxL,) -> (B, L)
  197. nll = nll.view(batch_size, -1)
  198. return nll, text_lengths
  199. def batchify_nll(self,
  200. text: torch.Tensor,
  201. punc: torch.Tensor,
  202. text_lengths: torch.Tensor,
  203. punc_lengths: torch.Tensor,
  204. batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
  205. """Compute negative log likelihood(nll) from transformer language model
  206. To avoid OOM, this fuction seperate the input into batches.
  207. Then call nll for each batch and combine and return results.
  208. Args:
  209. text: (Batch, Length)
  210. punc: (Batch, Length)
  211. text_lengths: (Batch,)
  212. batch_size: int, samples each batch contain when computing nll,
  213. you may change this to avoid OOM or increase
  214. """
  215. total_num = text.size(0)
  216. if total_num <= batch_size:
  217. nll, x_lengths = self.nll(text, punc, text_lengths)
  218. else:
  219. nlls = []
  220. x_lengths = []
  221. max_length = text_lengths.max()
  222. start_idx = 0
  223. while True:
  224. end_idx = min(start_idx + batch_size, total_num)
  225. batch_text = text[start_idx:end_idx, :]
  226. batch_punc = punc[start_idx:end_idx, :]
  227. batch_text_lengths = text_lengths[start_idx:end_idx]
  228. # batch_nll: [B * T]
  229. batch_nll, batch_x_lengths = self.nll(batch_text, batch_punc, batch_text_lengths, max_length=max_length)
  230. nlls.append(batch_nll)
  231. x_lengths.append(batch_x_lengths)
  232. start_idx = end_idx
  233. if start_idx == total_num:
  234. break
  235. nll = torch.cat(nlls)
  236. x_lengths = torch.cat(x_lengths)
  237. assert nll.size(0) == total_num
  238. assert x_lengths.size(0) == total_num
  239. return nll, x_lengths
  240. def forward(
  241. self,
  242. text: torch.Tensor,
  243. punc: torch.Tensor,
  244. text_lengths: torch.Tensor,
  245. punc_lengths: torch.Tensor,
  246. vad_indexes: Optional[torch.Tensor] = None,
  247. vad_indexes_lengths: Optional[torch.Tensor] = None,
  248. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  249. nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
  250. ntokens = y_lengths.sum()
  251. loss = nll.sum() / ntokens
  252. stats = dict(loss=loss.detach())
  253. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  254. loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
  255. return loss, stats, weight
  256. def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
  257. text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
  258. return {}
  259. def inference(self,
  260. text: torch.Tensor,
  261. text_lengths: torch.Tensor,
  262. vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
  263. if self.punc_model.with_vad():
  264. assert vad_indexes is not None
  265. return self.punc_model(text, text_lengths, vad_indexes)
  266. else:
  267. return self.punc_model(text, text_lengths)