# e_branchformer_encoder.py
  1. # Copyright 2022 Kwangyoun Kim (ASAPP inc.)
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """E-Branchformer encoder definition.
  4. Reference:
  5. Kwangyoun Kim, Felix Wu, Yifan Peng, Jing Pan,
  6. Prashant Sridhar, Kyu J. Han, Shinji Watanabe,
  7. "E-Branchformer: Branchformer with Enhanced merging
  8. for speech recognition," in SLT 2022.
  9. """
  10. import logging
  11. from typing import List, Optional, Tuple
  12. import torch
  13. from funasr.models.ctc import CTC
  14. from funasr.models.encoder.abs_encoder import AbsEncoder
  15. from funasr.modules.cgmlp import ConvolutionalGatingMLP
  16. from funasr.modules.fastformer import FastSelfAttention
  17. from funasr.modules.nets_utils import get_activation, make_pad_mask
  18. from funasr.modules.attention import ( # noqa: H301
  19. LegacyRelPositionMultiHeadedAttention,
  20. MultiHeadedAttention,
  21. RelPositionMultiHeadedAttention,
  22. )
  23. from funasr.modules.embedding import ( # noqa: H301
  24. LegacyRelPositionalEncoding,
  25. PositionalEncoding,
  26. RelPositionalEncoding,
  27. ScaledPositionalEncoding,
  28. )
  29. from funasr.modules.layer_norm import LayerNorm
  30. from funasr.modules.positionwise_feed_forward import (
  31. PositionwiseFeedForward,
  32. )
  33. from funasr.modules.repeat import repeat
  34. from funasr.modules.subsampling import (
  35. Conv2dSubsampling,
  36. Conv2dSubsampling2,
  37. Conv2dSubsampling6,
  38. Conv2dSubsampling8,
  39. TooShortUttError,
  40. check_short_utt,
  41. )
