calculate_all_attentions.py

from collections import defaultdict
from typing import Dict
from typing import List

import torch

from funasr.modules.rnn.attentions import AttAdd
from funasr.modules.rnn.attentions import AttCov
from funasr.modules.rnn.attentions import AttCovLoc
from funasr.modules.rnn.attentions import AttDot
from funasr.modules.rnn.attentions import AttForward
from funasr.modules.rnn.attentions import AttForwardTA
from funasr.modules.rnn.attentions import AttLoc
from funasr.modules.rnn.attentions import AttLoc2D
from funasr.modules.rnn.attentions import AttLocRec
from funasr.modules.rnn.attentions import AttMultiHeadAdd
from funasr.modules.rnn.attentions import AttMultiHeadDot
from funasr.modules.rnn.attentions import AttMultiHeadLoc
from funasr.modules.rnn.attentions import AttMultiHeadMultiResLoc
from funasr.modules.rnn.attentions import NoAtt
from funasr.modules.attention import MultiHeadedAttention
from funasr.models.base_model import FunASRModel


@torch.no_grad()
def calculate_all_attentions(
    model: FunASRModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
    """Derive the outputs from all the attention layers.

    Args:
        model: A FunASRModel instance.
        batch: The same keyword arguments as model.forward.
    Returns:
        return_dict: A dict mapping layer names to lists of attention tensors,
        key_names x batch x (D1, D2, ...)

    """
    bs = len(next(iter(batch.values())))
    assert all(len(v) == bs for v in batch.values()), {
        k: v.shape for k, v in batch.items()
    }

    # 1. Register forward_hook fn to save the output from specific layers
    outputs = {}
    handles = {}
    for name, modu in model.named_modules():

        def hook(module, input, output, name=name):
            if isinstance(module, MultiHeadedAttention):
                # NOTE(kamo): MultiHeadedAttention doesn't return attention weight
                # attn: (B, Head, Tout, Tin)
                outputs[name] = module.attn.detach().cpu()
            elif isinstance(module, AttLoc2D):
                c, w = output
                # w: the previous attentions, concatenated
                # w: (B, nprev, Tin)
                att_w = w[:, -1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, (AttCov, AttCovLoc)):
                c, w = output
                assert isinstance(w, list), type(w)
                # w: list of previous attentions
                # w: nprev x (B, Tin)
                att_w = w[-1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, AttLocRec):
                # w: (B, Tin)
                c, (w, (att_h, att_c)) = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttMultiHeadDot,
                    AttMultiHeadAdd,
                    AttMultiHeadLoc,
                    AttMultiHeadMultiResLoc,
                ),
            ):
                c, w = output
                # w: nhead x (B, Tin)
                assert isinstance(w, list), type(w)
                att_w = [_w.detach().cpu() for _w in w]
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttAdd,
                    AttDot,
                    AttForward,
                    AttForwardTA,
                    AttLoc,
                    NoAtt,
                ),
            ):
                c, w = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)

        handle = modu.register_forward_hook(hook)
        handles[name] = handle

    # 2. Forward the samples one by one.
    # Batch mode isn't used, to keep the requirements on each model small.
    keys = []
    for k in batch:
        if not k.endswith("_lengths"):
            keys.append(k)

    return_dict = defaultdict(list)
    for ibatch in range(bs):
        # *: (B, L, ...) -> (1, L2, ...)
        _sample = {
            k: batch[k][ibatch, None, : batch[k + "_lengths"][ibatch]]
            if k + "_lengths" in batch
            else batch[k][ibatch, None]
            for k in keys
        }

        # *_lengths: (B,) -> (1,)
        _sample.update(
            {
                k + "_lengths": batch[k + "_lengths"][ibatch, None]
                for k in keys
                if k + "_lengths" in batch
            }
        )

        model(**_sample)

        # Derive the attention results
        for name, output in outputs.items():
            if isinstance(output, list):
                if isinstance(output[0], list):
                    # output: nhead x (Tout, Tin)
                    output = torch.stack(
                        [
                            # Tout x (1, Tin) -> (Tout, Tin)
                            torch.cat([o[idx] for o in output], dim=0)
                            for idx in range(len(output[0]))
                        ],
                        dim=0,
                    )
                else:
                    # Tout x (1, Tin) -> (Tout, Tin)
                    output = torch.cat(output, dim=0)
            else:
                # output: (1, NHead, Tout, Tin) -> (NHead, Tout, Tin)
                output = output.squeeze(0)
            # output: (Tout, Tin) or (NHead, Tout, Tin)
            return_dict[name].append(output)
        outputs.clear()

    # 3. Remove all hooks
    for _, handle in handles.items():
        handle.remove()

    return dict(return_dict)
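

# Usage sketch: a minimal, hedged example of how this helper can be called.
# It assumes `model` is an already-built FunASRModel and that the keys of
# `batch` match the keyword arguments of `model.forward`; the example key
# names ("speech"/"speech_lengths", "text"/"text_lengths") are assumptions
# and depend on the task definition.
#
#     att_dict = calculate_all_attentions(
#         model,
#         {"speech": speech, "speech_lengths": speech_lengths,
#          "text": text, "text_lengths": text_lengths},
#     )
#     for layer_name, per_sample in att_dict.items():
#         # per_sample[i]: (Tout, Tin) for RNN-style attentions, or
#         # (NHead, Tout, Tin) for MultiHeadedAttention layers
#         print(layer_name, per_sample[0].shape)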