attention.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Shigeki Karita
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Multi-Head Attention layer definition."""
  6. import math
  7. import numpy
  8. import torch
  9. from torch import nn
  10. from typing import Optional, Tuple
  11. class MultiHeadedAttention(nn.Module):
  12. """Multi-Head Attention layer.
  13. Args:
  14. n_head (int): The number of heads.
  15. n_feat (int): The number of features.
  16. dropout_rate (float): Dropout rate.
  17. """
  18. def __init__(self, n_head, n_feat, dropout_rate):
  19. """Construct an MultiHeadedAttention object."""
  20. super(MultiHeadedAttention, self).__init__()
  21. assert n_feat % n_head == 0
  22. # We assume d_v always equals d_k
  23. self.d_k = n_feat // n_head
  24. self.h = n_head
  25. self.linear_q = nn.Linear(n_feat, n_feat)
  26. self.linear_k = nn.Linear(n_feat, n_feat)
  27. self.linear_v = nn.Linear(n_feat, n_feat)
  28. self.linear_out = nn.Linear(n_feat, n_feat)
  29. self.attn = None
  30. self.dropout = nn.Dropout(p=dropout_rate)
  31. def forward_qkv(self, query, key, value):
  32. """Transform query, key and value.
  33. Args:
  34. query (torch.Tensor): Query tensor (#batch, time1, size).
  35. key (torch.Tensor): Key tensor (#batch, time2, size).
  36. value (torch.Tensor): Value tensor (#batch, time2, size).
  37. Returns:
  38. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  39. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  40. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  41. """
  42. n_batch = query.size(0)
  43. q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
  44. k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
  45. v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
  46. q = q.transpose(1, 2) # (batch, head, time1, d_k)
  47. k = k.transpose(1, 2) # (batch, head, time2, d_k)
  48. v = v.transpose(1, 2) # (batch, head, time2, d_k)
  49. return q, k, v
  50. def forward_attention(self, value, scores, mask):
  51. """Compute attention context vector.
  52. Args:
  53. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  54. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  55. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  56. Returns:
  57. torch.Tensor: Transformed value (#batch, time1, d_model)
  58. weighted by the attention score (#batch, time1, time2).
  59. """
  60. n_batch = value.size(0)
  61. if mask is not None:
  62. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  63. min_value = float(
  64. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  65. )
  66. scores = scores.masked_fill(mask, min_value)
  67. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  68. mask, 0.0
  69. ) # (batch, head, time1, time2)
  70. else:
  71. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  72. p_attn = self.dropout(self.attn)
  73. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  74. x = (
  75. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  76. ) # (batch, time1, d_model)
  77. return self.linear_out(x) # (batch, time1, d_model)
  78. def forward(self, query, key, value, mask):
  79. """Compute scaled dot product attention.
  80. Args:
  81. query (torch.Tensor): Query tensor (#batch, time1, size).
  82. key (torch.Tensor): Key tensor (#batch, time2, size).
  83. value (torch.Tensor): Value tensor (#batch, time2, size).
  84. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  85. (#batch, time1, time2).
  86. Returns:
  87. torch.Tensor: Output tensor (#batch, time1, d_model).
  88. """
  89. q, k, v = self.forward_qkv(query, key, value)
  90. scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
  91. return self.forward_attention(v, scores, mask)
  92. class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
  93. """Multi-Head Attention layer with relative position encoding (old version).
  94. Details can be found in https://github.com/espnet/espnet/pull/2816.
  95. Paper: https://arxiv.org/abs/1901.02860
  96. Args:
  97. n_head (int): The number of heads.
  98. n_feat (int): The number of features.
  99. dropout_rate (float): Dropout rate.
  100. zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
  101. """
  102. def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
  103. """Construct an RelPositionMultiHeadedAttention object."""
  104. super().__init__(n_head, n_feat, dropout_rate)
  105. self.zero_triu = zero_triu
  106. # linear transformation for positional encoding
  107. self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
  108. # these two learnable bias are used in matrix c and matrix d
  109. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  110. self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
  111. self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
  112. torch.nn.init.xavier_uniform_(self.pos_bias_u)
  113. torch.nn.init.xavier_uniform_(self.pos_bias_v)
  114. def rel_shift(self, x):
  115. """Compute relative positional encoding.
  116. Args:
  117. x (torch.Tensor): Input tensor (batch, head, time1, time2).
  118. Returns:
  119. torch.Tensor: Output tensor.
  120. """
  121. zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
  122. x_padded = torch.cat([zero_pad, x], dim=-1)
  123. x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
  124. x = x_padded[:, :, 1:].view_as(x)
  125. if self.zero_triu:
  126. ones = torch.ones((x.size(2), x.size(3)))
  127. x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
  128. return x
  129. def forward(self, query, key, value, pos_emb, mask):
  130. """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
  131. Args:
  132. query (torch.Tensor): Query tensor (#batch, time1, size).
  133. key (torch.Tensor): Key tensor (#batch, time2, size).
  134. value (torch.Tensor): Value tensor (#batch, time2, size).
  135. pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
  136. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  137. (#batch, time1, time2).
  138. Returns:
  139. torch.Tensor: Output tensor (#batch, time1, d_model).
  140. """
  141. q, k, v = self.forward_qkv(query, key, value)
  142. q = q.transpose(1, 2) # (batch, time1, head, d_k)
  143. n_batch_pos = pos_emb.size(0)
  144. p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
  145. p = p.transpose(1, 2) # (batch, head, time1, d_k)
  146. # (batch, head, time1, d_k)
  147. q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
  148. # (batch, head, time1, d_k)
  149. q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
  150. # compute attention score
  151. # first compute matrix a and matrix c
  152. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  153. # (batch, head, time1, time2)
  154. matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
  155. # compute matrix b and matrix d
  156. # (batch, head, time1, time1)
  157. matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
  158. matrix_bd = self.rel_shift(matrix_bd)
  159. scores = (matrix_ac + matrix_bd) / math.sqrt(
  160. self.d_k
  161. ) # (batch, head, time1, time2)
  162. return self.forward_attention(v, scores, mask)
  163. class RelPositionMultiHeadedAttention(MultiHeadedAttention):
  164. """Multi-Head Attention layer with relative position encoding (new implementation).
  165. Details can be found in https://github.com/espnet/espnet/pull/2816.
  166. Paper: https://arxiv.org/abs/1901.02860
  167. Args:
  168. n_head (int): The number of heads.
  169. n_feat (int): The number of features.
  170. dropout_rate (float): Dropout rate.
  171. zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
  172. """
  173. def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
  174. """Construct an RelPositionMultiHeadedAttention object."""
  175. super().__init__(n_head, n_feat, dropout_rate)
  176. self.zero_triu = zero_triu
  177. # linear transformation for positional encoding
  178. self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
  179. # these two learnable bias are used in matrix c and matrix d
  180. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  181. self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
  182. self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
  183. torch.nn.init.xavier_uniform_(self.pos_bias_u)
  184. torch.nn.init.xavier_uniform_(self.pos_bias_v)
  185. def rel_shift(self, x):
  186. """Compute relative positional encoding.
  187. Args:
  188. x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
  189. time1 means the length of query vector.
  190. Returns:
  191. torch.Tensor: Output tensor.
  192. """
  193. zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
  194. x_padded = torch.cat([zero_pad, x], dim=-1)
  195. x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
  196. x = x_padded[:, :, 1:].view_as(x)[
  197. :, :, :, : x.size(-1) // 2 + 1
  198. ] # only keep the positions from 0 to time2
  199. if self.zero_triu:
  200. ones = torch.ones((x.size(2), x.size(3)), device=x.device)
  201. x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
  202. return x
  203. def forward(self, query, key, value, pos_emb, mask):
  204. """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
  205. Args:
  206. query (torch.Tensor): Query tensor (#batch, time1, size).
  207. key (torch.Tensor): Key tensor (#batch, time2, size).
  208. value (torch.Tensor): Value tensor (#batch, time2, size).
  209. pos_emb (torch.Tensor): Positional embedding tensor
  210. (#batch, 2*time1-1, size).
  211. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  212. (#batch, time1, time2).
  213. Returns:
  214. torch.Tensor: Output tensor (#batch, time1, d_model).
  215. """
  216. q, k, v = self.forward_qkv(query, key, value)
  217. q = q.transpose(1, 2) # (batch, time1, head, d_k)
  218. n_batch_pos = pos_emb.size(0)
  219. p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
  220. p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
  221. # (batch, head, time1, d_k)
  222. q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
  223. # (batch, head, time1, d_k)
  224. q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
  225. # compute attention score
  226. # first compute matrix a and matrix c
  227. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  228. # (batch, head, time1, time2)
  229. matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
  230. # compute matrix b and matrix d
  231. # (batch, head, time1, 2*time1-1)
  232. matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
  233. matrix_bd = self.rel_shift(matrix_bd)
  234. scores = (matrix_ac + matrix_bd) / math.sqrt(
  235. self.d_k
  236. ) # (batch, head, time1, time2)
  237. return self.forward_attention(v, scores, mask)
class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer fused with an FSMN memory block (SANM).

    Computes multi-head self-attention from a single fused q/k/v projection
    and adds the output of a depthwise-convolution FSMN branch applied to the
    (pre-head-split) value sequence.

    Args:
        n_head (int): The number of heads.
        in_feat (int): Input feature size fed to the fused q/k/v projection.
        n_feat (int): The number of attention features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN depthwise convolution.
        sanm_shfit (int): Extra left shift of the FSMN padding window
            (parameter spelling kept for backward compatibility).
    """

    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadedAttentionSANM, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        # Single projection producing q, k and v in one matmul.
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
        # Depthwise (groups == channels) 1-D convolution implementing the FSMN memory.
        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
        # Asymmetric padding keeps the convolution output the same length as its
        # input; sanm_shfit > 0 moves the window further into the past.
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        """Apply the FSMN memory block to the value sequence.

        Args:
            inputs (torch.Tensor): Value tensor (#batch, time1, n_feat).
            mask (torch.Tensor): Mask, reshaped to (#batch, time1, 1); may be None.
            mask_shfit_chunk (torch.Tensor): Optional chunk-streaming mask,
                multiplied into ``mask``.

        Returns:
            torch.Tensor: FSMN memory (#batch, time1, n_feat).
        """
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        # Residual connection around the convolution.
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Project the input once and split it into per-head q/k/v.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Value before the head split (#batch, time2, n_feat),
                consumed by the FSMN branch.
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2);
                may be None.
            mask_att_chunk_encoder (torch.Tensor): Optional chunk-streaming mask,
                multiplied into ``mask``.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # Most negative representable value for the score dtype.
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            # Re-zero masked positions after the softmax.
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute SANM attention: scaled dot-product attention plus FSMN memory.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2); may be None.
            mask_shfit_chunk (torch.Tensor): Optional chunk mask for the FSMN branch.
            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask for attention.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        # Fold the 1/sqrt(d_k) scaling into the query instead of the scores.
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory
  347. class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
  348. def __init__(self, *args, **kwargs):
  349. super().__init__(*args, **kwargs)
  350. def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
  351. q_h, k_h, v_h, v = self.forward_qkv(x)
  352. fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
  353. q_h = q_h * self.d_k ** (-0.5)
  354. scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  355. att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
  356. return att_outs + fsmn_memory
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block used on the decoder side of SANM.

    Despite the name, this module performs no dot-product attention: it only
    applies the depthwise-convolution (FSMN) branch, with optional support for
    a streaming convolution-input cache.

    Args:
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the depthwise convolution.
        sanm_shfit (int): Extra left shift of the padding window (parameter
            spelling kept for backward compatibility).
    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadedAttentionSANMDecoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        # Depthwise (groups == channels) 1-D convolution over time.
        self.fsmn_block = nn.Conv1d(n_feat, n_feat,
                                    kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
        # Asymmetric padding keeps the convolution output the same length as its
        # input; sanm_shfit > 0 moves the window further into the past.
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
        """Apply the FSMN block with optional streaming cache.

        Args:
            inputs (torch.Tensor): Input tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor, reshaped to (#batch, time1, 1);
                may be None.
            cache (torch.Tensor): Convolution input from the previous streaming
                step (#batch, size, cached_time), or None for the first call.
            mask_shfit_chunk (torch.Tensor): Optional chunk-streaming mask,
                multiplied into ``mask``.

        Returns:
            tuple: Output tensor (#batch, time1', size) and the updated cache.
        """
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        b, d, t = x.size()
        if cache is None:
            x = self.pad_fn(x)
            if not self.training:
                # At inference, keep the padded input as the cache for the
                # next streaming step.
                cache = x
        else:
            # Streaming step: drop the oldest cached frame, append the new
            # frames, then keep exactly what the kernel needs to emit t outputs.
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1):]
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        if x.size(1) != inputs.size(1):
            # NOTE(review): this branch appears to assume single-frame decoding —
            # only the last input frame is kept for the residual; confirm.
            inputs = inputs[:, -1, :]
        x = x + inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache
  424. class MultiHeadedAttentionCrossAtt(nn.Module):
  425. """Multi-Head Attention layer.
  426. Args:
  427. n_head (int): The number of heads.
  428. n_feat (int): The number of features.
  429. dropout_rate (float): Dropout rate.
  430. """
  431. def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
  432. """Construct an MultiHeadedAttention object."""
  433. super(MultiHeadedAttentionCrossAtt, self).__init__()
  434. assert n_feat % n_head == 0
  435. # We assume d_v always equals d_k
  436. self.d_k = n_feat // n_head
  437. self.h = n_head
  438. self.linear_q = nn.Linear(n_feat, n_feat)
  439. # self.linear_k = nn.Linear(n_feat, n_feat)
  440. # self.linear_v = nn.Linear(n_feat, n_feat)
  441. self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
  442. self.linear_out = nn.Linear(n_feat, n_feat)
  443. self.attn = None
  444. self.dropout = nn.Dropout(p=dropout_rate)
  445. def forward_qkv(self, x, memory):
  446. """Transform query, key and value.
  447. Args:
  448. query (torch.Tensor): Query tensor (#batch, time1, size).
  449. key (torch.Tensor): Key tensor (#batch, time2, size).
  450. value (torch.Tensor): Value tensor (#batch, time2, size).
  451. Returns:
  452. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  453. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  454. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  455. """
  456. # print("in forward_qkv, x", x.size())
  457. b = x.size(0)
  458. q = self.linear_q(x)
  459. q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
  460. k_v = self.linear_k_v(memory)
  461. k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
  462. k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  463. v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  464. return q_h, k_h, v_h
  465. def forward_attention(self, value, scores, mask):
  466. """Compute attention context vector.
  467. Args:
  468. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  469. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  470. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  471. Returns:
  472. torch.Tensor: Transformed value (#batch, time1, d_model)
  473. weighted by the attention score (#batch, time1, time2).
  474. """
  475. n_batch = value.size(0)
  476. if mask is not None:
  477. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  478. min_value = float(
  479. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  480. )
  481. # logging.info(
  482. # "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
  483. scores = scores.masked_fill(mask, min_value)
  484. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  485. mask, 0.0
  486. ) # (batch, head, time1, time2)
  487. else:
  488. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  489. p_attn = self.dropout(self.attn)
  490. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  491. x = (
  492. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  493. ) # (batch, time1, d_model)
  494. return self.linear_out(x) # (batch, time1, d_model)
  495. def forward(self, x, memory, memory_mask):
  496. """Compute scaled dot product attention.
  497. Args:
  498. query (torch.Tensor): Query tensor (#batch, time1, size).
  499. key (torch.Tensor): Key tensor (#batch, time2, size).
  500. value (torch.Tensor): Value tensor (#batch, time2, size).
  501. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  502. (#batch, time1, time2).
  503. Returns:
  504. torch.Tensor: Output tensor (#batch, time1, d_model).
  505. """
  506. q_h, k_h, v_h = self.forward_qkv(x, memory)
  507. q_h = q_h * self.d_k ** (-0.5)
  508. scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  509. return self.forward_attention(v_h, scores, memory_mask)
  510. class MultiHeadSelfAttention(nn.Module):
  511. """Multi-Head Attention layer.
  512. Args:
  513. n_head (int): The number of heads.
  514. n_feat (int): The number of features.
  515. dropout_rate (float): Dropout rate.
  516. """
  517. def __init__(self, n_head, in_feat, n_feat, dropout_rate):
  518. """Construct an MultiHeadedAttention object."""
  519. super(MultiHeadSelfAttention, self).__init__()
  520. assert n_feat % n_head == 0
  521. # We assume d_v always equals d_k
  522. self.d_k = n_feat // n_head
  523. self.h = n_head
  524. self.linear_out = nn.Linear(n_feat, n_feat)
  525. self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
  526. self.attn = None
  527. self.dropout = nn.Dropout(p=dropout_rate)
  528. def forward_qkv(self, x):
  529. """Transform query, key and value.
  530. Args:
  531. query (torch.Tensor): Query tensor (#batch, time1, size).
  532. key (torch.Tensor): Key tensor (#batch, time2, size).
  533. value (torch.Tensor): Value tensor (#batch, time2, size).
  534. Returns:
  535. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  536. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  537. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  538. """
  539. b, t, d = x.size()
  540. q_k_v = self.linear_q_k_v(x)
  541. q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
  542. q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
  543. k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  544. v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  545. return q_h, k_h, v_h, v
  546. def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
  547. """Compute attention context vector.
  548. Args:
  549. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  550. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  551. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  552. Returns:
  553. torch.Tensor: Transformed value (#batch, time1, d_model)
  554. weighted by the attention score (#batch, time1, time2).
  555. """
  556. n_batch = value.size(0)
  557. if mask is not None:
  558. if mask_att_chunk_encoder is not None:
  559. mask = mask * mask_att_chunk_encoder
  560. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  561. min_value = float(
  562. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  563. )
  564. scores = scores.masked_fill(mask, min_value)
  565. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  566. mask, 0.0
  567. ) # (batch, head, time1, time2)
  568. else:
  569. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  570. p_attn = self.dropout(self.attn)
  571. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  572. x = (
  573. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  574. ) # (batch, time1, d_model)
  575. return self.linear_out(x) # (batch, time1, d_model)
  576. def forward(self, x, mask, mask_att_chunk_encoder=None):
  577. """Compute scaled dot product attention.
  578. Args:
  579. query (torch.Tensor): Query tensor (#batch, time1, size).
  580. key (torch.Tensor): Key tensor (#batch, time2, size).
  581. value (torch.Tensor): Value tensor (#batch, time2, size).
  582. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  583. (#batch, time1, time2).
  584. Returns:
  585. torch.Tensor: Output tensor (#batch, time1, d_model).
  586. """
  587. q_h, k_h, v_h, v = self.forward_qkv(x)
  588. q_h = q_h * self.d_k ** (-0.5)
  589. scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  590. att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
  591. return att_outs
  592. class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
  593. """RelPositionMultiHeadedAttention definition.
  594. Args:
  595. num_heads: Number of attention heads.
  596. embed_size: Embedding size.
  597. dropout_rate: Dropout rate.
  598. """
  599. def __init__(
  600. self,
  601. num_heads: int,
  602. embed_size: int,
  603. dropout_rate: float = 0.0,
  604. simplified_attention_score: bool = False,
  605. ) -> None:
  606. """Construct an MultiHeadedAttention object."""
  607. super().__init__()
  608. self.d_k = embed_size // num_heads
  609. self.num_heads = num_heads
  610. assert self.d_k * num_heads == embed_size, (
  611. "embed_size (%d) must be divisible by num_heads (%d)",
  612. (embed_size, num_heads),
  613. )
  614. self.linear_q = torch.nn.Linear(embed_size, embed_size)
  615. self.linear_k = torch.nn.Linear(embed_size, embed_size)
  616. self.linear_v = torch.nn.Linear(embed_size, embed_size)
  617. self.linear_out = torch.nn.Linear(embed_size, embed_size)
  618. if simplified_attention_score:
  619. self.linear_pos = torch.nn.Linear(embed_size, num_heads)
  620. self.compute_att_score = self.compute_simplified_attention_score
  621. else:
  622. self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
  623. self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
  624. self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
  625. torch.nn.init.xavier_uniform_(self.pos_bias_u)
  626. torch.nn.init.xavier_uniform_(self.pos_bias_v)
  627. self.compute_att_score = self.compute_attention_score
  628. self.dropout = torch.nn.Dropout(p=dropout_rate)
  629. self.attn = None
  630. def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
  631. """Compute relative positional encoding.
  632. Args:
  633. x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
  634. left_context: Number of frames in left context.
  635. Returns:
  636. x: Output sequence. (B, H, T_1, T_2)
  637. """
  638. batch_size, n_heads, time1, n = x.shape
  639. time2 = time1 + left_context
  640. batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
  641. return x.as_strided(
  642. (batch_size, n_heads, time1, time2),
  643. (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
  644. storage_offset=(n_stride * (time1 - 1)),
  645. )
  646. def compute_simplified_attention_score(
  647. self,
  648. query: torch.Tensor,
  649. key: torch.Tensor,
  650. pos_enc: torch.Tensor,
  651. left_context: int = 0,
  652. ) -> torch.Tensor:
  653. """Simplified attention score computation.
  654. Reference: https://github.com/k2-fsa/icefall/pull/458
  655. Args:
  656. query: Transformed query tensor. (B, H, T_1, d_k)
  657. key: Transformed key tensor. (B, H, T_2, d_k)
  658. pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
  659. left_context: Number of frames in left context.
  660. Returns:
  661. : Attention score. (B, H, T_1, T_2)
  662. """
  663. pos_enc = self.linear_pos(pos_enc)
  664. matrix_ac = torch.matmul(query, key.transpose(2, 3))
  665. matrix_bd = self.rel_shift(
  666. pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
  667. left_context=left_context,
  668. )
  669. return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
  670. def compute_attention_score(
  671. self,
  672. query: torch.Tensor,
  673. key: torch.Tensor,
  674. pos_enc: torch.Tensor,
  675. left_context: int = 0,
  676. ) -> torch.Tensor:
  677. """Attention score computation.
  678. Args:
  679. query: Transformed query tensor. (B, H, T_1, d_k)
  680. key: Transformed key tensor. (B, H, T_2, d_k)
  681. pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
  682. left_context: Number of frames in left context.
  683. Returns:
  684. : Attention score. (B, H, T_1, T_2)
  685. """
  686. p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
  687. query = query.transpose(1, 2)
  688. q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
  689. q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
  690. matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
  691. matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
  692. matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
  693. return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
  694. def forward_qkv(
  695. self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
  696. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  697. """Transform query, key and value.
  698. Args:
  699. query: Query tensor. (B, T_1, size)
  700. key: Key tensor. (B, T_2, size)
  701. v: Value tensor. (B, T_2, size)
  702. Returns:
  703. q: Transformed query tensor. (B, H, T_1, d_k)
  704. k: Transformed key tensor. (B, H, T_2, d_k)
  705. v: Transformed value tensor. (B, H, T_2, d_k)
  706. """
  707. n_batch = query.size(0)
  708. q = (
  709. self.linear_q(query)
  710. .view(n_batch, -1, self.num_heads, self.d_k)
  711. .transpose(1, 2)
  712. )
  713. k = (
  714. self.linear_k(key)
  715. .view(n_batch, -1, self.num_heads, self.d_k)
  716. .transpose(1, 2)
  717. )
  718. v = (
  719. self.linear_v(value)
  720. .view(n_batch, -1, self.num_heads, self.d_k)
  721. .transpose(1, 2)
  722. )
  723. return q, k, v
  724. def forward_attention(
  725. self,
  726. value: torch.Tensor,
  727. scores: torch.Tensor,
  728. mask: torch.Tensor,
  729. chunk_mask: Optional[torch.Tensor] = None,
  730. ) -> torch.Tensor:
  731. """Compute attention context vector.
  732. Args:
  733. value: Transformed value. (B, H, T_2, d_k)
  734. scores: Attention score. (B, H, T_1, T_2)
  735. mask: Source mask. (B, T_2)
  736. chunk_mask: Chunk mask. (T_1, T_1)
  737. Returns:
  738. attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
  739. """
  740. batch_size = scores.size(0)
  741. mask = mask.unsqueeze(1).unsqueeze(2)
  742. if chunk_mask is not None:
  743. mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
  744. scores = scores.masked_fill(mask, float("-inf"))
  745. self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
  746. attn_output = self.dropout(self.attn)
  747. attn_output = torch.matmul(attn_output, value)
  748. attn_output = self.linear_out(
  749. attn_output.transpose(1, 2)
  750. .contiguous()
  751. .view(batch_size, -1, self.num_heads * self.d_k)
  752. )
  753. return attn_output
  754. def forward(
  755. self,
  756. query: torch.Tensor,
  757. key: torch.Tensor,
  758. value: torch.Tensor,
  759. pos_enc: torch.Tensor,
  760. mask: torch.Tensor,
  761. chunk_mask: Optional[torch.Tensor] = None,
  762. left_context: int = 0,
  763. ) -> torch.Tensor:
  764. """Compute scaled dot product attention with rel. positional encoding.
  765. Args:
  766. query: Query tensor. (B, T_1, size)
  767. key: Key tensor. (B, T_2, size)
  768. value: Value tensor. (B, T_2, size)
  769. pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
  770. mask: Source mask. (B, T_2)
  771. chunk_mask: Chunk mask. (T_1, T_1)
  772. left_context: Number of frames in left context.
  773. Returns:
  774. : Output tensor. (B, T_1, H * d_k)
  775. """
  776. q, k, v = self.forward_qkv(query, key, value)
  777. scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
  778. return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)