transformer_encoder.py

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer encoder definition."""
import logging
from typing import List
from typing import Optional
from typing import Tuple

import torch
from torch import nn
from typeguard import check_argument_types

from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.dynamic_conv import DynamicConvolution
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.lightconv import LightweightConvolution
from funasr.modules.lightconv2d import LightweightConvolution2D
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.nets_utils import rename_state_dict
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt


class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer will be applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear layer will be applied, i.e. x -> x + att(x).
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return
            its input as-is with the given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, 1, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(
                self.self_attn(x_q, x, x, mask)
            )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, mask
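

# A minimal, hypothetical sketch (not part of the original module) showing how
# `EncoderLayer` composes the attention and feed-forward modules imported
# above. The helper name `_example_encoder_layer` and all shapes (batch=2,
# time=10, size=256) are illustrative assumptions.
def _example_encoder_layer():
    layer = EncoderLayer(
        size=256,
        self_attn=MultiHeadedAttention(4, 256, 0.1),
        feed_forward=PositionwiseFeedForward(256, 2048, 0.1),
        dropout_rate=0.1,
    )
    x = torch.randn(2, 10, 256)                    # (batch, time, size)
    mask = torch.ones(2, 1, 10, dtype=torch.bool)  # all frames valid
    y, y_mask = layer(x, mask)
    assert y.shape == (2, 10, 256)                 # shape is preserved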


class TransformerEncoder(AbsEncoder):
    """Transformer encoder module.

    Args:
        input_size: input dim
        output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of encoder blocks
        dropout_rate: dropout rate
        attention_dropout_rate: dropout rate in attention
        positional_dropout_rate: dropout rate after adding positional encoding
        input_layer: input layer type
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear will be applied,
            i.e. x -> x + att(x)
        positionwise_layer_type: "linear", "conv1d", or "conv1d-linear"
        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
        padding_idx: padding_idx for input_layer=embed
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(
                    attention_heads, output_size, attention_dropout_rate
                ),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.

        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if self.embed is None:
            pass
        elif isinstance(
            self.embed,
            (
                Conv2dSubsampling,
                Conv2dSubsampling2,
                Conv2dSubsampling6,
                Conv2dSubsampling8,
            ),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
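

# A minimal, hypothetical usage sketch (not part of the original module). The
# helper name `_example_transformer_encoder` and all shapes are illustrative
# assumptions; with the default "conv2d" input layer the time axis is
# subsampled by a factor of 4 before the attention blocks run.
def _example_transformer_encoder():
    encoder = TransformerEncoder(input_size=80, output_size=256, num_blocks=6)
    xs_pad = torch.randn(2, 100, 80)  # (batch, time, feature), zero-padded
    ilens = torch.tensor([100, 80])   # valid frames per utterance
    out, olens, _ = encoder(xs_pad, ilens)
    # `out` has the subsampled time axis; `olens` holds the matching lengths.
    assert out.size(0) == 2 and out.size(2) == encoder.output_size()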


def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    # https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
    rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
    # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
    rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)
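

# A hedged illustration (the key names below are assumptions) of what
# `_pre_hook` does when an old checkpoint is loaded: legacy parameter names
# are renamed in place before `load_state_dict` consumes the state dict, e.g.
#     "encoder.input_layer.0.weight" -> "encoder.embed.0.weight"
#     "encoder.norm.weight"          -> "encoder.after_norm.weight"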


class TransformerEncoder_s0(torch.nn.Module):
    """Transformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        conv_wshare (int): The number of kernels of convolution. Only used when
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        conv_kernel_length (Union[int, str]): Kernel size str of convolution
            (e.g. 71_71_71_71_71_71). Only used when selfattention_layer_type
            == "lightconv*" or "dynamicconv*".
        conv_usebias (bool): Whether to use bias in convolution. Only used when
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        pos_enc_class (torch.nn.Module): Positional encoding module class.
            `PositionalEncoding` or `ScaledPositionalEncoding`
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer will be applied,
            i.e. x -> x + linear(concat(x, att(x)));
            if False, no additional linear layer will be applied, i.e. x -> x + att(x).
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        selfattention_layer_type (str): Encoder attention layer type.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
        intermediate_layers (Union[List[int], None]): Indices of intermediate CTC
            layers. Indices start from 1.
            If not None, intermediate outputs are returned (which changes the return
            type signature.)
    """

    def __init__(
        self,
        idim,
        attention_dim=256,
        attention_heads=4,
        conv_wshare=4,
        conv_kernel_length="11",
        conv_usebias=False,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        selfattention_layer_type="selfattn",
        padding_idx=-1,
        stochastic_depth_rate=0.0,
        intermediate_layers=None,
        ctc_softmax=None,
        conditioning_layer_dim=None,
    ):
        """Construct an Encoder object."""
        super(TransformerEncoder_s0, self).__init__()
        self._register_load_state_dict_pre_hook(_pre_hook)

        self.conv_subsampling_factor = 1
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 4
        elif input_layer == "conv2d-scaled-pos-enc":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
            self.conv_subsampling_factor = 4
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 6
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 8
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before
        positionwise_layer, positionwise_layer_args = self.get_positionwise_layer(
            positionwise_layer_type,
            attention_dim,
            linear_units,
            dropout_rate,
            positionwise_conv_kernel_size,
        )
        if selfattention_layer_type in [
            "selfattn",
            "rel_selfattn",
            "legacy_rel_selfattn",
        ]:
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = [
                (
                    attention_heads,
                    attention_dim,
                    attention_dropout_rate,
                )
            ] * num_blocks
        elif selfattention_layer_type == "lightconv":
            logging.info("encoder self-attention layer type = lightweight convolution")
            encoder_selfattn_layer = LightweightConvolution
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "lightconv2d":
            logging.info(
                "encoder self-attention layer "
                "type = lightweight convolution 2-dimensional"
            )
            encoder_selfattn_layer = LightweightConvolution2D
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "dynamicconv":
            logging.info("encoder self-attention layer type = dynamic convolution")
            encoder_selfattn_layer = DynamicConvolution
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "dynamicconv2d":
            logging.info(
                "encoder self-attention layer type = dynamic convolution 2-dimensional"
            )
            encoder_selfattn_layer = DynamicConvolution2D
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        else:
            raise NotImplementedError(selfattention_layer_type)

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / num_blocks,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

        self.intermediate_layers = intermediate_layers
        self.use_conditioning = True if ctc_softmax is not None else False
        if self.use_conditioning:
            self.ctc_softmax = ctc_softmax
            self.conditioning_layer = torch.nn.Linear(
                conditioning_layer_dim, attention_dim
            )

    def get_positionwise_layer(
        self,
        positionwise_layer_type="linear",
        attention_dim=256,
        linear_units=2048,
        dropout_rate=0.1,
        positionwise_conv_kernel_size=1,
    ):
        """Define positionwise layer."""
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")
        return positionwise_layer, positionwise_layer_args
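
    # A minimal sketch (illustrative only, not executed) of the factory
    # pattern above: the method returns a class plus its constructor args,
    # and the caller instantiates one instance per encoder block, e.g.
    #
    #     layer_cls, layer_args = self.get_positionwise_layer(
    #         "conv1d", 256, 2048, 0.1, 3
    #     )
    #     ffn = layer_cls(*layer_args)  # -> MultiLayeredConv1d(256, 2048, 3, 0.1)
    #     ffn(torch.randn(2, 10, 256)).shape  # -> (2, 10, 256)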

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor (#batch, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)

        if self.intermediate_layers is None:
            xs, masks = self.encoders(xs, masks)
        else:
            intermediate_outputs = []
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs, masks = encoder_layer(xs, masks)

                if layer_idx + 1 in self.intermediate_layers:
                    encoder_output = xs
                    # intermediate branches also require normalization.
                    if self.normalize_before:
                        encoder_output = self.after_norm(encoder_output)
                    intermediate_outputs.append(encoder_output)

                    if self.use_conditioning:
                        intermediate_result = self.ctc_softmax(encoder_output)
                        xs = xs + self.conditioning_layer(intermediate_result)

        if self.normalize_before:
            xs = self.after_norm(xs)

        if self.intermediate_layers is not None:
            return xs, masks, intermediate_outputs
        return xs, masks

    def forward_one_step(self, xs, masks, cache=None):
        """Encode input frame.

        Args:
            xs (torch.Tensor): Input tensor.
            masks (torch.Tensor): Mask tensor.
            cache (List[torch.Tensor]): List of cache tensors.

        Returns:
            torch.Tensor: Output tensor.
            torch.Tensor: Mask tensor.
            List[torch.Tensor]: List of new cache tensors.
        """
        if isinstance(self.embed, Conv2dSubsampling):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)
        if cache is None:
            cache = [None for _ in range(len(self.encoders))]
        new_cache = []
        for c, e in zip(cache, self.encoders):
            xs, masks = e(xs, masks, cache=c)
            new_cache.append(xs)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks, new_cache
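

# A minimal, hypothetical streaming sketch (not part of the original module).
# The helper name `_example_forward_one_step`, the "linear" input layer, and
# all shapes are illustrative assumptions: the full prefix is passed in each
# step, while `cache` holds every layer's previous output so that only the
# newest frame is used as the attention query inside `EncoderLayer`.
def _example_forward_one_step():
    enc = TransformerEncoder_s0(idim=256, input_layer="linear", num_blocks=2)
    enc.eval()
    feats = torch.randn(1, 3, 256)  # (batch, time, idim)
    cache = None
    with torch.no_grad():
        for t in range(1, feats.size(1) + 1):
            # each call sees one more frame; `cache` carries the past outputs
            ys, _, cache = enc.forward_one_step(feats[:, :t], None, cache=cache)
    assert ys.shape == (1, 3, 256)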