# sanm_encoder.py

from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import torch
import torch.nn as nn
from typeguard import check_argument_types

from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
            mask_shfit_chunk (torch.Tensor): Optional chunk mask for the FSMN memory block (streaming).
            mask_att_chunk_encoder (torch.Tensor): Optional chunk-wise attention mask (streaming).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
            The cache and chunk masks are returned unchanged so that stacked
            layers can be chained via ``repeat``.
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            # keep the same tuple arity as the normal path so `repeat` can
            # chain layers without dropping the chunk masks
            return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
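
# --- added usage sketch (not part of the original file) ----------------------
# Driving one EncoderLayerSANM directly. The dims, dropout values, and the
# MultiHeadedAttentionSANM argument order follow the constructor calls in the
# encoder classes below; treat this as a hedged sketch, not a reference call:
#
#     attn = MultiHeadedAttentionSANM(4, 256, 256, 0.0, 11, 0)
#     ffn = PositionwiseFeedForward(256, 2048, 0.1)
#     layer = EncoderLayerSANM(256, 256, attn, ffn, dropout_rate=0.1)
#     x = torch.randn(2, 50, 256)    # (batch, time, size)
#     mask = torch.ones(2, 1, 50)    # 1 marks non-padded positions
#     x, mask, *_ = layer(x, mask)
#
# `repeat` chains such layers by unpacking the returned 5-tuple into the next
# layer's arguments, which is why the chunk masks are passed through unchanged.
# ------------------------------------------------------------------------------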


class SANMEncoder(AbsEncoder):
    """SAN-M encoder module.

    Author: Speech Lab, Alibaba Group, China

    SAN-M: Memory Equipped Self-Attention for End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            # the first layer attends over the raw feature dim (input_size);
            # all remaining layers operate at output_size
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None  # set externally when CTC conditioning is used
        self.dropout = nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: not used for now
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # xs_pad = self.dropout(xs_pad)
        encoder_outs = self.encoders0(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
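
# --- added usage sketch (not part of the original file) ----------------------
# Offline SAN-M encoding with `input_layer="pe"`, so that the first layer's
# input_size -> output_size projection matches the raw feature dim. Shapes
# follow the forward() docstring; the feature dim 80 is an assumption:
#
#     encoder = SANMEncoder(input_size=80, output_size=256, input_layer="pe")
#     xs = torch.randn(2, 100, 80)        # (B, L, D) padded features
#     ilens = torch.tensor([100, 73])     # true lengths per utterance
#     ys, olens, _ = encoder(xs, ilens)   # ys: (2, 100, 256), olens: (2,)
# ------------------------------------------------------------------------------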


class SANMEncoderChunkOpt(AbsEncoder):
    """Chunk-optimized SAN-M encoder for streaming recognition.

    Author: Speech Lab, Alibaba Group, China

    SCAMA: Streaming Chunk-Aware Multihead Attention for Online End-to-End
    Speech Recognition
    https://arxiv.org/abs/2006.01712
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        chunk_size: Union[int, Sequence[int]] = (16,),
        stride: Union[int, Sequence[int]] = (10,),
        pad_left: Union[int, Sequence[int]] = (0,),
        encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            encoder_selfattn_layer = MultiHeadedAttentionSANM
            # the first layer attends over the raw feature dim (input_size);
            # all remaining layers operate at output_size
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None  # set externally when CTC conditioning is used

        # half the FSMN kernel width, passed to the chunk utility as its fsmn shift
        shfit_fsmn = (kernel_size - 1) // 2
        self.overlap_chunk_cls = overlap_chunk(
            chunk_size=chunk_size,
            stride=stride,
            pad_left=pad_left,
            shfit_fsmn=shfit_fsmn,
            encoder_att_look_back_factor=encoder_att_look_back_factor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
        )

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: not used for now
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        mask_shfit_chunk, mask_att_chunk_encoder = None, None
        if self.overlap_chunk_cls is not None:
            ilens = masks.squeeze(1).sum(1)
            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(
                xs_pad, ilens, chunk_outs=chunk_outs
            )
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
            )
            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
            )

        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
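

# A hedged smoke test added for illustration; it is not part of the original
# file. It assumes the funasr modules imported above are available and uses
# `input_layer="pe"` so the first SANM layer performs the input_size ->
# output_size projection. All sizes and chunk parameters are illustrative.
if __name__ == "__main__":
    feats = torch.randn(2, 100, 80)        # (B, L, D) dummy fbank-like input
    feat_lens = torch.tensor([100, 80])    # per-utterance lengths

    # offline encoder
    enc = SANMEncoder(input_size=80, output_size=256, num_blocks=4, input_layer="pe")
    ys, olens, _ = enc(feats.clone(), feat_lens)  # clone: forward scales in place
    print("SANMEncoder:", ys.shape, olens.tolist())

    # streaming, chunk-optimized encoder
    enc_chunk = SANMEncoderChunkOpt(
        input_size=80,
        output_size=256,
        num_blocks=4,
        input_layer="pe",
        chunk_size=(16,),
        stride=(10,),
        pad_left=(0,),
        encoder_att_look_back_factor=(1,),
        decoder_att_look_back_factor=(1,),
    )
    ys2, olens2, _ = enc_chunk(feats.clone(), feat_lens)
    print("SANMEncoderChunkOpt:", ys2.shape, olens2.tolist())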