# sanm_decoder.py
  1. from typing import List
  2. from typing import Tuple
  3. import torch
  4. import torch.nn as nn
  5. from funasr.modules.streaming_utils import utils as myutils
  6. from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
  7. from typeguard import check_argument_types
  8. from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
  9. from funasr.modules.embedding import PositionalEncoding
  10. from funasr.modules.layer_norm import LayerNorm
  11. from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
  12. from funasr.modules.repeat import repeat
  13. class DecoderLayerSANM(nn.Module):
  14. """Single decoder layer module.
  15. Args:
  16. size (int): Input dimension.
  17. self_attn (torch.nn.Module): Self-attention module instance.
  18. `MultiHeadedAttention` instance can be used as the argument.
  19. src_attn (torch.nn.Module): Self-attention module instance.
  20. `MultiHeadedAttention` instance can be used as the argument.
  21. feed_forward (torch.nn.Module): Feed-forward module instance.
  22. `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
  23. can be used as the argument.
  24. dropout_rate (float): Dropout rate.
  25. normalize_before (bool): Whether to use layer_norm before the first block.
  26. concat_after (bool): Whether to concat attention layer's input and output.
  27. if True, additional linear will be applied.
  28. i.e. x -> x + linear(concat(x, att(x)))
  29. if False, no additional linear will be applied. i.e. x -> x + att(x)
  30. """
  31. def __init__(
  32. self,
  33. size,
  34. self_attn,
  35. src_attn,
  36. feed_forward,
  37. dropout_rate,
  38. normalize_before=True,
  39. concat_after=False,
  40. ):
  41. """Construct an DecoderLayer object."""
  42. super(DecoderLayerSANM, self).__init__()
  43. self.size = size
  44. self.self_attn = self_attn
  45. self.src_attn = src_attn
  46. self.feed_forward = feed_forward
  47. self.norm1 = LayerNorm(size)
  48. if self_attn is not None:
  49. self.norm2 = LayerNorm(size)
  50. if src_attn is not None:
  51. self.norm3 = LayerNorm(size)
  52. self.dropout = nn.Dropout(dropout_rate)
  53. self.normalize_before = normalize_before
  54. self.concat_after = concat_after
  55. if self.concat_after:
  56. self.concat_linear1 = nn.Linear(size + size, size)
  57. self.concat_linear2 = nn.Linear(size + size, size)
  58. def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
  59. """Compute decoded features.
  60. Args:
  61. tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
  62. tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
  63. memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
  64. memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
  65. cache (List[torch.Tensor]): List of cached tensors.
  66. Each tensor shape should be (#batch, maxlen_out - 1, size).
  67. Returns:
  68. torch.Tensor: Output tensor(#batch, maxlen_out, size).
  69. torch.Tensor: Mask for output tensor (#batch, maxlen_out).
  70. torch.Tensor: Encoded memory (#batch, maxlen_in, size).
  71. torch.Tensor: Encoded memory mask (#batch, maxlen_in).
  72. """
  73. # tgt = self.dropout(tgt)
  74. residual = tgt
  75. if self.normalize_before:
  76. tgt = self.norm1(tgt)
  77. tgt = self.feed_forward(tgt)
  78. x = tgt
  79. if self.self_attn:
  80. if self.normalize_before:
  81. tgt = self.norm2(tgt)
  82. if self.training:
  83. cache = None
  84. x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
  85. x = residual + self.dropout(x)
  86. if self.src_attn is not None:
  87. residual = x
  88. if self.normalize_before:
  89. x = self.norm3(x)
  90. x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
  91. return x, tgt_mask, memory, memory_mask, cache
  92. class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
  93. """
  94. author: Speech Lab, Alibaba Group, China
  95. SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
  96. https://arxiv.org/abs/2006.01713
  97. """
  98. def __init__(
  99. self,
  100. vocab_size: int,
  101. encoder_output_size: int,
  102. attention_heads: int = 4,
  103. linear_units: int = 2048,
  104. num_blocks: int = 6,
  105. dropout_rate: float = 0.1,
  106. positional_dropout_rate: float = 0.1,
  107. self_attention_dropout_rate: float = 0.0,
  108. src_attention_dropout_rate: float = 0.0,
  109. input_layer: str = "embed",
  110. use_output_layer: bool = True,
  111. pos_enc_class=PositionalEncoding,
  112. normalize_before: bool = True,
  113. concat_after: bool = False,
  114. att_layer_num: int = 6,
  115. kernel_size: int = 21,
  116. sanm_shfit: int = None,
  117. concat_embeds: bool = False,
  118. attention_dim: int = None,
  119. ):
  120. assert check_argument_types()
  121. super().__init__(
  122. vocab_size=vocab_size,
  123. encoder_output_size=encoder_output_size,
  124. dropout_rate=dropout_rate,
  125. positional_dropout_rate=positional_dropout_rate,
  126. input_layer=input_layer,
  127. use_output_layer=use_output_layer,
  128. pos_enc_class=pos_enc_class,
  129. normalize_before=normalize_before,
  130. )
  131. if attention_dim is None:
  132. attention_dim = encoder_output_size
  133. if input_layer == "embed":
  134. self.embed = torch.nn.Sequential(
  135. torch.nn.Embedding(vocab_size, attention_dim),
  136. )
  137. elif input_layer == "linear":
  138. self.embed = torch.nn.Sequential(
  139. torch.nn.Linear(vocab_size, attention_dim),
  140. torch.nn.LayerNorm(attention_dim),
  141. torch.nn.Dropout(dropout_rate),
  142. torch.nn.ReLU(),
  143. pos_enc_class(attention_dim, positional_dropout_rate),
  144. )
  145. else:
  146. raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
  147. self.normalize_before = normalize_before
  148. if self.normalize_before:
  149. self.after_norm = LayerNorm(attention_dim)
  150. if use_output_layer:
  151. self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
  152. else:
  153. self.output_layer = None
  154. self.att_layer_num = att_layer_num
  155. self.num_blocks = num_blocks
  156. if sanm_shfit is None:
  157. sanm_shfit = (kernel_size - 1) // 2
  158. self.decoders = repeat(
  159. att_layer_num,
  160. lambda lnum: DecoderLayerSANM(
  161. attention_dim,
  162. MultiHeadedAttentionSANMDecoder(
  163. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
  164. ),
  165. MultiHeadedAttentionCrossAtt(
  166. attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
  167. ),
  168. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  169. dropout_rate,
  170. normalize_before,
  171. concat_after,
  172. ),
  173. )
  174. if num_blocks - att_layer_num <= 0:
  175. self.decoders2 = None
  176. else:
  177. self.decoders2 = repeat(
  178. num_blocks - att_layer_num,
  179. lambda lnum: DecoderLayerSANM(
  180. attention_dim,
  181. MultiHeadedAttentionSANMDecoder(
  182. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
  183. ),
  184. None,
  185. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  186. dropout_rate,
  187. normalize_before,
  188. concat_after,
  189. ),
  190. )
  191. self.decoders3 = repeat(
  192. 1,
  193. lambda lnum: DecoderLayerSANM(
  194. attention_dim,
  195. None,
  196. None,
  197. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  198. dropout_rate,
  199. normalize_before,
  200. concat_after,
  201. ),
  202. )
  203. if concat_embeds:
  204. self.embed_concat_ffn = repeat(
  205. 1,
  206. lambda lnum: DecoderLayerSANM(
  207. attention_dim + encoder_output_size,
  208. None,
  209. None,
  210. PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
  211. adim=attention_dim),
  212. dropout_rate,
  213. normalize_before,
  214. concat_after,
  215. ),
  216. )
  217. else:
  218. self.embed_concat_ffn = None
  219. self.concat_embeds = concat_embeds
  220. def forward(
  221. self,
  222. hs_pad: torch.Tensor,
  223. hlens: torch.Tensor,
  224. ys_in_pad: torch.Tensor,
  225. ys_in_lens: torch.Tensor,
  226. chunk_mask: torch.Tensor = None,
  227. pre_acoustic_embeds: torch.Tensor = None,
  228. ) -> Tuple[torch.Tensor, torch.Tensor]:
  229. """Forward decoder.
  230. Args:
  231. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  232. hlens: (batch)
  233. ys_in_pad:
  234. input token ids, int64 (batch, maxlen_out)
  235. if input_layer == "embed"
  236. input tensor (batch, maxlen_out, #mels) in the other cases
  237. ys_in_lens: (batch)
  238. Returns:
  239. (tuple): tuple containing:
  240. x: decoded token score before softmax (batch, maxlen_out, token)
  241. if use_output_layer is True,
  242. olens: (batch, )
  243. """
  244. tgt = ys_in_pad
  245. tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  246. memory = hs_pad
  247. memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  248. if chunk_mask is not None:
  249. memory_mask = memory_mask * chunk_mask
  250. if tgt_mask.size(1) != memory_mask.size(1):
  251. memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
  252. x = self.embed(tgt)
  253. if pre_acoustic_embeds is not None and self.concat_embeds:
  254. x = torch.cat((x, pre_acoustic_embeds), dim=-1)
  255. x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
  256. x, tgt_mask, memory, memory_mask, _ = self.decoders(
  257. x, tgt_mask, memory, memory_mask
  258. )
  259. if self.decoders2 is not None:
  260. x, tgt_mask, memory, memory_mask, _ = self.decoders2(
  261. x, tgt_mask, memory, memory_mask
  262. )
  263. x, tgt_mask, memory, memory_mask, _ = self.decoders3(
  264. x, tgt_mask, memory, memory_mask
  265. )
  266. if self.normalize_before:
  267. x = self.after_norm(x)
  268. if self.output_layer is not None:
  269. x = self.output_layer(x)
  270. olens = tgt_mask.sum(1)
  271. return x, olens
  272. def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
  273. """Score."""
  274. ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
  275. logp, state = self.forward_one_step(
  276. ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
  277. cache=state
  278. )
  279. return logp.squeeze(0), state
  280. def forward_one_step(
  281. self,
  282. tgt: torch.Tensor,
  283. tgt_mask: torch.Tensor,
  284. memory: torch.Tensor,
  285. memory_mask: torch.Tensor = None,
  286. pre_acoustic_embeds: torch.Tensor = None,
  287. cache: List[torch.Tensor] = None,
  288. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  289. """Forward one step.
  290. Args:
  291. tgt: input token ids, int64 (batch, maxlen_out)
  292. tgt_mask: input token mask, (batch, maxlen_out)
  293. dtype=torch.uint8 in PyTorch 1.2-
  294. dtype=torch.bool in PyTorch 1.2+ (include 1.2)
  295. memory: encoded memory, float32 (batch, maxlen_in, feat)
  296. cache: cached output list of (batch, max_time_out-1, size)
  297. Returns:
  298. y, cache: NN output value and cache per `self.decoders`.
  299. y.shape` is (batch, maxlen_out, token)
  300. """
  301. x = tgt[:, -1:]
  302. tgt_mask = None
  303. x = self.embed(x)
  304. if pre_acoustic_embeds is not None and self.concat_embeds:
  305. x = torch.cat((x, pre_acoustic_embeds), dim=-1)
  306. x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
  307. if cache is None:
  308. cache_layer_num = len(self.decoders)
  309. if self.decoders2 is not None:
  310. cache_layer_num += len(self.decoders2)
  311. cache = [None] * cache_layer_num
  312. new_cache = []
  313. # for c, decoder in zip(cache, self.decoders):
  314. for i in range(self.att_layer_num):
  315. decoder = self.decoders[i]
  316. c = cache[i]
  317. x, tgt_mask, memory, memory_mask, c_ret = decoder(
  318. x, tgt_mask, memory, memory_mask, cache=c
  319. )
  320. new_cache.append(c_ret)
  321. if self.num_blocks - self.att_layer_num >= 1:
  322. for i in range(self.num_blocks - self.att_layer_num):
  323. j = i + self.att_layer_num
  324. decoder = self.decoders2[i]
  325. c = cache[j]
  326. x, tgt_mask, memory, memory_mask, c_ret = decoder(
  327. x, tgt_mask, memory, memory_mask, cache=c
  328. )
  329. new_cache.append(c_ret)
  330. for decoder in self.decoders3:
  331. x, tgt_mask, memory, memory_mask, _ = decoder(
  332. x, tgt_mask, memory, None, cache=None
  333. )
  334. if self.normalize_before:
  335. y = self.after_norm(x[:, -1])
  336. else:
  337. y = x[:, -1]
  338. if self.output_layer is not None:
  339. y = self.output_layer(y)
  340. y = torch.log_softmax(y, dim=-1)
  341. return y, new_cache
  342. class ParaformerSANMDecoder(BaseTransformerDecoder):
  343. """
  344. author: Speech Lab, Alibaba Group, China
  345. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
  346. https://arxiv.org/abs/2006.01713
  347. """
  348. def __init__(
  349. self,
  350. vocab_size: int,
  351. encoder_output_size: int,
  352. attention_heads: int = 4,
  353. linear_units: int = 2048,
  354. num_blocks: int = 6,
  355. dropout_rate: float = 0.1,
  356. positional_dropout_rate: float = 0.1,
  357. self_attention_dropout_rate: float = 0.0,
  358. src_attention_dropout_rate: float = 0.0,
  359. input_layer: str = "embed",
  360. use_output_layer: bool = True,
  361. pos_enc_class=PositionalEncoding,
  362. normalize_before: bool = True,
  363. concat_after: bool = False,
  364. att_layer_num: int = 6,
  365. kernel_size: int = 21,
  366. sanm_shfit: int = 0,
  367. ):
  368. assert check_argument_types()
  369. super().__init__(
  370. vocab_size=vocab_size,
  371. encoder_output_size=encoder_output_size,
  372. dropout_rate=dropout_rate,
  373. positional_dropout_rate=positional_dropout_rate,
  374. input_layer=input_layer,
  375. use_output_layer=use_output_layer,
  376. pos_enc_class=pos_enc_class,
  377. normalize_before=normalize_before,
  378. )
  379. attention_dim = encoder_output_size
  380. if input_layer == "embed":
  381. self.embed = torch.nn.Sequential(
  382. torch.nn.Embedding(vocab_size, attention_dim),
  383. # pos_enc_class(attention_dim, positional_dropout_rate),
  384. )
  385. elif input_layer == "linear":
  386. self.embed = torch.nn.Sequential(
  387. torch.nn.Linear(vocab_size, attention_dim),
  388. torch.nn.LayerNorm(attention_dim),
  389. torch.nn.Dropout(dropout_rate),
  390. torch.nn.ReLU(),
  391. pos_enc_class(attention_dim, positional_dropout_rate),
  392. )
  393. else:
  394. raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
  395. self.normalize_before = normalize_before
  396. if self.normalize_before:
  397. self.after_norm = LayerNorm(attention_dim)
  398. if use_output_layer:
  399. self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
  400. else:
  401. self.output_layer = None
  402. self.att_layer_num = att_layer_num
  403. self.num_blocks = num_blocks
  404. if sanm_shfit is None:
  405. sanm_shfit = (kernel_size - 1) // 2
  406. self.decoders = repeat(
  407. att_layer_num,
  408. lambda lnum: DecoderLayerSANM(
  409. attention_dim,
  410. MultiHeadedAttentionSANMDecoder(
  411. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
  412. ),
  413. MultiHeadedAttentionCrossAtt(
  414. attention_heads, attention_dim, src_attention_dropout_rate
  415. ),
  416. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  417. dropout_rate,
  418. normalize_before,
  419. concat_after,
  420. ),
  421. )
  422. if num_blocks - att_layer_num <= 0:
  423. self.decoders2 = None
  424. else:
  425. self.decoders2 = repeat(
  426. num_blocks - att_layer_num,
  427. lambda lnum: DecoderLayerSANM(
  428. attention_dim,
  429. MultiHeadedAttentionSANMDecoder(
  430. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
  431. ),
  432. None,
  433. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  434. dropout_rate,
  435. normalize_before,
  436. concat_after,
  437. ),
  438. )
  439. self.decoders3 = repeat(
  440. 1,
  441. lambda lnum: DecoderLayerSANM(
  442. attention_dim,
  443. None,
  444. None,
  445. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  446. dropout_rate,
  447. normalize_before,
  448. concat_after,
  449. ),
  450. )
  451. def forward(
  452. self,
  453. hs_pad: torch.Tensor,
  454. hlens: torch.Tensor,
  455. ys_in_pad: torch.Tensor,
  456. ys_in_lens: torch.Tensor,
  457. ) -> Tuple[torch.Tensor, torch.Tensor]:
  458. """Forward decoder.
  459. Args:
  460. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  461. hlens: (batch)
  462. ys_in_pad:
  463. input token ids, int64 (batch, maxlen_out)
  464. if input_layer == "embed"
  465. input tensor (batch, maxlen_out, #mels) in the other cases
  466. ys_in_lens: (batch)
  467. Returns:
  468. (tuple): tuple containing:
  469. x: decoded token score before softmax (batch, maxlen_out, token)
  470. if use_output_layer is True,
  471. olens: (batch, )
  472. """
  473. tgt = ys_in_pad
  474. tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  475. memory = hs_pad
  476. memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  477. x = tgt
  478. x, tgt_mask, memory, memory_mask, _ = self.decoders(
  479. x, tgt_mask, memory, memory_mask
  480. )
  481. if self.decoders2 is not None:
  482. x, tgt_mask, memory, memory_mask, _ = self.decoders2(
  483. x, tgt_mask, memory, memory_mask
  484. )
  485. x, tgt_mask, memory, memory_mask, _ = self.decoders3(
  486. x, tgt_mask, memory, memory_mask
  487. )
  488. if self.normalize_before:
  489. x = self.after_norm(x)
  490. if self.output_layer is not None:
  491. x = self.output_layer(x)
  492. olens = tgt_mask.sum(1)
  493. return x, olens
  494. def score(self, ys, state, x):
  495. """Score."""
  496. ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
  497. logp, state = self.forward_one_step(
  498. ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
  499. )
  500. return logp.squeeze(0), state
  501. def forward_one_step(
  502. self,
  503. tgt: torch.Tensor,
  504. tgt_mask: torch.Tensor,
  505. memory: torch.Tensor,
  506. cache: List[torch.Tensor] = None,
  507. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  508. """Forward one step.
  509. Args:
  510. tgt: input token ids, int64 (batch, maxlen_out)
  511. tgt_mask: input token mask, (batch, maxlen_out)
  512. dtype=torch.uint8 in PyTorch 1.2-
  513. dtype=torch.bool in PyTorch 1.2+ (include 1.2)
  514. memory: encoded memory, float32 (batch, maxlen_in, feat)
  515. cache: cached output list of (batch, max_time_out-1, size)
  516. Returns:
  517. y, cache: NN output value and cache per `self.decoders`.
  518. y.shape` is (batch, maxlen_out, token)
  519. """
  520. x = self.embed(tgt)
  521. if cache is None:
  522. cache_layer_num = len(self.decoders)
  523. if self.decoders2 is not None:
  524. cache_layer_num += len(self.decoders2)
  525. cache = [None] * cache_layer_num
  526. new_cache = []
  527. # for c, decoder in zip(cache, self.decoders):
  528. for i in range(self.att_layer_num):
  529. decoder = self.decoders[i]
  530. c = cache[i]
  531. x, tgt_mask, memory, memory_mask, c_ret = decoder(
  532. x, tgt_mask, memory, None, cache=c
  533. )
  534. new_cache.append(c_ret)
  535. if self.num_blocks - self.att_layer_num > 1:
  536. for i in range(self.num_blocks - self.att_layer_num):
  537. j = i + self.att_layer_num
  538. decoder = self.decoders2[i]
  539. c = cache[j]
  540. x, tgt_mask, memory, memory_mask, c_ret = decoder(
  541. x, tgt_mask, memory, None, cache=c
  542. )
  543. new_cache.append(c_ret)
  544. for decoder in self.decoders3:
  545. x, tgt_mask, memory, memory_mask, _ = decoder(
  546. x, tgt_mask, memory, None, cache=None
  547. )
  548. if self.normalize_before:
  549. y = self.after_norm(x[:, -1])
  550. else:
  551. y = x[:, -1]
  552. if self.output_layer is not None:
  553. y = torch.log_softmax(self.output_layer(y), dim=-1)
  554. return y, new_cache