branchformer_encoder.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. # Copyright 2022 Yifan Peng (Carnegie Mellon University)
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Branchformer encoder definition.
  4. Reference:
  5. Yifan Peng, Siddharth Dalmia, Ian Lane, and Shinji Watanabe,
  6. “Branchformer: Parallel MLP-Attention Architectures to Capture
  7. Local and Global Context for Speech Recognition and Understanding,”
  8. in Proceedings of ICML, 2022.
  9. """
  10. import logging
  11. from typing import List, Optional, Tuple, Union
  12. import numpy
  13. import torch
  14. from funasr.models.encoder.abs_encoder import AbsEncoder
  15. from funasr.modules.cgmlp import ConvolutionalGatingMLP
  16. from funasr.modules.fastformer import FastSelfAttention
  17. from funasr.modules.nets_utils import make_pad_mask
  18. from funasr.modules.attention import ( # noqa: H301
  19. LegacyRelPositionMultiHeadedAttention,
  20. MultiHeadedAttention,
  21. RelPositionMultiHeadedAttention,
  22. )
  23. from funasr.modules.embedding import ( # noqa: H301
  24. LegacyRelPositionalEncoding,
  25. PositionalEncoding,
  26. RelPositionalEncoding,
  27. ScaledPositionalEncoding,
  28. )
  29. from funasr.modules.layer_norm import LayerNorm
  30. from funasr.modules.repeat import repeat
  31. from funasr.modules.subsampling import (
  32. Conv2dSubsampling,
  33. Conv2dSubsampling2,
  34. Conv2dSubsampling6,
  35. Conv2dSubsampling8,
  36. TooShortUttError,
  37. check_short_utt,
  38. )
  39. class BranchformerEncoderLayer(torch.nn.Module):
  40. """Branchformer encoder layer module.
  41. Args:
  42. size (int): model dimension
  43. attn: standard self-attention or efficient attention, optional
  44. cgmlp: ConvolutionalGatingMLP, optional
  45. dropout_rate (float): dropout probability
  46. merge_method (str): concat, learned_ave, fixed_ave
  47. cgmlp_weight (float): weight of the cgmlp branch, between 0 and 1,
  48. used if merge_method is fixed_ave
  49. attn_branch_drop_rate (float): probability of dropping the attn branch,
  50. used if merge_method is learned_ave
  51. stochastic_depth_rate (float): stochastic depth probability
  52. """
  53. def __init__(
  54. self,
  55. size: int,
  56. attn: Optional[torch.nn.Module],
  57. cgmlp: Optional[torch.nn.Module],
  58. dropout_rate: float,
  59. merge_method: str,
  60. cgmlp_weight: float = 0.5,
  61. attn_branch_drop_rate: float = 0.0,
  62. stochastic_depth_rate: float = 0.0,
  63. ):
  64. super().__init__()
  65. assert (attn is not None) or (
  66. cgmlp is not None
  67. ), "At least one branch should be valid"
  68. self.size = size
  69. self.attn = attn
  70. self.cgmlp = cgmlp
  71. self.merge_method = merge_method
  72. self.cgmlp_weight = cgmlp_weight
  73. self.attn_branch_drop_rate = attn_branch_drop_rate
  74. self.stochastic_depth_rate = stochastic_depth_rate
  75. self.use_two_branches = (attn is not None) and (cgmlp is not None)
  76. if attn is not None:
  77. self.norm_mha = LayerNorm(size) # for the MHA module
  78. if cgmlp is not None:
  79. self.norm_mlp = LayerNorm(size) # for the MLP module
  80. self.norm_final = LayerNorm(size) # for the final output of the block
  81. self.dropout = torch.nn.Dropout(dropout_rate)
  82. if self.use_two_branches:
  83. if merge_method == "concat":
  84. self.merge_proj = torch.nn.Linear(size + size, size)
  85. elif merge_method == "learned_ave":
  86. # attention-based pooling for two branches
  87. self.pooling_proj1 = torch.nn.Linear(size, 1)
  88. self.pooling_proj2 = torch.nn.Linear(size, 1)
  89. # linear projections for calculating merging weights
  90. self.weight_proj1 = torch.nn.Linear(size, 1)
  91. self.weight_proj2 = torch.nn.Linear(size, 1)
  92. # linear projection after weighted average
  93. self.merge_proj = torch.nn.Linear(size, size)
  94. elif merge_method == "fixed_ave":
  95. assert (
  96. 0.0 <= cgmlp_weight <= 1.0
  97. ), "cgmlp weight should be between 0.0 and 1.0"
  98. # remove the other branch if only one branch is used
  99. if cgmlp_weight == 0.0:
  100. self.use_two_branches = False
  101. self.cgmlp = None
  102. self.norm_mlp = None
  103. elif cgmlp_weight == 1.0:
  104. self.use_two_branches = False
  105. self.attn = None
  106. self.norm_mha = None
  107. # linear projection after weighted average
  108. self.merge_proj = torch.nn.Linear(size, size)
  109. else:
  110. raise ValueError(f"unknown merge method: {merge_method}")
  111. else:
  112. self.merge_proj = torch.nn.Identity()
  113. def forward(self, x_input, mask, cache=None):
  114. """Compute encoded features.
  115. Args:
  116. x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
  117. - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
  118. - w/o pos emb: Tensor (#batch, time, size).
  119. mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
  120. cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  121. Returns:
  122. torch.Tensor: Output tensor (#batch, time, size).
  123. torch.Tensor: Mask tensor (#batch, time).
  124. """
  125. if cache is not None:
  126. raise NotImplementedError("cache is not None, which is not tested")
  127. if isinstance(x_input, tuple):
  128. x, pos_emb = x_input[0], x_input[1]
  129. else:
  130. x, pos_emb = x_input, None
  131. skip_layer = False
  132. # with stochastic depth, residual connection `x + f(x)` becomes
  133. # `x <- x + 1 / (1 - p) * f(x)` at training time.
  134. stoch_layer_coeff = 1.0
  135. if self.training and self.stochastic_depth_rate > 0:
  136. skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
  137. stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
  138. if skip_layer:
  139. if cache is not None:
  140. x = torch.cat([cache, x], dim=1)
  141. if pos_emb is not None:
  142. return (x, pos_emb), mask
  143. return x, mask
  144. # Two branches
  145. x1 = x
  146. x2 = x
  147. # Branch 1: multi-headed attention module
  148. if self.attn is not None:
  149. x1 = self.norm_mha(x1)
  150. if isinstance(self.attn, FastSelfAttention):
  151. x_att = self.attn(x1, mask)
  152. else:
  153. if pos_emb is not None:
  154. x_att = self.attn(x1, x1, x1, pos_emb, mask)
  155. else:
  156. x_att = self.attn(x1, x1, x1, mask)
  157. x1 = self.dropout(x_att)
  158. # Branch 2: convolutional gating mlp
  159. if self.cgmlp is not None:
  160. x2 = self.norm_mlp(x2)
  161. if pos_emb is not None:
  162. x2 = (x2, pos_emb)
  163. x2 = self.cgmlp(x2, mask)
  164. if isinstance(x2, tuple):
  165. x2 = x2[0]
  166. x2 = self.dropout(x2)
  167. # Merge two branches
  168. if self.use_two_branches:
  169. if self.merge_method == "concat":
  170. x = x + stoch_layer_coeff * self.dropout(
  171. self.merge_proj(torch.cat([x1, x2], dim=-1))
  172. )
  173. elif self.merge_method == "learned_ave":
  174. if (
  175. self.training
  176. and self.attn_branch_drop_rate > 0
  177. and torch.rand(1).item() < self.attn_branch_drop_rate
  178. ):
  179. # Drop the attn branch
  180. w1, w2 = 0.0, 1.0
  181. else:
  182. # branch1
  183. score1 = (
  184. self.pooling_proj1(x1).transpose(1, 2) / self.size**0.5
  185. ) # (batch, 1, time)
  186. if mask is not None:
  187. min_value = float(
  188. numpy.finfo(
  189. torch.tensor(0, dtype=score1.dtype).numpy().dtype
  190. ).min
  191. )
  192. score1 = score1.masked_fill(mask.eq(0), min_value)
  193. score1 = torch.softmax(score1, dim=-1).masked_fill(
  194. mask.eq(0), 0.0
  195. )
  196. else:
  197. score1 = torch.softmax(score1, dim=-1)
  198. pooled1 = torch.matmul(score1, x1).squeeze(1) # (batch, size)
  199. weight1 = self.weight_proj1(pooled1) # (batch, 1)
  200. # branch2
  201. score2 = (
  202. self.pooling_proj2(x2).transpose(1, 2) / self.size**0.5
  203. ) # (batch, 1, time)
  204. if mask is not None:
  205. min_value = float(
  206. numpy.finfo(
  207. torch.tensor(0, dtype=score2.dtype).numpy().dtype
  208. ).min
  209. )
  210. score2 = score2.masked_fill(mask.eq(0), min_value)
  211. score2 = torch.softmax(score2, dim=-1).masked_fill(
  212. mask.eq(0), 0.0
  213. )
  214. else:
  215. score2 = torch.softmax(score2, dim=-1)
  216. pooled2 = torch.matmul(score2, x2).squeeze(1) # (batch, size)
  217. weight2 = self.weight_proj2(pooled2) # (batch, 1)
  218. # normalize weights of two branches
  219. merge_weights = torch.softmax(
  220. torch.cat([weight1, weight2], dim=-1), dim=-1
  221. ) # (batch, 2)
  222. merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
  223. -1
  224. ) # (batch, 2, 1, 1)
  225. w1, w2 = merge_weights[:, 0], merge_weights[:, 1] # (batch, 1, 1)
  226. x = x + stoch_layer_coeff * self.dropout(
  227. self.merge_proj(w1 * x1 + w2 * x2)
  228. )
  229. elif self.merge_method == "fixed_ave":
  230. x = x + stoch_layer_coeff * self.dropout(
  231. self.merge_proj(
  232. (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
  233. )
  234. )
  235. else:
  236. raise RuntimeError(f"unknown merge method: {self.merge_method}")
  237. else:
  238. if self.attn is None:
  239. x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x2))
  240. elif self.cgmlp is None:
  241. x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x1))
  242. else:
  243. # This should not happen
  244. raise RuntimeError("Both branches are not None, which is unexpected.")
  245. x = self.norm_final(x)
  246. if pos_emb is not None:
  247. return (x, pos_emb), mask
  248. return x, mask
  249. class BranchformerEncoder(AbsEncoder):
  250. """Branchformer encoder module."""
  251. def __init__(
  252. self,
  253. input_size: int,
  254. output_size: int = 256,
  255. use_attn: bool = True,
  256. attention_heads: int = 4,
  257. attention_layer_type: str = "rel_selfattn",
  258. pos_enc_layer_type: str = "rel_pos",
  259. rel_pos_type: str = "latest",
  260. use_cgmlp: bool = True,
  261. cgmlp_linear_units: int = 2048,
  262. cgmlp_conv_kernel: int = 31,
  263. use_linear_after_conv: bool = False,
  264. gate_activation: str = "identity",
  265. merge_method: str = "concat",
  266. cgmlp_weight: Union[float, List[float]] = 0.5,
  267. attn_branch_drop_rate: Union[float, List[float]] = 0.0,
  268. num_blocks: int = 12,
  269. dropout_rate: float = 0.1,
  270. positional_dropout_rate: float = 0.1,
  271. attention_dropout_rate: float = 0.0,
  272. input_layer: Optional[str] = "conv2d",
  273. zero_triu: bool = False,
  274. padding_idx: int = -1,
  275. stochastic_depth_rate: Union[float, List[float]] = 0.0,
  276. ):
  277. super().__init__()
  278. self._output_size = output_size
  279. if rel_pos_type == "legacy":
  280. if pos_enc_layer_type == "rel_pos":
  281. pos_enc_layer_type = "legacy_rel_pos"
  282. if attention_layer_type == "rel_selfattn":
  283. attention_layer_type = "legacy_rel_selfattn"
  284. elif rel_pos_type == "latest":
  285. assert attention_layer_type != "legacy_rel_selfattn"
  286. assert pos_enc_layer_type != "legacy_rel_pos"
  287. else:
  288. raise ValueError("unknown rel_pos_type: " + rel_pos_type)
  289. if pos_enc_layer_type == "abs_pos":
  290. pos_enc_class = PositionalEncoding
  291. elif pos_enc_layer_type == "scaled_abs_pos":
  292. pos_enc_class = ScaledPositionalEncoding
  293. elif pos_enc_layer_type == "rel_pos":
  294. assert attention_layer_type == "rel_selfattn"
  295. pos_enc_class = RelPositionalEncoding
  296. elif pos_enc_layer_type == "legacy_rel_pos":
  297. assert attention_layer_type == "legacy_rel_selfattn"
  298. pos_enc_class = LegacyRelPositionalEncoding
  299. logging.warning(
  300. "Using legacy_rel_pos and it will be deprecated in the future."
  301. )
  302. else:
  303. raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
  304. if input_layer == "linear":
  305. self.embed = torch.nn.Sequential(
  306. torch.nn.Linear(input_size, output_size),
  307. torch.nn.LayerNorm(output_size),
  308. torch.nn.Dropout(dropout_rate),
  309. pos_enc_class(output_size, positional_dropout_rate),
  310. )
  311. elif input_layer == "conv2d":
  312. self.embed = Conv2dSubsampling(
  313. input_size,
  314. output_size,
  315. dropout_rate,
  316. pos_enc_class(output_size, positional_dropout_rate),
  317. )
  318. elif input_layer == "conv2d2":
  319. self.embed = Conv2dSubsampling2(
  320. input_size,
  321. output_size,
  322. dropout_rate,
  323. pos_enc_class(output_size, positional_dropout_rate),
  324. )
  325. elif input_layer == "conv2d6":
  326. self.embed = Conv2dSubsampling6(
  327. input_size,
  328. output_size,
  329. dropout_rate,
  330. pos_enc_class(output_size, positional_dropout_rate),
  331. )
  332. elif input_layer == "conv2d8":
  333. self.embed = Conv2dSubsampling8(
  334. input_size,
  335. output_size,
  336. dropout_rate,
  337. pos_enc_class(output_size, positional_dropout_rate),
  338. )
  339. elif input_layer == "embed":
  340. self.embed = torch.nn.Sequential(
  341. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  342. pos_enc_class(output_size, positional_dropout_rate),
  343. )
  344. elif isinstance(input_layer, torch.nn.Module):
  345. self.embed = torch.nn.Sequential(
  346. input_layer,
  347. pos_enc_class(output_size, positional_dropout_rate),
  348. )
  349. elif input_layer is None:
  350. if input_size == output_size:
  351. self.embed = None
  352. else:
  353. self.embed = torch.nn.Linear(input_size, output_size)
  354. else:
  355. raise ValueError("unknown input_layer: " + input_layer)
  356. if attention_layer_type == "selfattn":
  357. encoder_selfattn_layer = MultiHeadedAttention
  358. encoder_selfattn_layer_args = (
  359. attention_heads,
  360. output_size,
  361. attention_dropout_rate,
  362. )
  363. elif attention_layer_type == "legacy_rel_selfattn":
  364. assert pos_enc_layer_type == "legacy_rel_pos"
  365. encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
  366. encoder_selfattn_layer_args = (
  367. attention_heads,
  368. output_size,
  369. attention_dropout_rate,
  370. )
  371. logging.warning(
  372. "Using legacy_rel_selfattn and it will be deprecated in the future."
  373. )
  374. elif attention_layer_type == "rel_selfattn":
  375. assert pos_enc_layer_type == "rel_pos"
  376. encoder_selfattn_layer = RelPositionMultiHeadedAttention
  377. encoder_selfattn_layer_args = (
  378. attention_heads,
  379. output_size,
  380. attention_dropout_rate,
  381. zero_triu,
  382. )
  383. elif attention_layer_type == "fast_selfattn":
  384. assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
  385. encoder_selfattn_layer = FastSelfAttention
  386. encoder_selfattn_layer_args = (
  387. output_size,
  388. attention_heads,
  389. attention_dropout_rate,
  390. )
  391. else:
  392. raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)
  393. cgmlp_layer = ConvolutionalGatingMLP
  394. cgmlp_layer_args = (
  395. output_size,
  396. cgmlp_linear_units,
  397. cgmlp_conv_kernel,
  398. dropout_rate,
  399. use_linear_after_conv,
  400. gate_activation,
  401. )
  402. if isinstance(stochastic_depth_rate, float):
  403. stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
  404. if len(stochastic_depth_rate) != num_blocks:
  405. raise ValueError(
  406. f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
  407. f"should be equal to num_blocks ({num_blocks})"
  408. )
  409. if isinstance(cgmlp_weight, float):
  410. cgmlp_weight = [cgmlp_weight] * num_blocks
  411. if len(cgmlp_weight) != num_blocks:
  412. raise ValueError(
  413. f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to "
  414. f"num_blocks ({num_blocks})"
  415. )
  416. if isinstance(attn_branch_drop_rate, float):
  417. attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks
  418. if len(attn_branch_drop_rate) != num_blocks:
  419. raise ValueError(
  420. f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
  421. f"should be equal to num_blocks ({num_blocks})"
  422. )
  423. self.encoders = repeat(
  424. num_blocks,
  425. lambda lnum: BranchformerEncoderLayer(
  426. output_size,
  427. encoder_selfattn_layer(*encoder_selfattn_layer_args)
  428. if use_attn
  429. else None,
  430. cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
  431. dropout_rate,
  432. merge_method,
  433. cgmlp_weight[lnum],
  434. attn_branch_drop_rate[lnum],
  435. stochastic_depth_rate[lnum],
  436. ),
  437. )
  438. self.after_norm = LayerNorm(output_size)
  439. def output_size(self) -> int:
  440. return self._output_size
  441. def forward(
  442. self,
  443. xs_pad: torch.Tensor,
  444. ilens: torch.Tensor,
  445. prev_states: torch.Tensor = None,
  446. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  447. """Calculate forward propagation.
  448. Args:
  449. xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
  450. ilens (torch.Tensor): Input length (#batch).
  451. prev_states (torch.Tensor): Not to be used now.
  452. Returns:
  453. torch.Tensor: Output tensor (#batch, L, output_size).
  454. torch.Tensor: Output length (#batch).
  455. torch.Tensor: Not to be used now.
  456. """
  457. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  458. if (
  459. isinstance(self.embed, Conv2dSubsampling)
  460. or isinstance(self.embed, Conv2dSubsampling2)
  461. or isinstance(self.embed, Conv2dSubsampling6)
  462. or isinstance(self.embed, Conv2dSubsampling8)
  463. ):
  464. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  465. if short_status:
  466. raise TooShortUttError(
  467. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  468. + f"(it needs more than {limit_size} frames), return empty results",
  469. xs_pad.size(1),
  470. limit_size,
  471. )
  472. xs_pad, masks = self.embed(xs_pad, masks)
  473. elif self.embed is not None:
  474. xs_pad = self.embed(xs_pad)
  475. xs_pad, masks = self.encoders(xs_pad, masks)
  476. if isinstance(xs_pad, tuple):
  477. xs_pad = xs_pad[0]
  478. xs_pad = self.after_norm(xs_pad)
  479. olens = masks.squeeze(1).sum(1)
  480. return xs_pad, olens, None