# conformer_encoder.py
  1. # Copyright 2020 Tomoki Hayashi
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Conformer encoder definition."""
  4. import logging
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. from torch import nn
  11. from typeguard import check_argument_types
  12. from funasr.models.ctc import CTC
  13. from funasr.models.encoder.abs_encoder import AbsEncoder
  14. from funasr.modules.attention import (
  15. MultiHeadedAttention, # noqa: H301
  16. RelPositionMultiHeadedAttention, # noqa: H301
  17. LegacyRelPositionMultiHeadedAttention, # noqa: H301
  18. )
  19. from funasr.modules.embedding import (
  20. PositionalEncoding, # noqa: H301
  21. ScaledPositionalEncoding, # noqa: H301
  22. RelPositionalEncoding, # noqa: H301
  23. LegacyRelPositionalEncoding, # noqa: H301
  24. )
  25. from funasr.modules.layer_norm import LayerNorm
  26. from funasr.modules.multi_layer_conv import Conv1dLinear
  27. from funasr.modules.multi_layer_conv import MultiLayeredConv1d
  28. from funasr.modules.nets_utils import get_activation
  29. from funasr.modules.nets_utils import make_pad_mask
  30. from funasr.modules.positionwise_feed_forward import (
  31. PositionwiseFeedForward, # noqa: H301
  32. )
  33. from funasr.modules.repeat import repeat
  34. from funasr.modules.subsampling import Conv2dSubsampling
  35. from funasr.modules.subsampling import Conv2dSubsampling2
  36. from funasr.modules.subsampling import Conv2dSubsampling6
  37. from funasr.modules.subsampling import Conv2dSubsampling8
  38. from funasr.modules.subsampling import TooShortUttError
  39. from funasr.modules.subsampling import check_short_utt
  40. from funasr.modules.subsampling import Conv2dSubsamplingPad
  41. class ConvolutionModule(nn.Module):
  42. """ConvolutionModule in Conformer model.
  43. Args:
  44. channels (int): The number of channels of conv layers.
  45. kernel_size (int): Kernerl size of conv layers.
  46. """
  47. def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
  48. """Construct an ConvolutionModule object."""
  49. super(ConvolutionModule, self).__init__()
  50. # kernerl_size should be a odd number for 'SAME' padding
  51. assert (kernel_size - 1) % 2 == 0
  52. self.pointwise_conv1 = nn.Conv1d(
  53. channels,
  54. 2 * channels,
  55. kernel_size=1,
  56. stride=1,
  57. padding=0,
  58. bias=bias,
  59. )
  60. self.depthwise_conv = nn.Conv1d(
  61. channels,
  62. channels,
  63. kernel_size,
  64. stride=1,
  65. padding=(kernel_size - 1) // 2,
  66. groups=channels,
  67. bias=bias,
  68. )
  69. self.norm = nn.BatchNorm1d(channels)
  70. self.pointwise_conv2 = nn.Conv1d(
  71. channels,
  72. channels,
  73. kernel_size=1,
  74. stride=1,
  75. padding=0,
  76. bias=bias,
  77. )
  78. self.activation = activation
  79. def forward(self, x):
  80. """Compute convolution module.
  81. Args:
  82. x (torch.Tensor): Input tensor (#batch, time, channels).
  83. Returns:
  84. torch.Tensor: Output tensor (#batch, time, channels).
  85. """
  86. # exchange the temporal dimension and the feature dimension
  87. x = x.transpose(1, 2)
  88. # GLU mechanism
  89. x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
  90. x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
  91. # 1D Depthwise Conv
  92. x = self.depthwise_conv(x)
  93. x = self.activation(self.norm(x))
  94. x = self.pointwise_conv2(x)
  95. return x.transpose(1, 2)
class EncoderLayer(nn.Module):
    """Encoder layer module.

    One Conformer block: optional macaron feed-forward, multi-headed
    self-attention, optional convolution module, and a final feed-forward,
    each wrapped in residual connections with pre- or post-layer-norm.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument. May be None to disable macaron style.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument. May be None.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # Macaron style halves each feed-forward's contribution.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        # Relative-position attention passes (x, pos_emb) tuples between layers.
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
        if skip_layer:
            # Layer is skipped: return input unchanged (cache re-attached first).
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask
        # whether to use macaron style (first half feed-forward)
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
                self.feed_forward_macaron(x)
            )
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)
        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        if cache is None:
            x_q = x
        else:
            # With a cache, only the newest frame is used as the query;
            # keys/values still attend over the whole sequence.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]
        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)
        if self.concat_after:
            # x -> x + linear(concat(x, att(x)))
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)
        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)
        # feed forward module (second half when macaron style is used)
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        if not self.normalize_before:
            x = self.norm_ff(x)
        # Extra final norm only exists when the conv module is enabled.
        if self.conv_module is not None:
            x = self.norm_final(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask
  241. class ConformerEncoder(AbsEncoder):
  242. """Conformer encoder module.
  243. Args:
  244. input_size (int): Input dimension.
  245. output_size (int): Dimension of attention.
  246. attention_heads (int): The number of heads of multi head attention.
  247. linear_units (int): The number of units of position-wise feed forward.
  248. num_blocks (int): The number of decoder blocks.
  249. dropout_rate (float): Dropout rate.
  250. attention_dropout_rate (float): Dropout rate in attention.
  251. positional_dropout_rate (float): Dropout rate after adding positional encoding.
  252. input_layer (Union[str, torch.nn.Module]): Input layer type.
  253. normalize_before (bool): Whether to use layer_norm before the first block.
  254. concat_after (bool): Whether to concat attention layer's input and output.
  255. If True, additional linear will be applied.
  256. i.e. x -> x + linear(concat(x, att(x)))
  257. If False, no additional linear will be applied. i.e. x -> x + att(x)
  258. positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
  259. positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
  260. rel_pos_type (str): Whether to use the latest relative positional encoding or
  261. the legacy one. The legacy relative positional encoding will be deprecated
  262. in the future. More Details can be found in
  263. https://github.com/espnet/espnet/pull/2816.
  264. encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
  265. encoder_attn_layer_type (str): Encoder attention layer type.
  266. activation_type (str): Encoder activation function type.
  267. macaron_style (bool): Whether to use macaron style for positionwise layer.
  268. use_cnn_module (bool): Whether to use convolution module.
  269. zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
  270. cnn_module_kernel (int): Kernerl size of convolution module.
  271. padding_idx (int): Padding idx for input_layer=embed.
  272. """
  273. def __init__(
  274. self,
  275. input_size: int,
  276. output_size: int = 256,
  277. attention_heads: int = 4,
  278. linear_units: int = 2048,
  279. num_blocks: int = 6,
  280. dropout_rate: float = 0.1,
  281. positional_dropout_rate: float = 0.1,
  282. attention_dropout_rate: float = 0.0,
  283. input_layer: str = "conv2d",
  284. normalize_before: bool = True,
  285. concat_after: bool = False,
  286. positionwise_layer_type: str = "linear",
  287. positionwise_conv_kernel_size: int = 3,
  288. macaron_style: bool = False,
  289. rel_pos_type: str = "legacy",
  290. pos_enc_layer_type: str = "rel_pos",
  291. selfattention_layer_type: str = "rel_selfattn",
  292. activation_type: str = "swish",
  293. use_cnn_module: bool = True,
  294. zero_triu: bool = False,
  295. cnn_module_kernel: int = 31,
  296. padding_idx: int = -1,
  297. interctc_layer_idx: List[int] = [],
  298. interctc_use_conditioning: bool = False,
  299. stochastic_depth_rate: Union[float, List[float]] = 0.0,
  300. ):
  301. assert check_argument_types()
  302. super().__init__()
  303. self._output_size = output_size
  304. if rel_pos_type == "legacy":
  305. if pos_enc_layer_type == "rel_pos":
  306. pos_enc_layer_type = "legacy_rel_pos"
  307. if selfattention_layer_type == "rel_selfattn":
  308. selfattention_layer_type = "legacy_rel_selfattn"
  309. elif rel_pos_type == "latest":
  310. assert selfattention_layer_type != "legacy_rel_selfattn"
  311. assert pos_enc_layer_type != "legacy_rel_pos"
  312. else:
  313. raise ValueError("unknown rel_pos_type: " + rel_pos_type)
  314. activation = get_activation(activation_type)
  315. if pos_enc_layer_type == "abs_pos":
  316. pos_enc_class = PositionalEncoding
  317. elif pos_enc_layer_type == "scaled_abs_pos":
  318. pos_enc_class = ScaledPositionalEncoding
  319. elif pos_enc_layer_type == "rel_pos":
  320. assert selfattention_layer_type == "rel_selfattn"
  321. pos_enc_class = RelPositionalEncoding
  322. elif pos_enc_layer_type == "legacy_rel_pos":
  323. assert selfattention_layer_type == "legacy_rel_selfattn"
  324. pos_enc_class = LegacyRelPositionalEncoding
  325. logging.warning(
  326. "Using legacy_rel_pos and it will be deprecated in the future."
  327. )
  328. else:
  329. raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
  330. if input_layer == "linear":
  331. self.embed = torch.nn.Sequential(
  332. torch.nn.Linear(input_size, output_size),
  333. torch.nn.LayerNorm(output_size),
  334. torch.nn.Dropout(dropout_rate),
  335. pos_enc_class(output_size, positional_dropout_rate),
  336. )
  337. elif input_layer == "conv2d":
  338. self.embed = Conv2dSubsampling(
  339. input_size,
  340. output_size,
  341. dropout_rate,
  342. pos_enc_class(output_size, positional_dropout_rate),
  343. )
  344. elif input_layer == "conv2dpad":
  345. self.embed = Conv2dSubsamplingPad(
  346. input_size,
  347. output_size,
  348. dropout_rate,
  349. pos_enc_class(output_size, positional_dropout_rate),
  350. )
  351. elif input_layer == "conv2d2":
  352. self.embed = Conv2dSubsampling2(
  353. input_size,
  354. output_size,
  355. dropout_rate,
  356. pos_enc_class(output_size, positional_dropout_rate),
  357. )
  358. elif input_layer == "conv2d6":
  359. self.embed = Conv2dSubsampling6(
  360. input_size,
  361. output_size,
  362. dropout_rate,
  363. pos_enc_class(output_size, positional_dropout_rate),
  364. )
  365. elif input_layer == "conv2d8":
  366. self.embed = Conv2dSubsampling8(
  367. input_size,
  368. output_size,
  369. dropout_rate,
  370. pos_enc_class(output_size, positional_dropout_rate),
  371. )
  372. elif input_layer == "embed":
  373. self.embed = torch.nn.Sequential(
  374. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  375. pos_enc_class(output_size, positional_dropout_rate),
  376. )
  377. elif isinstance(input_layer, torch.nn.Module):
  378. self.embed = torch.nn.Sequential(
  379. input_layer,
  380. pos_enc_class(output_size, positional_dropout_rate),
  381. )
  382. elif input_layer is None:
  383. self.embed = torch.nn.Sequential(
  384. pos_enc_class(output_size, positional_dropout_rate)
  385. )
  386. else:
  387. raise ValueError("unknown input_layer: " + input_layer)
  388. self.normalize_before = normalize_before
  389. if positionwise_layer_type == "linear":
  390. positionwise_layer = PositionwiseFeedForward
  391. positionwise_layer_args = (
  392. output_size,
  393. linear_units,
  394. dropout_rate,
  395. activation,
  396. )
  397. elif positionwise_layer_type == "conv1d":
  398. positionwise_layer = MultiLayeredConv1d
  399. positionwise_layer_args = (
  400. output_size,
  401. linear_units,
  402. positionwise_conv_kernel_size,
  403. dropout_rate,
  404. )
  405. elif positionwise_layer_type == "conv1d-linear":
  406. positionwise_layer = Conv1dLinear
  407. positionwise_layer_args = (
  408. output_size,
  409. linear_units,
  410. positionwise_conv_kernel_size,
  411. dropout_rate,
  412. )
  413. else:
  414. raise NotImplementedError("Support only linear or conv1d.")
  415. if selfattention_layer_type == "selfattn":
  416. encoder_selfattn_layer = MultiHeadedAttention
  417. encoder_selfattn_layer_args = (
  418. attention_heads,
  419. output_size,
  420. attention_dropout_rate,
  421. )
  422. elif selfattention_layer_type == "legacy_rel_selfattn":
  423. assert pos_enc_layer_type == "legacy_rel_pos"
  424. encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
  425. encoder_selfattn_layer_args = (
  426. attention_heads,
  427. output_size,
  428. attention_dropout_rate,
  429. )
  430. logging.warning(
  431. "Using legacy_rel_selfattn and it will be deprecated in the future."
  432. )
  433. elif selfattention_layer_type == "rel_selfattn":
  434. assert pos_enc_layer_type == "rel_pos"
  435. encoder_selfattn_layer = RelPositionMultiHeadedAttention
  436. encoder_selfattn_layer_args = (
  437. attention_heads,
  438. output_size,
  439. attention_dropout_rate,
  440. zero_triu,
  441. )
  442. else:
  443. raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
  444. convolution_layer = ConvolutionModule
  445. convolution_layer_args = (output_size, cnn_module_kernel, activation)
  446. if isinstance(stochastic_depth_rate, float):
  447. stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
  448. if len(stochastic_depth_rate) != num_blocks:
  449. raise ValueError(
  450. f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
  451. f"should be equal to num_blocks ({num_blocks})"
  452. )
  453. self.encoders = repeat(
  454. num_blocks,
  455. lambda lnum: EncoderLayer(
  456. output_size,
  457. encoder_selfattn_layer(*encoder_selfattn_layer_args),
  458. positionwise_layer(*positionwise_layer_args),
  459. positionwise_layer(*positionwise_layer_args) if macaron_style else None,
  460. convolution_layer(*convolution_layer_args) if use_cnn_module else None,
  461. dropout_rate,
  462. normalize_before,
  463. concat_after,
  464. stochastic_depth_rate[lnum],
  465. ),
  466. )
  467. if self.normalize_before:
  468. self.after_norm = LayerNorm(output_size)
  469. self.interctc_layer_idx = interctc_layer_idx
  470. if len(interctc_layer_idx) > 0:
  471. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  472. self.interctc_use_conditioning = interctc_use_conditioning
  473. self.conditioning_layer = None
  474. def output_size(self) -> int:
  475. return self._output_size
  476. def forward(
  477. self,
  478. xs_pad: torch.Tensor,
  479. ilens: torch.Tensor,
  480. prev_states: torch.Tensor = None,
  481. ctc: CTC = None,
  482. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  483. """Calculate forward propagation.
  484. Args:
  485. xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
  486. ilens (torch.Tensor): Input length (#batch).
  487. prev_states (torch.Tensor): Not to be used now.
  488. Returns:
  489. torch.Tensor: Output tensor (#batch, L, output_size).
  490. torch.Tensor: Output length (#batch).
  491. torch.Tensor: Not to be used now.
  492. """
  493. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  494. if (
  495. isinstance(self.embed, Conv2dSubsampling)
  496. or isinstance(self.embed, Conv2dSubsampling2)
  497. or isinstance(self.embed, Conv2dSubsampling6)
  498. or isinstance(self.embed, Conv2dSubsampling8)
  499. or isinstance(self.embed, Conv2dSubsamplingPad)
  500. ):
  501. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  502. if short_status:
  503. raise TooShortUttError(
  504. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  505. + f"(it needs more than {limit_size} frames), return empty results",
  506. xs_pad.size(1),
  507. limit_size,
  508. )
  509. xs_pad, masks = self.embed(xs_pad, masks)
  510. else:
  511. xs_pad = self.embed(xs_pad)
  512. intermediate_outs = []
  513. if len(self.interctc_layer_idx) == 0:
  514. xs_pad, masks = self.encoders(xs_pad, masks)
  515. else:
  516. for layer_idx, encoder_layer in enumerate(self.encoders):
  517. xs_pad, masks = encoder_layer(xs_pad, masks)
  518. if layer_idx + 1 in self.interctc_layer_idx:
  519. encoder_out = xs_pad
  520. if isinstance(encoder_out, tuple):
  521. encoder_out = encoder_out[0]
  522. # intermediate outputs are also normalized
  523. if self.normalize_before:
  524. encoder_out = self.after_norm(encoder_out)
  525. intermediate_outs.append((layer_idx + 1, encoder_out))
  526. if self.interctc_use_conditioning:
  527. ctc_out = ctc.softmax(encoder_out)
  528. if isinstance(xs_pad, tuple):
  529. x, pos_emb = xs_pad
  530. x = x + self.conditioning_layer(ctc_out)
  531. xs_pad = (x, pos_emb)
  532. else:
  533. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  534. if isinstance(xs_pad, tuple):
  535. xs_pad = xs_pad[0]
  536. if self.normalize_before:
  537. xs_pad = self.after_norm(xs_pad)
  538. olens = masks.squeeze(1).sum(1)
  539. if len(intermediate_outs) > 0:
  540. return (xs_pad, intermediate_outs), olens, None
  541. return xs_pad, olens, None