# sanm_encoder.py
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging

import numpy as np
import torch
import torch.nn as nn
from typeguard import check_argument_types

from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import (
    MultiHeadedAttention,
    MultiHeadedAttentionSANM,
    MultiHeadedAttentionSANMwithMask,
)
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.mask import subsequent_mask, vad_mask
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt

class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayerSANM object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
            torch.Tensor: Cache tensor, passed through unchanged.
            torch.Tensor: Chunk shift mask, passed through unchanged.
            torch.Tensor: Chunk attention mask, passed through unchanged.
        """
        skip_layer = False
        # With stochastic depth, the residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time; e.g. with p = 0.5,
        # a surviving layer's branch output is scaled by 2.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
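
# A minimal shape-check sketch for a single block (sizes are illustrative and
# the MultiHeadedAttentionSANM / PositionwiseFeedForward constructor arguments
# are assumed from how SANMEncoder below builds its layers):
#
#   layer = EncoderLayerSANM(
#       in_size=256,
#       size=256,
#       self_attn=MultiHeadedAttentionSANM(4, 256, 256, 0.0, 11, 0),
#       feed_forward=PositionwiseFeedForward(256, 2048, 0.1),
#       dropout_rate=0.1,
#   )
#   x = torch.randn(2, 100, 256)      # (batch, time, size)
#   mask = torch.ones(2, 1, 100)      # (batch, 1, time)
#   y, mask, *_ = layer(x, mask)      # y: (2, 100, 256)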

class SANMEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            self.encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            # MultiHeadedAttention takes no separate input size, so the first
            # block is built with the same arguments as the remaining blocks.
            encoder_selfattn_layer_args0 = encoder_selfattn_layer_args
        elif selfattention_layer_type == "sanm":
            self.encoder_selfattn_layer = MultiHeadedAttentionSANM
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
        else:
            raise ValueError("unknown selfattention_layer_type: " + selfattention_layer_type)

        # The first block maps input_size -> output_size; the remaining
        # num_blocks - 1 blocks keep the width at output_size.
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input lengths (B)
            prev_states: not used yet
            ctc: CTC module used for intermediate CTC conditioning
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # xs_pad = self.dropout(xs_pad)
        encoder_outs = self.encoders0(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
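
# Sketch of a typical call (sizes illustrative; with input_layer="pe" no conv
# subsampling is applied, so the output keeps the input frame rate):
#
#   encoder = SANMEncoder(input_size=560, output_size=512, num_blocks=6,
#                         input_layer="pe")
#   feats = torch.randn(2, 100, 560)              # (B, T, D) padded features
#   feats_len = torch.tensor([100, 80])           # valid lengths per utterance
#   out, out_lens, _ = encoder(feats, feats_len)  # out: (2, 100, 512)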

class SANMVadEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            self.encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            # MultiHeadedAttention takes no separate input size, so the first
            # block is built with the same arguments as the remaining blocks.
            encoder_selfattn_layer_args0 = encoder_selfattn_layer_args
        elif selfattention_layer_type == "sanm":
            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
        else:
            raise ValueError("unknown selfattention_layer_type: " + selfattention_layer_type)

        # Same two-stage layout as SANMEncoder: one input_size -> output_size
        # block, then num_blocks - 1 blocks at output_size.
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        vad_indexes: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input lengths (B)
            vad_indexes: VAD boundary index per utterance, used to build the
                attention mask of the last encoder layer
            prev_states: not used yet
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        no_future_masks = masks & sub_masks

        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # xs_pad = self.dropout(xs_pad)
        mask_tup0 = [masks, no_future_masks]
        encoder_outs = self.encoders0(xs_pad, mask_tup0)
        # Keep the original padding masks; the returned mask tuple is not needed.
        xs_pad, _ = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        # Unlike SANMEncoder, each layer needs its own attention mask, so the
        # repeat-based fast path cannot be used; apply the blocks one by one.
        for layer_idx, encoder_layer in enumerate(self.encoders):
            if layer_idx + 1 == len(self.encoders):
                # Last layer: confine attention to the VAD segment boundaries.
                corner_mask = torch.ones(
                    masks.size(0),
                    masks.size(-1),
                    masks.size(-1),
                    device=xs_pad.device,
                    dtype=torch.bool,
                )
                for word_index, length in enumerate(ilens):
                    corner_mask[word_index, :, :] = vad_mask(
                        masks.size(-1),
                        vad_indexes[word_index],
                        device=xs_pad.device,
                    )
                layer_mask = masks & corner_mask
            else:
                layer_mask = no_future_masks
            mask_tup1 = [masks, layer_mask]
            encoder_outs = encoder_layer(xs_pad, mask_tup1)
            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
            if layer_idx + 1 in self.interctc_layer_idx:
                encoder_out = xs_pad
                # intermediate outputs are also normalized
                if self.normalize_before:
                    encoder_out = self.after_norm(encoder_out)
                intermediate_outs.append((layer_idx + 1, encoder_out))
                if self.interctc_use_conditioning:
                    ctc_out = ctc.softmax(encoder_out)
                    xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
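
# Sketch of a call with VAD boundaries (shapes illustrative; each entry of
# vad_indexes is assumed to be the boundary index handed to vad_mask when
# building the last layer's attention mask):
#
#   encoder = SANMVadEncoder(input_size=560, output_size=512, num_blocks=6,
#                            input_layer="pe")
#   feats = torch.randn(2, 100, 560)
#   feats_len = torch.tensor([100, 80])
#   vad_idx = torch.tensor([60, 50])      # one boundary per utterance
#   out, out_lens, _ = encoder(feats, feats_len, vad_idx)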