encoder.py

# Copyright 2022 Kwangyoun Kim (ASAPP inc.)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""E-Branchformer encoder definition.

Reference:
    Kwangyoun Kim, Felix Wu, Yifan Peng, Jing Pan,
    Prashant Sridhar, Kyu J. Han, Shinji Watanabe,
    "E-Branchformer: Branchformer with Enhanced merging
    for speech recognition," in SLT 2022.
"""

import logging
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from funasr.models.ctc.ctc import CTC
from funasr.models.branchformer.cgmlp import ConvolutionalGatingMLP
from funasr.models.branchformer.fastformer import FastSelfAttention
from funasr.models.transformer.utils.nets_utils import get_activation, make_pad_mask
from funasr.models.transformer.attention import (  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,
    MultiHeadedAttention,
    RelPositionMultiHeadedAttention,
)
from funasr.models.transformer.embedding import (  # noqa: H301
    LegacyRelPositionalEncoding,
    PositionalEncoding,
    RelPositionalEncoding,
    ScaledPositionalEncoding,
)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
    TooShortUttError,
    check_short_utt,
)
from funasr.register import tables


class EBranchformerEncoderLayer(torch.nn.Module):
    """E-Branchformer encoder layer module.

    Args:
        size (int): model dimension
        attn: standard self-attention or efficient attention
        cgmlp: ConvolutionalGatingMLP
        feed_forward: feed-forward module, optional
        feed_forward_macaron: macaron-style feed-forward module, optional
        dropout_rate (float): dropout probability
        merge_conv_kernel (int): kernel size of the depth-wise conv in merge module
    """

    def __init__(
        self,
        size: int,
        attn: torch.nn.Module,
        cgmlp: torch.nn.Module,
        feed_forward: Optional[torch.nn.Module],
        feed_forward_macaron: Optional[torch.nn.Module],
        dropout_rate: float,
        merge_conv_kernel: int = 3,
    ):
        super().__init__()

        self.size = size
        self.attn = attn
        self.cgmlp = cgmlp
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.ff_scale = 1.0
        if self.feed_forward is not None:
            self.norm_ff = LayerNorm(size)
        if self.feed_forward_macaron is not None:
            self.ff_scale = 0.5
            self.norm_ff_macaron = LayerNorm(size)
        self.norm_mha = LayerNorm(size)  # for the MHA module
        self.norm_mlp = LayerNorm(size)  # for the MLP module
        self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.depthwise_conv_fusion = torch.nn.Conv1d(
            size + size,
            size + size,
            kernel_size=merge_conv_kernel,
            stride=1,
            padding=(merge_conv_kernel - 1) // 2,
            groups=size + size,
            bias=True,
        )
        self.merge_proj = torch.nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if cache is not None:
            raise NotImplementedError("cache is not None, which is not tested")

        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        if self.feed_forward_macaron is not None:
            residual = x
            x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))

        # Two branches
        x1 = x
        x2 = x

        # Branch 1: multi-headed attention module
        x1 = self.norm_mha(x1)
        if isinstance(self.attn, FastSelfAttention):
            x_att = self.attn(x1, mask)
        else:
            if pos_emb is not None:
                x_att = self.attn(x1, x1, x1, pos_emb, mask)
            else:
                x_att = self.attn(x1, x1, x1, mask)
        x1 = self.dropout(x_att)

        # Branch 2: convolutional gating mlp
        x2 = self.norm_mlp(x2)
        if pos_emb is not None:
            x2 = (x2, pos_emb)
        x2 = self.cgmlp(x2, mask)
        if isinstance(x2, tuple):
            x2 = x2[0]
        x2 = self.dropout(x2)
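
        # Enhanced merging (cf. the SLT 2022 paper referenced above): the
        # depth-wise conv below injects local temporal context into the
        # concatenated branch outputs before the linear merge projection.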
        # Merge two branches
        x_concat = torch.cat([x1, x2], dim=-1)
        x_tmp = x_concat.transpose(1, 2)
        x_tmp = self.depthwise_conv_fusion(x_tmp)
        x_tmp = x_tmp.transpose(1, 2)
        x = x + self.dropout(self.merge_proj(x_concat + x_tmp))

        if self.feed_forward is not None:
            # feed forward module
            residual = x
            x = self.norm_ff(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
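

# A minimal usage sketch for a single layer (illustrative, not part of the
# original file; the constructor argument orders mirror the *_layer_args
# tuples built in EBranchformerEncoder below, but the concrete sizes are
# assumptions):
#
#     attn = MultiHeadedAttention(4, 256, 0.0)
#     cgmlp = ConvolutionalGatingMLP(256, 2048, 31, 0.1, False, "identity")
#     layer = EBranchformerEncoderLayer(256, attn, cgmlp, None, None, 0.1)
#     x = torch.randn(2, 50, 256)                      # (#batch, time, size)
#     mask = torch.ones(2, 1, 50, dtype=torch.bool)    # True = valid frame
#     y, _ = layer(x, mask)                            # y: (2, 50, 256)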


@tables.register("encoder_classes", "EBranchformerEncoder")
class EBranchformerEncoder(nn.Module):
    """E-Branchformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        attention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        rel_pos_type: str = "latest",
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        zero_triu: bool = False,
        padding_idx: int = -1,
        layer_drop_rate: float = 0.0,
        max_pos_emb_len: int = 5000,
        use_ffn: bool = False,
        macaron_ffn: bool = False,
        ffn_activation_type: str = "swish",
        linear_units: int = 2048,
        positionwise_layer_type: str = "linear",
        merge_conv_kernel: int = 3,
        interctc_layer_idx: Optional[List[int]] = None,
        interctc_use_conditioning: bool = False,
    ):
        super().__init__()
        self._output_size = output_size

        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if attention_layer_type == "rel_selfattn":
                attention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert attention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert attention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert attention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning("Using legacy_rel_pos and it will be deprecated in the future.")
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)
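        # The conv2d* front-ends subsample along time by roughly 1/4, 1/2,
        # 1/6, and 1/8 respectively, shrinking output lengths and masks
        # accordingly; "linear", "embed", and None keep the input length.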

        activation = get_activation(ffn_activation_type)
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type is None:
            logging.warning("no macaron ffn")
        else:
            raise ValueError("Support only linear.")

        if attention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif attention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
        elif attention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        elif attention_layer_type == "fast_selfattn":
            assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
            encoder_selfattn_layer = FastSelfAttention
            encoder_selfattn_layer_args = (
                output_size,
                attention_heads,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)

        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
        )

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EBranchformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                cgmlp_layer(*cgmlp_layer_args),
                positionwise_layer(*positionwise_layer_args) if use_ffn else None,
                positionwise_layer(*positionwise_layer_args)
                if use_ffn and macaron_ffn
                else None,
                dropout_rate,
                merge_conv_kernel,
            ),
            layer_drop_rate,
        )
        self.after_norm = LayerNorm(output_size)

        if interctc_layer_idx is None:
            interctc_layer_idx = []
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
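        # Note: conditioning_layer is only a placeholder here; when
        # interctc_use_conditioning is enabled, it is expected to be assigned
        # by the surrounding model (mapping CTC posteriors back to output_size,
        # as used in forward() below).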

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        max_layer: int = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not used for now.
            ctc (CTC): Intermediate CTC module.
            max_layer (int): Layer depth below which InterCTC is applied.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not used for now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        elif self.embed is not None:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            if max_layer is not None and 0 <= max_layer < len(self.encoders):
                for layer_idx, encoder_layer in enumerate(self.encoders):
                    xs_pad, masks = encoder_layer(xs_pad, masks)
                    if layer_idx >= max_layer:
                        break
            else:
                xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    if isinstance(encoder_out, tuple):
                        encoder_out = encoder_out[0]
                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        if isinstance(xs_pad, tuple):
                            xs_pad = list(xs_pad)
                            xs_pad[0] = xs_pad[0] + self.conditioning_layer(ctc_out)
                            xs_pad = tuple(xs_pad)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
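

# A minimal end-to-end sketch (illustrative, not part of the original file;
# the 80-dim features, batch shape, and lengths are assumptions):
#
#     encoder = EBranchformerEncoder(input_size=80, output_size=256, num_blocks=12)
#     feats = torch.randn(4, 200, 80)                  # (#batch, L, input_size)
#     feat_lens = torch.tensor([200, 180, 160, 120])   # (#batch,)
#     out, out_lens, _ = encoder(feats, feat_lens)
#     # out: (4, L', 256) with L' ~ L // 4 under the default conv2d front-end;
#     # out_lens holds the per-utterance subsampled lengths.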