class EBranchformerEncoderLayer(torch.nn.Module):
    """E-Branchformer encoder layer module.

    Runs two parallel branches over the input — a global-context branch
    (multi-headed self-attention) and a local-context branch
    (convolutional gating MLP) — and merges them with a depth-wise
    convolution plus a linear projection, with optional (macaron-style)
    feed-forward modules around the branches.

    Args:
        size (int): model dimension
        attn: standard self-attention or efficient attention
        cgmlp: ConvolutionalGatingMLP
        feed_forward: feed-forward module, optional
        feed_forward_macaron: macaron-style feed-forward module, optional
        dropout_rate (float): dropout probability
        merge_conv_kernel (int): kernel size of the depth-wise conv in merge module
    """

    def __init__(
        self,
        size: int,
        attn: torch.nn.Module,
        cgmlp: torch.nn.Module,
        feed_forward: Optional[torch.nn.Module],
        feed_forward_macaron: Optional[torch.nn.Module],
        dropout_rate: float,
        merge_conv_kernel: int = 3,
    ):
        super().__init__()

        self.size = size
        self.attn = attn
        self.cgmlp = cgmlp
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        # Residual scale for the FFN modules: 0.5 when macaron-style (two
        # half-step FFNs bracket the block), otherwise 1.0.
        self.ff_scale = 1.0
        if self.feed_forward is not None:
            self.norm_ff = LayerNorm(size)
        if self.feed_forward_macaron is not None:
            self.ff_scale = 0.5
            self.norm_ff_macaron = LayerNorm(size)
        self.norm_mha = LayerNorm(size)  # for the MHA module
        self.norm_mlp = LayerNorm(size)  # for the MLP module
        self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = torch.nn.Dropout(dropout_rate)

        # Depth-wise 1-D conv over the concatenated branch outputs
        # (groups == channels makes it depth-wise; padding keeps the
        # sequence length unchanged).
        self.depthwise_conv_fusion = torch.nn.Conv1d(
            size + size,
            size + size,
            kernel_size=merge_conv_kernel,
            stride=1,
            padding=(merge_conv_kernel - 1) // 2,
            groups=size + size,
            bias=True,
        )
        self.merge_proj = torch.nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
                Must be None; caching is not implemented for this layer.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).

        Raises:
            NotImplementedError: If ``cache`` is not None.
        """
        if cache is not None:
            raise NotImplementedError("cache is not None, which is not tested")

        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        # Optional macaron-style FFN before the branches (half-step residual).
        if self.feed_forward_macaron is not None:
            residual = x
            x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))

        # Two branches
        x1 = x
        x2 = x

        # Branch 1: multi-headed attention module
        x1 = self.norm_mha(x1)
        if isinstance(self.attn, FastSelfAttention):
            # FastSelfAttention takes (x, mask) rather than (q, k, v, ...).
            x_att = self.attn(x1, mask)
        else:
            if pos_emb is not None:
                # Relative-position attention variants take pos_emb explicitly.
                x_att = self.attn(x1, x1, x1, pos_emb, mask)
            else:
                x_att = self.attn(x1, x1, x1, mask)
        x1 = self.dropout(x_att)

        # Branch 2: convolutional gating mlp
        x2 = self.norm_mlp(x2)
        if pos_emb is not None:
            x2 = (x2, pos_emb)
        x2 = self.cgmlp(x2, mask)
        if isinstance(x2, tuple):
            # cgmlp returns (x, pos_emb) when given a tuple; keep features only.
            x2 = x2[0]
        x2 = self.dropout(x2)

        # Merge two branches: the depth-wise conv output is added to the raw
        # concatenation (local residual) before projecting back to `size`.
        x_concat = torch.cat([x1, x2], dim=-1)
        x_tmp = x_concat.transpose(1, 2)
        x_tmp = self.depthwise_conv_fusion(x_tmp)
        x_tmp = x_tmp.transpose(1, 2)
        x = x + self.dropout(self.merge_proj(x_concat + x_tmp))

        if self.feed_forward is not None:
            # feed forward module
            residual = x
            x = self.norm_ff(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
class EBranchformerEncoder(AbsEncoder):
    """E-Branchformer encoder module.

    Stacks `num_blocks` EBranchformerEncoderLayer blocks on top of a
    configurable input/subsampling layer, with optional intermediate-CTC
    taps and CTC-conditioned residual connections.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        attention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        rel_pos_type: str = "latest",
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        zero_triu: bool = False,
        padding_idx: int = -1,
        layer_drop_rate: float = 0.0,
        max_pos_emb_len: int = 5000,
        use_ffn: bool = False,
        macaron_ffn: bool = False,
        ffn_activation_type: str = "swish",
        linear_units: int = 2048,
        positionwise_layer_type: str = "linear",
        merge_conv_kernel: int = 3,
        interctc_layer_idx: Optional[List[int]] = None,
        interctc_use_conditioning: bool = False,
    ):
        """Construct an EBranchformerEncoder.

        Args:
            input_size: Dimension of the input features.
            output_size: Model (and output) dimension of every block.
            attention_layer_type: One of "selfattn", "rel_selfattn",
                "legacy_rel_selfattn", "fast_selfattn".
            pos_enc_layer_type: One of "abs_pos", "scaled_abs_pos",
                "rel_pos", "legacy_rel_pos"; must pair with the chosen
                attention type (asserted below).
            rel_pos_type: "latest" or "legacy"; "legacy" remaps the
                "rel_pos"/"rel_selfattn" names to their legacy variants.
            input_layer: Input/subsampling frontend selector ("linear",
                "conv2d", "conv2d2", "conv2d6", "conv2d8", "embed", a
                torch.nn.Module instance, or None for identity/linear).
            interctc_layer_idx: 1-based block indices after which
                intermediate CTC outputs are collected.
            interctc_use_conditioning: If True, feed CTC softmax of each
                intermediate output back into the stream through
                ``self.conditioning_layer``.
        """
        super().__init__()
        self._output_size = output_size

        # "legacy" remaps the modern rel-pos names onto the pre-refactor
        # implementations; "latest" forbids the legacy names outright.
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if attention_layer_type == "rel_selfattn":
                attention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert attention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        # Select the positional-encoding class; rel-pos variants must be
        # paired with the matching attention implementation.
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert attention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert attention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        # Build the input frontend.  Conv2dSubsampling* variants reduce the
        # time resolution; "embed" is for discrete token inputs.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        activation = get_activation(ffn_activation_type)
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type is None:
            # NOTE(review): in this branch `positionwise_layer` stays unbound;
            # combining positionwise_layer_type=None with use_ffn=True would
            # raise NameError below — confirm callers never do that.
            logging.warning("no macaron ffn")
        else:
            raise ValueError("Support only linear.")

        # Select the self-attention implementation and its constructor args.
        if attention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif attention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif attention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        elif attention_layer_type == "fast_selfattn":
            # FastSelfAttention works only with absolute positional encodings.
            assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
            encoder_selfattn_layer = FastSelfAttention
            encoder_selfattn_layer_args = (
                output_size,
                attention_heads,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)

        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
        )

        # One fresh layer per block; `repeat` applies stochastic layer drop
        # with probability `layer_drop_rate`.
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EBranchformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                cgmlp_layer(*cgmlp_layer_args),
                positionwise_layer(*positionwise_layer_args) if use_ffn else None,
                positionwise_layer(*positionwise_layer_args)
                if use_ffn and macaron_ffn
                else None,
                dropout_rate,
                merge_conv_kernel,
            ),
            layer_drop_rate,
        )
        self.after_norm = LayerNorm(output_size)

        if interctc_layer_idx is None:
            interctc_layer_idx = []
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            # Indices are 1-based and must be strictly inside the stack.
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # NOTE(review): presumably assigned externally (e.g. by the enclosing
        # model) before forward() when interctc_use_conditioning is True —
        # verify against the caller.
        self.conditioning_layer = None

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        max_layer: int = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
            ctc (CTC): Intermediate CTC module.
            max_layer (int): Layer depth below which InterCTC is applied.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size); when
                intermediate CTC taps exist, a tuple
                (output, [(layer_idx, intermediate_out), ...]) instead.
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        Raises:
            TooShortUttError: If the input is too short for the configured
                subsampling frontend.
        """
        # Padding mask: True on valid frames, shape (#batch, 1, L).
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            # Subsampling frontends need a minimum number of input frames.
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        elif self.embed is not None:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            if max_layer is not None and 0 <= max_layer < len(self.encoders):
                # Early exit after layer index `max_layer` (inclusive).
                for layer_idx, encoder_layer in enumerate(self.encoders):
                    xs_pad, masks = encoder_layer(xs_pad, masks)
                    if layer_idx >= max_layer:
                        break
            else:
                xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            # Run layer by layer, collecting intermediate CTC outputs at the
            # configured 1-based indices.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    if isinstance(encoder_out, tuple):
                        # Drop pos_emb; keep features only for the CTC tap.
                        encoder_out = encoder_out[0]

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        # Self-conditioned CTC: feed the CTC posterior back
                        # into the encoder stream through a learned projection.
                        ctc_out = ctc.softmax(encoder_out)

                        if isinstance(xs_pad, tuple):
                            xs_pad = list(xs_pad)
                            xs_pad[0] = xs_pad[0] + self.conditioning_layer(ctc_out)
                            xs_pad = tuple(xs_pad)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        xs_pad = self.after_norm(xs_pad)
        # Output lengths recovered from the (possibly subsampled) mask.
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None