- """Fastformer attention definition.
- Reference:
- Wu et al., "Fastformer: Additive Attention Can Be All You Need"
- https://arxiv.org/abs/2108.09084
- https://github.com/wuch15/Fastformer
- """
- import numpy
- import torch
- class FastSelfAttention(torch.nn.Module):
- """Fast self-attention used in Fastformer."""
- def __init__(
- self,
- size,
- attention_heads,
- dropout_rate,
- ):
- super().__init__()
- if size % attention_heads != 0:
- raise ValueError(
- f"Hidden size ({size}) is not an integer multiple "
- f"of attention heads ({attention_heads})"
- )
- self.attention_head_size = size // attention_heads
- self.num_attention_heads = attention_heads
- self.query = torch.nn.Linear(size, size)
- self.query_att = torch.nn.Linear(size, attention_heads)
- self.key = torch.nn.Linear(size, size)
- self.key_att = torch.nn.Linear(size, attention_heads)
- self.transform = torch.nn.Linear(size, size)
- self.dropout = torch.nn.Dropout(dropout_rate)
- def espnet_initialization_fn(self):
- self.apply(self.init_weights)
- def init_weights(self, module):
- if isinstance(module, torch.nn.Linear):
- module.weight.data.normal_(mean=0.0, std=0.02)
- if isinstance(module, torch.nn.Linear) and module.bias is not None:
- module.bias.data.zero_()
- def transpose_for_scores(self, x):
- """Reshape and transpose to compute scores.
- Args:
- x: (batch, time, size = n_heads * attn_dim)
- Returns:
- (batch, n_heads, time, attn_dim)
- """
- new_x_shape = x.shape[:-1] + (
- self.num_attention_heads,
- self.attention_head_size,
- )
- return x.reshape(*new_x_shape).transpose(1, 2)
- def forward(self, xs_pad, mask):
- """Forward method.
- Args:
- xs_pad: (batch, time, size = n_heads * attn_dim)
- mask: (batch, 1, time), nonpadding is 1, padding is 0
- Returns:
- torch.Tensor: (batch, time, size)
- """
        batch_size, seq_len, _ = xs_pad.shape
        mixed_query_layer = self.query(xs_pad)  # (batch, time, size)
        mixed_key_layer = self.key(xs_pad)  # (batch, time, size)

        if mask is not None:
            mask = mask.eq(0)  # padding is 1, nonpadding is 0

        # (batch, n_heads, time)
        query_for_score = (
            self.query_att(mixed_query_layer).transpose(1, 2)
            / self.attention_head_size**0.5
        )
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
                ).min
            )
            query_for_score = query_for_score.masked_fill(mask, min_value)
            query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
        else:
            query_weight = torch.softmax(query_for_score, dim=-1)
        query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)

        query_layer = self.transpose_for_scores(
            mixed_query_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_query = (
            torch.matmul(query_weight, query_layer)
            .transpose(1, 2)
            .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
        )  # (batch, 1, size = n_heads * attn_dim)
        pooled_query = self.dropout(pooled_query)
        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)

        mixed_query_key_layer = (
            mixed_key_layer * pooled_query_repeat
        )  # (batch, time, size)

        # (batch, n_heads, time)
        query_key_score = (
            self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
        ).transpose(1, 2)
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
                ).min
            )
            query_key_score = query_key_score.masked_fill(mask, min_value)
            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
                mask, 0.0
            )
        else:
            query_key_weight = torch.softmax(query_key_score, dim=-1)
        query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)

        key_layer = self.transpose_for_scores(
            mixed_query_key_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_key = torch.matmul(
            query_key_weight, key_layer
        )  # (batch, n_heads, 1, attn_dim)
        pooled_key = self.dropout(pooled_key)

        # NOTE: value = query, due to param sharing
        weighted_value = (pooled_key * query_layer).transpose(
            1, 2
        )  # (batch, time, n_heads, attn_dim)
        weighted_value = weighted_value.reshape(
            weighted_value.shape[:-2]
            + (self.num_attention_heads * self.attention_head_size,)
        )  # (batch, time, size)
        weighted_value = (
            self.dropout(self.transform(weighted_value)) + mixed_query_layer
        )

        return weighted_value
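

# Minimal usage sketch (illustrative addition, not part of the original module):
# builds a FastSelfAttention layer and runs a random padded batch through it.
# Tensor shapes follow the docstrings above; the concrete sizes (batch=2,
# time=5, size=8, 2 heads) are arbitrary example values.
if __name__ == "__main__":
    torch.manual_seed(0)

    batch, time, size, heads = 2, 5, 8, 2
    attn = FastSelfAttention(size=size, attention_heads=heads, dropout_rate=0.1)
    attn.espnet_initialization_fn()

    xs_pad = torch.randn(batch, time, size)
    # mask: (batch, 1, time), 1 for real frames, 0 for padding;
    # the second utterance in the batch has two padded frames.
    mask = torch.tensor([[[1, 1, 1, 1, 1]], [[1, 1, 1, 0, 0]]])

    out = attn(xs_pad, mask)
    print(out.shape)  # torch.Size([2, 5, 8])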