# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter

from funasr.models.data2vec.quant_noise import quant_noise


class FairseqDropout(nn.Module):
    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        if self.p > 0 and (self.training or self.apply_during_inference):
            return F.dropout(x, p=self.p, training=True, inplace=inplace)
        else:
            return x

    def make_generation_fast_(
        self,
        name: str,
        retain_dropout: bool = False,
        retain_dropout_modules: Optional[List[str]] = None,
        **kwargs
    ):
        if retain_dropout:
            if retain_dropout_modules is not None and self.module_name is None:
                logging.warning(
                    "Cannot enable dropout during inference for module {} "
                    "because module_name was not set".format(name)
                )
            elif (
                retain_dropout_modules is None  # if None, apply to all modules
                or self.module_name in retain_dropout_modules
            ):
                logging.info(
                    "Enabling dropout during inference for module: {}".format(name)
                )
                self.apply_during_inference = True
            else:
                logging.info("Disabling dropout for module: {}".format(name))


class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and value to be of the same size"
        )

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.skip_embed_dim_check = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def _get_reserve_head_index(self, num_heads_to_keep: int):
        # Rank heads by the L1 norm of their q/k/v projection weights and biases and
        # return the (start_idx, end_idx) row ranges of the heads to keep.
        k_proj_heads_norm = []
        q_proj_heads_norm = []
        v_proj_heads_norm = []

        for i in range(self.num_heads):
            start_idx = i * self.head_dim
            end_idx = (i + 1) * self.head_dim
            k_proj_heads_norm.append(
                torch.sum(
                    torch.abs(
                        self.k_proj.weight[
                            start_idx:end_idx,
                        ]
                    )
                ).tolist()
                + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
            )
            q_proj_heads_norm.append(
                torch.sum(
                    torch.abs(
                        self.q_proj.weight[
                            start_idx:end_idx,
                        ]
                    )
                ).tolist()
                + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
            )
            v_proj_heads_norm.append(
                torch.sum(
                    torch.abs(
                        self.v_proj.weight[
                            start_idx:end_idx,
                        ]
                    )
                ).tolist()
                + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
            )

        heads_norm = []
        for i in range(self.num_heads):
            heads_norm.append(
                k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
            )

        sorted_head_index = sorted(
            range(self.num_heads), key=lambda k: heads_norm[k], reverse=True
        )
        reserve_head_index = []
        for i in range(num_heads_to_keep):
            start = sorted_head_index[i] * self.head_dim
            end = (sorted_head_index[i] + 1) * self.head_dim
            reserve_head_index.append((start, end))

        return reserve_head_index

    def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
        # Rebuild the q/k/v/out projections so that they only contain the rows
        # (columns, for out_proj) belonging to the heads listed in reserve_head_index.
        new_q_weight = []
        new_q_bias = []
        new_k_weight = []
        new_k_bias = []
        new_v_weight = []
        new_v_bias = []
        new_out_proj_weight = []

        for ele in reserve_head_index:
            start_idx, end_idx = ele
            new_q_weight.append(
                self.q_proj.weight[
                    start_idx:end_idx,
                ]
            )
            new_q_bias.append(self.q_proj.bias[start_idx:end_idx])

            new_k_weight.append(
                self.k_proj.weight[
                    start_idx:end_idx,
                ]
            )
            new_k_bias.append(self.k_proj.bias[start_idx:end_idx])

            new_v_weight.append(
                self.v_proj.weight[
                    start_idx:end_idx,
                ]
            )
            new_v_bias.append(self.v_proj.bias[start_idx:end_idx])

            new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])

        new_q_weight = torch.cat(new_q_weight).detach()
        new_k_weight = torch.cat(new_k_weight).detach()
        new_v_weight = torch.cat(new_v_weight).detach()
        new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
        new_q_weight.requires_grad = True
        new_k_weight.requires_grad = True
        new_v_weight.requires_grad = True
        new_out_proj_weight.requires_grad = True

        new_q_bias = torch.cat(new_q_bias).detach()
        new_q_bias.requires_grad = True

        new_k_bias = torch.cat(new_k_bias).detach()
        new_k_bias.requires_grad = True

        new_v_bias = torch.cat(new_v_bias).detach()
        new_v_bias.requires_grad = True

        self.q_proj.weight = torch.nn.Parameter(new_q_weight)
        self.q_proj.bias = torch.nn.Parameter(new_q_bias)

        self.k_proj.weight = torch.nn.Parameter(new_k_weight)
        self.k_proj.bias = torch.nn.Parameter(new_k_bias)

        self.v_proj.weight = torch.nn.Parameter(new_v_weight)
        self.v_proj.bias = torch.nn.Parameter(new_v_bias)

        self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight)

        self.num_heads = len(reserve_head_index)
        self.embed_dim = self.head_dim * self.num_heads
        self.q_proj.out_features = self.embed_dim
        self.k_proj.out_features = self.embed_dim
        self.v_proj.out_features = self.embed_dim

    def _set_skip_embed_dim_check(self):
        self.skip_embed_dim_check = True
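
    # A usage sketch for the head-pruning helpers above (illustrative only; the call
    # pattern below is an assumption, not taken from the original file):
    #
    #     mha = MultiheadAttention(embed_dim=512, num_heads=8, self_attention=True)
    #     keep = mha._get_reserve_head_index(num_heads_to_keep=4)  # ranked by L1 norm
    #     mha._adaptive_prune_heads(keep)
    #     mha._set_skip_embed_dim_check()  # inputs keep the original embed_dim,
    #                                      # so the strict dim check must be skipped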

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
- """Input shape: Time x Batch x Channel
- Args:
- key_padding_mask (ByteTensor, optional): mask to exclude
- keys that are pads, of shape `(batch, src_len)`, where
- padding elements are indicated by 1s.
- need_weights (bool, optional): return the attention weights,
- averaged over heads (default: False).
- attn_mask (ByteTensor, optional): typically used to
- implement causal attention, where the mask prevents the
- attention from looking forward in time (default: None).
- before_softmax (bool, optional): return the raw attention
- weights and values before the attention softmax.
- need_head_weights (bool, optional): return the attention
- weights for each head. Implies *need_weights*. Default:
- return the average attention weights over all heads.
- """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        if not self.skip_embed_dim_check:
            assert (
                embed_dim == self.embed_dim
            ), f"query dim {embed_dim} != {self.embed_dim}"
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert (src_len, bsz) == value.shape[:2]

        if (
            not self.onnx_trace
            and not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            # The multi-head attention implemented in PyTorch enforces a strict check
            # on the input embedding dimension and the K/Q/V projection dimensions.
            # Since pruning breaks this check and the PyTorch API is not easy to
            # modify, we bypass the PyTorch MHA whenever the embed_dim check must be
            # skipped.
            and not self.skip_embed_dim_check
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
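
    # Call sketch for forward() (illustrative; the tensor names and the additive
    # causal mask below are assumptions, not taken from the original file):
    #
    #     x = torch.randn(tgt_len, bsz, embed_dim)   # Time x Batch x Channel
    #     causal = torch.triu(                        # -inf above the diagonal
    #         x.new_full((tgt_len, tgt_len), float("-inf")), diagonal=1
    #     )
    #     out, weights = mha(x, x, x, attn_mask=causal, need_weights=True)
    #     # out: (tgt_len, bsz, embed_dim); weights: (bsz, tgt_len, tgt_len)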

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1
                )
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1
                )
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask
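
    # Shape sketch (illustrative, not from the original source): with a cached mask of
    # shape (bsz, prev_len) and a new step mask of shape (bsz, 1), the helper above
    # returns a float mask of shape (bsz, prev_len + 1); when only one of the two masks
    # exists, the missing part is filled with zeros ("not padding") up to src_len.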

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention and input_buffer_k.size(
                        0
                    ) == new_order.size(0):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        # Relies on fairseq-style incremental-state helpers (get_incremental_state /
        # set_incremental_state) being available on the module.
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                        dim : 2 * dim
                    ]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
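

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the original
    # module): run self-attention over a random Time x Batch x Channel input and check
    # the output shapes. Assumes the funasr quant_noise import above resolves.
    torch.manual_seed(0)
    tgt_len, bsz, embed_dim, num_heads = 5, 2, 16, 4

    mha = MultiheadAttention(embed_dim, num_heads, self_attention=True)
    mha.eval()

    x = torch.randn(tgt_len, bsz, embed_dim)
    # Mark the last position of the second sequence as padding (True/1 = pad).
    key_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)
    key_padding_mask[1, -1] = True

    with torch.no_grad():
        out, weights = mha(x, x, x, key_padding_mask=key_padding_mask)

    assert out.shape == (tgt_len, bsz, embed_dim)
    assert weights.shape == (bsz, tgt_len, tgt_len)
    print("attn output:", tuple(out.shape), "attn weights:", tuple(weights.shape))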