# decoder.py
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import torch
  6. from typing import List, Tuple
  7. from funasr.register import tables
  8. from funasr.models.scama import utils as myutils
  9. from funasr.models.transformer.utils.repeat import repeat
  10. from funasr.models.transformer.decoder import DecoderLayer
  11. from funasr.models.transformer.layer_norm import LayerNorm
  12. from funasr.models.transformer.embedding import PositionalEncoding
  13. from funasr.models.transformer.attention import MultiHeadedAttention
  14. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  15. from funasr.models.transformer.decoder import BaseTransformerDecoder
  16. from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
  17. from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
  18. from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
class DecoderLayerSANM(torch.nn.Module):
    """Single SANM decoder layer.

    Unlike a vanilla Transformer decoder layer, the feed-forward block runs
    FIRST, followed by the (optional) FSMN-style self-attention and the
    (optional) cross-attention over the encoder memory.  Either attention
    module may be ``None``, in which case the corresponding sub-block (and
    its LayerNorm) is skipped entirely.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance
            (e.g. ``MultiHeadedAttentionSANMDecoder``), or ``None`` to skip.
        src_attn (torch.nn.Module): Cross-attention module instance
            (e.g. ``MultiHeadedAttentionCrossAtt``), or ``None`` to skip.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            ``PositionwiseFeedForward``, ``MultiLayeredConv1d``, or
            ``Conv1dLinear`` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before each block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear will be applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear is applied, i.e. x -> x + att(x).
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super(DecoderLayerSANM, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # norm1 always exists (feed-forward block); norm2/norm3 are created
        # only when the matching attention module is present.
        self.norm1 = LayerNorm(size)
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = torch.nn.Linear(size + size, size)
            self.concat_linear2 = torch.nn.Linear(size + size, size)
        # When reserve_attn is True, forward() appends each cross-attention
        # weight matrix to attn_mat (used for attention inspection/alignment).
        self.reserve_attn = False
        self.attn_mat = []

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features (parallel, whole-sequence pass).

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors
                (passed through unchanged here).

        Returns:
            Tuple of (x, tgt_mask, memory, memory_mask, cache), where x is the
            output tensor (#batch, maxlen_out, size) and the remaining values
            are forwarded so layers can be chained via ``repeat``.
        """
        # tgt = self.dropout(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        # Feed-forward block runs first in the SANM layer ordering.
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            x, _ = self.self_attn(tgt, tgt_mask)
            # Residual connects back to the layer input (before feed-forward).
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            if self.reserve_attn:
                # Keep the attention matrix for later inspection.
                x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
                self.attn_mat.append(attn_mat)
            else:
                x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)

            x = residual + self.dropout(x_src_attn)
            # x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

        return x, tgt_mask, memory, memory_mask, cache

    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Run the layer and return only the cross-attention weight matrix.

        NOTE(review): unlike forward(), this path ignores ``normalize_before``
        (norms are applied unconditionally) and applies no dropout — it is an
        inference-time probe, not a training path. Requires ``src_attn`` to be
        non-None.
        """
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
        residual = x
        x = self.norm3(x)
        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
        return attn_mat

    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features for one incremental decoding step.

        Same structure as forward(), but threads an FSMN cache through the
        self-attention so only the newest frame needs to be processed.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (torch.Tensor): FSMN cache for this layer's self-attention.

        Returns:
            Tuple of (x, tgt_mask, memory, memory_mask, cache) with the
            updated FSMN cache.
        """
        # tgt = self.dropout(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            if self.training:
                # Caching is an inference-only optimization; disable in training.
                cache = None
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + self.dropout(x)
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

        return x, tgt_mask, memory, memory_mask, cache

    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
        """Compute decoded features for streaming (chunk-by-chunk) decoding.

        No masks are used; instead the FSMN self-attention and the
        cross-attention each maintain their own cache across chunks.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, chunk_len, size).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            fsmn_cache: Cache for the FSMN self-attention, or None.
            opt_cache: Cache for the chunked cross-attention, or None.
            chunk_size: Streaming chunk size forwarded to the cross-attention.
            look_back (int): Number of past chunks the cross-attention may see.

        Returns:
            Tuple of (x, memory, fsmn_cache, opt_cache) with updated caches.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
            x = residual + self.dropout(x)
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            # NOTE(review): no dropout on this residual path, unlike
            # forward()/forward_one_step — presumably inference-only; confirm.
            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
            x = residual + x
        return x, memory, fsmn_cache, opt_cache
  183. @tables.register("decoder_classes", "ParaformerSANMDecoder")
  184. class ParaformerSANMDecoder(BaseTransformerDecoder):
  185. """
  186. Author: Speech Lab of DAMO Academy, Alibaba Group
  187. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
  188. https://arxiv.org/abs/2006.01713
  189. """
  190. def __init__(
  191. self,
  192. vocab_size: int,
  193. encoder_output_size: int,
  194. attention_heads: int = 4,
  195. linear_units: int = 2048,
  196. num_blocks: int = 6,
  197. dropout_rate: float = 0.1,
  198. positional_dropout_rate: float = 0.1,
  199. self_attention_dropout_rate: float = 0.0,
  200. src_attention_dropout_rate: float = 0.0,
  201. input_layer: str = "embed",
  202. use_output_layer: bool = True,
  203. wo_input_layer: bool = False,
  204. pos_enc_class=PositionalEncoding,
  205. normalize_before: bool = True,
  206. concat_after: bool = False,
  207. att_layer_num: int = 6,
  208. kernel_size: int = 21,
  209. sanm_shfit: int = 0,
  210. lora_list: List[str] = None,
  211. lora_rank: int = 8,
  212. lora_alpha: int = 16,
  213. lora_dropout: float = 0.1,
  214. chunk_multiply_factor: tuple = (1,),
  215. tf2torch_tensor_name_prefix_torch: str = "decoder",
  216. tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
  217. ):
  218. super().__init__(
  219. vocab_size=vocab_size,
  220. encoder_output_size=encoder_output_size,
  221. dropout_rate=dropout_rate,
  222. positional_dropout_rate=positional_dropout_rate,
  223. input_layer=input_layer,
  224. use_output_layer=use_output_layer,
  225. pos_enc_class=pos_enc_class,
  226. normalize_before=normalize_before,
  227. )
  228. attention_dim = encoder_output_size
  229. if wo_input_layer:
  230. self.embed = None
  231. else:
  232. if input_layer == "embed":
  233. self.embed = torch.nn.Sequential(
  234. torch.nn.Embedding(vocab_size, attention_dim),
  235. # pos_enc_class(attention_dim, positional_dropout_rate),
  236. )
  237. elif input_layer == "linear":
  238. self.embed = torch.nn.Sequential(
  239. torch.nn.Linear(vocab_size, attention_dim),
  240. torch.nn.LayerNorm(attention_dim),
  241. torch.nn.Dropout(dropout_rate),
  242. torch.nn.ReLU(),
  243. pos_enc_class(attention_dim, positional_dropout_rate),
  244. )
  245. else:
  246. raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
  247. self.normalize_before = normalize_before
  248. if self.normalize_before:
  249. self.after_norm = LayerNorm(attention_dim)
  250. if use_output_layer:
  251. self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
  252. else:
  253. self.output_layer = None
  254. self.att_layer_num = att_layer_num
  255. self.num_blocks = num_blocks
  256. if sanm_shfit is None:
  257. sanm_shfit = (kernel_size - 1) // 2
  258. self.decoders = repeat(
  259. att_layer_num,
  260. lambda lnum: DecoderLayerSANM(
  261. attention_dim,
  262. MultiHeadedAttentionSANMDecoder(
  263. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
  264. ),
  265. MultiHeadedAttentionCrossAtt(
  266. attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
  267. ),
  268. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  269. dropout_rate,
  270. normalize_before,
  271. concat_after,
  272. ),
  273. )
  274. if num_blocks - att_layer_num <= 0:
  275. self.decoders2 = None
  276. else:
  277. self.decoders2 = repeat(
  278. num_blocks - att_layer_num,
  279. lambda lnum: DecoderLayerSANM(
  280. attention_dim,
  281. MultiHeadedAttentionSANMDecoder(
  282. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
  283. ),
  284. None,
  285. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  286. dropout_rate,
  287. normalize_before,
  288. concat_after,
  289. ),
  290. )
  291. self.decoders3 = repeat(
  292. 1,
  293. lambda lnum: DecoderLayerSANM(
  294. attention_dim,
  295. None,
  296. None,
  297. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  298. dropout_rate,
  299. normalize_before,
  300. concat_after,
  301. ),
  302. )
  303. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  304. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  305. self.chunk_multiply_factor = chunk_multiply_factor
  306. def forward(
  307. self,
  308. hs_pad: torch.Tensor,
  309. hlens: torch.Tensor,
  310. ys_in_pad: torch.Tensor,
  311. ys_in_lens: torch.Tensor,
  312. return_hidden: bool = False,
  313. return_both: bool= False,
  314. chunk_mask: torch.Tensor = None,
  315. ) -> Tuple[torch.Tensor, torch.Tensor]:
  316. """Forward decoder.
  317. Args:
  318. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  319. hlens: (batch)
  320. ys_in_pad:
  321. input token ids, int64 (batch, maxlen_out)
  322. if input_layer == "embed"
  323. input tensor (batch, maxlen_out, #mels) in the other cases
  324. ys_in_lens: (batch)
  325. Returns:
  326. (tuple): tuple containing:
  327. x: decoded token score before softmax (batch, maxlen_out, token)
  328. if use_output_layer is True,
  329. olens: (batch, )
  330. """
  331. tgt = ys_in_pad
  332. tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  333. memory = hs_pad
  334. memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  335. if chunk_mask is not None:
  336. memory_mask = memory_mask * chunk_mask
  337. if tgt_mask.size(1) != memory_mask.size(1):
  338. memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
  339. x = tgt
  340. x, tgt_mask, memory, memory_mask, _ = self.decoders(
  341. x, tgt_mask, memory, memory_mask
  342. )
  343. if self.decoders2 is not None:
  344. x, tgt_mask, memory, memory_mask, _ = self.decoders2(
  345. x, tgt_mask, memory, memory_mask
  346. )
  347. x, tgt_mask, memory, memory_mask, _ = self.decoders3(
  348. x, tgt_mask, memory, memory_mask
  349. )
  350. if self.normalize_before:
  351. hidden = self.after_norm(x)
  352. olens = tgt_mask.sum(1)
  353. if self.output_layer is not None and return_hidden is False:
  354. x = self.output_layer(hidden)
  355. return x, olens
  356. if return_both:
  357. x = self.output_layer(hidden)
  358. return x, hidden, olens
  359. return hidden, olens
  360. def score(self, ys, state, x):
  361. """Score."""
  362. ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
  363. logp, state = self.forward_one_step(
  364. ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
  365. )
  366. return logp.squeeze(0), state
  367. def forward_asf2(
  368. self,
  369. hs_pad: torch.Tensor,
  370. hlens: torch.Tensor,
  371. ys_in_pad: torch.Tensor,
  372. ys_in_lens: torch.Tensor,
  373. ):
  374. tgt = ys_in_pad
  375. tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  376. memory = hs_pad
  377. memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  378. tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
  379. attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
  380. return attn_mat
  381. def forward_asf6(
  382. self,
  383. hs_pad: torch.Tensor,
  384. hlens: torch.Tensor,
  385. ys_in_pad: torch.Tensor,
  386. ys_in_lens: torch.Tensor,
  387. ):
  388. tgt = ys_in_pad
  389. tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  390. memory = hs_pad
  391. memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  392. tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
  393. tgt, tgt_mask, memory, memory_mask, _ = self.decoders[1](tgt, tgt_mask, memory, memory_mask)
  394. tgt, tgt_mask, memory, memory_mask, _ = self.decoders[2](tgt, tgt_mask, memory, memory_mask)
  395. tgt, tgt_mask, memory, memory_mask, _ = self.decoders[3](tgt, tgt_mask, memory, memory_mask)
  396. tgt, tgt_mask, memory, memory_mask, _ = self.decoders[4](tgt, tgt_mask, memory, memory_mask)
  397. attn_mat = self.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
  398. return attn_mat
  399. def forward_chunk(
  400. self,
  401. memory: torch.Tensor,
  402. tgt: torch.Tensor,
  403. cache: dict = None,
  404. ) -> Tuple[torch.Tensor, torch.Tensor]:
  405. """Forward decoder.
  406. Args:
  407. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  408. hlens: (batch)
  409. ys_in_pad:
  410. input token ids, int64 (batch, maxlen_out)
  411. if input_layer == "embed"
  412. input tensor (batch, maxlen_out, #mels) in the other cases
  413. ys_in_lens: (batch)
  414. Returns:
  415. (tuple): tuple containing:
  416. x: decoded token score before softmax (batch, maxlen_out, token)
  417. if use_output_layer is True,
  418. olens: (batch, )
  419. """
  420. x = tgt
  421. if cache["decode_fsmn"] is None:
  422. cache_layer_num = len(self.decoders)
  423. if self.decoders2 is not None:
  424. cache_layer_num += len(self.decoders2)
  425. fsmn_cache = [None] * cache_layer_num
  426. else:
  427. fsmn_cache = cache["decode_fsmn"]
  428. if cache["opt"] is None:
  429. cache_layer_num = len(self.decoders)
  430. opt_cache = [None] * cache_layer_num
  431. else:
  432. opt_cache = cache["opt"]
  433. for i in range(self.att_layer_num):
  434. decoder = self.decoders[i]
  435. x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
  436. x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
  437. chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
  438. )
  439. if self.num_blocks - self.att_layer_num > 1:
  440. for i in range(self.num_blocks - self.att_layer_num):
  441. j = i + self.att_layer_num
  442. decoder = self.decoders2[i]
  443. x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
  444. x, memory, fsmn_cache=fsmn_cache[j]
  445. )
  446. for decoder in self.decoders3:
  447. x, memory, _, _ = decoder.forward_chunk(
  448. x, memory
  449. )
  450. if self.normalize_before:
  451. x = self.after_norm(x)
  452. if self.output_layer is not None:
  453. x = self.output_layer(x)
  454. cache["decode_fsmn"] = fsmn_cache
  455. if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
  456. cache["opt"] = opt_cache
  457. return x
  458. def forward_one_step(
  459. self,
  460. tgt: torch.Tensor,
  461. tgt_mask: torch.Tensor,
  462. memory: torch.Tensor,
  463. cache: List[torch.Tensor] = None,
  464. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  465. """Forward one step.
  466. Args:
  467. tgt: input token ids, int64 (batch, maxlen_out)
  468. tgt_mask: input token mask, (batch, maxlen_out)
  469. dtype=torch.uint8 in PyTorch 1.2-
  470. dtype=torch.bool in PyTorch 1.2+ (include 1.2)
  471. memory: encoded memory, float32 (batch, maxlen_in, feat)
  472. cache: cached output list of (batch, max_time_out-1, size)
  473. Returns:
  474. y, cache: NN output value and cache per `self.decoders`.
  475. y.shape` is (batch, maxlen_out, token)
  476. """
  477. x = self.embed(tgt)
  478. if cache is None:
  479. cache_layer_num = len(self.decoders)
  480. if self.decoders2 is not None:
  481. cache_layer_num += len(self.decoders2)
  482. cache = [None] * cache_layer_num
  483. new_cache = []
  484. # for c, decoder in zip(cache, self.decoders):
  485. for i in range(self.att_layer_num):
  486. decoder = self.decoders[i]
  487. c = cache[i]
  488. x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
  489. x, tgt_mask, memory, None, cache=c
  490. )
  491. new_cache.append(c_ret)
  492. if self.num_blocks - self.att_layer_num > 1:
  493. for i in range(self.num_blocks - self.att_layer_num):
  494. j = i + self.att_layer_num
  495. decoder = self.decoders2[i]
  496. c = cache[j]
  497. x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
  498. x, tgt_mask, memory, None, cache=c
  499. )
  500. new_cache.append(c_ret)
  501. for decoder in self.decoders3:
  502. x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
  503. x, tgt_mask, memory, None, cache=None
  504. )
  505. if self.normalize_before:
  506. y = self.after_norm(x[:, -1])
  507. else:
  508. y = x[:, -1]
  509. if self.output_layer is not None:
  510. y = torch.log_softmax(self.output_layer(y), dim=-1)
  511. return y, new_cache
  512. @tables.register("decoder_classes", "ParaformerSANDecoder")
  513. class ParaformerSANDecoder(BaseTransformerDecoder):
  514. """
  515. Author: Speech Lab of DAMO Academy, Alibaba Group
  516. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
  517. https://arxiv.org/abs/2006.01713
  518. """
  519. def __init__(
  520. self,
  521. vocab_size: int,
  522. encoder_output_size: int,
  523. attention_heads: int = 4,
  524. linear_units: int = 2048,
  525. num_blocks: int = 6,
  526. dropout_rate: float = 0.1,
  527. positional_dropout_rate: float = 0.1,
  528. self_attention_dropout_rate: float = 0.0,
  529. src_attention_dropout_rate: float = 0.0,
  530. input_layer: str = "embed",
  531. use_output_layer: bool = True,
  532. pos_enc_class=PositionalEncoding,
  533. normalize_before: bool = True,
  534. concat_after: bool = False,
  535. embeds_id: int = -1,
  536. ):
  537. super().__init__(
  538. vocab_size=vocab_size,
  539. encoder_output_size=encoder_output_size,
  540. dropout_rate=dropout_rate,
  541. positional_dropout_rate=positional_dropout_rate,
  542. input_layer=input_layer,
  543. use_output_layer=use_output_layer,
  544. pos_enc_class=pos_enc_class,
  545. normalize_before=normalize_before,
  546. )
  547. attention_dim = encoder_output_size
  548. self.decoders = repeat(
  549. num_blocks,
  550. lambda lnum: DecoderLayer(
  551. attention_dim,
  552. MultiHeadedAttention(
  553. attention_heads, attention_dim, self_attention_dropout_rate
  554. ),
  555. MultiHeadedAttention(
  556. attention_heads, attention_dim, src_attention_dropout_rate
  557. ),
  558. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  559. dropout_rate,
  560. normalize_before,
  561. concat_after,
  562. ),
  563. )
  564. self.embeds_id = embeds_id
  565. self.attention_dim = attention_dim
  566. def forward(
  567. self,
  568. hs_pad: torch.Tensor,
  569. hlens: torch.Tensor,
  570. ys_in_pad: torch.Tensor,
  571. ys_in_lens: torch.Tensor,
  572. ) -> Tuple[torch.Tensor, torch.Tensor]:
  573. """Forward decoder.
  574. Args:
  575. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  576. hlens: (batch)
  577. ys_in_pad:
  578. input token ids, int64 (batch, maxlen_out)
  579. if input_layer == "embed"
  580. input tensor (batch, maxlen_out, #mels) in the other cases
  581. ys_in_lens: (batch)
  582. Returns:
  583. (tuple): tuple containing:
  584. x: decoded token score before softmax (batch, maxlen_out, token)
  585. if use_output_layer is True,
  586. olens: (batch, )
  587. """
  588. tgt = ys_in_pad
  589. tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
  590. memory = hs_pad
  591. memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
  592. memory.device
  593. )
  594. # Padding for Longformer
  595. if memory_mask.shape[-1] != memory.shape[1]:
  596. padlen = memory.shape[1] - memory_mask.shape[-1]
  597. memory_mask = torch.nn.functional.pad(
  598. memory_mask, (0, padlen), "constant", False
  599. )
  600. # x = self.embed(tgt)
  601. x = tgt
  602. embeds_outputs = None
  603. for layer_id, decoder in enumerate(self.decoders):
  604. x, tgt_mask, memory, memory_mask = decoder(
  605. x, tgt_mask, memory, memory_mask
  606. )
  607. if layer_id == self.embeds_id:
  608. embeds_outputs = x
  609. if self.normalize_before:
  610. x = self.after_norm(x)
  611. if self.output_layer is not None:
  612. x = self.output_layer(x)
  613. olens = tgt_mask.sum(1)
  614. if embeds_outputs is not None:
  615. return x, olens, embeds_outputs
  616. else:
  617. return x, olens