attention.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Shigeki Karita
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Multi-Head Attention layer definition."""
  6. import math
  7. import numpy
  8. import torch
  9. from torch import nn
  10. class MultiHeadedAttention(nn.Module):
  11. """Multi-Head Attention layer.
  12. Args:
  13. n_head (int): The number of heads.
  14. n_feat (int): The number of features.
  15. dropout_rate (float): Dropout rate.
  16. """
  17. def __init__(self, n_head, n_feat, dropout_rate):
  18. """Construct an MultiHeadedAttention object."""
  19. super(MultiHeadedAttention, self).__init__()
  20. assert n_feat % n_head == 0
  21. # We assume d_v always equals d_k
  22. self.d_k = n_feat // n_head
  23. self.h = n_head
  24. self.linear_q = nn.Linear(n_feat, n_feat)
  25. self.linear_k = nn.Linear(n_feat, n_feat)
  26. self.linear_v = nn.Linear(n_feat, n_feat)
  27. self.linear_out = nn.Linear(n_feat, n_feat)
  28. self.attn = None
  29. self.dropout = nn.Dropout(p=dropout_rate)
  30. def forward_qkv(self, query, key, value):
  31. """Transform query, key and value.
  32. Args:
  33. query (torch.Tensor): Query tensor (#batch, time1, size).
  34. key (torch.Tensor): Key tensor (#batch, time2, size).
  35. value (torch.Tensor): Value tensor (#batch, time2, size).
  36. Returns:
  37. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  38. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  39. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  40. """
  41. n_batch = query.size(0)
  42. q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
  43. k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
  44. v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
  45. q = q.transpose(1, 2) # (batch, head, time1, d_k)
  46. k = k.transpose(1, 2) # (batch, head, time2, d_k)
  47. v = v.transpose(1, 2) # (batch, head, time2, d_k)
  48. return q, k, v
  49. def forward_attention(self, value, scores, mask):
  50. """Compute attention context vector.
  51. Args:
  52. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  53. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  54. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  55. Returns:
  56. torch.Tensor: Transformed value (#batch, time1, d_model)
  57. weighted by the attention score (#batch, time1, time2).
  58. """
  59. n_batch = value.size(0)
  60. if mask is not None:
  61. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  62. min_value = float(
  63. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  64. )
  65. scores = scores.masked_fill(mask, min_value)
  66. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  67. mask, 0.0
  68. ) # (batch, head, time1, time2)
  69. else:
  70. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  71. p_attn = self.dropout(self.attn)
  72. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  73. x = (
  74. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  75. ) # (batch, time1, d_model)
  76. return self.linear_out(x) # (batch, time1, d_model)
  77. def forward(self, query, key, value, mask):
  78. """Compute scaled dot product attention.
  79. Args:
  80. query (torch.Tensor): Query tensor (#batch, time1, size).
  81. key (torch.Tensor): Key tensor (#batch, time2, size).
  82. value (torch.Tensor): Value tensor (#batch, time2, size).
  83. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  84. (#batch, time1, time2).
  85. Returns:
  86. torch.Tensor: Output tensor (#batch, time1, d_model).
  87. """
  88. q, k, v = self.forward_qkv(query, key, value)
  89. scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
  90. return self.forward_attention(v, scores, mask)
  91. class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
  92. """Multi-Head Attention layer with relative position encoding (old version).
  93. Details can be found in https://github.com/espnet/espnet/pull/2816.
  94. Paper: https://arxiv.org/abs/1901.02860
  95. Args:
  96. n_head (int): The number of heads.
  97. n_feat (int): The number of features.
  98. dropout_rate (float): Dropout rate.
  99. zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
  100. """
  101. def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
  102. """Construct an RelPositionMultiHeadedAttention object."""
  103. super().__init__(n_head, n_feat, dropout_rate)
  104. self.zero_triu = zero_triu
  105. # linear transformation for positional encoding
  106. self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
  107. # these two learnable bias are used in matrix c and matrix d
  108. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  109. self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
  110. self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
  111. torch.nn.init.xavier_uniform_(self.pos_bias_u)
  112. torch.nn.init.xavier_uniform_(self.pos_bias_v)
  113. def rel_shift(self, x):
  114. """Compute relative positional encoding.
  115. Args:
  116. x (torch.Tensor): Input tensor (batch, head, time1, time2).
  117. Returns:
  118. torch.Tensor: Output tensor.
  119. """
  120. zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
  121. x_padded = torch.cat([zero_pad, x], dim=-1)
  122. x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
  123. x = x_padded[:, :, 1:].view_as(x)
  124. if self.zero_triu:
  125. ones = torch.ones((x.size(2), x.size(3)))
  126. x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
  127. return x
  128. def forward(self, query, key, value, pos_emb, mask):
  129. """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
  130. Args:
  131. query (torch.Tensor): Query tensor (#batch, time1, size).
  132. key (torch.Tensor): Key tensor (#batch, time2, size).
  133. value (torch.Tensor): Value tensor (#batch, time2, size).
  134. pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
  135. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  136. (#batch, time1, time2).
  137. Returns:
  138. torch.Tensor: Output tensor (#batch, time1, d_model).
  139. """
  140. q, k, v = self.forward_qkv(query, key, value)
  141. q = q.transpose(1, 2) # (batch, time1, head, d_k)
  142. n_batch_pos = pos_emb.size(0)
  143. p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
  144. p = p.transpose(1, 2) # (batch, head, time1, d_k)
  145. # (batch, head, time1, d_k)
  146. q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
  147. # (batch, head, time1, d_k)
  148. q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
  149. # compute attention score
  150. # first compute matrix a and matrix c
  151. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  152. # (batch, head, time1, time2)
  153. matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
  154. # compute matrix b and matrix d
  155. # (batch, head, time1, time1)
  156. matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
  157. matrix_bd = self.rel_shift(matrix_bd)
  158. scores = (matrix_ac + matrix_bd) / math.sqrt(
  159. self.d_k
  160. ) # (batch, head, time1, time2)
  161. return self.forward_attention(v, scores, mask)
  162. class RelPositionMultiHeadedAttention(MultiHeadedAttention):
  163. """Multi-Head Attention layer with relative position encoding (new implementation).
  164. Details can be found in https://github.com/espnet/espnet/pull/2816.
  165. Paper: https://arxiv.org/abs/1901.02860
  166. Args:
  167. n_head (int): The number of heads.
  168. n_feat (int): The number of features.
  169. dropout_rate (float): Dropout rate.
  170. zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
  171. """
  172. def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
  173. """Construct an RelPositionMultiHeadedAttention object."""
  174. super().__init__(n_head, n_feat, dropout_rate)
  175. self.zero_triu = zero_triu
  176. # linear transformation for positional encoding
  177. self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
  178. # these two learnable bias are used in matrix c and matrix d
  179. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  180. self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
  181. self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
  182. torch.nn.init.xavier_uniform_(self.pos_bias_u)
  183. torch.nn.init.xavier_uniform_(self.pos_bias_v)
  184. def rel_shift(self, x):
  185. """Compute relative positional encoding.
  186. Args:
  187. x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
  188. time1 means the length of query vector.
  189. Returns:
  190. torch.Tensor: Output tensor.
  191. """
  192. zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
  193. x_padded = torch.cat([zero_pad, x], dim=-1)
  194. x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
  195. x = x_padded[:, :, 1:].view_as(x)[
  196. :, :, :, : x.size(-1) // 2 + 1
  197. ] # only keep the positions from 0 to time2
  198. if self.zero_triu:
  199. ones = torch.ones((x.size(2), x.size(3)), device=x.device)
  200. x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
  201. return x
  202. def forward(self, query, key, value, pos_emb, mask):
  203. """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
  204. Args:
  205. query (torch.Tensor): Query tensor (#batch, time1, size).
  206. key (torch.Tensor): Key tensor (#batch, time2, size).
  207. value (torch.Tensor): Value tensor (#batch, time2, size).
  208. pos_emb (torch.Tensor): Positional embedding tensor
  209. (#batch, 2*time1-1, size).
  210. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  211. (#batch, time1, time2).
  212. Returns:
  213. torch.Tensor: Output tensor (#batch, time1, d_model).
  214. """
  215. q, k, v = self.forward_qkv(query, key, value)
  216. q = q.transpose(1, 2) # (batch, time1, head, d_k)
  217. n_batch_pos = pos_emb.size(0)
  218. p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
  219. p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
  220. # (batch, head, time1, d_k)
  221. q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
  222. # (batch, head, time1, d_k)
  223. q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
  224. # compute attention score
  225. # first compute matrix a and matrix c
  226. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  227. # (batch, head, time1, time2)
  228. matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
  229. # compute matrix b and matrix d
  230. # (batch, head, time1, 2*time1-1)
  231. matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
  232. matrix_bd = self.rel_shift(matrix_bd)
  233. scores = (matrix_ac + matrix_bd) / math.sqrt(
  234. self.d_k
  235. ) # (batch, head, time1, time2)
  236. return self.forward_attention(v, scores, mask)
  237. class MultiHeadedAttentionSANM(nn.Module):
  238. """Multi-Head Attention layer.
  239. Args:
  240. n_head (int): The number of heads.
  241. n_feat (int): The number of features.
  242. dropout_rate (float): Dropout rate.
  243. """
  244. def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
  245. """Construct an MultiHeadedAttention object."""
  246. super(MultiHeadedAttentionSANM, self).__init__()
  247. assert n_feat % n_head == 0
  248. # We assume d_v always equals d_k
  249. self.d_k = n_feat // n_head
  250. self.h = n_head
  251. # self.linear_q = nn.Linear(n_feat, n_feat)
  252. # self.linear_k = nn.Linear(n_feat, n_feat)
  253. # self.linear_v = nn.Linear(n_feat, n_feat)
  254. self.linear_out = nn.Linear(n_feat, n_feat)
  255. self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
  256. self.attn = None
  257. self.dropout = nn.Dropout(p=dropout_rate)
  258. self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
  259. # padding
  260. left_padding = (kernel_size - 1) // 2
  261. if sanm_shfit > 0:
  262. left_padding = left_padding + sanm_shfit
  263. right_padding = kernel_size - 1 - left_padding
  264. self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
  265. def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
  266. b, t, d = inputs.size()
  267. if mask is not None:
  268. mask = torch.reshape(mask, (b, -1, 1))
  269. if mask_shfit_chunk is not None:
  270. mask = mask * mask_shfit_chunk
  271. inputs = inputs * mask
  272. x = inputs.transpose(1, 2)
  273. x = self.pad_fn(x)
  274. x = self.fsmn_block(x)
  275. x = x.transpose(1, 2)
  276. x += inputs
  277. x = self.dropout(x)
  278. return x * mask
  279. def forward_qkv(self, x):
  280. """Transform query, key and value.
  281. Args:
  282. query (torch.Tensor): Query tensor (#batch, time1, size).
  283. key (torch.Tensor): Key tensor (#batch, time2, size).
  284. value (torch.Tensor): Value tensor (#batch, time2, size).
  285. Returns:
  286. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  287. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  288. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  289. """
  290. b, t, d = x.size()
  291. q_k_v = self.linear_q_k_v(x)
  292. q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
  293. q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
  294. k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  295. v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  296. return q_h, k_h, v_h, v
  297. def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
  298. """Compute attention context vector.
  299. Args:
  300. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  301. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  302. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  303. Returns:
  304. torch.Tensor: Transformed value (#batch, time1, d_model)
  305. weighted by the attention score (#batch, time1, time2).
  306. """
  307. n_batch = value.size(0)
  308. if mask is not None:
  309. if mask_att_chunk_encoder is not None:
  310. mask = mask * mask_att_chunk_encoder
  311. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  312. min_value = float(
  313. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  314. )
  315. scores = scores.masked_fill(mask, min_value)
  316. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  317. mask, 0.0
  318. ) # (batch, head, time1, time2)
  319. else:
  320. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  321. p_attn = self.dropout(self.attn)
  322. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  323. x = (
  324. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  325. ) # (batch, time1, d_model)
  326. return self.linear_out(x) # (batch, time1, d_model)
  327. def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
  328. """Compute scaled dot product attention.
  329. Args:
  330. query (torch.Tensor): Query tensor (#batch, time1, size).
  331. key (torch.Tensor): Key tensor (#batch, time2, size).
  332. value (torch.Tensor): Value tensor (#batch, time2, size).
  333. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  334. (#batch, time1, time2).
  335. Returns:
  336. torch.Tensor: Output tensor (#batch, time1, d_model).
  337. """
  338. q_h, k_h, v_h, v = self.forward_qkv(x)
  339. fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
  340. q_h = q_h * self.d_k ** (-0.5)
  341. scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  342. att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
  343. return att_outs + fsmn_memory
  344. class MultiHeadedAttentionSANMDecoder(nn.Module):
  345. """Multi-Head Attention layer.
  346. Args:
  347. n_head (int): The number of heads.
  348. n_feat (int): The number of features.
  349. dropout_rate (float): Dropout rate.
  350. """
  351. def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
  352. """Construct an MultiHeadedAttention object."""
  353. super(MultiHeadedAttentionSANMDecoder, self).__init__()
  354. self.dropout = nn.Dropout(p=dropout_rate)
  355. self.fsmn_block = nn.Conv1d(n_feat, n_feat,
  356. kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
  357. # padding
  358. # padding
  359. left_padding = (kernel_size - 1) // 2
  360. if sanm_shfit > 0:
  361. left_padding = left_padding + sanm_shfit
  362. right_padding = kernel_size - 1 - left_padding
  363. self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
  364. self.kernel_size = kernel_size
  365. def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
  366. '''
  367. :param x: (#batch, time1, size).
  368. :param mask: Mask tensor (#batch, 1, time)
  369. :return:
  370. '''
  371. # print("in fsmn, inputs", inputs.size())
  372. b, t, d = inputs.size()
  373. # logging.info(
  374. # "mask: {}".format(mask.size()))
  375. if mask is not None:
  376. mask = torch.reshape(mask, (b ,-1, 1))
  377. # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
  378. if mask_shfit_chunk is not None:
  379. # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
  380. mask = mask * mask_shfit_chunk
  381. # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
  382. # print("in fsmn, mask", mask.size())
  383. # print("in fsmn, inputs", inputs.size())
  384. inputs = inputs * mask
  385. x = inputs.transpose(1, 2)
  386. b, d, t = x.size()
  387. if cache is None:
  388. # print("in fsmn, cache is None, x", x.size())
  389. x = self.pad_fn(x)
  390. if not self.training and t <= 1:
  391. cache = x
  392. else:
  393. # print("in fsmn, cache is not None, x", x.size())
  394. # x = torch.cat((x, cache), dim=2)[:, :, :-1]
  395. # if t < self.kernel_size:
  396. # x = self.pad_fn(x)
  397. x = torch.cat((cache[:, :, 1:], x), dim=2)
  398. x = x[:, :, -self.kernel_size:]
  399. # print("in fsmn, cache is not None, x_cat", x.size())
  400. cache = x
  401. x = self.fsmn_block(x)
  402. x = x.transpose(1, 2)
  403. # print("in fsmn, fsmn_out", x.size())
  404. if x.size(1) != inputs.size(1):
  405. inputs = inputs[:, -1, :]
  406. x = x + inputs
  407. x = self.dropout(x)
  408. if mask is not None:
  409. x = x * mask
  410. return x, cache
  411. class MultiHeadedAttentionCrossAtt(nn.Module):
  412. """Multi-Head Attention layer.
  413. Args:
  414. n_head (int): The number of heads.
  415. n_feat (int): The number of features.
  416. dropout_rate (float): Dropout rate.
  417. """
  418. def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
  419. """Construct an MultiHeadedAttention object."""
  420. super(MultiHeadedAttentionCrossAtt, self).__init__()
  421. assert n_feat % n_head == 0
  422. # We assume d_v always equals d_k
  423. self.d_k = n_feat // n_head
  424. self.h = n_head
  425. self.linear_q = nn.Linear(n_feat, n_feat)
  426. # self.linear_k = nn.Linear(n_feat, n_feat)
  427. # self.linear_v = nn.Linear(n_feat, n_feat)
  428. self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
  429. self.linear_out = nn.Linear(n_feat, n_feat)
  430. self.attn = None
  431. self.dropout = nn.Dropout(p=dropout_rate)
  432. def forward_qkv(self, x, memory):
  433. """Transform query, key and value.
  434. Args:
  435. query (torch.Tensor): Query tensor (#batch, time1, size).
  436. key (torch.Tensor): Key tensor (#batch, time2, size).
  437. value (torch.Tensor): Value tensor (#batch, time2, size).
  438. Returns:
  439. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  440. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  441. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  442. """
  443. # print("in forward_qkv, x", x.size())
  444. b = x.size(0)
  445. q = self.linear_q(x)
  446. q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
  447. k_v = self.linear_k_v(memory)
  448. k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
  449. k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  450. v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  451. return q_h, k_h, v_h
  452. def forward_attention(self, value, scores, mask):
  453. """Compute attention context vector.
  454. Args:
  455. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  456. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  457. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  458. Returns:
  459. torch.Tensor: Transformed value (#batch, time1, d_model)
  460. weighted by the attention score (#batch, time1, time2).
  461. """
  462. n_batch = value.size(0)
  463. if mask is not None:
  464. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  465. min_value = float(
  466. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  467. )
  468. # logging.info(
  469. # "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
  470. scores = scores.masked_fill(mask, min_value)
  471. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  472. mask, 0.0
  473. ) # (batch, head, time1, time2)
  474. else:
  475. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  476. p_attn = self.dropout(self.attn)
  477. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  478. x = (
  479. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  480. ) # (batch, time1, d_model)
  481. return self.linear_out(x) # (batch, time1, d_model)
  482. def forward(self, x, memory, memory_mask):
  483. """Compute scaled dot product attention.
  484. Args:
  485. query (torch.Tensor): Query tensor (#batch, time1, size).
  486. key (torch.Tensor): Key tensor (#batch, time2, size).
  487. value (torch.Tensor): Value tensor (#batch, time2, size).
  488. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  489. (#batch, time1, time2).
  490. Returns:
  491. torch.Tensor: Output tensor (#batch, time1, d_model).
  492. """
  493. q_h, k_h, v_h = self.forward_qkv(x, memory)
  494. q_h = q_h * self.d_k ** (-0.5)
  495. scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  496. return self.forward_attention(v_h, scores, memory_mask)
  497. class MultiHeadSelfAttention(nn.Module):
  498. """Multi-Head Attention layer.
  499. Args:
  500. n_head (int): The number of heads.
  501. n_feat (int): The number of features.
  502. dropout_rate (float): Dropout rate.
  503. """
  504. def __init__(self, n_head, in_feat, n_feat, dropout_rate):
  505. """Construct an MultiHeadedAttention object."""
  506. super(MultiHeadSelfAttention, self).__init__()
  507. assert n_feat % n_head == 0
  508. # We assume d_v always equals d_k
  509. self.d_k = n_feat // n_head
  510. self.h = n_head
  511. self.linear_out = nn.Linear(n_feat, n_feat)
  512. self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
  513. self.attn = None
  514. self.dropout = nn.Dropout(p=dropout_rate)
  515. def forward_qkv(self, x):
  516. """Transform query, key and value.
  517. Args:
  518. query (torch.Tensor): Query tensor (#batch, time1, size).
  519. key (torch.Tensor): Key tensor (#batch, time2, size).
  520. value (torch.Tensor): Value tensor (#batch, time2, size).
  521. Returns:
  522. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  523. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  524. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  525. """
  526. b, t, d = x.size()
  527. q_k_v = self.linear_q_k_v(x)
  528. q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
  529. q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
  530. k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  531. v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  532. return q_h, k_h, v_h, v
  533. def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
  534. """Compute attention context vector.
  535. Args:
  536. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  537. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  538. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  539. Returns:
  540. torch.Tensor: Transformed value (#batch, time1, d_model)
  541. weighted by the attention score (#batch, time1, time2).
  542. """
  543. n_batch = value.size(0)
  544. if mask is not None:
  545. if mask_att_chunk_encoder is not None:
  546. mask = mask * mask_att_chunk_encoder
  547. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  548. min_value = float(
  549. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  550. )
  551. scores = scores.masked_fill(mask, min_value)
  552. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  553. mask, 0.0
  554. ) # (batch, head, time1, time2)
  555. else:
  556. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  557. p_attn = self.dropout(self.attn)
  558. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  559. x = (
  560. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  561. ) # (batch, time1, d_model)
  562. return self.linear_out(x) # (batch, time1, d_model)
  563. def forward(self, x, mask, mask_att_chunk_encoder=None):
  564. """Compute scaled dot product attention.
  565. Args:
  566. query (torch.Tensor): Query tensor (#batch, time1, size).
  567. key (torch.Tensor): Key tensor (#batch, time2, size).
  568. value (torch.Tensor): Value tensor (#batch, time2, size).
  569. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  570. (#batch, time1, time2).
  571. Returns:
  572. torch.Tensor: Output tensor (#batch, time1, d_model).
  573. """
  574. q_h, k_h, v_h, v = self.forward_qkv(x)
  575. q_h = q_h * self.d_k ** (-0.5)
  576. scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  577. att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
  578. return att_outs