# attention.py
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Shigeki Karita
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Multi-Head Attention layer definition."""
  6. import math
  7. import numpy
  8. import torch
  9. from torch import nn
  10. from typing import Optional, Tuple
  11. import torch.nn.functional as F
  12. from funasr.modules.nets_utils import make_pad_mask
  13. class MultiHeadedAttention(nn.Module):
  14. """Multi-Head Attention layer.
  15. Args:
  16. n_head (int): The number of heads.
  17. n_feat (int): The number of features.
  18. dropout_rate (float): Dropout rate.
  19. """
  20. def __init__(self, n_head, n_feat, dropout_rate):
  21. """Construct an MultiHeadedAttention object."""
  22. super(MultiHeadedAttention, self).__init__()
  23. assert n_feat % n_head == 0
  24. # We assume d_v always equals d_k
  25. self.d_k = n_feat // n_head
  26. self.h = n_head
  27. self.linear_q = nn.Linear(n_feat, n_feat)
  28. self.linear_k = nn.Linear(n_feat, n_feat)
  29. self.linear_v = nn.Linear(n_feat, n_feat)
  30. self.linear_out = nn.Linear(n_feat, n_feat)
  31. self.attn = None
  32. self.dropout = nn.Dropout(p=dropout_rate)
  33. def forward_qkv(self, query, key, value):
  34. """Transform query, key and value.
  35. Args:
  36. query (torch.Tensor): Query tensor (#batch, time1, size).
  37. key (torch.Tensor): Key tensor (#batch, time2, size).
  38. value (torch.Tensor): Value tensor (#batch, time2, size).
  39. Returns:
  40. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  41. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  42. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  43. """
  44. n_batch = query.size(0)
  45. q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
  46. k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
  47. v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
  48. q = q.transpose(1, 2) # (batch, head, time1, d_k)
  49. k = k.transpose(1, 2) # (batch, head, time2, d_k)
  50. v = v.transpose(1, 2) # (batch, head, time2, d_k)
  51. return q, k, v
  52. def forward_attention(self, value, scores, mask):
  53. """Compute attention context vector.
  54. Args:
  55. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  56. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  57. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  58. Returns:
  59. torch.Tensor: Transformed value (#batch, time1, d_model)
  60. weighted by the attention score (#batch, time1, time2).
  61. """
  62. n_batch = value.size(0)
  63. if mask is not None:
  64. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  65. min_value = float(
  66. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  67. )
  68. scores = scores.masked_fill(mask, min_value)
  69. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  70. mask, 0.0
  71. ) # (batch, head, time1, time2)
  72. else:
  73. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  74. p_attn = self.dropout(self.attn)
  75. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  76. x = (
  77. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  78. ) # (batch, time1, d_model)
  79. return self.linear_out(x) # (batch, time1, d_model)
  80. def forward(self, query, key, value, mask):
  81. """Compute scaled dot product attention.
  82. Args:
  83. query (torch.Tensor): Query tensor (#batch, time1, size).
  84. key (torch.Tensor): Key tensor (#batch, time2, size).
  85. value (torch.Tensor): Value tensor (#batch, time2, size).
  86. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  87. (#batch, time1, time2).
  88. Returns:
  89. torch.Tensor: Output tensor (#batch, time1, d_model).
  90. """
  91. q, k, v = self.forward_qkv(query, key, value)
  92. scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
  93. return self.forward_attention(v, scores, mask)
  94. class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
  95. """Multi-Head Attention layer with relative position encoding (old version).
  96. Details can be found in https://github.com/espnet/espnet/pull/2816.
  97. Paper: https://arxiv.org/abs/1901.02860
  98. Args:
  99. n_head (int): The number of heads.
  100. n_feat (int): The number of features.
  101. dropout_rate (float): Dropout rate.
  102. zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
  103. """
  104. def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
  105. """Construct an RelPositionMultiHeadedAttention object."""
  106. super().__init__(n_head, n_feat, dropout_rate)
  107. self.zero_triu = zero_triu
  108. # linear transformation for positional encoding
  109. self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
  110. # these two learnable bias are used in matrix c and matrix d
  111. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  112. self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
  113. self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
  114. torch.nn.init.xavier_uniform_(self.pos_bias_u)
  115. torch.nn.init.xavier_uniform_(self.pos_bias_v)
  116. def rel_shift(self, x):
  117. """Compute relative positional encoding.
  118. Args:
  119. x (torch.Tensor): Input tensor (batch, head, time1, time2).
  120. Returns:
  121. torch.Tensor: Output tensor.
  122. """
  123. zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
  124. x_padded = torch.cat([zero_pad, x], dim=-1)
  125. x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
  126. x = x_padded[:, :, 1:].view_as(x)
  127. if self.zero_triu:
  128. ones = torch.ones((x.size(2), x.size(3)))
  129. x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
  130. return x
  131. def forward(self, query, key, value, pos_emb, mask):
  132. """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
  133. Args:
  134. query (torch.Tensor): Query tensor (#batch, time1, size).
  135. key (torch.Tensor): Key tensor (#batch, time2, size).
  136. value (torch.Tensor): Value tensor (#batch, time2, size).
  137. pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
  138. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  139. (#batch, time1, time2).
  140. Returns:
  141. torch.Tensor: Output tensor (#batch, time1, d_model).
  142. """
  143. q, k, v = self.forward_qkv(query, key, value)
  144. q = q.transpose(1, 2) # (batch, time1, head, d_k)
  145. n_batch_pos = pos_emb.size(0)
  146. p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
  147. p = p.transpose(1, 2) # (batch, head, time1, d_k)
  148. # (batch, head, time1, d_k)
  149. q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
  150. # (batch, head, time1, d_k)
  151. q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
  152. # compute attention score
  153. # first compute matrix a and matrix c
  154. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  155. # (batch, head, time1, time2)
  156. matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
  157. # compute matrix b and matrix d
  158. # (batch, head, time1, time1)
  159. matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
  160. matrix_bd = self.rel_shift(matrix_bd)
  161. scores = (matrix_ac + matrix_bd) / math.sqrt(
  162. self.d_k
  163. ) # (batch, head, time1, time2)
  164. return self.forward_attention(v, scores, mask)
  165. class RelPositionMultiHeadedAttention(MultiHeadedAttention):
  166. """Multi-Head Attention layer with relative position encoding (new implementation).
  167. Details can be found in https://github.com/espnet/espnet/pull/2816.
  168. Paper: https://arxiv.org/abs/1901.02860
  169. Args:
  170. n_head (int): The number of heads.
  171. n_feat (int): The number of features.
  172. dropout_rate (float): Dropout rate.
  173. zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
  174. """
  175. def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
  176. """Construct an RelPositionMultiHeadedAttention object."""
  177. super().__init__(n_head, n_feat, dropout_rate)
  178. self.zero_triu = zero_triu
  179. # linear transformation for positional encoding
  180. self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
  181. # these two learnable bias are used in matrix c and matrix d
  182. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  183. self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
  184. self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
  185. torch.nn.init.xavier_uniform_(self.pos_bias_u)
  186. torch.nn.init.xavier_uniform_(self.pos_bias_v)
  187. def rel_shift(self, x):
  188. """Compute relative positional encoding.
  189. Args:
  190. x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
  191. time1 means the length of query vector.
  192. Returns:
  193. torch.Tensor: Output tensor.
  194. """
  195. zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
  196. x_padded = torch.cat([zero_pad, x], dim=-1)
  197. x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
  198. x = x_padded[:, :, 1:].view_as(x)[
  199. :, :, :, : x.size(-1) // 2 + 1
  200. ] # only keep the positions from 0 to time2
  201. if self.zero_triu:
  202. ones = torch.ones((x.size(2), x.size(3)), device=x.device)
  203. x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
  204. return x
  205. def forward(self, query, key, value, pos_emb, mask):
  206. """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
  207. Args:
  208. query (torch.Tensor): Query tensor (#batch, time1, size).
  209. key (torch.Tensor): Key tensor (#batch, time2, size).
  210. value (torch.Tensor): Value tensor (#batch, time2, size).
  211. pos_emb (torch.Tensor): Positional embedding tensor
  212. (#batch, 2*time1-1, size).
  213. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  214. (#batch, time1, time2).
  215. Returns:
  216. torch.Tensor: Output tensor (#batch, time1, d_model).
  217. """
  218. q, k, v = self.forward_qkv(query, key, value)
  219. q = q.transpose(1, 2) # (batch, time1, head, d_k)
  220. n_batch_pos = pos_emb.size(0)
  221. p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
  222. p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
  223. # (batch, head, time1, d_k)
  224. q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
  225. # (batch, head, time1, d_k)
  226. q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
  227. # compute attention score
  228. # first compute matrix a and matrix c
  229. # as described in https://arxiv.org/abs/1901.02860 Section 3.3
  230. # (batch, head, time1, time2)
  231. matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
  232. # compute matrix b and matrix d
  233. # (batch, head, time1, 2*time1-1)
  234. matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
  235. matrix_bd = self.rel_shift(matrix_bd)
  236. scores = (matrix_ac + matrix_bd) / math.sqrt(
  237. self.d_k
  238. ) # (batch, head, time1, time2)
  239. return self.forward_attention(v, scores, mask)
  240. class MultiHeadedAttentionSANM(nn.Module):
  241. """Multi-Head Attention layer.
  242. Args:
  243. n_head (int): The number of heads.
  244. n_feat (int): The number of features.
  245. dropout_rate (float): Dropout rate.
  246. """
  247. def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
  248. """Construct an MultiHeadedAttention object."""
  249. super(MultiHeadedAttentionSANM, self).__init__()
  250. assert n_feat % n_head == 0
  251. # We assume d_v always equals d_k
  252. self.d_k = n_feat // n_head
  253. self.h = n_head
  254. # self.linear_q = nn.Linear(n_feat, n_feat)
  255. # self.linear_k = nn.Linear(n_feat, n_feat)
  256. # self.linear_v = nn.Linear(n_feat, n_feat)
  257. self.linear_out = nn.Linear(n_feat, n_feat)
  258. self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
  259. self.attn = None
  260. self.dropout = nn.Dropout(p=dropout_rate)
  261. self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
  262. # padding
  263. left_padding = (kernel_size - 1) // 2
  264. if sanm_shfit > 0:
  265. left_padding = left_padding + sanm_shfit
  266. right_padding = kernel_size - 1 - left_padding
  267. self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
  268. def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
  269. b, t, d = inputs.size()
  270. if mask is not None:
  271. mask = torch.reshape(mask, (b, -1, 1))
  272. if mask_shfit_chunk is not None:
  273. mask = mask * mask_shfit_chunk
  274. inputs = inputs * mask
  275. x = inputs.transpose(1, 2)
  276. x = self.pad_fn(x)
  277. x = self.fsmn_block(x)
  278. x = x.transpose(1, 2)
  279. x += inputs
  280. x = self.dropout(x)
  281. if mask is not None:
  282. x = x * mask
  283. return x
  284. def forward_qkv(self, x):
  285. """Transform query, key and value.
  286. Args:
  287. query (torch.Tensor): Query tensor (#batch, time1, size).
  288. key (torch.Tensor): Key tensor (#batch, time2, size).
  289. value (torch.Tensor): Value tensor (#batch, time2, size).
  290. Returns:
  291. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  292. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  293. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  294. """
  295. b, t, d = x.size()
  296. q_k_v = self.linear_q_k_v(x)
  297. q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
  298. q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
  299. k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  300. v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  301. return q_h, k_h, v_h, v
  302. def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
  303. """Compute attention context vector.
  304. Args:
  305. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  306. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  307. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  308. Returns:
  309. torch.Tensor: Transformed value (#batch, time1, d_model)
  310. weighted by the attention score (#batch, time1, time2).
  311. """
  312. n_batch = value.size(0)
  313. if mask is not None:
  314. if mask_att_chunk_encoder is not None:
  315. mask = mask * mask_att_chunk_encoder
  316. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  317. min_value = float(
  318. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  319. )
  320. scores = scores.masked_fill(mask, min_value)
  321. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  322. mask, 0.0
  323. ) # (batch, head, time1, time2)
  324. else:
  325. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  326. p_attn = self.dropout(self.attn)
  327. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  328. x = (
  329. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  330. ) # (batch, time1, d_model)
  331. return self.linear_out(x) # (batch, time1, d_model)
  332. def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
  333. """Compute scaled dot product attention.
  334. Args:
  335. query (torch.Tensor): Query tensor (#batch, time1, size).
  336. key (torch.Tensor): Key tensor (#batch, time2, size).
  337. value (torch.Tensor): Value tensor (#batch, time2, size).
  338. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  339. (#batch, time1, time2).
  340. Returns:
  341. torch.Tensor: Output tensor (#batch, time1, d_model).
  342. """
  343. q_h, k_h, v_h, v = self.forward_qkv(x)
  344. fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
  345. q_h = q_h * self.d_k ** (-0.5)
  346. scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  347. att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
  348. return att_outs + fsmn_memory
  349. class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
  350. def __init__(self, *args, **kwargs):
  351. super().__init__(*args, **kwargs)
  352. def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
  353. q_h, k_h, v_h, v = self.forward_qkv(x)
  354. fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
  355. q_h = q_h * self.d_k ** (-0.5)
  356. scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  357. att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
  358. return att_outs + fsmn_memory
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block used as the decoder-side "self-attention".

    Only the depthwise-convolution memory branch is kept here — there is no
    query/key/value projection in this module.

    Args:
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): FSMN convolution kernel size.
        sanm_shfit (int): Extra left shift of the padding window.
    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadedAttentionSANMDecoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        # Depthwise (groups=n_feat) conv implementing the FSMN memory.
        self.fsmn_block = nn.Conv1d(n_feat, n_feat,
                                    kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
        # Asymmetric padding: sanm_shfit > 0 moves the window into the past.
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
        '''
        :param inputs: (#batch, time1, size).
        :param mask: Mask tensor (#batch, 1, time)
        :param cache: padded/concatenated past frames kept between streaming
            calls, shape (#batch, size, cache_t); None on the first call.
        :param mask_shfit_chunk: optional extra mask multiplied into ``mask``.
        :return: tuple of output (#batch, time1, size) and the updated cache.
        '''
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)  # (batch, size, time) for Conv1d
        b, d, t = x.size()
        if cache is None:
            x = self.pad_fn(x)
            if not self.training:
                # At inference the padded input becomes the cache for the
                # next streaming step; during training no cache is kept.
                cache = x
        else:
            # Streaming step: drop the oldest cached frame, append the new
            # frames, and keep exactly enough context (kernel_size + t - 1)
            # for the convolution to emit t frames.
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1):]
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        if x.size(1) != inputs.size(1):
            # NOTE(review): keeps only the last input frame when the conv
            # output is shorter than the input. The missing ':' (i.e.
            # inputs[:, -1:, :]) leaves this 2-D, so the addition below relies
            # on broadcasting — appears to assume batch size 1; confirm
            # against the streaming caller.
            inputs = inputs[:, -1, :]
        x = x + inputs  # residual connection
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache
  426. class MultiHeadedAttentionCrossAtt(nn.Module):
  427. """Multi-Head Attention layer.
  428. Args:
  429. n_head (int): The number of heads.
  430. n_feat (int): The number of features.
  431. dropout_rate (float): Dropout rate.
  432. """
  433. def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
  434. """Construct an MultiHeadedAttention object."""
  435. super(MultiHeadedAttentionCrossAtt, self).__init__()
  436. assert n_feat % n_head == 0
  437. # We assume d_v always equals d_k
  438. self.d_k = n_feat // n_head
  439. self.h = n_head
  440. self.linear_q = nn.Linear(n_feat, n_feat)
  441. # self.linear_k = nn.Linear(n_feat, n_feat)
  442. # self.linear_v = nn.Linear(n_feat, n_feat)
  443. self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
  444. self.linear_out = nn.Linear(n_feat, n_feat)
  445. self.attn = None
  446. self.dropout = nn.Dropout(p=dropout_rate)
  447. def forward_qkv(self, x, memory):
  448. """Transform query, key and value.
  449. Args:
  450. query (torch.Tensor): Query tensor (#batch, time1, size).
  451. key (torch.Tensor): Key tensor (#batch, time2, size).
  452. value (torch.Tensor): Value tensor (#batch, time2, size).
  453. Returns:
  454. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  455. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  456. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  457. """
  458. # print("in forward_qkv, x", x.size())
  459. b = x.size(0)
  460. q = self.linear_q(x)
  461. q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
  462. k_v = self.linear_k_v(memory)
  463. k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
  464. k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  465. v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  466. return q_h, k_h, v_h
  467. def forward_attention(self, value, scores, mask):
  468. """Compute attention context vector.
  469. Args:
  470. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  471. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  472. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  473. Returns:
  474. torch.Tensor: Transformed value (#batch, time1, d_model)
  475. weighted by the attention score (#batch, time1, time2).
  476. """
  477. n_batch = value.size(0)
  478. if mask is not None:
  479. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  480. min_value = float(
  481. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  482. )
  483. # logging.info(
  484. # "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
  485. scores = scores.masked_fill(mask, min_value)
  486. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  487. mask, 0.0
  488. ) # (batch, head, time1, time2)
  489. else:
  490. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  491. p_attn = self.dropout(self.attn)
  492. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  493. x = (
  494. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  495. ) # (batch, time1, d_model)
  496. return self.linear_out(x) # (batch, time1, d_model)
  497. def forward(self, x, memory, memory_mask):
  498. """Compute scaled dot product attention.
  499. Args:
  500. query (torch.Tensor): Query tensor (#batch, time1, size).
  501. key (torch.Tensor): Key tensor (#batch, time2, size).
  502. value (torch.Tensor): Value tensor (#batch, time2, size).
  503. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  504. (#batch, time1, time2).
  505. Returns:
  506. torch.Tensor: Output tensor (#batch, time1, d_model).
  507. """
  508. q_h, k_h, v_h = self.forward_qkv(x, memory)
  509. q_h = q_h * self.d_k ** (-0.5)
  510. scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  511. return self.forward_attention(v_h, scores, memory_mask)
  512. class MultiHeadSelfAttention(nn.Module):
  513. """Multi-Head Attention layer.
  514. Args:
  515. n_head (int): The number of heads.
  516. n_feat (int): The number of features.
  517. dropout_rate (float): Dropout rate.
  518. """
  519. def __init__(self, n_head, in_feat, n_feat, dropout_rate):
  520. """Construct an MultiHeadedAttention object."""
  521. super(MultiHeadSelfAttention, self).__init__()
  522. assert n_feat % n_head == 0
  523. # We assume d_v always equals d_k
  524. self.d_k = n_feat // n_head
  525. self.h = n_head
  526. self.linear_out = nn.Linear(n_feat, n_feat)
  527. self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
  528. self.attn = None
  529. self.dropout = nn.Dropout(p=dropout_rate)
  530. def forward_qkv(self, x):
  531. """Transform query, key and value.
  532. Args:
  533. query (torch.Tensor): Query tensor (#batch, time1, size).
  534. key (torch.Tensor): Key tensor (#batch, time2, size).
  535. value (torch.Tensor): Value tensor (#batch, time2, size).
  536. Returns:
  537. torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  538. torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  539. torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
  540. """
  541. b, t, d = x.size()
  542. q_k_v = self.linear_q_k_v(x)
  543. q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
  544. q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
  545. k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  546. v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
  547. return q_h, k_h, v_h, v
  548. def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
  549. """Compute attention context vector.
  550. Args:
  551. value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
  552. scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
  553. mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
  554. Returns:
  555. torch.Tensor: Transformed value (#batch, time1, d_model)
  556. weighted by the attention score (#batch, time1, time2).
  557. """
  558. n_batch = value.size(0)
  559. if mask is not None:
  560. if mask_att_chunk_encoder is not None:
  561. mask = mask * mask_att_chunk_encoder
  562. mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
  563. min_value = float(
  564. numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
  565. )
  566. scores = scores.masked_fill(mask, min_value)
  567. self.attn = torch.softmax(scores, dim=-1).masked_fill(
  568. mask, 0.0
  569. ) # (batch, head, time1, time2)
  570. else:
  571. self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
  572. p_attn = self.dropout(self.attn)
  573. x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
  574. x = (
  575. x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
  576. ) # (batch, time1, d_model)
  577. return self.linear_out(x) # (batch, time1, d_model)
  578. def forward(self, x, mask, mask_att_chunk_encoder=None):
  579. """Compute scaled dot product attention.
  580. Args:
  581. query (torch.Tensor): Query tensor (#batch, time1, size).
  582. key (torch.Tensor): Key tensor (#batch, time2, size).
  583. value (torch.Tensor): Value tensor (#batch, time2, size).
  584. mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
  585. (#batch, time1, time2).
  586. Returns:
  587. torch.Tensor: Output tensor (#batch, time1, d_model).
  588. """
  589. q_h, k_h, v_h, v = self.forward_qkv(x)
  590. q_h = q_h * self.d_k ** (-0.5)
  591. scores = torch.matmul(q_h, k_h.transpose(-2, -1))
  592. att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
  593. return att_outs
  594. class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
  595. """RelPositionMultiHeadedAttention definition.
  596. Args:
  597. num_heads: Number of attention heads.
  598. embed_size: Embedding size.
  599. dropout_rate: Dropout rate.
  600. """
  601. def __init__(
  602. self,
  603. num_heads: int,
  604. embed_size: int,
  605. dropout_rate: float = 0.0,
  606. simplified_attention_score: bool = False,
  607. ) -> None:
  608. """Construct an MultiHeadedAttention object."""
  609. super().__init__()
  610. self.d_k = embed_size // num_heads
  611. self.num_heads = num_heads
  612. assert self.d_k * num_heads == embed_size, (
  613. "embed_size (%d) must be divisible by num_heads (%d)",
  614. (embed_size, num_heads),
  615. )
  616. self.linear_q = torch.nn.Linear(embed_size, embed_size)
  617. self.linear_k = torch.nn.Linear(embed_size, embed_size)
  618. self.linear_v = torch.nn.Linear(embed_size, embed_size)
  619. self.linear_out = torch.nn.Linear(embed_size, embed_size)
  620. if simplified_attention_score:
  621. self.linear_pos = torch.nn.Linear(embed_size, num_heads)
  622. self.compute_att_score = self.compute_simplified_attention_score
  623. else:
  624. self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
  625. self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
  626. self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
  627. torch.nn.init.xavier_uniform_(self.pos_bias_u)
  628. torch.nn.init.xavier_uniform_(self.pos_bias_v)
  629. self.compute_att_score = self.compute_attention_score
  630. self.dropout = torch.nn.Dropout(p=dropout_rate)
  631. self.attn = None
  632. def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
  633. """Compute relative positional encoding.
  634. Args:
  635. x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
  636. left_context: Number of frames in left context.
  637. Returns:
  638. x: Output sequence. (B, H, T_1, T_2)
  639. """
  640. batch_size, n_heads, time1, n = x.shape
  641. time2 = time1 + left_context
  642. batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
  643. return x.as_strided(
  644. (batch_size, n_heads, time1, time2),
  645. (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
  646. storage_offset=(n_stride * (time1 - 1)),
  647. )
  648. def compute_simplified_attention_score(
  649. self,
  650. query: torch.Tensor,
  651. key: torch.Tensor,
  652. pos_enc: torch.Tensor,
  653. left_context: int = 0,
  654. ) -> torch.Tensor:
  655. """Simplified attention score computation.
  656. Reference: https://github.com/k2-fsa/icefall/pull/458
  657. Args:
  658. query: Transformed query tensor. (B, H, T_1, d_k)
  659. key: Transformed key tensor. (B, H, T_2, d_k)
  660. pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
  661. left_context: Number of frames in left context.
  662. Returns:
  663. : Attention score. (B, H, T_1, T_2)
  664. """
  665. pos_enc = self.linear_pos(pos_enc)
  666. matrix_ac = torch.matmul(query, key.transpose(2, 3))
  667. matrix_bd = self.rel_shift(
  668. pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
  669. left_context=left_context,
  670. )
  671. return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
  672. def compute_attention_score(
  673. self,
  674. query: torch.Tensor,
  675. key: torch.Tensor,
  676. pos_enc: torch.Tensor,
  677. left_context: int = 0,
  678. ) -> torch.Tensor:
  679. """Attention score computation.
  680. Args:
  681. query: Transformed query tensor. (B, H, T_1, d_k)
  682. key: Transformed key tensor. (B, H, T_2, d_k)
  683. pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
  684. left_context: Number of frames in left context.
  685. Returns:
  686. : Attention score. (B, H, T_1, T_2)
  687. """
  688. p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
  689. query = query.transpose(1, 2)
  690. q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
  691. q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
  692. matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
  693. matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
  694. matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
  695. return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
  696. def forward_qkv(
  697. self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
  698. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  699. """Transform query, key and value.
  700. Args:
  701. query: Query tensor. (B, T_1, size)
  702. key: Key tensor. (B, T_2, size)
  703. v: Value tensor. (B, T_2, size)
  704. Returns:
  705. q: Transformed query tensor. (B, H, T_1, d_k)
  706. k: Transformed key tensor. (B, H, T_2, d_k)
  707. v: Transformed value tensor. (B, H, T_2, d_k)
  708. """
  709. n_batch = query.size(0)
  710. q = (
  711. self.linear_q(query)
  712. .view(n_batch, -1, self.num_heads, self.d_k)
  713. .transpose(1, 2)
  714. )
  715. k = (
  716. self.linear_k(key)
  717. .view(n_batch, -1, self.num_heads, self.d_k)
  718. .transpose(1, 2)
  719. )
  720. v = (
  721. self.linear_v(value)
  722. .view(n_batch, -1, self.num_heads, self.d_k)
  723. .transpose(1, 2)
  724. )
  725. return q, k, v
  726. def forward_attention(
  727. self,
  728. value: torch.Tensor,
  729. scores: torch.Tensor,
  730. mask: torch.Tensor,
  731. chunk_mask: Optional[torch.Tensor] = None,
  732. ) -> torch.Tensor:
  733. """Compute attention context vector.
  734. Args:
  735. value: Transformed value. (B, H, T_2, d_k)
  736. scores: Attention score. (B, H, T_1, T_2)
  737. mask: Source mask. (B, T_2)
  738. chunk_mask: Chunk mask. (T_1, T_1)
  739. Returns:
  740. attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
  741. """
  742. batch_size = scores.size(0)
  743. mask = mask.unsqueeze(1).unsqueeze(2)
  744. if chunk_mask is not None:
  745. mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
  746. scores = scores.masked_fill(mask, float("-inf"))
  747. self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
  748. attn_output = self.dropout(self.attn)
  749. attn_output = torch.matmul(attn_output, value)
  750. attn_output = self.linear_out(
  751. attn_output.transpose(1, 2)
  752. .contiguous()
  753. .view(batch_size, -1, self.num_heads * self.d_k)
  754. )
  755. return attn_output
  756. def forward(
  757. self,
  758. query: torch.Tensor,
  759. key: torch.Tensor,
  760. value: torch.Tensor,
  761. pos_enc: torch.Tensor,
  762. mask: torch.Tensor,
  763. chunk_mask: Optional[torch.Tensor] = None,
  764. left_context: int = 0,
  765. ) -> torch.Tensor:
  766. """Compute scaled dot product attention with rel. positional encoding.
  767. Args:
  768. query: Query tensor. (B, T_1, size)
  769. key: Key tensor. (B, T_2, size)
  770. value: Value tensor. (B, T_2, size)
  771. pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
  772. mask: Source mask. (B, T_2)
  773. chunk_mask: Chunk mask. (T_1, T_1)
  774. left_context: Number of frames in left context.
  775. Returns:
  776. : Output tensor. (B, T_1, H * d_k)
  777. """
  778. q, k, v = self.forward_qkv(query, key, value)
  779. scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
  780. return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
  781. class CosineDistanceAttention(nn.Module):
  782. """ Compute Cosine Distance between spk decoder output and speaker profile
  783. Args:
  784. profile_path: speaker profile file path (.npy file)
  785. """
  786. def __init__(self):
  787. super().__init__()
  788. self.softmax = nn.Softmax(dim=-1)
  789. def forward(self, spk_decoder_out, profile, profile_lens=None):
  790. """
  791. Args:
  792. spk_decoder_out(torch.Tensor):(B, L, D)
  793. spk_profiles(torch.Tensor):(B, N, D)
  794. """
  795. x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D)
  796. if profile_lens is not None:
  797. mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
  798. min_value = float(
  799. numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
  800. )
  801. weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
  802. weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0) # (B, L, N)
  803. else:
  804. x = x[:, -1:, :, :]
  805. weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
  806. weights = self.softmax(weights_not_softmax) # (B, 1, N)
  807. spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D)
  808. return spk_embedding, weights