attention.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention layer definition."""

import math

import numpy
import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k.
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
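

# A minimal usage sketch (not part of the original file; shapes and values are
# illustrative assumptions): self-attention is just query == key == value.
#
#     mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
#     x = torch.randn(2, 50, 256)        # (batch, time, n_feat)
#     out = mha(x, x, x, mask=None)      # (2, 50, 256)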


class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.

    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a LegacyRelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # These two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3.
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, time2).

        Returns:
            torch.Tensor: Output tensor.

        """
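        # Pad-and-reshape trick (a descriptive note, not in the original):
        # prepending a zero column and viewing the buffer with the last two
        # dims swapped shifts row i left by (time1 - 1 - i), so entry (i, j)
        # ends up holding the score for relative position j - i.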
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if self.zero_triu:
            # Create the mask on the same device as x to avoid a device mismatch.
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # Compute attention score.
        # First compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3.
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # Compute matrix b and matrix d.
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.

    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # These two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3.
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.

        """
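        # Same pad-and-reshape trick as the legacy version (descriptive note,
        # not in the original), but here the input spans 2*time1-1 relative
        # offsets; after the shift only the first time1 columns are valid, so
        # the right half is sliced off below.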
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, 2*time1-1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # Compute attention score.
        # First compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3.
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # Compute matrix b and matrix d.
        # (batch, head, time1, 2*time1-1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)
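

# A minimal usage sketch (illustrative assumptions, not from the original file):
# pos_emb is typically produced by a relative positional encoding module and
# spans 2*time1-1 relative offsets.
#
#     rel_mha = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
#     x = torch.randn(2, 50, 256)                 # (batch, time1, n_feat)
#     pos_emb = torch.randn(1, 2 * 50 - 1, 256)   # (1, 2*time1-1, n_feat)
#     out = rel_mha(x, x, x, pos_emb, mask=None)  # (2, 50, 256)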


class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer with an FSMN memory block (SANM).

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN memory block.
        sanm_shfit (int): Left shift of the FSMN padding window.

    """

    def __init__(
        self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0
    ):
        """Construct a MultiHeadedAttentionSANM object."""
        super(MultiHeadedAttentionSANM, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k.
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        # A single projection produces query, key and value in one pass.
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
        # Depthwise 1-D convolution implementing the FSMN memory block.
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # Asymmetric padding keeps the output length equal to the input length;
        # sanm_shfit moves the window further into the past.
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        """Apply the FSMN memory block to the value sequence."""
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs  # residual connection
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Transform the input into query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Untransposed value tensor (#batch, time2, n_feat).

        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask multiplied
                into mask for chunk-wise streaming encoders.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention plus the FSMN memory output.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
            mask_shfit_chunk (torch.Tensor): Optional chunk mask for the FSMN block.
            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask for attention.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory
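

# A minimal usage sketch (illustrative assumptions): with in_feat == n_feat the
# layer is a drop-in self-attention block whose output adds the FSMN memory.
#
#     sanm = MultiHeadedAttentionSANM(
#         n_head=4, in_feat=256, n_feat=256, dropout_rate=0.1, kernel_size=11
#     )
#     x = torch.randn(2, 50, 256)
#     out = sanm(x, mask=None)           # (2, 50, 256)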


class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
    """MultiHeadedAttentionSANM variant that takes a pair of masks.

    mask[0] is applied to the FSMN memory block and mask[1] to the
    attention scores.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
        return att_outs + fsmn_memory


class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block used in the SANM decoder.

    Args:
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN memory block.
        sanm_shfit (int): Left shift of the FSMN padding window.

    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct a MultiHeadedAttentionSANMDecoder object."""
        super(MultiHeadedAttentionSANMDecoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        # Depthwise 1-D convolution implementing the FSMN memory block.
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # Asymmetric padding; sanm_shfit moves the window further into the past.
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
        """Apply the FSMN memory block.

        Args:
            inputs (torch.Tensor): Input tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time).
            cache (torch.Tensor): Cached convolution input for streaming decoding.
            mask_shfit_chunk (torch.Tensor): Optional chunk mask.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, size).
            torch.Tensor: Updated cache.

        """
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        b, d, t = x.size()
        if cache is None:
            x = self.pad_fn(x)
            if not self.training:
                cache = x
        else:
            # Streaming decoding: append the new frames to the cached context
            # and keep only the last kernel_size + t - 1 columns.
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1):]
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)

        if x.size(1) != inputs.size(1):
            inputs = inputs[:, -1, :]
        x = x + inputs  # residual connection
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache
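

# A minimal streaming sketch (illustrative assumptions): in eval mode the first
# call returns a cache that later single-frame calls extend, so the depthwise
# convolution always sees a full kernel_size window.
#
#     fsmn = MultiHeadedAttentionSANMDecoder(n_feat=256, dropout_rate=0.0, kernel_size=11)
#     fsmn.eval()
#     y, cache = fsmn(torch.randn(1, 1, 256), mask=None)  # prime the cache
#     y, cache = fsmn(torch.randn(1, 1, 256), mask=None, cache=cache)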


class MultiHeadedAttentionCrossAtt(nn.Module):
    """Multi-Head cross-attention layer over an encoder memory.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        encoder_output_size (int): Feature size of the encoder memory
            (defaults to n_feat when None).

    """

    def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
        """Construct a MultiHeadedAttentionCrossAtt object."""
        super(MultiHeadedAttentionCrossAtt, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k.
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        # Keys and values are projected from the encoder memory in one pass.
        self.linear_k_v = nn.Linear(
            n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2
        )
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x, memory):
        """Transform the query input and encoder memory into query, key and value.

        Args:
            x (torch.Tensor): Query input tensor (#batch, time1, size).
            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

        """
        b = x.size(0)
        q = self.linear_q(x)
        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time1, d_k)
        k_v = self.linear_k_v(memory)
        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        return q_h, k_h, v_h

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, memory, memory_mask):
        """Compute scaled dot product cross-attention.

        Args:
            x (torch.Tensor): Query input tensor (#batch, time1, size).
            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).
            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, memory_mask)
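

# A minimal usage sketch (illustrative assumptions): decoder states attend over
# a (possibly differently sized) encoder memory.
#
#     xatt = MultiHeadedAttentionCrossAtt(
#         n_head=4, n_feat=256, dropout_rate=0.1, encoder_output_size=512
#     )
#     dec = torch.randn(2, 10, 256)           # (batch, time1, n_feat)
#     enc = torch.randn(2, 50, 512)           # (batch, time2, encoder_output_size)
#     out = xatt(dec, enc, memory_mask=None)  # (2, 10, 256)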


class MultiHeadSelfAttention(nn.Module):
    """Multi-Head self-attention layer with a fused query/key/value projection.

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
        """Construct a MultiHeadSelfAttention object."""
        super(MultiHeadSelfAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k.
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        # A single projection produces query, key and value in one pass.
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x):
        """Transform the input into query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Untransposed value tensor (#batch, time2, n_feat).

        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask multiplied
                into mask for chunk-wise streaming encoders.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask for attention.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs
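

# A minimal usage sketch (illustrative assumptions): unlike the SANM variant,
# this layer returns the attention output alone, with no FSMN memory term.
#
#     mhsa = MultiHeadSelfAttention(n_head=4, in_feat=256, n_feat=256, dropout_rate=0.1)
#     x = torch.randn(2, 50, 256)
#     out = mhsa(x, mask=None)           # (2, 50, 256)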