branchformer_encoder.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. # Copyright 2022 Yifan Peng (Carnegie Mellon University)
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Branchformer encoder definition.
  4. Reference:
  5. Yifan Peng, Siddharth Dalmia, Ian Lane, and Shinji Watanabe,
  6. “Branchformer: Parallel MLP-Attention Architectures to Capture
  7. Local and Global Context for Speech Recognition and Understanding,”
  8. in Proceedings of ICML, 2022.
  9. """
  10. import logging
  11. from typing import List, Optional, Tuple, Union
  12. import numpy
  13. import torch
  14. from typeguard import check_argument_types
  15. from funasr.models.encoder.abs_encoder import AbsEncoder
  16. from funasr.modules.cgmlp import ConvolutionalGatingMLP
  17. from funasr.modules.fastformer import FastSelfAttention
  18. from funasr.modules.nets_utils import make_pad_mask
  19. from funasr.modules.attention import ( # noqa: H301
  20. LegacyRelPositionMultiHeadedAttention,
  21. MultiHeadedAttention,
  22. RelPositionMultiHeadedAttention,
  23. )
  24. from funasr.modules.embedding import ( # noqa: H301
  25. LegacyRelPositionalEncoding,
  26. PositionalEncoding,
  27. RelPositionalEncoding,
  28. ScaledPositionalEncoding,
  29. )
  30. from funasr.modules.layer_norm import LayerNorm
  31. from funasr.modules.repeat import repeat
  32. from funasr.modules.subsampling import (
  33. Conv2dSubsampling,
  34. Conv2dSubsampling2,
  35. Conv2dSubsampling6,
  36. Conv2dSubsampling8,
  37. TooShortUttError,
  38. check_short_utt,
  39. )
  40. class BranchformerEncoderLayer(torch.nn.Module):
  41. """Branchformer encoder layer module.
  42. Args:
  43. size (int): model dimension
  44. attn: standard self-attention or efficient attention, optional
  45. cgmlp: ConvolutionalGatingMLP, optional
  46. dropout_rate (float): dropout probability
  47. merge_method (str): concat, learned_ave, fixed_ave
  48. cgmlp_weight (float): weight of the cgmlp branch, between 0 and 1,
  49. used if merge_method is fixed_ave
  50. attn_branch_drop_rate (float): probability of dropping the attn branch,
  51. used if merge_method is learned_ave
  52. stochastic_depth_rate (float): stochastic depth probability
  53. """
  54. def __init__(
  55. self,
  56. size: int,
  57. attn: Optional[torch.nn.Module],
  58. cgmlp: Optional[torch.nn.Module],
  59. dropout_rate: float,
  60. merge_method: str,
  61. cgmlp_weight: float = 0.5,
  62. attn_branch_drop_rate: float = 0.0,
  63. stochastic_depth_rate: float = 0.0,
  64. ):
  65. super().__init__()
  66. assert (attn is not None) or (
  67. cgmlp is not None
  68. ), "At least one branch should be valid"
  69. self.size = size
  70. self.attn = attn
  71. self.cgmlp = cgmlp
  72. self.merge_method = merge_method
  73. self.cgmlp_weight = cgmlp_weight
  74. self.attn_branch_drop_rate = attn_branch_drop_rate
  75. self.stochastic_depth_rate = stochastic_depth_rate
  76. self.use_two_branches = (attn is not None) and (cgmlp is not None)
  77. if attn is not None:
  78. self.norm_mha = LayerNorm(size) # for the MHA module
  79. if cgmlp is not None:
  80. self.norm_mlp = LayerNorm(size) # for the MLP module
  81. self.norm_final = LayerNorm(size) # for the final output of the block
  82. self.dropout = torch.nn.Dropout(dropout_rate)
  83. if self.use_two_branches:
  84. if merge_method == "concat":
  85. self.merge_proj = torch.nn.Linear(size + size, size)
  86. elif merge_method == "learned_ave":
  87. # attention-based pooling for two branches
  88. self.pooling_proj1 = torch.nn.Linear(size, 1)
  89. self.pooling_proj2 = torch.nn.Linear(size, 1)
  90. # linear projections for calculating merging weights
  91. self.weight_proj1 = torch.nn.Linear(size, 1)
  92. self.weight_proj2 = torch.nn.Linear(size, 1)
  93. # linear projection after weighted average
  94. self.merge_proj = torch.nn.Linear(size, size)
  95. elif merge_method == "fixed_ave":
  96. assert (
  97. 0.0 <= cgmlp_weight <= 1.0
  98. ), "cgmlp weight should be between 0.0 and 1.0"
  99. # remove the other branch if only one branch is used
  100. if cgmlp_weight == 0.0:
  101. self.use_two_branches = False
  102. self.cgmlp = None
  103. self.norm_mlp = None
  104. elif cgmlp_weight == 1.0:
  105. self.use_two_branches = False
  106. self.attn = None
  107. self.norm_mha = None
  108. # linear projection after weighted average
  109. self.merge_proj = torch.nn.Linear(size, size)
  110. else:
  111. raise ValueError(f"unknown merge method: {merge_method}")
  112. else:
  113. self.merge_proj = torch.nn.Identity()
  114. def forward(self, x_input, mask, cache=None):
  115. """Compute encoded features.
  116. Args:
  117. x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
  118. - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
  119. - w/o pos emb: Tensor (#batch, time, size).
  120. mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
  121. cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  122. Returns:
  123. torch.Tensor: Output tensor (#batch, time, size).
  124. torch.Tensor: Mask tensor (#batch, time).
  125. """
  126. if cache is not None:
  127. raise NotImplementedError("cache is not None, which is not tested")
  128. if isinstance(x_input, tuple):
  129. x, pos_emb = x_input[0], x_input[1]
  130. else:
  131. x, pos_emb = x_input, None
  132. skip_layer = False
  133. # with stochastic depth, residual connection `x + f(x)` becomes
  134. # `x <- x + 1 / (1 - p) * f(x)` at training time.
  135. stoch_layer_coeff = 1.0
  136. if self.training and self.stochastic_depth_rate > 0:
  137. skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
  138. stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
  139. if skip_layer:
  140. if cache is not None:
  141. x = torch.cat([cache, x], dim=1)
  142. if pos_emb is not None:
  143. return (x, pos_emb), mask
  144. return x, mask
  145. # Two branches
  146. x1 = x
  147. x2 = x
  148. # Branch 1: multi-headed attention module
  149. if self.attn is not None:
  150. x1 = self.norm_mha(x1)
  151. if isinstance(self.attn, FastSelfAttention):
  152. x_att = self.attn(x1, mask)
  153. else:
  154. if pos_emb is not None:
  155. x_att = self.attn(x1, x1, x1, pos_emb, mask)
  156. else:
  157. x_att = self.attn(x1, x1, x1, mask)
  158. x1 = self.dropout(x_att)
  159. # Branch 2: convolutional gating mlp
  160. if self.cgmlp is not None:
  161. x2 = self.norm_mlp(x2)
  162. if pos_emb is not None:
  163. x2 = (x2, pos_emb)
  164. x2 = self.cgmlp(x2, mask)
  165. if isinstance(x2, tuple):
  166. x2 = x2[0]
  167. x2 = self.dropout(x2)
  168. # Merge two branches
  169. if self.use_two_branches:
  170. if self.merge_method == "concat":
  171. x = x + stoch_layer_coeff * self.dropout(
  172. self.merge_proj(torch.cat([x1, x2], dim=-1))
  173. )
  174. elif self.merge_method == "learned_ave":
  175. if (
  176. self.training
  177. and self.attn_branch_drop_rate > 0
  178. and torch.rand(1).item() < self.attn_branch_drop_rate
  179. ):
  180. # Drop the attn branch
  181. w1, w2 = 0.0, 1.0
  182. else:
  183. # branch1
  184. score1 = (
  185. self.pooling_proj1(x1).transpose(1, 2) / self.size**0.5
  186. ) # (batch, 1, time)
  187. if mask is not None:
  188. min_value = float(
  189. numpy.finfo(
  190. torch.tensor(0, dtype=score1.dtype).numpy().dtype
  191. ).min
  192. )
  193. score1 = score1.masked_fill(mask.eq(0), min_value)
  194. score1 = torch.softmax(score1, dim=-1).masked_fill(
  195. mask.eq(0), 0.0
  196. )
  197. else:
  198. score1 = torch.softmax(score1, dim=-1)
  199. pooled1 = torch.matmul(score1, x1).squeeze(1) # (batch, size)
  200. weight1 = self.weight_proj1(pooled1) # (batch, 1)
  201. # branch2
  202. score2 = (
  203. self.pooling_proj2(x2).transpose(1, 2) / self.size**0.5
  204. ) # (batch, 1, time)
  205. if mask is not None:
  206. min_value = float(
  207. numpy.finfo(
  208. torch.tensor(0, dtype=score2.dtype).numpy().dtype
  209. ).min
  210. )
  211. score2 = score2.masked_fill(mask.eq(0), min_value)
  212. score2 = torch.softmax(score2, dim=-1).masked_fill(
  213. mask.eq(0), 0.0
  214. )
  215. else:
  216. score2 = torch.softmax(score2, dim=-1)
  217. pooled2 = torch.matmul(score2, x2).squeeze(1) # (batch, size)
  218. weight2 = self.weight_proj2(pooled2) # (batch, 1)
  219. # normalize weights of two branches
  220. merge_weights = torch.softmax(
  221. torch.cat([weight1, weight2], dim=-1), dim=-1
  222. ) # (batch, 2)
  223. merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
  224. -1
  225. ) # (batch, 2, 1, 1)
  226. w1, w2 = merge_weights[:, 0], merge_weights[:, 1] # (batch, 1, 1)
  227. x = x + stoch_layer_coeff * self.dropout(
  228. self.merge_proj(w1 * x1 + w2 * x2)
  229. )
  230. elif self.merge_method == "fixed_ave":
  231. x = x + stoch_layer_coeff * self.dropout(
  232. self.merge_proj(
  233. (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
  234. )
  235. )
  236. else:
  237. raise RuntimeError(f"unknown merge method: {self.merge_method}")
  238. else:
  239. if self.attn is None:
  240. x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x2))
  241. elif self.cgmlp is None:
  242. x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x1))
  243. else:
  244. # This should not happen
  245. raise RuntimeError("Both branches are not None, which is unexpected.")
  246. x = self.norm_final(x)
  247. if pos_emb is not None:
  248. return (x, pos_emb), mask
  249. return x, mask
  250. class BranchformerEncoder(AbsEncoder):
  251. """Branchformer encoder module."""
  252. def __init__(
  253. self,
  254. input_size: int,
  255. output_size: int = 256,
  256. use_attn: bool = True,
  257. attention_heads: int = 4,
  258. attention_layer_type: str = "rel_selfattn",
  259. pos_enc_layer_type: str = "rel_pos",
  260. rel_pos_type: str = "latest",
  261. use_cgmlp: bool = True,
  262. cgmlp_linear_units: int = 2048,
  263. cgmlp_conv_kernel: int = 31,
  264. use_linear_after_conv: bool = False,
  265. gate_activation: str = "identity",
  266. merge_method: str = "concat",
  267. cgmlp_weight: Union[float, List[float]] = 0.5,
  268. attn_branch_drop_rate: Union[float, List[float]] = 0.0,
  269. num_blocks: int = 12,
  270. dropout_rate: float = 0.1,
  271. positional_dropout_rate: float = 0.1,
  272. attention_dropout_rate: float = 0.0,
  273. input_layer: Optional[str] = "conv2d",
  274. zero_triu: bool = False,
  275. padding_idx: int = -1,
  276. stochastic_depth_rate: Union[float, List[float]] = 0.0,
  277. ):
  278. assert check_argument_types()
  279. super().__init__()
  280. self._output_size = output_size
  281. if rel_pos_type == "legacy":
  282. if pos_enc_layer_type == "rel_pos":
  283. pos_enc_layer_type = "legacy_rel_pos"
  284. if attention_layer_type == "rel_selfattn":
  285. attention_layer_type = "legacy_rel_selfattn"
  286. elif rel_pos_type == "latest":
  287. assert attention_layer_type != "legacy_rel_selfattn"
  288. assert pos_enc_layer_type != "legacy_rel_pos"
  289. else:
  290. raise ValueError("unknown rel_pos_type: " + rel_pos_type)
  291. if pos_enc_layer_type == "abs_pos":
  292. pos_enc_class = PositionalEncoding
  293. elif pos_enc_layer_type == "scaled_abs_pos":
  294. pos_enc_class = ScaledPositionalEncoding
  295. elif pos_enc_layer_type == "rel_pos":
  296. assert attention_layer_type == "rel_selfattn"
  297. pos_enc_class = RelPositionalEncoding
  298. elif pos_enc_layer_type == "legacy_rel_pos":
  299. assert attention_layer_type == "legacy_rel_selfattn"
  300. pos_enc_class = LegacyRelPositionalEncoding
  301. logging.warning(
  302. "Using legacy_rel_pos and it will be deprecated in the future."
  303. )
  304. else:
  305. raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
  306. if input_layer == "linear":
  307. self.embed = torch.nn.Sequential(
  308. torch.nn.Linear(input_size, output_size),
  309. torch.nn.LayerNorm(output_size),
  310. torch.nn.Dropout(dropout_rate),
  311. pos_enc_class(output_size, positional_dropout_rate),
  312. )
  313. elif input_layer == "conv2d":
  314. self.embed = Conv2dSubsampling(
  315. input_size,
  316. output_size,
  317. dropout_rate,
  318. pos_enc_class(output_size, positional_dropout_rate),
  319. )
  320. elif input_layer == "conv2d2":
  321. self.embed = Conv2dSubsampling2(
  322. input_size,
  323. output_size,
  324. dropout_rate,
  325. pos_enc_class(output_size, positional_dropout_rate),
  326. )
  327. elif input_layer == "conv2d6":
  328. self.embed = Conv2dSubsampling6(
  329. input_size,
  330. output_size,
  331. dropout_rate,
  332. pos_enc_class(output_size, positional_dropout_rate),
  333. )
  334. elif input_layer == "conv2d8":
  335. self.embed = Conv2dSubsampling8(
  336. input_size,
  337. output_size,
  338. dropout_rate,
  339. pos_enc_class(output_size, positional_dropout_rate),
  340. )
  341. elif input_layer == "embed":
  342. self.embed = torch.nn.Sequential(
  343. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  344. pos_enc_class(output_size, positional_dropout_rate),
  345. )
  346. elif isinstance(input_layer, torch.nn.Module):
  347. self.embed = torch.nn.Sequential(
  348. input_layer,
  349. pos_enc_class(output_size, positional_dropout_rate),
  350. )
  351. elif input_layer is None:
  352. if input_size == output_size:
  353. self.embed = None
  354. else:
  355. self.embed = torch.nn.Linear(input_size, output_size)
  356. else:
  357. raise ValueError("unknown input_layer: " + input_layer)
  358. if attention_layer_type == "selfattn":
  359. encoder_selfattn_layer = MultiHeadedAttention
  360. encoder_selfattn_layer_args = (
  361. attention_heads,
  362. output_size,
  363. attention_dropout_rate,
  364. )
  365. elif attention_layer_type == "legacy_rel_selfattn":
  366. assert pos_enc_layer_type == "legacy_rel_pos"
  367. encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
  368. encoder_selfattn_layer_args = (
  369. attention_heads,
  370. output_size,
  371. attention_dropout_rate,
  372. )
  373. logging.warning(
  374. "Using legacy_rel_selfattn and it will be deprecated in the future."
  375. )
  376. elif attention_layer_type == "rel_selfattn":
  377. assert pos_enc_layer_type == "rel_pos"
  378. encoder_selfattn_layer = RelPositionMultiHeadedAttention
  379. encoder_selfattn_layer_args = (
  380. attention_heads,
  381. output_size,
  382. attention_dropout_rate,
  383. zero_triu,
  384. )
  385. elif attention_layer_type == "fast_selfattn":
  386. assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
  387. encoder_selfattn_layer = FastSelfAttention
  388. encoder_selfattn_layer_args = (
  389. output_size,
  390. attention_heads,
  391. attention_dropout_rate,
  392. )
  393. else:
  394. raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)
  395. cgmlp_layer = ConvolutionalGatingMLP
  396. cgmlp_layer_args = (
  397. output_size,
  398. cgmlp_linear_units,
  399. cgmlp_conv_kernel,
  400. dropout_rate,
  401. use_linear_after_conv,
  402. gate_activation,
  403. )
  404. if isinstance(stochastic_depth_rate, float):
  405. stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
  406. if len(stochastic_depth_rate) != num_blocks:
  407. raise ValueError(
  408. f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
  409. f"should be equal to num_blocks ({num_blocks})"
  410. )
  411. if isinstance(cgmlp_weight, float):
  412. cgmlp_weight = [cgmlp_weight] * num_blocks
  413. if len(cgmlp_weight) != num_blocks:
  414. raise ValueError(
  415. f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to "
  416. f"num_blocks ({num_blocks})"
  417. )
  418. if isinstance(attn_branch_drop_rate, float):
  419. attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks
  420. if len(attn_branch_drop_rate) != num_blocks:
  421. raise ValueError(
  422. f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
  423. f"should be equal to num_blocks ({num_blocks})"
  424. )
  425. self.encoders = repeat(
  426. num_blocks,
  427. lambda lnum: BranchformerEncoderLayer(
  428. output_size,
  429. encoder_selfattn_layer(*encoder_selfattn_layer_args)
  430. if use_attn
  431. else None,
  432. cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
  433. dropout_rate,
  434. merge_method,
  435. cgmlp_weight[lnum],
  436. attn_branch_drop_rate[lnum],
  437. stochastic_depth_rate[lnum],
  438. ),
  439. )
  440. self.after_norm = LayerNorm(output_size)
  441. def output_size(self) -> int:
  442. return self._output_size
  443. def forward(
  444. self,
  445. xs_pad: torch.Tensor,
  446. ilens: torch.Tensor,
  447. prev_states: torch.Tensor = None,
  448. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  449. """Calculate forward propagation.
  450. Args:
  451. xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
  452. ilens (torch.Tensor): Input length (#batch).
  453. prev_states (torch.Tensor): Not to be used now.
  454. Returns:
  455. torch.Tensor: Output tensor (#batch, L, output_size).
  456. torch.Tensor: Output length (#batch).
  457. torch.Tensor: Not to be used now.
  458. """
  459. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  460. if (
  461. isinstance(self.embed, Conv2dSubsampling)
  462. or isinstance(self.embed, Conv2dSubsampling2)
  463. or isinstance(self.embed, Conv2dSubsampling6)
  464. or isinstance(self.embed, Conv2dSubsampling8)
  465. ):
  466. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  467. if short_status:
  468. raise TooShortUttError(
  469. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  470. + f"(it needs more than {limit_size} frames), return empty results",
  471. xs_pad.size(1),
  472. limit_size,
  473. )
  474. xs_pad, masks = self.embed(xs_pad, masks)
  475. elif self.embed is not None:
  476. xs_pad = self.embed(xs_pad)
  477. xs_pad, masks = self.encoders(xs_pad, masks)
  478. if isinstance(xs_pad, tuple):
  479. xs_pad = xs_pad[0]
  480. xs_pad = self.after_norm(xs_pad)
  481. olens = masks.squeeze(1).sum(1)
  482. return xs_pad, olens, None