attention.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention layer definition."""

import math

import numpy
import torch
from torch import nn
from typing import Optional, Tuple
import torch.nn.functional as F

from funasr.models.transformer.utils.nets_utils import make_pad_mask
import funasr.models.lora.layers as lora


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
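

# Minimal usage sketch for MultiHeadedAttention (not part of the original file):
# it shows the expected call pattern with query = key = value and a boolean
# padding mask of shape (#batch, 1, time2). The helper name, shapes and
# hyper-parameters below are illustrative assumptions only.
def _example_multi_headed_attention():
    mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
    x = torch.randn(2, 10, 256)  # (#batch, time1, size)
    mask = torch.ones(2, 1, 10, dtype=torch.bool)  # True marks valid frames
    out = mha(x, x, x, mask)  # self-attention: query = key = value
    assert out.shape == (2, 10, 256)
    return out

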
class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer with an FSMN memory block (SANM).

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        """Construct a MultiHeadedAttentionSANM object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        # self.linear_q = nn.Linear(n_feat, n_feat)
        # self.linear_k = nn.Linear(n_feat, n_feat)
        # self.linear_v = nn.Linear(n_feat, n_feat)
        if lora_list is not None:
            if "o" in lora_list:
                self.linear_out = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_out = nn.Linear(n_feat, n_feat)
            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
            if lora_qkv_list == [False, False, False]:
                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
            else:
                self.linear_q_k_v = lora.MergedLinear(
                    in_feat,
                    n_feat * 3,
                    r=lora_rank,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    enable_lora=lora_qkv_list,
                )
        else:
            self.linear_out = nn.Linear(n_feat, n_feat)
            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Transform the input into query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
            torch.Tensor: Value tensor before the head split (#batch, time, n_feat).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention plus the FSMN memory.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
                (#batch, time, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute scaled dot product attention for one streaming chunk.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            cache (dict): Cached keys and values from previous chunks.

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).
            dict: Updated cache.
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        if (chunk_size is not None and look_back > 0) or look_back == -1:
            if cache is not None:
                k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
                v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)

                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
                if look_back != -1:
                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
            else:
                cache_tmp = {
                    "k": k_h[:, :, :-(chunk_size[2]), :],
                    "v": v_h[:, :, :-(chunk_size[2]), :],
                }
                cache = cache_tmp
        fsmn_memory = self.forward_fsmn(v, None)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, None)
        return att_outs + fsmn_memory, cache
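

# Minimal usage sketch for MultiHeadedAttentionSANM (not part of the original
# file): the SANM layer takes a single input tensor plus a (#batch, 1, time)
# mask and returns attention output + FSMN memory. The helper name, kernel
# size and feature sizes are illustrative assumptions.
def _example_multi_headed_attention_sanm():
    att = MultiHeadedAttentionSANM(
        n_head=4, in_feat=256, n_feat=256, dropout_rate=0.1, kernel_size=11
    )
    x = torch.randn(2, 10, 256)  # (#batch, time, in_feat)
    mask = torch.ones(2, 1, 10)  # non-padded frames marked with 1
    out = att(x, mask)  # attention output + FSMN memory, same shape as x
    assert out.shape == (2, 10, 256)
    return out

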
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block used on the decoder side of SANM.

    Args:
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN convolution.
    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct a MultiHeadedAttentionSANMDecoder object."""
        super(MultiHeadedAttentionSANMDecoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
        """Forward the FSMN memory block.

        Args:
            inputs (torch.Tensor): Input tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor (#batch, time1, 1).
            cache (torch.Tensor): Cached convolution input for streaming decoding.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, size).
            torch.Tensor: Updated cache.
        """
        # print("in fsmn, inputs", inputs.size())
        b, t, d = inputs.size()
        # logging.info("mask: {}".format(mask.size()))
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
            if mask_shfit_chunk is not None:
                # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
                mask = mask * mask_shfit_chunk
            # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
            # print("in fsmn, mask", mask.size())
            # print("in fsmn, inputs", inputs.size())
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        b, d, t = x.size()
        if cache is None:
            # print("in fsmn, cache is None, x", x.size())
            x = self.pad_fn(x)
            if not self.training:
                cache = x
        else:
            # print("in fsmn, cache is not None, x", x.size())
            # x = torch.cat((x, cache), dim=2)[:, :, :-1]
            # if t < self.kernel_size:
            #     x = self.pad_fn(x)
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1):]
            # print("in fsmn, cache is not None, x_cat", x.size())
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        # print("in fsmn, fsmn_out", x.size())
        if x.size(1) != inputs.size(1):
            inputs = inputs[:, -1, :]
        x = x + inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache
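

# Minimal usage sketch for MultiHeadedAttentionSANMDecoder (not part of the
# original file): the decoder-side FSMN block is purely convolutional, so it is
# called with the value tensor and a (#batch, time, 1) mask and returns both the
# output and a cache that can be fed back on the next streaming step. The helper
# name and all shapes below are illustrative assumptions.
def _example_sanm_decoder_fsmn():
    fsmn = MultiHeadedAttentionSANMDecoder(n_feat=256, dropout_rate=0.1, kernel_size=11)
    fsmn.eval()  # in eval mode the first call also returns a usable cache
    x = torch.randn(2, 10, 256)
    mask = torch.ones(2, 10, 1)
    out, cache = fsmn(x, mask)
    assert out.shape == (2, 10, 256)
    return out, cache

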
class MultiHeadedAttentionCrossAtt(nn.Module):
    """Multi-Head cross-attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(
        self,
        n_head,
        n_feat,
        dropout_rate,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
        encoder_output_size=None,
    ):
        """Construct a MultiHeadedAttentionCrossAtt object."""
        super(MultiHeadedAttentionCrossAtt, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        if lora_list is not None:
            if "q" in lora_list:
                self.linear_q = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_q = nn.Linear(n_feat, n_feat)
            lora_kv_list = ["k" in lora_list, "v" in lora_list]
            if lora_kv_list == [False, False]:
                self.linear_k_v = nn.Linear(
                    n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2
                )
            else:
                self.linear_k_v = lora.MergedLinear(
                    n_feat if encoder_output_size is None else encoder_output_size,
                    n_feat * 2,
                    r=lora_rank,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    enable_lora=lora_kv_list,
                )
            if "o" in lora_list:
                self.linear_out = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_out = nn.Linear(n_feat, n_feat)
        else:
            self.linear_q = nn.Linear(n_feat, n_feat)
            self.linear_k_v = nn.Linear(
                n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2
            )
            self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x, memory):
        """Transform the query from x and key/value from memory.

        Args:
            x (torch.Tensor): Query-side tensor (#batch, time1, size).
            memory (torch.Tensor): Key/value-side tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        # print("in forward_qkv, x", x.size())
        b = x.size(0)
        q = self.linear_q(x)
        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)

        k_v = self.linear_k_v(memory)
        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)

        return q_h, k_h, v_h

    def forward_attention(self, value, scores, mask, ret_attn=False):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            # logging.info("scores: {}, mask_size: {}".format(scores.size(), mask.size()))
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        if ret_attn:
            return self.linear_out(x), self.attn  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, memory, memory_mask, ret_attn=False):
        """Compute scaled dot product cross-attention.

        Args:
            x (torch.Tensor): Query-side tensor (#batch, time1, size).
            memory (torch.Tensor): Key/value-side tensor (#batch, time2, size).
            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, memory_mask, ret_attn=ret_attn)

    def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
        """Compute cross-attention for one streaming chunk.

        Args:
            x (torch.Tensor): Query-side tensor (#batch, time1, size).
            memory (torch.Tensor): Key/value-side tensor (#batch, time2, size).
            cache (dict): Cached keys and values from previous chunks.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            dict: Updated cache.
        """
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        if chunk_size is not None and look_back > 0:
            if cache is not None:
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)
                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
            else:
                cache_tmp = {
                    "k": k_h[:, :, -(look_back * chunk_size[1]):, :],
                    "v": v_h[:, :, -(look_back * chunk_size[1]):, :],
                }
                cache = cache_tmp
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, None), cache
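

# Minimal usage sketch for MultiHeadedAttentionCrossAtt (not part of the
# original file): cross-attention between decoder states x and encoder memory,
# with a boolean memory mask of shape (#batch, 1, time2). The helper name and
# shapes are illustrative assumptions.
def _example_cross_attention():
    att = MultiHeadedAttentionCrossAtt(n_head=4, n_feat=256, dropout_rate=0.1)
    x = torch.randn(2, 6, 256)  # decoder side (#batch, time1, n_feat)
    memory = torch.randn(2, 10, 256)  # encoder side (#batch, time2, n_feat)
    memory_mask = torch.ones(2, 1, 10, dtype=torch.bool)
    out = att(x, memory, memory_mask)
    assert out.shape == (2, 6, 256)
    return out

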
class MultiHeadSelfAttention(nn.Module):
    """Multi-Head self-attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
        """Construct a MultiHeadSelfAttention object."""
        super(MultiHeadSelfAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x):
        """Transform the input into query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
            torch.Tensor: Value tensor before the head split (#batch, time, n_feat).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_att_chunk_encoder=None):
        """Compute scaled dot product self-attention.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
                (#batch, time, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs
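

# Minimal usage sketch for MultiHeadSelfAttention (not part of the original
# file): this is the SANM-style self-attention without the FSMN branch; in_feat
# and n_feat may differ. The helper name and shapes are illustrative assumptions.
def _example_multi_head_self_attention():
    att = MultiHeadSelfAttention(n_head=4, in_feat=256, n_feat=256, dropout_rate=0.1)
    x = torch.randn(2, 10, 256)  # (#batch, time, in_feat)
    mask = torch.ones(2, 1, 10, dtype=torch.bool)
    out = att(x, mask)
    assert out.shape == (2, 10, 256)
    return out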