"""Fastformer attention definition.

Reference:
    Wu et al., "Fastformer: Additive Attention Can Be All You Need"
    https://arxiv.org/abs/2108.09084
    https://github.com/wuch15/Fastformer
"""
import numpy
import torch
  9. class FastSelfAttention(torch.nn.Module):
  10. """Fast self-attention used in Fastformer."""
  11. def __init__(
  12. self,
  13. size,
  14. attention_heads,
  15. dropout_rate,
  16. ):
  17. super().__init__()
  18. if size % attention_heads != 0:
  19. raise ValueError(
  20. f"Hidden size ({size}) is not an integer multiple "
  21. f"of attention heads ({attention_heads})"
  22. )
  23. self.attention_head_size = size // attention_heads
  24. self.num_attention_heads = attention_heads
  25. self.query = torch.nn.Linear(size, size)
  26. self.query_att = torch.nn.Linear(size, attention_heads)
  27. self.key = torch.nn.Linear(size, size)
  28. self.key_att = torch.nn.Linear(size, attention_heads)
  29. self.transform = torch.nn.Linear(size, size)
  30. self.dropout = torch.nn.Dropout(dropout_rate)
  31. def espnet_initialization_fn(self):
  32. self.apply(self.init_weights)
  33. def init_weights(self, module):
  34. if isinstance(module, torch.nn.Linear):
  35. module.weight.data.normal_(mean=0.0, std=0.02)
  36. if isinstance(module, torch.nn.Linear) and module.bias is not None:
  37. module.bias.data.zero_()
  38. def transpose_for_scores(self, x):
  39. """Reshape and transpose to compute scores.
  40. Args:
  41. x: (batch, time, size = n_heads * attn_dim)
  42. Returns:
  43. (batch, n_heads, time, attn_dim)
  44. """
  45. new_x_shape = x.shape[:-1] + (
  46. self.num_attention_heads,
  47. self.attention_head_size,
  48. )
  49. return x.reshape(*new_x_shape).transpose(1, 2)
  50. def forward(self, xs_pad, mask):
  51. """Forward method.
  52. Args:
  53. xs_pad: (batch, time, size = n_heads * attn_dim)
  54. mask: (batch, 1, time), nonpadding is 1, padding is 0
  55. Returns:
  56. torch.Tensor: (batch, time, size)
  57. """
  58. batch_size, seq_len, _ = xs_pad.shape
  59. mixed_query_layer = self.query(xs_pad) # (batch, time, size)
  60. mixed_key_layer = self.key(xs_pad) # (batch, time, size)
  61. if mask is not None:
  62. mask = mask.eq(0) # padding is 1, nonpadding is 0
  63. # (batch, n_heads, time)
  64. query_for_score = (
  65. self.query_att(mixed_query_layer).transpose(1, 2)
  66. / self.attention_head_size**0.5
  67. )
  68. if mask is not None:
  69. min_value = float(
  70. numpy.finfo(
  71. torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
  72. ).min
  73. )
  74. query_for_score = query_for_score.masked_fill(mask, min_value)
  75. query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
  76. else:
  77. query_weight = torch.softmax(query_for_score, dim=-1)
  78. query_weight = query_weight.unsqueeze(2) # (batch, n_heads, 1, time)
  79. query_layer = self.transpose_for_scores(
  80. mixed_query_layer
  81. ) # (batch, n_heads, time, attn_dim)
  82. pooled_query = (
  83. torch.matmul(query_weight, query_layer)
  84. .transpose(1, 2)
  85. .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
  86. ) # (batch, 1, size = n_heads * attn_dim)
  87. pooled_query = self.dropout(pooled_query)
  88. pooled_query_repeat = pooled_query.repeat(1, seq_len, 1) # (batch, time, size)
  89. mixed_query_key_layer = (
  90. mixed_key_layer * pooled_query_repeat
  91. ) # (batch, time, size)
  92. # (batch, n_heads, time)
  93. query_key_score = (
  94. self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
  95. ).transpose(1, 2)
  96. if mask is not None:
  97. min_value = float(
  98. numpy.finfo(
  99. torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
  100. ).min
  101. )
  102. query_key_score = query_key_score.masked_fill(mask, min_value)
  103. query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
  104. mask, 0.0
  105. )
  106. else:
  107. query_key_weight = torch.softmax(query_key_score, dim=-1)
  108. query_key_weight = query_key_weight.unsqueeze(2) # (batch, n_heads, 1, time)
  109. key_layer = self.transpose_for_scores(
  110. mixed_query_key_layer
  111. ) # (batch, n_heads, time, attn_dim)
  112. pooled_key = torch.matmul(
  113. query_key_weight, key_layer
  114. ) # (batch, n_heads, 1, attn_dim)
  115. pooled_key = self.dropout(pooled_key)
  116. # NOTE: value = query, due to param sharing
  117. weighted_value = (pooled_key * query_layer).transpose(
  118. 1, 2
  119. ) # (batch, time, n_heads, attn_dim)
  120. weighted_value = weighted_value.reshape(
  121. weighted_value.shape[:-2]
  122. + (self.num_attention_heads * self.attention_head_size,)
  123. ) # (batch, time, size)
  124. weighted_value = (
  125. self.dropout(self.transform(weighted_value)) + mixed_query_layer
  126. )
  127. return weighted_value