decoder.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647
  1. # Copyright 2019 Shigeki Karita
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Decoder definition."""
  4. from typing import Any
  5. from typing import List
  6. from typing import Sequence
  7. from typing import Tuple
  8. import torch
  9. from torch import nn
  10. from funasr.models.transformer.attention import MultiHeadedAttention
  11. from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
  12. from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
  13. from funasr.models.transformer.embedding import PositionalEncoding
  14. from funasr.models.transformer.layer_norm import LayerNorm
  15. from funasr.models.transformer.utils.lightconv import LightweightConvolution
  16. from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
  17. from funasr.models.transformer.utils.mask import subsequent_mask
  18. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  19. from funasr.models.transformer.positionwise_feed_forward import (
  20. PositionwiseFeedForward, # noqa: H301
  21. )
  22. from funasr.models.transformer.utils.repeat import repeat
  23. from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
  24. from funasr.register import tables
class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Applies masked self-attention, source-attention over the encoder memory,
    and a position-wise feed-forward network, each wrapped in a residual
    connection with (pre- or post-) layer normalization.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Source-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # One LayerNorm per sub-block (self-attn / src-attn / feed-forward).
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # Projections that fuse [attention input ; attention output] back to `size`.
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor(#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]
        if self.concat_after:
            # Fuse query and self-attention output through a learned projection.
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)
        if cache is not None:
            # Re-attach the cached prefix so the output covers the full sequence.
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, memory, memory_mask
class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
    """Base class of Transformer decoder module.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        super().__init__()
        attention_dim = encoder_output_size
        if input_layer == "embed":
            # Token-id input: embedding lookup + positional encoding.
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            # Dense feature input (e.g. #mels): linear projection pipeline.
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None
        # Must set by the inheritance
        self.decoders = None

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)

        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
                olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L) — combine padding mask with the causal mask.
        tgt_mask = tgt_mask & m
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        return x, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)

        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            # First step: no per-layer cache yet.
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            # memory_mask is None during incremental decoding.
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache

    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]
        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
  310. @tables.register("decoder_classes", "TransformerDecoder")
  311. class TransformerDecoder(BaseTransformerDecoder):
  312. def __init__(
  313. self,
  314. vocab_size: int,
  315. encoder_output_size: int,
  316. attention_heads: int = 4,
  317. linear_units: int = 2048,
  318. num_blocks: int = 6,
  319. dropout_rate: float = 0.1,
  320. positional_dropout_rate: float = 0.1,
  321. self_attention_dropout_rate: float = 0.0,
  322. src_attention_dropout_rate: float = 0.0,
  323. input_layer: str = "embed",
  324. use_output_layer: bool = True,
  325. pos_enc_class=PositionalEncoding,
  326. normalize_before: bool = True,
  327. concat_after: bool = False,
  328. ):
  329. super().__init__(
  330. vocab_size=vocab_size,
  331. encoder_output_size=encoder_output_size,
  332. dropout_rate=dropout_rate,
  333. positional_dropout_rate=positional_dropout_rate,
  334. input_layer=input_layer,
  335. use_output_layer=use_output_layer,
  336. pos_enc_class=pos_enc_class,
  337. normalize_before=normalize_before,
  338. )
  339. attention_dim = encoder_output_size
  340. self.decoders = repeat(
  341. num_blocks,
  342. lambda lnum: DecoderLayer(
  343. attention_dim,
  344. MultiHeadedAttention(
  345. attention_heads, attention_dim, self_attention_dropout_rate
  346. ),
  347. MultiHeadedAttention(
  348. attention_heads, attention_dim, src_attention_dropout_rate
  349. ),
  350. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  351. dropout_rate,
  352. normalize_before,
  353. concat_after,
  354. ),
  355. )
  356. @tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder")
  357. class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
  358. def __init__(
  359. self,
  360. vocab_size: int,
  361. encoder_output_size: int,
  362. attention_heads: int = 4,
  363. linear_units: int = 2048,
  364. num_blocks: int = 6,
  365. dropout_rate: float = 0.1,
  366. positional_dropout_rate: float = 0.1,
  367. self_attention_dropout_rate: float = 0.0,
  368. src_attention_dropout_rate: float = 0.0,
  369. input_layer: str = "embed",
  370. use_output_layer: bool = True,
  371. pos_enc_class=PositionalEncoding,
  372. normalize_before: bool = True,
  373. concat_after: bool = False,
  374. conv_wshare: int = 4,
  375. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  376. conv_usebias: int = False,
  377. ):
  378. if len(conv_kernel_length) != num_blocks:
  379. raise ValueError(
  380. "conv_kernel_length must have equal number of values to num_blocks: "
  381. f"{len(conv_kernel_length)} != {num_blocks}"
  382. )
  383. super().__init__(
  384. vocab_size=vocab_size,
  385. encoder_output_size=encoder_output_size,
  386. dropout_rate=dropout_rate,
  387. positional_dropout_rate=positional_dropout_rate,
  388. input_layer=input_layer,
  389. use_output_layer=use_output_layer,
  390. pos_enc_class=pos_enc_class,
  391. normalize_before=normalize_before,
  392. )
  393. attention_dim = encoder_output_size
  394. self.decoders = repeat(
  395. num_blocks,
  396. lambda lnum: DecoderLayer(
  397. attention_dim,
  398. LightweightConvolution(
  399. wshare=conv_wshare,
  400. n_feat=attention_dim,
  401. dropout_rate=self_attention_dropout_rate,
  402. kernel_size=conv_kernel_length[lnum],
  403. use_kernel_mask=True,
  404. use_bias=conv_usebias,
  405. ),
  406. MultiHeadedAttention(
  407. attention_heads, attention_dim, src_attention_dropout_rate
  408. ),
  409. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  410. dropout_rate,
  411. normalize_before,
  412. concat_after,
  413. ),
  414. )
  415. @tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder")
  416. class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
  417. def __init__(
  418. self,
  419. vocab_size: int,
  420. encoder_output_size: int,
  421. attention_heads: int = 4,
  422. linear_units: int = 2048,
  423. num_blocks: int = 6,
  424. dropout_rate: float = 0.1,
  425. positional_dropout_rate: float = 0.1,
  426. self_attention_dropout_rate: float = 0.0,
  427. src_attention_dropout_rate: float = 0.0,
  428. input_layer: str = "embed",
  429. use_output_layer: bool = True,
  430. pos_enc_class=PositionalEncoding,
  431. normalize_before: bool = True,
  432. concat_after: bool = False,
  433. conv_wshare: int = 4,
  434. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  435. conv_usebias: int = False,
  436. ):
  437. if len(conv_kernel_length) != num_blocks:
  438. raise ValueError(
  439. "conv_kernel_length must have equal number of values to num_blocks: "
  440. f"{len(conv_kernel_length)} != {num_blocks}"
  441. )
  442. super().__init__(
  443. vocab_size=vocab_size,
  444. encoder_output_size=encoder_output_size,
  445. dropout_rate=dropout_rate,
  446. positional_dropout_rate=positional_dropout_rate,
  447. input_layer=input_layer,
  448. use_output_layer=use_output_layer,
  449. pos_enc_class=pos_enc_class,
  450. normalize_before=normalize_before,
  451. )
  452. attention_dim = encoder_output_size
  453. self.decoders = repeat(
  454. num_blocks,
  455. lambda lnum: DecoderLayer(
  456. attention_dim,
  457. LightweightConvolution2D(
  458. wshare=conv_wshare,
  459. n_feat=attention_dim,
  460. dropout_rate=self_attention_dropout_rate,
  461. kernel_size=conv_kernel_length[lnum],
  462. use_kernel_mask=True,
  463. use_bias=conv_usebias,
  464. ),
  465. MultiHeadedAttention(
  466. attention_heads, attention_dim, src_attention_dropout_rate
  467. ),
  468. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  469. dropout_rate,
  470. normalize_before,
  471. concat_after,
  472. ),
  473. )
  474. @tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder")
  475. class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
  476. def __init__(
  477. self,
  478. vocab_size: int,
  479. encoder_output_size: int,
  480. attention_heads: int = 4,
  481. linear_units: int = 2048,
  482. num_blocks: int = 6,
  483. dropout_rate: float = 0.1,
  484. positional_dropout_rate: float = 0.1,
  485. self_attention_dropout_rate: float = 0.0,
  486. src_attention_dropout_rate: float = 0.0,
  487. input_layer: str = "embed",
  488. use_output_layer: bool = True,
  489. pos_enc_class=PositionalEncoding,
  490. normalize_before: bool = True,
  491. concat_after: bool = False,
  492. conv_wshare: int = 4,
  493. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  494. conv_usebias: int = False,
  495. ):
  496. if len(conv_kernel_length) != num_blocks:
  497. raise ValueError(
  498. "conv_kernel_length must have equal number of values to num_blocks: "
  499. f"{len(conv_kernel_length)} != {num_blocks}"
  500. )
  501. super().__init__(
  502. vocab_size=vocab_size,
  503. encoder_output_size=encoder_output_size,
  504. dropout_rate=dropout_rate,
  505. positional_dropout_rate=positional_dropout_rate,
  506. input_layer=input_layer,
  507. use_output_layer=use_output_layer,
  508. pos_enc_class=pos_enc_class,
  509. normalize_before=normalize_before,
  510. )
  511. attention_dim = encoder_output_size
  512. self.decoders = repeat(
  513. num_blocks,
  514. lambda lnum: DecoderLayer(
  515. attention_dim,
  516. DynamicConvolution(
  517. wshare=conv_wshare,
  518. n_feat=attention_dim,
  519. dropout_rate=self_attention_dropout_rate,
  520. kernel_size=conv_kernel_length[lnum],
  521. use_kernel_mask=True,
  522. use_bias=conv_usebias,
  523. ),
  524. MultiHeadedAttention(
  525. attention_heads, attention_dim, src_attention_dropout_rate
  526. ),
  527. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  528. dropout_rate,
  529. normalize_before,
  530. concat_after,
  531. ),
  532. )
  533. @tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder")
  534. class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
  535. def __init__(
  536. self,
  537. vocab_size: int,
  538. encoder_output_size: int,
  539. attention_heads: int = 4,
  540. linear_units: int = 2048,
  541. num_blocks: int = 6,
  542. dropout_rate: float = 0.1,
  543. positional_dropout_rate: float = 0.1,
  544. self_attention_dropout_rate: float = 0.0,
  545. src_attention_dropout_rate: float = 0.0,
  546. input_layer: str = "embed",
  547. use_output_layer: bool = True,
  548. pos_enc_class=PositionalEncoding,
  549. normalize_before: bool = True,
  550. concat_after: bool = False,
  551. conv_wshare: int = 4,
  552. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  553. conv_usebias: int = False,
  554. ):
  555. if len(conv_kernel_length) != num_blocks:
  556. raise ValueError(
  557. "conv_kernel_length must have equal number of values to num_blocks: "
  558. f"{len(conv_kernel_length)} != {num_blocks}"
  559. )
  560. super().__init__(
  561. vocab_size=vocab_size,
  562. encoder_output_size=encoder_output_size,
  563. dropout_rate=dropout_rate,
  564. positional_dropout_rate=positional_dropout_rate,
  565. input_layer=input_layer,
  566. use_output_layer=use_output_layer,
  567. pos_enc_class=pos_enc_class,
  568. normalize_before=normalize_before,
  569. )
  570. attention_dim = encoder_output_size
  571. self.decoders = repeat(
  572. num_blocks,
  573. lambda lnum: DecoderLayer(
  574. attention_dim,
  575. DynamicConvolution2D(
  576. wshare=conv_wshare,
  577. n_feat=attention_dim,
  578. dropout_rate=self_attention_dropout_rate,
  579. kernel_size=conv_kernel_length[lnum],
  580. use_kernel_mask=True,
  581. use_bias=conv_usebias,
  582. ),
  583. MultiHeadedAttention(
  584. attention_heads, attention_dim, src_attention_dropout_rate
  585. ),
  586. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  587. dropout_rate,
  588. normalize_before,
  589. concat_after,
  590. ),
  591. )