# conformer_encoder.py
  1. # Copyright 2020 Tomoki Hayashi
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Conformer encoder definition."""
  4. import logging
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. from torch import nn
  11. from typeguard import check_argument_types
  12. from funasr.models.ctc import CTC
  13. from funasr.models.encoder.abs_encoder import AbsEncoder
  14. from funasr.modules.attention import (
  15. MultiHeadedAttention, # noqa: H301
  16. RelPositionMultiHeadedAttention, # noqa: H301
  17. LegacyRelPositionMultiHeadedAttention, # noqa: H301
  18. )
  19. from funasr.modules.embedding import (
  20. PositionalEncoding, # noqa: H301
  21. ScaledPositionalEncoding, # noqa: H301
  22. RelPositionalEncoding, # noqa: H301
  23. LegacyRelPositionalEncoding, # noqa: H301
  24. )
  25. from funasr.modules.layer_norm import LayerNorm
  26. from funasr.modules.multi_layer_conv import Conv1dLinear
  27. from funasr.modules.multi_layer_conv import MultiLayeredConv1d
  28. from funasr.modules.nets_utils import get_activation
  29. from funasr.modules.nets_utils import make_pad_mask
  30. from funasr.modules.positionwise_feed_forward import (
  31. PositionwiseFeedForward, # noqa: H301
  32. )
  33. from funasr.modules.repeat import repeat
  34. from funasr.modules.subsampling import Conv2dSubsampling
  35. from funasr.modules.subsampling import Conv2dSubsampling2
  36. from funasr.modules.subsampling import Conv2dSubsampling6
  37. from funasr.modules.subsampling import Conv2dSubsampling8
  38. from funasr.modules.subsampling import TooShortUttError
  39. from funasr.modules.subsampling import check_short_utt
  40. class ConvolutionModule(nn.Module):
  41. """ConvolutionModule in Conformer model.
  42. Args:
  43. channels (int): The number of channels of conv layers.
  44. kernel_size (int): Kernerl size of conv layers.
  45. """
  46. def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
  47. """Construct an ConvolutionModule object."""
  48. super(ConvolutionModule, self).__init__()
  49. # kernerl_size should be a odd number for 'SAME' padding
  50. assert (kernel_size - 1) % 2 == 0
  51. self.pointwise_conv1 = nn.Conv1d(
  52. channels,
  53. 2 * channels,
  54. kernel_size=1,
  55. stride=1,
  56. padding=0,
  57. bias=bias,
  58. )
  59. self.depthwise_conv = nn.Conv1d(
  60. channels,
  61. channels,
  62. kernel_size,
  63. stride=1,
  64. padding=(kernel_size - 1) // 2,
  65. groups=channels,
  66. bias=bias,
  67. )
  68. self.norm = nn.BatchNorm1d(channels)
  69. self.pointwise_conv2 = nn.Conv1d(
  70. channels,
  71. channels,
  72. kernel_size=1,
  73. stride=1,
  74. padding=0,
  75. bias=bias,
  76. )
  77. self.activation = activation
  78. def forward(self, x):
  79. """Compute convolution module.
  80. Args:
  81. x (torch.Tensor): Input tensor (#batch, time, channels).
  82. Returns:
  83. torch.Tensor: Output tensor (#batch, time, channels).
  84. """
  85. # exchange the temporal dimension and the feature dimension
  86. x = x.transpose(1, 2)
  87. # GLU mechanism
  88. x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
  89. x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
  90. # 1D Depthwise Conv
  91. x = self.depthwise_conv(x)
  92. x = self.activation(self.norm(x))
  93. x = self.pointwise_conv2(x)
  94. return x.transpose(1, 2)
class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # Macaron style: each of the two feed-forwards contributes half.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        # Relative-position attention variants pass the positional embedding
        # alongside x as a tuple; absolute-position variants pass x alone.
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
        if skip_layer:
            # Skipped layer still honors the cache contract so the output
            # length matches the non-skipped path.
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask
        # whether to use macaron style (half-step feed-forward before attention)
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
                self.feed_forward_macaron(x)
            )
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)
        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        if cache is None:
            x_q = x
        else:
            # Incremental decoding: only the newest frame queries the full
            # sequence; residual/mask are sliced to match the query length.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]
        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)
        if self.concat_after:
            # NOTE(review): with cache is not None, x (full length) and x_att
            # (query length) differ along time, so this cat would fail —
            # presumably concat_after is never combined with caching; confirm.
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)
        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)
        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        if not self.normalize_before:
            x = self.norm_ff(x)
        if self.conv_module is not None:
            x = self.norm_final(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask
  240. class ConformerEncoder(AbsEncoder):
  241. """Conformer encoder module.
  242. Args:
  243. input_size (int): Input dimension.
  244. output_size (int): Dimension of attention.
  245. attention_heads (int): The number of heads of multi head attention.
  246. linear_units (int): The number of units of position-wise feed forward.
  247. num_blocks (int): The number of decoder blocks.
  248. dropout_rate (float): Dropout rate.
  249. attention_dropout_rate (float): Dropout rate in attention.
  250. positional_dropout_rate (float): Dropout rate after adding positional encoding.
  251. input_layer (Union[str, torch.nn.Module]): Input layer type.
  252. normalize_before (bool): Whether to use layer_norm before the first block.
  253. concat_after (bool): Whether to concat attention layer's input and output.
  254. If True, additional linear will be applied.
  255. i.e. x -> x + linear(concat(x, att(x)))
  256. If False, no additional linear will be applied. i.e. x -> x + att(x)
  257. positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
  258. positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
  259. rel_pos_type (str): Whether to use the latest relative positional encoding or
  260. the legacy one. The legacy relative positional encoding will be deprecated
  261. in the future. More Details can be found in
  262. https://github.com/espnet/espnet/pull/2816.
  263. encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
  264. encoder_attn_layer_type (str): Encoder attention layer type.
  265. activation_type (str): Encoder activation function type.
  266. macaron_style (bool): Whether to use macaron style for positionwise layer.
  267. use_cnn_module (bool): Whether to use convolution module.
  268. zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
  269. cnn_module_kernel (int): Kernerl size of convolution module.
  270. padding_idx (int): Padding idx for input_layer=embed.
  271. """
  272. def __init__(
  273. self,
  274. input_size: int,
  275. output_size: int = 256,
  276. attention_heads: int = 4,
  277. linear_units: int = 2048,
  278. num_blocks: int = 6,
  279. dropout_rate: float = 0.1,
  280. positional_dropout_rate: float = 0.1,
  281. attention_dropout_rate: float = 0.0,
  282. input_layer: str = "conv2d",
  283. normalize_before: bool = True,
  284. concat_after: bool = False,
  285. positionwise_layer_type: str = "linear",
  286. positionwise_conv_kernel_size: int = 3,
  287. macaron_style: bool = False,
  288. rel_pos_type: str = "legacy",
  289. pos_enc_layer_type: str = "rel_pos",
  290. selfattention_layer_type: str = "rel_selfattn",
  291. activation_type: str = "swish",
  292. use_cnn_module: bool = True,
  293. zero_triu: bool = False,
  294. cnn_module_kernel: int = 31,
  295. padding_idx: int = -1,
  296. interctc_layer_idx: List[int] = [],
  297. interctc_use_conditioning: bool = False,
  298. stochastic_depth_rate: Union[float, List[float]] = 0.0,
  299. ):
  300. assert check_argument_types()
  301. super().__init__()
  302. self._output_size = output_size
  303. if rel_pos_type == "legacy":
  304. if pos_enc_layer_type == "rel_pos":
  305. pos_enc_layer_type = "legacy_rel_pos"
  306. if selfattention_layer_type == "rel_selfattn":
  307. selfattention_layer_type = "legacy_rel_selfattn"
  308. elif rel_pos_type == "latest":
  309. assert selfattention_layer_type != "legacy_rel_selfattn"
  310. assert pos_enc_layer_type != "legacy_rel_pos"
  311. else:
  312. raise ValueError("unknown rel_pos_type: " + rel_pos_type)
  313. activation = get_activation(activation_type)
  314. if pos_enc_layer_type == "abs_pos":
  315. pos_enc_class = PositionalEncoding
  316. elif pos_enc_layer_type == "scaled_abs_pos":
  317. pos_enc_class = ScaledPositionalEncoding
  318. elif pos_enc_layer_type == "rel_pos":
  319. assert selfattention_layer_type == "rel_selfattn"
  320. pos_enc_class = RelPositionalEncoding
  321. elif pos_enc_layer_type == "legacy_rel_pos":
  322. assert selfattention_layer_type == "legacy_rel_selfattn"
  323. pos_enc_class = LegacyRelPositionalEncoding
  324. logging.warning(
  325. "Using legacy_rel_pos and it will be deprecated in the future."
  326. )
  327. else:
  328. raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
  329. if input_layer == "linear":
  330. self.embed = torch.nn.Sequential(
  331. torch.nn.Linear(input_size, output_size),
  332. torch.nn.LayerNorm(output_size),
  333. torch.nn.Dropout(dropout_rate),
  334. pos_enc_class(output_size, positional_dropout_rate),
  335. )
  336. elif input_layer == "conv2d":
  337. self.embed = Conv2dSubsampling(
  338. input_size,
  339. output_size,
  340. dropout_rate,
  341. pos_enc_class(output_size, positional_dropout_rate),
  342. )
  343. elif input_layer == "conv2d2":
  344. self.embed = Conv2dSubsampling2(
  345. input_size,
  346. output_size,
  347. dropout_rate,
  348. pos_enc_class(output_size, positional_dropout_rate),
  349. )
  350. elif input_layer == "conv2d6":
  351. self.embed = Conv2dSubsampling6(
  352. input_size,
  353. output_size,
  354. dropout_rate,
  355. pos_enc_class(output_size, positional_dropout_rate),
  356. )
  357. elif input_layer == "conv2d8":
  358. self.embed = Conv2dSubsampling8(
  359. input_size,
  360. output_size,
  361. dropout_rate,
  362. pos_enc_class(output_size, positional_dropout_rate),
  363. )
  364. elif input_layer == "embed":
  365. self.embed = torch.nn.Sequential(
  366. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  367. pos_enc_class(output_size, positional_dropout_rate),
  368. )
  369. elif isinstance(input_layer, torch.nn.Module):
  370. self.embed = torch.nn.Sequential(
  371. input_layer,
  372. pos_enc_class(output_size, positional_dropout_rate),
  373. )
  374. elif input_layer is None:
  375. self.embed = torch.nn.Sequential(
  376. pos_enc_class(output_size, positional_dropout_rate)
  377. )
  378. else:
  379. raise ValueError("unknown input_layer: " + input_layer)
  380. self.normalize_before = normalize_before
  381. if positionwise_layer_type == "linear":
  382. positionwise_layer = PositionwiseFeedForward
  383. positionwise_layer_args = (
  384. output_size,
  385. linear_units,
  386. dropout_rate,
  387. activation,
  388. )
  389. elif positionwise_layer_type == "conv1d":
  390. positionwise_layer = MultiLayeredConv1d
  391. positionwise_layer_args = (
  392. output_size,
  393. linear_units,
  394. positionwise_conv_kernel_size,
  395. dropout_rate,
  396. )
  397. elif positionwise_layer_type == "conv1d-linear":
  398. positionwise_layer = Conv1dLinear
  399. positionwise_layer_args = (
  400. output_size,
  401. linear_units,
  402. positionwise_conv_kernel_size,
  403. dropout_rate,
  404. )
  405. else:
  406. raise NotImplementedError("Support only linear or conv1d.")
  407. if selfattention_layer_type == "selfattn":
  408. encoder_selfattn_layer = MultiHeadedAttention
  409. encoder_selfattn_layer_args = (
  410. attention_heads,
  411. output_size,
  412. attention_dropout_rate,
  413. )
  414. elif selfattention_layer_type == "legacy_rel_selfattn":
  415. assert pos_enc_layer_type == "legacy_rel_pos"
  416. encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
  417. encoder_selfattn_layer_args = (
  418. attention_heads,
  419. output_size,
  420. attention_dropout_rate,
  421. )
  422. logging.warning(
  423. "Using legacy_rel_selfattn and it will be deprecated in the future."
  424. )
  425. elif selfattention_layer_type == "rel_selfattn":
  426. assert pos_enc_layer_type == "rel_pos"
  427. encoder_selfattn_layer = RelPositionMultiHeadedAttention
  428. encoder_selfattn_layer_args = (
  429. attention_heads,
  430. output_size,
  431. attention_dropout_rate,
  432. zero_triu,
  433. )
  434. else:
  435. raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
  436. convolution_layer = ConvolutionModule
  437. convolution_layer_args = (output_size, cnn_module_kernel, activation)
  438. if isinstance(stochastic_depth_rate, float):
  439. stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
  440. if len(stochastic_depth_rate) != num_blocks:
  441. raise ValueError(
  442. f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
  443. f"should be equal to num_blocks ({num_blocks})"
  444. )
  445. self.encoders = repeat(
  446. num_blocks,
  447. lambda lnum: EncoderLayer(
  448. output_size,
  449. encoder_selfattn_layer(*encoder_selfattn_layer_args),
  450. positionwise_layer(*positionwise_layer_args),
  451. positionwise_layer(*positionwise_layer_args) if macaron_style else None,
  452. convolution_layer(*convolution_layer_args) if use_cnn_module else None,
  453. dropout_rate,
  454. normalize_before,
  455. concat_after,
  456. stochastic_depth_rate[lnum],
  457. ),
  458. )
  459. if self.normalize_before:
  460. self.after_norm = LayerNorm(output_size)
  461. self.interctc_layer_idx = interctc_layer_idx
  462. if len(interctc_layer_idx) > 0:
  463. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  464. self.interctc_use_conditioning = interctc_use_conditioning
  465. self.conditioning_layer = None
  466. def output_size(self) -> int:
  467. return self._output_size
  468. def forward(
  469. self,
  470. xs_pad: torch.Tensor,
  471. ilens: torch.Tensor,
  472. prev_states: torch.Tensor = None,
  473. ctc: CTC = None,
  474. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  475. """Calculate forward propagation.
  476. Args:
  477. xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
  478. ilens (torch.Tensor): Input length (#batch).
  479. prev_states (torch.Tensor): Not to be used now.
  480. Returns:
  481. torch.Tensor: Output tensor (#batch, L, output_size).
  482. torch.Tensor: Output length (#batch).
  483. torch.Tensor: Not to be used now.
  484. """
  485. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  486. if (
  487. isinstance(self.embed, Conv2dSubsampling)
  488. or isinstance(self.embed, Conv2dSubsampling2)
  489. or isinstance(self.embed, Conv2dSubsampling6)
  490. or isinstance(self.embed, Conv2dSubsampling8)
  491. ):
  492. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  493. if short_status:
  494. raise TooShortUttError(
  495. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  496. + f"(it needs more than {limit_size} frames), return empty results",
  497. xs_pad.size(1),
  498. limit_size,
  499. )
  500. xs_pad, masks = self.embed(xs_pad, masks)
  501. else:
  502. xs_pad = self.embed(xs_pad)
  503. intermediate_outs = []
  504. if len(self.interctc_layer_idx) == 0:
  505. xs_pad, masks = self.encoders(xs_pad, masks)
  506. else:
  507. for layer_idx, encoder_layer in enumerate(self.encoders):
  508. xs_pad, masks = encoder_layer(xs_pad, masks)
  509. if layer_idx + 1 in self.interctc_layer_idx:
  510. encoder_out = xs_pad
  511. if isinstance(encoder_out, tuple):
  512. encoder_out = encoder_out[0]
  513. # intermediate outputs are also normalized
  514. if self.normalize_before:
  515. encoder_out = self.after_norm(encoder_out)
  516. intermediate_outs.append((layer_idx + 1, encoder_out))
  517. if self.interctc_use_conditioning:
  518. ctc_out = ctc.softmax(encoder_out)
  519. if isinstance(xs_pad, tuple):
  520. x, pos_emb = xs_pad
  521. x = x + self.conditioning_layer(ctc_out)
  522. xs_pad = (x, pos_emb)
  523. else:
  524. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  525. if isinstance(xs_pad, tuple):
  526. xs_pad = xs_pad[0]
  527. if self.normalize_before:
  528. xs_pad = self.after_norm(xs_pad)
  529. olens = masks.squeeze(1).sum(1)
  530. if len(intermediate_outs) > 0:
  531. return (xs_pad, intermediate_outs), olens, None
  532. return xs_pad, olens, None