#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention layer definition."""

import math
from typing import Optional, Tuple

import numpy
import torch
import torch.nn.functional as F
from torch import nn

import funasr.modules.lora.layers as lora
from funasr.modules.nets_utils import make_pad_mask


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
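

# Illustrative usage sketch (editor's addition; not part of the original FunASR
# source): a minimal smoke test showing the tensor shapes MultiHeadedAttention
# expects. All sizes below are arbitrary placeholders.
def _demo_multi_headed_attention():
    batch, time, n_feat, n_head = 2, 5, 256, 4
    mha = MultiHeadedAttention(n_head, n_feat, dropout_rate=0.1)
    x = torch.randn(batch, time, n_feat)
    # Mask is True/1 for valid frames; forward_attention masks positions equal to 0.
    mask = torch.ones(batch, 1, time, dtype=torch.bool)
    out = mha(x, x, x, mask)  # self-attention: query == key == value
    assert out.shape == (batch, time, n_feat)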


class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a LegacyRelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, time2).

        Returns:
            torch.Tensor: Output tensor.
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)
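

# Illustrative usage sketch (editor's addition; not part of the original FunASR
# source). The legacy variant expects time1 positional embeddings, rather than
# the 2 * time1 - 1 used by the newer implementation below.
def _demo_legacy_rel_position_attention():
    batch, time, n_feat, n_head = 2, 5, 256, 4
    att = LegacyRelPositionMultiHeadedAttention(n_head, n_feat, dropout_rate=0.1)
    x = torch.randn(batch, time, n_feat)
    pos_emb = torch.randn(1, time, n_feat)
    mask = torch.ones(batch, 1, time, dtype=torch.bool)
    out = att(x, x, x, pos_emb, mask)
    assert out.shape == (batch, time, n_feat)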


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, 2*time1-1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, 2*time1-1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)
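

# Illustrative usage sketch (editor's addition; not part of the original FunASR
# source). The relative positional embedding spans 2 * time1 - 1 positions.
def _demo_rel_position_attention():
    batch, time, n_feat, n_head = 2, 5, 256, 4
    att = RelPositionMultiHeadedAttention(n_head, n_feat, dropout_rate=0.1)
    x = torch.randn(batch, time, n_feat)
    pos_emb = torch.randn(1, 2 * time - 1, n_feat)  # relative positions time-1 .. -(time-1)
    mask = torch.ones(batch, 1, time, dtype=torch.bool)
    out = att(x, x, x, pos_emb, mask)
    assert out.shape == (batch, time, n_feat)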


class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer with an FSMN memory block (SANM).

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN block.
        sanm_shfit (int): Left shift of the FSMN padding.
    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        """Construct a MultiHeadedAttentionSANM object."""
        super(MultiHeadedAttentionSANM, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        if lora_list is not None:
            if "o" in lora_list:
                self.linear_out = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_out = nn.Linear(n_feat, n_feat)
            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
            if lora_qkv_list == [False, False, False]:
                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
            else:
                self.linear_q_k_v = lora.MergedLinear(
                    in_feat,
                    n_feat * 3,
                    r=lora_rank,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    enable_lora=lora_qkv_list,
                )
        else:
            self.linear_out = nn.Linear(n_feat, n_feat)
            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Transform the input into query, key and value with a fused projection.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
            torch.Tensor: Untransposed value tensor (#batch, time, n_feat).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention plus the FSMN memory output.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
                (#batch, time, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory
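

# Illustrative usage sketch (editor's addition; not part of the original FunASR
# source): self-attention plus the FSMN memory branch over a 10-frame sequence.
def _demo_multi_headed_attention_sanm():
    batch, time, n_feat, n_head = 2, 10, 256, 4
    att = MultiHeadedAttentionSANM(n_head, n_feat, n_feat, dropout_rate=0.1, kernel_size=11)
    x = torch.randn(batch, time, n_feat)
    mask = torch.ones(batch, 1, time)  # 1 marks valid frames
    out = att(x, mask)
    assert out.shape == (batch, time, n_feat)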


class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
    """MultiHeadedAttentionSANM variant that takes a pair of masks.

    mask[0] is applied to the FSMN memory block and mask[1] to the
    attention scores.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
        return att_outs + fsmn_memory


class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block used in place of decoder self-attention.

    Args:
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN block.
        sanm_shfit (int): Left shift of the FSMN padding.
    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct a MultiHeadedAttentionSANMDecoder object."""
        super(MultiHeadedAttentionSANMDecoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
        """Compute the FSMN memory block output.

        Args:
            inputs (torch.Tensor): Input tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time).
            cache (torch.Tensor): Convolution input cache from the previous
                step, used for incremental decoding.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, size) and the updated cache.
        """
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        b, d, t = x.size()
        if cache is None:
            x = self.pad_fn(x)
            if not self.training:
                cache = x
        else:
            # shift the cache left by one frame and append the new frame(s)
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1):]
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)

        if x.size(1) != inputs.size(1):
            inputs = inputs[:, -1:, :]  # keep the time dimension when truncating
        x = x + inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache
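

# Illustrative usage sketch (editor's addition; not part of the original FunASR
# source): frame-by-frame decoding with the FSMN cache in eval mode.
def _demo_sanm_decoder_streaming():
    batch, n_feat, kernel_size = 1, 256, 11
    fsmn = MultiHeadedAttentionSANMDecoder(
        n_feat, dropout_rate=0.0, kernel_size=kernel_size
    ).eval()
    cache = None
    for _ in range(3):
        frame = torch.randn(batch, 1, n_feat)
        out, cache = fsmn(frame, None, cache=cache)  # mask is None at decode time
        assert out.shape == (batch, 1, n_feat)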


class MultiHeadedAttentionCrossAtt(nn.Module):
    """Multi-Head cross-attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(
        self,
        n_head,
        n_feat,
        dropout_rate,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
        encoder_output_size=None,
    ):
        """Construct a MultiHeadedAttentionCrossAtt object."""
        super(MultiHeadedAttentionCrossAtt, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        if lora_list is not None:
            if "q" in lora_list:
                self.linear_q = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_q = nn.Linear(n_feat, n_feat)
            lora_kv_list = ["k" in lora_list, "v" in lora_list]
            if lora_kv_list == [False, False]:
                self.linear_k_v = nn.Linear(
                    n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2
                )
            else:
                self.linear_k_v = lora.MergedLinear(
                    n_feat if encoder_output_size is None else encoder_output_size,
                    n_feat * 2,
                    r=lora_rank,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    enable_lora=lora_kv_list,
                )
            if "o" in lora_list:
                self.linear_out = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_out = nn.Linear(n_feat, n_feat)
        else:
            self.linear_q = nn.Linear(n_feat, n_feat)
            self.linear_k_v = nn.Linear(
                n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2
            )
            self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x, memory):
        """Transform the query from x and key/value from the encoder memory.

        Args:
            x (torch.Tensor): Query tensor (#batch, time1, size).
            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        b = x.size(0)
        q = self.linear_q(x)
        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)

        k_v = self.linear_k_v(memory)
        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)

        return q_h, k_h, v_h

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, memory, memory_mask):
        """Compute scaled dot product cross-attention.

        Args:
            x (torch.Tensor): Query tensor (#batch, time1, size).
            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).
            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, memory_mask)
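

# Illustrative usage sketch (editor's addition; not part of the original FunASR
# source): a decoder of 5 steps attending over 12 encoder frames.
def _demo_cross_attention():
    batch, t_dec, t_enc, n_feat, n_head = 2, 5, 12, 256, 4
    att = MultiHeadedAttentionCrossAtt(n_head, n_feat, dropout_rate=0.1)
    x = torch.randn(batch, t_dec, n_feat)
    memory = torch.randn(batch, t_enc, n_feat)
    memory_mask = torch.ones(batch, 1, t_enc, dtype=torch.bool)  # True marks valid frames
    out = att(x, memory, memory_mask)
    assert out.shape == (batch, t_dec, n_feat)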


class MultiHeadSelfAttention(nn.Module):
    """Multi-Head self-attention layer with a fused q/k/v projection.

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
        """Construct a MultiHeadSelfAttention object."""
        super(MultiHeadSelfAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x):
        """Transform the input into query, key and value with a fused projection.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
            torch.Tensor: Untransposed value tensor (#batch, time, n_feat).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
                (#batch, time, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs
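

# Illustrative usage sketch (editor's addition; not part of the original FunASR
# source): same interface as MultiHeadedAttentionSANM but without the FSMN block.
def _demo_multi_head_self_attention():
    batch, time, n_feat, n_head = 2, 10, 256, 4
    att = MultiHeadSelfAttention(n_head, n_feat, n_feat, dropout_rate=0.1)
    x = torch.randn(batch, time, n_feat)
    mask = torch.ones(batch, 1, time, dtype=torch.bool)
    out = att(x, mask)
    assert out.shape == (batch, time, n_feat)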


class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """RelPositionMultiHeadedAttention definition.

    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct a RelPositionMultiHeadedAttentionChunk object."""
        super().__init__()

        self.d_k = embed_size // num_heads
        self.num_heads = num_heads
        assert (
            self.d_k * num_heads == embed_size
        ), f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"

        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)
        self.linear_out = torch.nn.Linear(embed_size, embed_size)

        if simplified_attention_score:
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
            self.compute_att_score = self.compute_simplified_attention_score
        else:
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)

            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)

            self.compute_att_score = self.compute_attention_score

        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.attn = None

    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
            left_context: Number of frames in left context.

        Returns:
            x: Output sequence. (B, H, T_1, T_2)
        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context

        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()

        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )

    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.

        Reference: https://github.com/k2-fsa/icefall/pull/458

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        pos_enc = self.linear_pos(pos_enc)

        matrix_ac = torch.matmul(query, key.transpose(2, 3))
        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation.

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)

        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)

        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))

        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)

        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)
        """
        n_batch = query.size(0)

        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)

        Returns:
            attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
        """
        batch_size = scores.size(0)

        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)

        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)

        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )

        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
            left_context: Number of frames in left context.

        Returns:
            : Output tensor. (B, T_1, H * d_k)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)

        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
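

# Illustrative usage sketch (editor's addition; not part of the original FunASR
# source). Note that here `mask` is boolean with True marking *masked* (padded)
# positions, unlike the 1-for-valid convention of the classes above.
def _demo_rel_position_attention_chunk():
    batch, time, n_feat, n_head = 2, 8, 256, 4
    att = RelPositionMultiHeadedAttentionChunk(n_head, n_feat, dropout_rate=0.1)
    x = torch.randn(batch, time, n_feat)
    pos_enc = torch.randn(batch, 2 * time - 1, n_feat)
    mask = torch.zeros(batch, time, dtype=torch.bool)  # no padding
    out = att(x, x, x, pos_enc, mask)
    assert out.shape == (batch, time, n_feat)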


class CosineDistanceAttention(nn.Module):
    """Compute cosine distance between speaker decoder output and speaker profiles."""

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, spk_decoder_out, profile, profile_lens=None):
        """
        Args:
            spk_decoder_out (torch.Tensor): (B, L, D)
            profile (torch.Tensor): Speaker profiles (B, N, D).
            profile_lens (torch.Tensor): Number of valid profiles per utterance (B,).
        """
        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
        if profile_lens is not None:
            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
            )
            weights_not_softmax = F.cosine_similarity(
                x, profile.unsqueeze(1), dim=-1
            ).masked_fill(mask, min_value)
            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
        else:
            x = x[:, -1:, :, :]
            weights_not_softmax = F.cosine_similarity(
                x, profile.unsqueeze(1).to(x.device), dim=-1
            )
            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
        return spk_embedding, weights
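

# Illustrative usage sketch (editor's addition; not part of the original FunASR
# source): attending over N candidate speaker profiles at each decoder step.
def _demo_cosine_distance_attention():
    batch, length, n_spk, dim = 2, 5, 4, 192
    att = CosineDistanceAttention()
    decoder_out = torch.randn(batch, length, dim)
    profile = torch.randn(batch, n_spk, dim)
    profile_lens = torch.tensor([4, 3])  # number of valid profiles per utterance
    spk_embedding, weights = att(decoder_out, profile, profile_lens)
    assert spk_embedding.shape == (batch, length, dim)
    assert weights.shape == (batch, length, n_spk)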