# transformer_decoder.py
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Decoder definition."""
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

import torch
from torch import nn
from typeguard import check_argument_types

from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.dynamic_conv import DynamicConvolution
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.lightconv import LightweightConvolution
from funasr.modules.lightconv2d import LightweightConvolution2D
from funasr.modules.mask import subsequent_mask
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  26. class DecoderLayer(nn.Module):
  27. """Single decoder layer module.
  28. Args:
  29. size (int): Input dimension.
  30. self_attn (torch.nn.Module): Self-attention module instance.
  31. `MultiHeadedAttention` instance can be used as the argument.
  32. src_attn (torch.nn.Module): Self-attention module instance.
  33. `MultiHeadedAttention` instance can be used as the argument.
  34. feed_forward (torch.nn.Module): Feed-forward module instance.
  35. `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
  36. can be used as the argument.
  37. dropout_rate (float): Dropout rate.
  38. normalize_before (bool): Whether to use layer_norm before the first block.
  39. concat_after (bool): Whether to concat attention layer's input and output.
  40. if True, additional linear will be applied.
  41. i.e. x -> x + linear(concat(x, att(x)))
  42. if False, no additional linear will be applied. i.e. x -> x + att(x)
  43. """
  44. def __init__(
  45. self,
  46. size,
  47. self_attn,
  48. src_attn,
  49. feed_forward,
  50. dropout_rate,
  51. normalize_before=True,
  52. concat_after=False,
  53. ):
  54. """Construct an DecoderLayer object."""
  55. super(DecoderLayer, self).__init__()
  56. self.size = size
  57. self.self_attn = self_attn
  58. self.src_attn = src_attn
  59. self.feed_forward = feed_forward
  60. self.norm1 = LayerNorm(size)
  61. self.norm2 = LayerNorm(size)
  62. self.norm3 = LayerNorm(size)
  63. self.dropout = nn.Dropout(dropout_rate)
  64. self.normalize_before = normalize_before
  65. self.concat_after = concat_after
  66. if self.concat_after:
  67. self.concat_linear1 = nn.Linear(size + size, size)
  68. self.concat_linear2 = nn.Linear(size + size, size)
  69. def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
  70. """Compute decoded features.
  71. Args:
  72. tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
  73. tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
  74. memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
  75. memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
  76. cache (List[torch.Tensor]): List of cached tensors.
  77. Each tensor shape should be (#batch, maxlen_out - 1, size).
  78. Returns:
  79. torch.Tensor: Output tensor(#batch, maxlen_out, size).
  80. torch.Tensor: Mask for output tensor (#batch, maxlen_out).
  81. torch.Tensor: Encoded memory (#batch, maxlen_in, size).
  82. torch.Tensor: Encoded memory mask (#batch, maxlen_in).
  83. """
  84. residual = tgt
  85. if self.normalize_before:
  86. tgt = self.norm1(tgt)
  87. if cache is None:
  88. tgt_q = tgt
  89. tgt_q_mask = tgt_mask
  90. else:
  91. # compute only the last frame query keeping dim: max_time_out -> 1
  92. assert cache.shape == (
  93. tgt.shape[0],
  94. tgt.shape[1] - 1,
  95. self.size,
  96. ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
  97. tgt_q = tgt[:, -1:, :]
  98. residual = residual[:, -1:, :]
  99. tgt_q_mask = None
  100. if tgt_mask is not None:
  101. tgt_q_mask = tgt_mask[:, -1:, :]
  102. if self.concat_after:
  103. tgt_concat = torch.cat(
  104. (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
  105. )
  106. x = residual + self.concat_linear1(tgt_concat)
  107. else:
  108. x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
  109. if not self.normalize_before:
  110. x = self.norm1(x)
  111. residual = x
  112. if self.normalize_before:
  113. x = self.norm2(x)
  114. if self.concat_after:
  115. x_concat = torch.cat(
  116. (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
  117. )
  118. x = residual + self.concat_linear2(x_concat)
  119. else:
  120. x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
  121. if not self.normalize_before:
  122. x = self.norm2(x)
  123. residual = x
  124. if self.normalize_before:
  125. x = self.norm3(x)
  126. x = residual + self.dropout(self.feed_forward(x))
  127. if not self.normalize_before:
  128. x = self.norm3(x)
  129. if cache is not None:
  130. x = torch.cat([cache, x], dim=1)
  131. return x, tgt_mask, memory, memory_mask
  132. class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
  133. """Base class of Transfomer decoder module.
  134. Args:
  135. vocab_size: output dim
  136. encoder_output_size: dimension of attention
  137. attention_heads: the number of heads of multi head attention
  138. linear_units: the number of units of position-wise feed forward
  139. num_blocks: the number of decoder blocks
  140. dropout_rate: dropout rate
  141. self_attention_dropout_rate: dropout rate for attention
  142. input_layer: input layer type
  143. use_output_layer: whether to use output layer
  144. pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
  145. normalize_before: whether to use layer_norm before the first block
  146. concat_after: whether to concat attention layer's input and output
  147. if True, additional linear will be applied.
  148. i.e. x -> x + linear(concat(x, att(x)))
  149. if False, no additional linear will be applied.
  150. i.e. x -> x + att(x)
  151. """
  152. def __init__(
  153. self,
  154. vocab_size: int,
  155. encoder_output_size: int,
  156. dropout_rate: float = 0.1,
  157. positional_dropout_rate: float = 0.1,
  158. input_layer: str = "embed",
  159. use_output_layer: bool = True,
  160. pos_enc_class=PositionalEncoding,
  161. normalize_before: bool = True,
  162. ):
  163. assert check_argument_types()
  164. super().__init__()
  165. attention_dim = encoder_output_size
  166. if input_layer == "embed":
  167. self.embed = torch.nn.Sequential(
  168. torch.nn.Embedding(vocab_size, attention_dim),
  169. pos_enc_class(attention_dim, positional_dropout_rate),
  170. )
  171. elif input_layer == "linear":
  172. self.embed = torch.nn.Sequential(
  173. torch.nn.Linear(vocab_size, attention_dim),
  174. torch.nn.LayerNorm(attention_dim),
  175. torch.nn.Dropout(dropout_rate),
  176. torch.nn.ReLU(),
  177. pos_enc_class(attention_dim, positional_dropout_rate),
  178. )
  179. else:
  180. raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
  181. self.normalize_before = normalize_before
  182. if self.normalize_before:
  183. self.after_norm = LayerNorm(attention_dim)
  184. if use_output_layer:
  185. self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
  186. else:
  187. self.output_layer = None
  188. # Must set by the inheritance
  189. self.decoders = None
  190. def forward(
  191. self,
  192. hs_pad: torch.Tensor,
  193. hlens: torch.Tensor,
  194. ys_in_pad: torch.Tensor,
  195. ys_in_lens: torch.Tensor,
  196. ) -> Tuple[torch.Tensor, torch.Tensor]:
  197. """Forward decoder.
  198. Args:
  199. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  200. hlens: (batch)
  201. ys_in_pad:
  202. input token ids, int64 (batch, maxlen_out)
  203. if input_layer == "embed"
  204. input tensor (batch, maxlen_out, #mels) in the other cases
  205. ys_in_lens: (batch)
  206. Returns:
  207. (tuple): tuple containing:
  208. x: decoded token score before softmax (batch, maxlen_out, token)
  209. if use_output_layer is True,
  210. olens: (batch, )
  211. """
  212. tgt = ys_in_pad
  213. # tgt_mask: (B, 1, L)
  214. tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
  215. # m: (1, L, L)
  216. m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
  217. # tgt_mask: (B, L, L)
  218. tgt_mask = tgt_mask & m
  219. memory = hs_pad
  220. memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
  221. memory.device
  222. )
  223. # Padding for Longformer
  224. if memory_mask.shape[-1] != memory.shape[1]:
  225. padlen = memory.shape[1] - memory_mask.shape[-1]
  226. memory_mask = torch.nn.functional.pad(
  227. memory_mask, (0, padlen), "constant", False
  228. )
  229. x = self.embed(tgt)
  230. x, tgt_mask, memory, memory_mask = self.decoders(
  231. x, tgt_mask, memory, memory_mask
  232. )
  233. if self.normalize_before:
  234. x = self.after_norm(x)
  235. if self.output_layer is not None:
  236. x = self.output_layer(x)
  237. olens = tgt_mask.sum(1)
  238. return x, olens
  239. def forward_one_step(
  240. self,
  241. tgt: torch.Tensor,
  242. tgt_mask: torch.Tensor,
  243. memory: torch.Tensor,
  244. cache: List[torch.Tensor] = None,
  245. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  246. """Forward one step.
  247. Args:
  248. tgt: input token ids, int64 (batch, maxlen_out)
  249. tgt_mask: input token mask, (batch, maxlen_out)
  250. dtype=torch.uint8 in PyTorch 1.2-
  251. dtype=torch.bool in PyTorch 1.2+ (include 1.2)
  252. memory: encoded memory, float32 (batch, maxlen_in, feat)
  253. cache: cached output list of (batch, max_time_out-1, size)
  254. Returns:
  255. y, cache: NN output value and cache per `self.decoders`.
  256. y.shape` is (batch, maxlen_out, token)
  257. """
  258. x = self.embed(tgt)
  259. if cache is None:
  260. cache = [None] * len(self.decoders)
  261. new_cache = []
  262. for c, decoder in zip(cache, self.decoders):
  263. x, tgt_mask, memory, memory_mask = decoder(
  264. x, tgt_mask, memory, None, cache=c
  265. )
  266. new_cache.append(x)
  267. if self.normalize_before:
  268. y = self.after_norm(x[:, -1])
  269. else:
  270. y = x[:, -1]
  271. if self.output_layer is not None:
  272. y = torch.log_softmax(self.output_layer(y), dim=-1)
  273. return y, new_cache
  274. def score(self, ys, state, x):
  275. """Score."""
  276. ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
  277. logp, state = self.forward_one_step(
  278. ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
  279. )
  280. return logp.squeeze(0), state
  281. def batch_score(
  282. self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
  283. ) -> Tuple[torch.Tensor, List[Any]]:
  284. """Score new token batch.
  285. Args:
  286. ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
  287. states (List[Any]): Scorer states for prefix tokens.
  288. xs (torch.Tensor):
  289. The encoder feature that generates ys (n_batch, xlen, n_feat).
  290. Returns:
  291. tuple[torch.Tensor, List[Any]]: Tuple of
  292. batchfied scores for next token with shape of `(n_batch, n_vocab)`
  293. and next state list for ys.
  294. """
  295. # merge states
  296. n_batch = len(ys)
  297. n_layers = len(self.decoders)
  298. if states[0] is None:
  299. batch_state = None
  300. else:
  301. # transpose state of [batch, layer] into [layer, batch]
  302. batch_state = [
  303. torch.stack([states[b][i] for b in range(n_batch)])
  304. for i in range(n_layers)
  305. ]
  306. # batch decoding
  307. ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
  308. logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
  309. # transpose state of [layer, batch] into [batch, layer]
  310. state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
  311. return logp, state_list
  312. class TransformerDecoder(BaseTransformerDecoder):
  313. def __init__(
  314. self,
  315. vocab_size: int,
  316. encoder_output_size: int,
  317. attention_heads: int = 4,
  318. linear_units: int = 2048,
  319. num_blocks: int = 6,
  320. dropout_rate: float = 0.1,
  321. positional_dropout_rate: float = 0.1,
  322. self_attention_dropout_rate: float = 0.0,
  323. src_attention_dropout_rate: float = 0.0,
  324. input_layer: str = "embed",
  325. use_output_layer: bool = True,
  326. pos_enc_class=PositionalEncoding,
  327. normalize_before: bool = True,
  328. concat_after: bool = False,
  329. ):
  330. assert check_argument_types()
  331. super().__init__(
  332. vocab_size=vocab_size,
  333. encoder_output_size=encoder_output_size,
  334. dropout_rate=dropout_rate,
  335. positional_dropout_rate=positional_dropout_rate,
  336. input_layer=input_layer,
  337. use_output_layer=use_output_layer,
  338. pos_enc_class=pos_enc_class,
  339. normalize_before=normalize_before,
  340. )
  341. attention_dim = encoder_output_size
  342. self.decoders = repeat(
  343. num_blocks,
  344. lambda lnum: DecoderLayer(
  345. attention_dim,
  346. MultiHeadedAttention(
  347. attention_heads, attention_dim, self_attention_dropout_rate
  348. ),
  349. MultiHeadedAttention(
  350. attention_heads, attention_dim, src_attention_dropout_rate
  351. ),
  352. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  353. dropout_rate,
  354. normalize_before,
  355. concat_after,
  356. ),
  357. )
  358. class ParaformerDecoderSAN(BaseTransformerDecoder):
  359. """
  360. Author: Speech Lab of DAMO Academy, Alibaba Group
  361. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
  362. https://arxiv.org/abs/2006.01713
  363. """
  364. def __init__(
  365. self,
  366. vocab_size: int,
  367. encoder_output_size: int,
  368. attention_heads: int = 4,
  369. linear_units: int = 2048,
  370. num_blocks: int = 6,
  371. dropout_rate: float = 0.1,
  372. positional_dropout_rate: float = 0.1,
  373. self_attention_dropout_rate: float = 0.0,
  374. src_attention_dropout_rate: float = 0.0,
  375. input_layer: str = "embed",
  376. use_output_layer: bool = True,
  377. pos_enc_class=PositionalEncoding,
  378. normalize_before: bool = True,
  379. concat_after: bool = False,
  380. embeds_id: int = -1,
  381. ):
  382. assert check_argument_types()
  383. super().__init__(
  384. vocab_size=vocab_size,
  385. encoder_output_size=encoder_output_size,
  386. dropout_rate=dropout_rate,
  387. positional_dropout_rate=positional_dropout_rate,
  388. input_layer=input_layer,
  389. use_output_layer=use_output_layer,
  390. pos_enc_class=pos_enc_class,
  391. normalize_before=normalize_before,
  392. )
  393. attention_dim = encoder_output_size
  394. self.decoders = repeat(
  395. num_blocks,
  396. lambda lnum: DecoderLayer(
  397. attention_dim,
  398. MultiHeadedAttention(
  399. attention_heads, attention_dim, self_attention_dropout_rate
  400. ),
  401. MultiHeadedAttention(
  402. attention_heads, attention_dim, src_attention_dropout_rate
  403. ),
  404. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  405. dropout_rate,
  406. normalize_before,
  407. concat_after,
  408. ),
  409. )
  410. self.embeds_id = embeds_id
  411. self.attention_dim = attention_dim
  412. def forward(
  413. self,
  414. hs_pad: torch.Tensor,
  415. hlens: torch.Tensor,
  416. ys_in_pad: torch.Tensor,
  417. ys_in_lens: torch.Tensor,
  418. ) -> Tuple[torch.Tensor, torch.Tensor]:
  419. """Forward decoder.
  420. Args:
  421. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  422. hlens: (batch)
  423. ys_in_pad:
  424. input token ids, int64 (batch, maxlen_out)
  425. if input_layer == "embed"
  426. input tensor (batch, maxlen_out, #mels) in the other cases
  427. ys_in_lens: (batch)
  428. Returns:
  429. (tuple): tuple containing:
  430. x: decoded token score before softmax (batch, maxlen_out, token)
  431. if use_output_layer is True,
  432. olens: (batch, )
  433. """
  434. tgt = ys_in_pad
  435. tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
  436. memory = hs_pad
  437. memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
  438. memory.device
  439. )
  440. # Padding for Longformer
  441. if memory_mask.shape[-1] != memory.shape[1]:
  442. padlen = memory.shape[1] - memory_mask.shape[-1]
  443. memory_mask = torch.nn.functional.pad(
  444. memory_mask, (0, padlen), "constant", False
  445. )
  446. # x = self.embed(tgt)
  447. x = tgt
  448. embeds_outputs = None
  449. for layer_id, decoder in enumerate(self.decoders):
  450. x, tgt_mask, memory, memory_mask = decoder(
  451. x, tgt_mask, memory, memory_mask
  452. )
  453. if layer_id == self.embeds_id:
  454. embeds_outputs = x
  455. if self.normalize_before:
  456. x = self.after_norm(x)
  457. if self.output_layer is not None:
  458. x = self.output_layer(x)
  459. olens = tgt_mask.sum(1)
  460. if embeds_outputs is not None:
  461. return x, olens, embeds_outputs
  462. else:
  463. return x, olens
  464. class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
  465. def __init__(
  466. self,
  467. vocab_size: int,
  468. encoder_output_size: int,
  469. attention_heads: int = 4,
  470. linear_units: int = 2048,
  471. num_blocks: int = 6,
  472. dropout_rate: float = 0.1,
  473. positional_dropout_rate: float = 0.1,
  474. self_attention_dropout_rate: float = 0.0,
  475. src_attention_dropout_rate: float = 0.0,
  476. input_layer: str = "embed",
  477. use_output_layer: bool = True,
  478. pos_enc_class=PositionalEncoding,
  479. normalize_before: bool = True,
  480. concat_after: bool = False,
  481. conv_wshare: int = 4,
  482. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  483. conv_usebias: int = False,
  484. ):
  485. assert check_argument_types()
  486. if len(conv_kernel_length) != num_blocks:
  487. raise ValueError(
  488. "conv_kernel_length must have equal number of values to num_blocks: "
  489. f"{len(conv_kernel_length)} != {num_blocks}"
  490. )
  491. super().__init__(
  492. vocab_size=vocab_size,
  493. encoder_output_size=encoder_output_size,
  494. dropout_rate=dropout_rate,
  495. positional_dropout_rate=positional_dropout_rate,
  496. input_layer=input_layer,
  497. use_output_layer=use_output_layer,
  498. pos_enc_class=pos_enc_class,
  499. normalize_before=normalize_before,
  500. )
  501. attention_dim = encoder_output_size
  502. self.decoders = repeat(
  503. num_blocks,
  504. lambda lnum: DecoderLayer(
  505. attention_dim,
  506. LightweightConvolution(
  507. wshare=conv_wshare,
  508. n_feat=attention_dim,
  509. dropout_rate=self_attention_dropout_rate,
  510. kernel_size=conv_kernel_length[lnum],
  511. use_kernel_mask=True,
  512. use_bias=conv_usebias,
  513. ),
  514. MultiHeadedAttention(
  515. attention_heads, attention_dim, src_attention_dropout_rate
  516. ),
  517. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  518. dropout_rate,
  519. normalize_before,
  520. concat_after,
  521. ),
  522. )
  523. class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
  524. def __init__(
  525. self,
  526. vocab_size: int,
  527. encoder_output_size: int,
  528. attention_heads: int = 4,
  529. linear_units: int = 2048,
  530. num_blocks: int = 6,
  531. dropout_rate: float = 0.1,
  532. positional_dropout_rate: float = 0.1,
  533. self_attention_dropout_rate: float = 0.0,
  534. src_attention_dropout_rate: float = 0.0,
  535. input_layer: str = "embed",
  536. use_output_layer: bool = True,
  537. pos_enc_class=PositionalEncoding,
  538. normalize_before: bool = True,
  539. concat_after: bool = False,
  540. conv_wshare: int = 4,
  541. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  542. conv_usebias: int = False,
  543. ):
  544. assert check_argument_types()
  545. if len(conv_kernel_length) != num_blocks:
  546. raise ValueError(
  547. "conv_kernel_length must have equal number of values to num_blocks: "
  548. f"{len(conv_kernel_length)} != {num_blocks}"
  549. )
  550. super().__init__(
  551. vocab_size=vocab_size,
  552. encoder_output_size=encoder_output_size,
  553. dropout_rate=dropout_rate,
  554. positional_dropout_rate=positional_dropout_rate,
  555. input_layer=input_layer,
  556. use_output_layer=use_output_layer,
  557. pos_enc_class=pos_enc_class,
  558. normalize_before=normalize_before,
  559. )
  560. attention_dim = encoder_output_size
  561. self.decoders = repeat(
  562. num_blocks,
  563. lambda lnum: DecoderLayer(
  564. attention_dim,
  565. LightweightConvolution2D(
  566. wshare=conv_wshare,
  567. n_feat=attention_dim,
  568. dropout_rate=self_attention_dropout_rate,
  569. kernel_size=conv_kernel_length[lnum],
  570. use_kernel_mask=True,
  571. use_bias=conv_usebias,
  572. ),
  573. MultiHeadedAttention(
  574. attention_heads, attention_dim, src_attention_dropout_rate
  575. ),
  576. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  577. dropout_rate,
  578. normalize_before,
  579. concat_after,
  580. ),
  581. )
  582. class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
  583. def __init__(
  584. self,
  585. vocab_size: int,
  586. encoder_output_size: int,
  587. attention_heads: int = 4,
  588. linear_units: int = 2048,
  589. num_blocks: int = 6,
  590. dropout_rate: float = 0.1,
  591. positional_dropout_rate: float = 0.1,
  592. self_attention_dropout_rate: float = 0.0,
  593. src_attention_dropout_rate: float = 0.0,
  594. input_layer: str = "embed",
  595. use_output_layer: bool = True,
  596. pos_enc_class=PositionalEncoding,
  597. normalize_before: bool = True,
  598. concat_after: bool = False,
  599. conv_wshare: int = 4,
  600. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  601. conv_usebias: int = False,
  602. ):
  603. assert check_argument_types()
  604. if len(conv_kernel_length) != num_blocks:
  605. raise ValueError(
  606. "conv_kernel_length must have equal number of values to num_blocks: "
  607. f"{len(conv_kernel_length)} != {num_blocks}"
  608. )
  609. super().__init__(
  610. vocab_size=vocab_size,
  611. encoder_output_size=encoder_output_size,
  612. dropout_rate=dropout_rate,
  613. positional_dropout_rate=positional_dropout_rate,
  614. input_layer=input_layer,
  615. use_output_layer=use_output_layer,
  616. pos_enc_class=pos_enc_class,
  617. normalize_before=normalize_before,
  618. )
  619. attention_dim = encoder_output_size
  620. self.decoders = repeat(
  621. num_blocks,
  622. lambda lnum: DecoderLayer(
  623. attention_dim,
  624. DynamicConvolution(
  625. wshare=conv_wshare,
  626. n_feat=attention_dim,
  627. dropout_rate=self_attention_dropout_rate,
  628. kernel_size=conv_kernel_length[lnum],
  629. use_kernel_mask=True,
  630. use_bias=conv_usebias,
  631. ),
  632. MultiHeadedAttention(
  633. attention_heads, attention_dim, src_attention_dropout_rate
  634. ),
  635. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  636. dropout_rate,
  637. normalize_before,
  638. concat_after,
  639. ),
  640. )
  641. class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
  642. def __init__(
  643. self,
  644. vocab_size: int,
  645. encoder_output_size: int,
  646. attention_heads: int = 4,
  647. linear_units: int = 2048,
  648. num_blocks: int = 6,
  649. dropout_rate: float = 0.1,
  650. positional_dropout_rate: float = 0.1,
  651. self_attention_dropout_rate: float = 0.0,
  652. src_attention_dropout_rate: float = 0.0,
  653. input_layer: str = "embed",
  654. use_output_layer: bool = True,
  655. pos_enc_class=PositionalEncoding,
  656. normalize_before: bool = True,
  657. concat_after: bool = False,
  658. conv_wshare: int = 4,
  659. conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
  660. conv_usebias: int = False,
  661. ):
  662. assert check_argument_types()
  663. if len(conv_kernel_length) != num_blocks:
  664. raise ValueError(
  665. "conv_kernel_length must have equal number of values to num_blocks: "
  666. f"{len(conv_kernel_length)} != {num_blocks}"
  667. )
  668. super().__init__(
  669. vocab_size=vocab_size,
  670. encoder_output_size=encoder_output_size,
  671. dropout_rate=dropout_rate,
  672. positional_dropout_rate=positional_dropout_rate,
  673. input_layer=input_layer,
  674. use_output_layer=use_output_layer,
  675. pos_enc_class=pos_enc_class,
  676. normalize_before=normalize_before,
  677. )
  678. attention_dim = encoder_output_size
  679. self.decoders = repeat(
  680. num_blocks,
  681. lambda lnum: DecoderLayer(
  682. attention_dim,
  683. DynamicConvolution2D(
  684. wshare=conv_wshare,
  685. n_feat=attention_dim,
  686. dropout_rate=self_attention_dropout_rate,
  687. kernel_size=conv_kernel_length[lnum],
  688. use_kernel_mask=True,
  689. use_bias=conv_usebias,
  690. ),
  691. MultiHeadedAttention(
  692. attention_heads, attention_dim, src_attention_dropout_rate
  693. ),
  694. PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
  695. dropout_rate,
  696. normalize_before,
  697. concat_after,
  698. ),
  699. )