transformer_encoder.py

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple

import torch
from torch import nn
import logging

from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.nets_utils import rename_state_dict
from funasr.modules.dynamic_conv import DynamicConvolution
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
from funasr.modules.lightconv import LightweightConvolution
from funasr.modules.lightconv2d import LightweightConvolution2D
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt

class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """
    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
    def forward(self, x, mask, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(
                self.self_attn(x_q, x, x, mask)
            )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, mask
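

# --- Illustrative usage (not part of the original file) ------------------------
# A minimal sketch showing how a single EncoderLayer is constructed and called,
# assuming the ESPnet-style signatures of the MultiHeadedAttention
# (n_head, n_feat, dropout_rate) and PositionwiseFeedForward
# (idim, hidden_units, dropout_rate) classes imported above. The helper name
# and all shapes/hyperparameters are hypothetical.
def _example_encoder_layer_usage():
    """Run one EncoderLayer on a dummy batch of shape (#batch=2, time=10, size=256)."""
    layer = EncoderLayer(
        size=256,
        self_attn=MultiHeadedAttention(4, 256, 0.0),
        feed_forward=PositionwiseFeedForward(256, 2048, 0.1),
        dropout_rate=0.1,
    )
    x = torch.randn(2, 10, 256)
    mask = torch.ones(2, 1, 10, dtype=torch.bool)  # no padded frames in this dummy batch
    y, mask = layer(x, mask)
    return y.shape  # expected: torch.Size([2, 10, 256])

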
class TransformerEncoder(AbsEncoder):
    """Transformer encoder module.

    Args:
        input_size: input dim
        output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of encoder blocks
        dropout_rate: dropout rate
        attention_dropout_rate: dropout rate in attention
        positional_dropout_rate: dropout rate after adding positional encoding
        input_layer: input layer type
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
        positionwise_layer_type: linear or conv1d
        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
        padding_idx: padding_idx for input_layer=embed
    """
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
    ):
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(
                    attention_heads, output_size, attention_dropout_rate
                ),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None

    def output_size(self) -> int:
        return self._output_size
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.

        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
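

# --- Illustrative usage (not part of the original file) ------------------------
# A minimal sketch of a full TransformerEncoder forward pass, assuming 80-dim
# input features; all shapes and hyperparameters are hypothetical. With the
# default "conv2d" input layer, Conv2dSubsampling reduces the time axis by
# roughly a factor of 4, and the returned lengths are the subsampled lengths.
def _example_transformer_encoder_usage():
    """Encode a dummy padded batch of 80-dim features."""
    encoder = TransformerEncoder(input_size=80, output_size=256, num_blocks=6)
    xs_pad = torch.randn(2, 100, 80)  # (B, L, D)
    ilens = torch.tensor([100, 73])   # valid length of each utterance
    encoder_out, olens, _ = encoder(xs_pad, ilens)
    return encoder_out.shape, olens   # roughly (B, L/4, 256) and subsampled lengths

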
def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    # https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
    rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
    # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
    rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)
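

# --- Illustrative usage (not part of the original file) ------------------------
# A hedged sketch of what the pre-hook does when a legacy checkpoint is loaded:
# keys under "<prefix>input_layer." / "<prefix>norm." are renamed to
# "<prefix>embed." / "<prefix>after_norm." before load_state_dict consumes the
# state dict. The key names below are hypothetical, and this assumes
# rename_state_dict mutates the dict in place (which a load_state_dict
# pre-hook requires to have any effect).
def _example_pre_hook_rename():
    """Show the key renaming performed by _pre_hook on illustrative keys."""
    state_dict = {
        "encoder.input_layer.out.weight": torch.zeros(1),
        "encoder.norm.weight": torch.zeros(1),
    }
    _pre_hook(state_dict, "encoder.", None, True, [], [], [])
    return sorted(state_dict)  # expected: ["encoder.after_norm.weight", "encoder.embed.out.weight"]

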
class TransformerEncoder_s0(torch.nn.Module):
    """Transformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        conv_wshare (int): The number of kernel of convolution. Only used in
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        conv_kernel_length (Union[int, str]): Kernel size str of convolution
            (e.g. 71_71_71_71_71_71). Only used in selfattention_layer_type
            == "lightconv*" or "dynamicconv*".
        conv_usebias (bool): Whether to use bias in convolution. Only used in
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        pos_enc_class (torch.nn.Module): Positional encoding module class.
            `PositionalEncoding` or `ScaledPositionalEncoding`
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        selfattention_layer_type (str): Encoder attention layer type.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
        intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer.
            indices start from 1.
            if not None, intermediate outputs are returned (which changes return type
            signature.)
    """
    def __init__(
        self,
        idim,
        attention_dim=256,
        attention_heads=4,
        conv_wshare=4,
        conv_kernel_length="11",
        conv_usebias=False,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        selfattention_layer_type="selfattn",
        padding_idx=-1,
        stochastic_depth_rate=0.0,
        intermediate_layers=None,
        ctc_softmax=None,
        conditioning_layer_dim=None,
    ):
        """Construct an Encoder object."""
        super(TransformerEncoder_s0, self).__init__()
        self._register_load_state_dict_pre_hook(_pre_hook)

        self.conv_subsampling_factor = 1
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 4
        elif input_layer == "conv2d-scaled-pos-enc":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
            self.conv_subsampling_factor = 4
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 6
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 8
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before
        positionwise_layer, positionwise_layer_args = self.get_positionwise_layer(
            positionwise_layer_type,
            attention_dim,
            linear_units,
            dropout_rate,
            positionwise_conv_kernel_size,
        )
        if selfattention_layer_type in [
            "selfattn",
            "rel_selfattn",
            "legacy_rel_selfattn",
        ]:
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = [
                (
                    attention_heads,
                    attention_dim,
                    attention_dropout_rate,
                )
            ] * num_blocks
        elif selfattention_layer_type == "lightconv":
            logging.info("encoder self-attention layer type = lightweight convolution")
            encoder_selfattn_layer = LightweightConvolution
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "lightconv2d":
            logging.info(
                "encoder self-attention layer "
                "type = lightweight convolution 2-dimensional"
            )
            encoder_selfattn_layer = LightweightConvolution2D
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "dynamicconv":
            logging.info("encoder self-attention layer type = dynamic convolution")
            encoder_selfattn_layer = DynamicConvolution
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "dynamicconv2d":
            logging.info(
                "encoder self-attention layer type = dynamic convolution 2-dimensional"
            )
            encoder_selfattn_layer = DynamicConvolution2D
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        else:
            raise NotImplementedError(selfattention_layer_type)

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / num_blocks,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

        self.intermediate_layers = intermediate_layers
        self.use_conditioning = True if ctc_softmax is not None else False
        if self.use_conditioning:
            self.ctc_softmax = ctc_softmax
            self.conditioning_layer = torch.nn.Linear(
                conditioning_layer_dim, attention_dim
            )
    def get_positionwise_layer(
        self,
        positionwise_layer_type="linear",
        attention_dim=256,
        linear_units=2048,
        dropout_rate=0.1,
        positionwise_conv_kernel_size=1,
    ):
        """Define positionwise layer."""
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        return positionwise_layer, positionwise_layer_args
    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor (#batch, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)

        if self.intermediate_layers is None:
            xs, masks = self.encoders(xs, masks)
        else:
            intermediate_outputs = []
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs, masks = encoder_layer(xs, masks)

                if (
                    self.intermediate_layers is not None
                    and layer_idx + 1 in self.intermediate_layers
                ):
                    encoder_output = xs
                    # intermediate branches also require normalization.
                    if self.normalize_before:
                        encoder_output = self.after_norm(encoder_output)
                    intermediate_outputs.append(encoder_output)

                    if self.use_conditioning:
                        intermediate_result = self.ctc_softmax(encoder_output)
                        xs = xs + self.conditioning_layer(intermediate_result)

        if self.normalize_before:
            xs = self.after_norm(xs)

        if self.intermediate_layers is not None:
            return xs, masks, intermediate_outputs
        return xs, masks
    def forward_one_step(self, xs, masks, cache=None):
        """Encode input frame.

        Args:
            xs (torch.Tensor): Input tensor.
            masks (torch.Tensor): Mask tensor.
            cache (List[torch.Tensor]): List of cache tensors.

        Returns:
            torch.Tensor: Output tensor.
            torch.Tensor: Mask tensor.
            List[torch.Tensor]: List of new cache tensors.
        """
        if isinstance(self.embed, Conv2dSubsampling):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)
        if cache is None:
            cache = [None for _ in range(len(self.encoders))]
        new_cache = []
        for c, e in zip(cache, self.encoders):
            xs, masks = e(xs, masks, cache=c)
            new_cache.append(xs)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks, new_cache
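

# --- Illustrative usage (not part of the original file) ------------------------
# A minimal sketch of TransformerEncoder_s0 with the default self-attention
# layers; all shapes and hyperparameters are hypothetical. Note that for
# "lightconv*"/"dynamicconv*" layers, conv_kernel_length must supply one
# "_"-separated kernel size per block (e.g. "11_11_11_11" for num_blocks=4).
def _example_transformer_encoder_s0_usage():
    """Build an s0 encoder and run a dummy masked batch through it."""
    encoder = TransformerEncoder_s0(idim=80, attention_dim=256, num_blocks=4)
    xs = torch.randn(2, 100, 80)                     # (#batch, time, idim)
    masks = torch.ones(2, 1, 100, dtype=torch.bool)  # all frames valid
    xs, masks = encoder(xs, masks)                   # time axis reduced by
    return xs.shape, encoder.conv_subsampling_factor  # conv_subsampling_factor (4)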