# decoder.py — Paraformer decoder modules (SANM decoder layer and decoders).
import logging
from typing import List
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn

from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.scama import utils as myutils
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.transformer.decoder import DecoderLayer
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.repeat import repeat
from funasr.utils.register import register_class, registry_tables
class DecoderLayerSANM(nn.Module):
    """Single decoder layer module.

    Note the sub-layer order differs from a standard Transformer decoder:
    feed-forward runs first, then (FSMN-style) self-attention, then
    cross-attention. Either attention module may be None, in which case the
    corresponding sub-layer (and its LayerNorm) is skipped entirely.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttentionSANMDecoder` instance can be used as the
            argument. May be None to omit self-attention.
        src_attn (torch.nn.Module): Source- (cross-) attention module instance.
            `MultiHeadedAttentionCrossAtt` instance can be used as the
            argument. May be None to omit cross-attention.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super(DecoderLayerSANM, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        # norm2/norm3 are only created when the matching attention module
        # exists, keeping state dicts compatible with reduced layers.
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)
        # When reserve_attn is True, forward() appends each cross-attention
        # matrix to attn_mat (e.g. for alignment visualization).
        self.reserve_attn = False
        self.attn_mat = []

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features (training / full-sequence path).

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).
                Unused here; returned unchanged.

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
            cache: Passed through unchanged.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            x, _ = self.self_attn(tgt, tgt_mask)
            # Residual connects the feed-forward *input*, not its output.
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            if self.reserve_attn:
                # Keep the attention matrix for later inspection.
                x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
                self.attn_mat.append(attn_mat)
            else:
                x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)

            x = residual + self.dropout(x_src_attn)

        return x, tgt_mask, memory, memory_mask, cache

    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features for one incremental decoding step.

        Same sub-layer order as forward(), but the self-attention carries an
        FSMN cache so only the newest frame needs to be processed.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (torch.Tensor): Cached FSMN state for this layer's
                self-attention; ignored (reset to None) while training.

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
            cache: Updated self-attention cache.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            # Caching only makes sense at inference time.
            if self.training:
                cache = None
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

        return x, tgt_mask, memory, memory_mask, cache

    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
        """Compute decoded features for streaming chunk-wise inference.

        No masks are used (each chunk is dense), and no dropout is applied to
        the cross-attention output — this is an inference-only path.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, chunk_len, size).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            fsmn_cache (torch.Tensor): Cached FSMN state for self-attention.
            opt_cache: Cached cross-attention state (key/value history).
            chunk_size: Encoder chunk size, forwarded to the cross-attention.
            look_back (int): Number of previous chunks the cross-attention may
                attend to (0 = none; semantics defined by the attention module).

        Returns:
            torch.Tensor: Output tensor (#batch, chunk_len, size).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            fsmn_cache: Updated self-attention cache.
            opt_cache: Updated cross-attention cache.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
            x = residual + self.dropout(x)
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
            x = residual + x
        return x, memory, fsmn_cache, opt_cache
@register_class("decoder_classes", "ParaformerSANMDecoder")
class ParaformerSANMDecoder(BaseTransformerDecoder):
    """Paraformer decoder built from SANM (FSMN) decoder layers.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    The stack is split in three:
      * ``decoders``  — ``att_layer_num`` layers with FSMN self-attention and
        cross-attention to the encoder output;
      * ``decoders2`` — the remaining ``num_blocks - att_layer_num`` layers
        with FSMN self-attention only (None when there are no extra layers);
      * ``decoders3`` — one final feed-forward-only layer.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        wo_input_layer: bool = False,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
        lora_list: List[str] = None,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        chunk_multiply_factor: tuple = (1,),
        tf2torch_tensor_name_prefix_torch: str = "decoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size

        # The base class builds self.embed; here it is rebuilt (or removed)
        # to match Paraformer's conventions.
        if wo_input_layer:
            self.embed = None
        else:
            if input_layer == "embed":
                self.embed = torch.nn.Sequential(
                    torch.nn.Embedding(vocab_size, attention_dim),
                    # pos_enc_class(attention_dim, positional_dropout_rate),
                )
            elif input_layer == "linear":
                self.embed = torch.nn.Sequential(
                    torch.nn.Linear(vocab_size, attention_dim),
                    torch.nn.LayerNorm(attention_dim),
                    torch.nn.Dropout(dropout_rate),
                    torch.nn.ReLU(),
                    pos_enc_class(attention_dim, positional_dropout_rate),
                )
            else:
                raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        # A None shift means "center the FSMN kernel".
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2

        # Layers with both FSMN self-attention and cross-attention.
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # Remaining layers with self-attention only (no cross-attention).
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        # Final feed-forward-only layer.
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # Prefixes used when converting TF checkpoints to torch.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.chunk_multiply_factor = chunk_multiply_factor

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        return_hidden: bool = False,
        return_both: bool = False,
        chunk_mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
            return_hidden: if True, return the pre-output-layer hidden states
                instead of logits.
            return_both: if True, return (logits, hidden, olens).
            chunk_mask: optional extra mask multiplied into the memory mask
                (streaming/chunked attention).

        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
                olens: (batch, )
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            # Pad the memory mask by duplicating its second-to-last row when
            # target and memory lengths disagree by one (chunked alignment).
            if tgt_mask.size(1) != memory_mask.size(1):
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        # NOTE(review): `hidden` is only assigned when normalize_before is
        # True; with normalize_before=False the returns below would raise
        # UnboundLocalError — confirm all configs use normalize_before=True.
        if self.normalize_before:
            hidden = self.after_norm(x)

        olens = tgt_mask.sum(1)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(hidden)
            return x, olens
        if return_both:
            x = self.output_layer(hidden)
            return x, hidden, olens
        return hidden, olens

    def score(self, ys, state, x):
        """Score a single hypothesis (beam-search interface).

        Args:
            ys: token id sequence of the hypothesis so far.
            state: decoder cache from the previous call (or None).
            x: encoder output for this utterance (maxlen_in, feat).

        Returns:
            Log-probabilities for the next token and the updated cache.
        """
        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    def forward_chunk(
        self,
        memory: torch.Tensor,
        tgt: torch.Tensor,
        cache: dict = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder for one streaming chunk.

        Args:
            memory: encoded memory for the current chunk, float32
                (batch, chunk_in, feat)
            tgt: decoder input embeddings for the chunk
                (batch, chunk_out, feat)
            cache: streaming state dict; uses keys "decode_fsmn" (per-layer
                FSMN caches), "opt" (per-layer cross-attention caches),
                "chunk_size" and "decoder_chunk_look_back".

        Returns:
            x: decoded token scores for the chunk (after the output layer
                when use_output_layer is True). Caches are updated in-place
                inside ``cache``.
        """
        x = tgt
        # Lazily allocate one FSMN cache slot per self-attention layer.
        if cache["decode_fsmn"] is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            fsmn_cache = [None] * cache_layer_num
        else:
            fsmn_cache = cache["decode_fsmn"]
        # Cross-attention caches exist only for the first att_layer_num layers.
        if cache["opt"] is None:
            cache_layer_num = len(self.decoders)
            opt_cache = [None] * cache_layer_num
        else:
            opt_cache = cache["opt"]
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
                x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
                chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
            )
        # NOTE(review): this guard skips decoders2 when exactly one extra
        # layer exists (num_blocks - att_layer_num == 1) even though
        # decoders2 is not None in that case — confirm `> 1` vs `> 0`.
        if self.num_blocks - self.att_layer_num > 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
                    x, memory, fsmn_cache=fsmn_cache[j]
                )
        for decoder in self.decoders3:
            x, memory, _, _ = decoder.forward_chunk(
                x, memory
            )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        cache["decode_fsmn"] = fsmn_cache
        # Only persist cross-attention caches when look-back is enabled
        # (-1 means unlimited look-back).
        if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
            cache["opt"] = opt_cache
        return x

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)

        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            cache = [None] * cache_layer_num
        new_cache = []
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(c_ret)

        # NOTE(review): same `> 1` guard question as in forward_chunk — a
        # single decoders2 layer would be skipped here; confirm intended.
        if self.num_blocks - self.att_layer_num > 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                c = cache[j]
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                    x, tgt_mask, memory, None, cache=c
                )
                new_cache.append(c_ret)

        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=None
            )

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache
@register_class("decoder_classes", "ParaformerDecoderSAN")
class ParaformerDecoderSAN(BaseTransformerDecoder):
    """Paraformer decoder built from standard self-attention (SAN) layers.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = -1,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        # Index of the layer whose output is additionally returned by
        # forward(); -1 (the default) disables this extra output.
        self.embeds_id = embeds_id
        self.attention_dim = attention_dim

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)

        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
                olens: (batch, )
                embeds_outputs: hidden states of layer `embeds_id`, only
                    included when embeds_id matched a layer index.
        """
        tgt = ys_in_pad
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )

        # Paraformer feeds pre-computed embeddings directly; self.embed is
        # intentionally bypassed here.
        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask
            )
            if layer_id == self.embeds_id:
                embeds_outputs = x
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        if embeds_outputs is not None:
            return x, olens, embeds_outputs
        else:
            return x, olens