#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention layer definition."""

import math
from typing import Optional, Tuple

import numpy
import torch
import torch.nn.functional as F
from torch import nn

import funasr.modules.lora.layers as lora
from funasr.modules.nets_utils import make_pad_mask


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
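
# A minimal usage sketch (hypothetical shapes): plain self-attention over a
# batch of two 10-frame, 256-dim sequences with 4 heads.
#
#     mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
#     x = torch.randn(2, 10, 256)
#     out = mha(x, x, x, mask=None)  # (2, 10, 256)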


class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a LegacyRelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, time2).

        Returns:
            torch.Tensor: Output tensor.
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)
        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)
        if self.zero_triu:
            # allocate on the same device as x (the original CPU-only ones
            # would fail for CUDA inputs)
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)
        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)
        return self.forward_attention(v, scores, mask)


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)
        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, 2*time1-1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
        # compute matrix b and matrix d
        # (batch, head, time1, 2*time1-1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)
        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)
        return self.forward_attention(v, scores, mask)
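
# A minimal usage sketch (hypothetical shapes): pos_emb is assumed to come
# from an espnet-style RelPositionalEncoding, i.e. one row per relative
# offset from time1-1 down to -(time1-1).
#
#     rel_mha = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
#     x = torch.randn(2, 10, 256)
#     pos_emb = torch.randn(1, 2 * 10 - 1, 256)
#     out = rel_mha(x, x, x, pos_emb, mask=None)  # (2, 10, 256)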


class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer with an FSMN memory block (SANM).

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN memory block.
        sanm_shfit (int): Extra left shift of the FSMN padding.
    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        """Construct a MultiHeadedAttentionSANM object."""
        super(MultiHeadedAttentionSANM, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        if lora_list is not None:
            if "o" in lora_list:
                self.linear_out = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_out = nn.Linear(n_feat, n_feat)
            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
            if lora_qkv_list == [False, False, False]:
                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
            else:
                self.linear_q_k_v = lora.MergedLinear(
                    in_feat,
                    n_feat * 3,
                    r=lora_rank,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    enable_lora=lora_qkv_list,
                )
        else:
            self.linear_out = nn.Linear(n_feat, n_feat)
            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # asymmetric padding so the FSMN block keeps the sequence length
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        """Apply the FSMN memory block to the value sequence."""
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Transform the input into query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Untransposed value tensor (#batch, time2, n_feat).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask multiplied
                onto the attention mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention plus the FSMN memory output.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
            mask_shfit_chunk (torch.Tensor): Optional chunk mask for the FSMN block.
            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask for the attention.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute streaming chunk attention with a key/value cache.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).
            cache (dict): Cached keys/values from previous chunks.
            chunk_size (list): Chunk configuration; chunk_size[1] is the chunk
                length and chunk_size[2] the look-ahead size stripped from the cache.
            look_back (int): Number of chunks to keep in the cache
                (-1 keeps everything).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            dict: Updated cache.
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        # parenthesized explicitly: caching only applies when chunk_size is
        # given, for both look_back > 0 and the unlimited look_back == -1 case
        if chunk_size is not None and (look_back > 0 or look_back == -1):
            if cache is not None:
                k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
                v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)
                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
                if look_back != -1:
                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
            else:
                cache = {
                    "k": k_h[:, :, : -(chunk_size[2]), :],
                    "v": v_h[:, :, : -(chunk_size[2]), :],
                }
        fsmn_memory = self.forward_fsmn(v, None)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, None)
        return att_outs + fsmn_memory, cache
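
# A minimal usage sketch (hypothetical shapes) of the SANM layer: the
# attention output and the FSMN memory output are summed.
#
#     sanm = MultiHeadedAttentionSANM(
#         n_head=4, in_feat=256, n_feat=256, dropout_rate=0.1, kernel_size=11
#     )
#     x = torch.randn(2, 10, 256)
#     mask = torch.ones(2, 1, 10)  # 1 = valid frame, 0 = padding
#     out = sanm(x, mask)          # (2, 10, 256)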


class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
    """MultiHeadedAttentionSANM variant that takes a pair of masks:
    mask[0] is applied to the FSMN memory block and mask[1] to the attention."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
        return att_outs + fsmn_memory


class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block used in place of self-attention in the SANM decoder.

    Args:
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN memory block.
        sanm_shfit (int): Extra left shift of the FSMN padding.
    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct a MultiHeadedAttentionSANMDecoder object."""
        super(MultiHeadedAttentionSANMDecoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # asymmetric padding so the FSMN block keeps the sequence length
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
        """Apply the FSMN memory block.

        Args:
            inputs (torch.Tensor): Input tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time).
            cache (torch.Tensor): Cached padded inputs for streaming decoding.
            mask_shfit_chunk (torch.Tensor): Optional chunk mask.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, size).
            torch.Tensor: Updated cache.
        """
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        b, d, t = x.size()
        if cache is None:
            x = self.pad_fn(x)
            if not self.training:
                cache = x
        else:
            # shift the cached window by one frame and append the new input
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1):]
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        if x.size(1) != inputs.size(1):
            inputs = inputs[:, -1, :]
        x = x + inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache
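
# A minimal streaming sketch (hypothetical shapes): in eval mode the padded
# input is kept as the cache and shifted by one frame per decoding step.
#
#     fsmn = MultiHeadedAttentionSANMDecoder(n_feat=256, dropout_rate=0.1, kernel_size=11)
#     fsmn.eval()
#     out, cache = fsmn(torch.randn(2, 1, 256), mask=None)               # primes the cache
#     out, cache = fsmn(torch.randn(2, 1, 256), mask=None, cache=cache)  # next step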


class MultiHeadedAttentionCrossAtt(nn.Module):
    """Multi-Head cross-attention layer between decoder states and encoder memory.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        encoder_output_size (int): Feature size of the encoder memory, if it
            differs from n_feat.
    """

    def __init__(
        self,
        n_head,
        n_feat,
        dropout_rate,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
        encoder_output_size=None,
    ):
        """Construct a MultiHeadedAttentionCrossAtt object."""
        super(MultiHeadedAttentionCrossAtt, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        if lora_list is not None:
            if "q" in lora_list:
                self.linear_q = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_q = nn.Linear(n_feat, n_feat)
            lora_kv_list = ["k" in lora_list, "v" in lora_list]
            if lora_kv_list == [False, False]:
                self.linear_k_v = nn.Linear(
                    n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2
                )
            else:
                self.linear_k_v = lora.MergedLinear(
                    n_feat if encoder_output_size is None else encoder_output_size,
                    n_feat * 2,
                    r=lora_rank,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    enable_lora=lora_kv_list,
                )
            if "o" in lora_list:
                self.linear_out = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_out = nn.Linear(n_feat, n_feat)
        else:
            self.linear_q = nn.Linear(n_feat, n_feat)
            self.linear_k_v = nn.Linear(
                n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2
            )
            self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x, memory):
        """Transform query from x and key/value from memory.

        Args:
            x (torch.Tensor): Query input tensor (#batch, time1, size).
            memory (torch.Tensor): Key/value input tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        b = x.size(0)
        q = self.linear_q(x)
        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_v = self.linear_k_v(memory)
        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, memory, memory_mask):
        """Compute cross-attention between x and the encoder memory.

        Args:
            x (torch.Tensor): Query input tensor (#batch, time1, size).
            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).
            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, memory_mask)

    def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
        """Compute streaming chunk cross-attention with a key/value cache.

        Args:
            x (torch.Tensor): Query input tensor (#batch, time1, size).
            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).
            cache (dict): Cached keys/values from previous chunks.
            chunk_size (list): Chunk configuration; chunk_size[1] is the chunk length.
            look_back (int): Number of chunks to keep in the cache.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            dict: Updated cache.
        """
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        if chunk_size is not None and look_back > 0:
            if cache is not None:
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)
                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
            else:
                cache = {
                    "k": k_h[:, :, -(look_back * chunk_size[1]):, :],
                    "v": v_h[:, :, -(look_back * chunk_size[1]):, :],
                }
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, None), cache
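
# A minimal usage sketch (hypothetical shapes): decoder states attend over
# encoder memory of a different length.
#
#     cross = MultiHeadedAttentionCrossAtt(n_head=4, n_feat=256, dropout_rate=0.1)
#     dec = torch.randn(2, 5, 256)        # decoder states, time1 = 5
#     enc = torch.randn(2, 20, 256)       # encoder memory, time2 = 20
#     memory_mask = torch.ones(2, 1, 20)  # 1 = valid encoder frame
#     out = cross(dec, enc, memory_mask)  # (2, 5, 256)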


class MultiHeadSelfAttention(nn.Module):
    """Multi-Head self-attention layer with a fused q/k/v projection.

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
        """Construct a MultiHeadSelfAttention object."""
        super(MultiHeadSelfAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x):
        """Transform the input into query, key and value.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Untransposed value tensor (#batch, time2, n_feat).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask multiplied
                onto the attention mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_att_chunk_encoder=None):
        """Compute scaled dot product self-attention.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask for the attention.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs
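
# A minimal usage sketch (hypothetical shapes):
#
#     mhsa = MultiHeadSelfAttention(n_head=4, in_feat=256, n_feat=256, dropout_rate=0.1)
#     out = mhsa(torch.randn(2, 10, 256), mask=None)  # (2, 10, 256)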


class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """RelPositionMultiHeadedAttention definition.

    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size.
        dropout_rate: Dropout rate.
        simplified_attention_score: Whether to use the simplified attention
            score computation from icefall.
    """

    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct a RelPositionMultiHeadedAttentionChunk object."""
        super().__init__()
        self.d_k = embed_size // num_heads
        self.num_heads = num_heads
        # the original assert message was a never-formatted printf-style tuple
        assert self.d_k * num_heads == embed_size, (
            "embed_size (%d) must be divisible by num_heads (%d)" % (embed_size, num_heads)
        )
        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)
        self.linear_out = torch.nn.Linear(embed_size, embed_size)
        if simplified_attention_score:
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
            self.compute_att_score = self.compute_simplified_attention_score
        else:
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)
            self.compute_att_score = self.compute_attention_score
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.attn = None

    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
            left_context: Number of frames in left context.

        Returns:
            x: Output sequence. (B, H, T_1, T_2)
        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context
        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )
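
    # Note on rel_shift: the as_strided view implements the Transformer-XL
    # relative shift without materialising a padded copy. Element-wise it
    # aliases output[b, h, i, j] = x[b, h, i, time1 - 1 - i + j], so each
    # query row i is shifted left by i positions and column j of the result
    # holds the score for relative offset j - i.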

    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.

        Reference: https://github.com/k2-fsa/icefall/pull/458

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        pos_enc = self.linear_pos(pos_enc)
        matrix_ac = torch.matmul(query, key.transpose(2, 3))
        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )
        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation.

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)

        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)
        """
        n_batch = query.size(0)
        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)

        Returns:
            attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
        """
        batch_size = scores.size(0)
        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)
        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )
        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
            left_context: Number of frames in left context.

        Returns:
            : Output tensor. (B, T_1, H * d_k)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
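
# A minimal usage sketch (hypothetical shapes): here mask is a boolean source
# mask with True at padded frames.
#
#     chunk_mha = RelPositionMultiHeadedAttentionChunk(num_heads=4, embed_size=256)
#     x = torch.randn(2, 16, 256)
#     pos_enc = torch.randn(1, 2 * 16 - 1, 256)
#     mask = torch.zeros(2, 16, dtype=torch.bool)
#     out = chunk_mha(x, x, x, pos_enc, mask)  # (2, 16, 256)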


class CosineDistanceAttention(nn.Module):
    """Compute cosine distance between spk decoder output and speaker profiles."""

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, spk_decoder_out, profile, profile_lens=None):
        """
        Args:
            spk_decoder_out (torch.Tensor): (B, L, D)
            profile (torch.Tensor): (B, N, D) speaker profiles
            profile_lens (torch.Tensor): (B,) number of valid profiles per utterance
        """
        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
        if profile_lens is not None:
            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
            )
            weights_not_softmax = F.cosine_similarity(
                x, profile.unsqueeze(1), dim=-1
            ).masked_fill(mask, min_value)
            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
        else:
            x = x[:, -1:, :, :]
            weights_not_softmax = F.cosine_similarity(
                x, profile.unsqueeze(1).to(x.device), dim=-1
            )
            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
        return spk_embedding, weights
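
# A minimal usage sketch (hypothetical shapes):
#
#     cos_att = CosineDistanceAttention()
#     dec_out = torch.randn(2, 5, 192)     # (B, L, D) speaker decoder output
#     profile = torch.randn(2, 8, 192)     # (B, N, D) candidate speaker profiles
#     profile_lens = torch.tensor([8, 6])  # valid profiles per utterance
#     spk_emb, weights = cos_att(dec_out, profile, profile_lens)
#     # spk_emb: (2, 5, 192), weights: (2, 5, 8)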