multihead_attention.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter

from funasr.models.data2vec.quant_noise import quant_noise


class FairseqDropout(nn.Module):
    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        if self.p > 0 and (self.training or self.apply_during_inference):
            return F.dropout(x, p=self.p, training=True, inplace=inplace)
        else:
            return x

    def make_generation_fast_(
        self,
        name: str,
        retain_dropout: bool = False,
        retain_dropout_modules: Optional[List[str]] = None,
        **kwargs
    ):
        if retain_dropout:
            if retain_dropout_modules is not None and self.module_name is None:
                logging.warning(
                    "Cannot enable dropout during inference for module {} "
                    "because module_name was not set".format(name)
                )
            elif (
                retain_dropout_modules is None  # if None, apply to all modules
                or self.module_name in retain_dropout_modules
            ):
                logging.info(
                    "Enabling dropout during inference for module: {}".format(name)
                )
                self.apply_during_inference = True
            else:
                logging.info("Disabling dropout for module: {}".format(name))


class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and value to be of the same size"
        )
        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None
        self.add_zero_attn = add_zero_attn
        self.reset_parameters()
        self.onnx_trace = False
        self.skip_embed_dim_check = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
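
    # Head-pruning helper (descriptive comment added for clarity): the method below
    # ranks attention heads by the summed absolute value (L1 norm) of their Q/K/V
    # projection weights and biases, and returns the (start, end) index ranges of
    # the heads to keep.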
    def _get_reserve_head_index(self, num_heads_to_keep: int):
        k_proj_heads_norm = []
        q_proj_heads_norm = []
        v_proj_heads_norm = []

        for i in range(self.num_heads):
            start_idx = i * self.head_dim
            end_idx = (i + 1) * self.head_dim
            k_proj_heads_norm.append(
                torch.sum(
                    torch.abs(
                        self.k_proj.weight[
                            start_idx:end_idx,
                        ]
                    )
                ).tolist()
                + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
            )
            q_proj_heads_norm.append(
                torch.sum(
                    torch.abs(
                        self.q_proj.weight[
                            start_idx:end_idx,
                        ]
                    )
                ).tolist()
                + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
            )
            v_proj_heads_norm.append(
                torch.sum(
                    torch.abs(
                        self.v_proj.weight[
                            start_idx:end_idx,
                        ]
                    )
                ).tolist()
                + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
            )

        heads_norm = []
        for i in range(self.num_heads):
            heads_norm.append(
                k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
            )

        sorted_head_index = sorted(
            range(self.num_heads), key=lambda k: heads_norm[k], reverse=True
        )
        reserve_head_index = []
        for i in range(num_heads_to_keep):
            start = sorted_head_index[i] * self.head_dim
            end = (sorted_head_index[i] + 1) * self.head_dim
            reserve_head_index.append((start, end))
        return reserve_head_index
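
    # _adaptive_prune_heads rebuilds the Q/K/V and output projections so that they
    # contain only the reserved heads, then updates num_heads and embed_dim
    # accordingly (descriptive comment added for clarity).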
    def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
        new_q_weight = []
        new_q_bias = []
        new_k_weight = []
        new_k_bias = []
        new_v_weight = []
        new_v_bias = []
        new_out_proj_weight = []

        for ele in reserve_head_index:
            start_idx, end_idx = ele
            new_q_weight.append(
                self.q_proj.weight[
                    start_idx:end_idx,
                ]
            )
            new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
            new_k_weight.append(
                self.k_proj.weight[
                    start_idx:end_idx,
                ]
            )
            new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
            new_v_weight.append(
                self.v_proj.weight[
                    start_idx:end_idx,
                ]
            )
            new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
            new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])

        new_q_weight = torch.cat(new_q_weight).detach()
        new_k_weight = torch.cat(new_k_weight).detach()
        new_v_weight = torch.cat(new_v_weight).detach()
        new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
        new_q_weight.requires_grad = True
        new_k_weight.requires_grad = True
        new_v_weight.requires_grad = True
        new_out_proj_weight.requires_grad = True

        new_q_bias = torch.cat(new_q_bias).detach()
        new_q_bias.requires_grad = True
        new_k_bias = torch.cat(new_k_bias).detach()
        new_k_bias.requires_grad = True
        new_v_bias = torch.cat(new_v_bias).detach()
        new_v_bias.requires_grad = True

        self.q_proj.weight = torch.nn.Parameter(new_q_weight)
        self.q_proj.bias = torch.nn.Parameter(new_q_bias)
        self.k_proj.weight = torch.nn.Parameter(new_k_weight)
        self.k_proj.bias = torch.nn.Parameter(new_k_bias)
        self.v_proj.weight = torch.nn.Parameter(new_v_weight)
        self.v_proj.bias = torch.nn.Parameter(new_v_bias)
        self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight)

        self.num_heads = len(reserve_head_index)
        self.embed_dim = self.head_dim * self.num_heads
        self.q_proj.out_features = self.embed_dim
        self.k_proj.out_features = self.embed_dim
        self.v_proj.out_features = self.embed_dim

    def _set_skip_embed_dim_check(self):
        self.skip_embed_dim_check = True

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        if not self.skip_embed_dim_check:
            assert (
                embed_dim == self.embed_dim
            ), f"query dim {embed_dim} != {self.embed_dim}"
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert (src_len, bsz) == value.shape[:2]

        if (
            not self.onnx_trace
            and not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            # The multi-head attention implemented in PyTorch enforces a strict check
            # on the input embedding dimension and the K, Q, V projection dimensions.
            # Since pruning breaks that check and it is not easy to modify the PyTorch
            # API, it is preferable to bypass the PyTorch MHA when the embed_dim check
            # needs to be skipped.
            and not self.skip_embed_dim_check
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )
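
        # Fallback path: the attention is computed manually below. It is taken for
        # ONNX export, TPU execution, incremental (step-by-step) decoding, static
        # key/value, TorchScript, and when the embed_dim check is skipped after
        # head pruning.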
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1
                )
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1
                )
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention and input_buffer_k.size(
                        0
                    ) == new_order.size(0):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state
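
    # Note: _get_input_buffer / _set_input_buffer below call
    # self.get_incremental_state / self.set_incremental_state, which are not defined
    # in this file; they are assumed to be provided elsewhere (e.g. by a
    # fairseq-style incremental-state mixin).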
    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)
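
    # apply_sparse_mask returns the attention weights unchanged here; it is a hook
    # that subclasses can override to impose a sparse attention pattern.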
    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                        dim : 2 * dim
                    ]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
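

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original module):
    # it builds a small self-attention layer and runs one forward pass with inputs
    # in the (time, batch, channel) layout expected by MultiheadAttention.forward.
    # All concrete sizes below are arbitrary assumptions chosen for the example.
    torch.manual_seed(0)
    tgt_len, bsz, embed_dim, num_heads = 5, 2, 16, 4
    mha = MultiheadAttention(embed_dim, num_heads, dropout=0.1, self_attention=True)
    x = torch.randn(tgt_len, bsz, embed_dim)
    # Mark the last time step of the second sequence as padding (True = padded).
    key_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)
    key_padding_mask[1, -1] = True
    attn, attn_weights = mha(
        x, x, x, key_padding_mask=key_padding_mask, need_weights=True
    )
    print(attn.shape)          # expected: torch.Size([5, 2, 16])
    print(attn_weights.shape)  # expected: torch.Size([2, 5, 5]), averaged over heads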