#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch
from typing import List, Optional, Tuple

from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.attention import MultiHeadedAttention
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.ct_transformer_streaming.attention import MultiHeadedAttentionSANMwithMask
from funasr.models.transformer.utils.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
)

class EncoderLayerSANM(torch.nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = torch.nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time), plus the unchanged cache and
                chunk masks so that layers can be chained with repeat().
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute encoded features for one streaming chunk.

        Args:
            x (torch.Tensor): Input tensor (#batch, chunk_size, size).
            cache (torch.Tensor): Attention cache carried over from previous chunks.
            chunk_size (int): Number of frames per chunk.
            look_back (int): How far the attention may look back into the cache.

        Returns:
            torch.Tensor: Output tensor (#batch, chunk_size, size).
            torch.Tensor: Updated cache.
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)

        return x, cache
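
# Usage sketch (added for illustration, not part of the original file): how
# forward_chunk is typically driven in a streaming loop. The sizes, chunk_size,
# and look_back values below are illustrative assumptions; the self-attention
# module passed in must provide a matching forward_chunk method.
#
#     layer = EncoderLayerSANM(256, 256, self_attn, feed_forward, dropout_rate=0.1)
#     cache = None
#     for chunk in feature_chunks:  # each chunk: (batch, chunk_size, 256)
#         out, cache = layer.forward_chunk(chunk, cache, chunk_size=16, look_back=4)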


@tables.register("encoder_classes", "SANMVadEncoder")
class SANMVadEncoder(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        if selfattention_layer_type == "selfattn":
            # Assign the attribute (the original bound a local that was never
            # read), since the repeat() factories below use
            # self.encoder_selfattn_layer. Note that plain self-attention cannot
            # change the feature dimension, so this branch effectively requires
            # input_size == output_size.
            self.encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                attention_dropout_rate,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "sanm":
            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )
        # The first block projects input_size -> output_size; the remaining
        # num_blocks - 1 blocks keep the output_size dimension.
        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = torch.nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        vad_indexes: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode input frames.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            vad_indexes: per-utterance VAD positions used to build the last
                layer's attention mask (B)
            prev_states: not to be used now
            ctc: intermediate CTC module; not used in this forward pass

        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        no_future_masks = masks & sub_masks

        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            pass
        elif isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        # xs_pad = self.dropout(xs_pad)

        # All blocks except the last attend causally (no future frames).
        mask_tup0 = [masks, no_future_masks]
        encoder_outs = self.encoders0(xs_pad, mask_tup0)
        xs_pad = encoder_outs[0]

        intermediate_outs = []  # interctc outputs would go here; not populated in this loop
        for layer_idx, encoder_layer in enumerate(self.encoders):
            if layer_idx + 1 == len(self.encoders):
                # The last layer replaces the causal mask with one built from
                # the per-utterance VAD positions.
                corner_mask = torch.ones(
                    masks.size(0),
                    masks.size(-1),
                    masks.size(-1),
                    device=xs_pad.device,
                    dtype=torch.bool,
                )
                for word_index, _ in enumerate(ilens):
                    corner_mask[word_index, :, :] = vad_mask(
                        masks.size(-1), vad_indexes[word_index], device=xs_pad.device
                    )
                layer_mask = masks & corner_mask
            else:
                layer_mask = no_future_masks
            mask_tup1 = [masks, layer_mask]
            encoder_outs = encoder_layer(xs_pad, mask_tup1)
            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
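

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of FunASR):
    # the hyper-parameters, lengths, and VAD indexes below are arbitrary
    # assumptions, and input_layer="pe" is assumed so that the first SANM block
    # performs the input_size -> output_size projection itself.
    encoder = SANMVadEncoder(input_size=80, output_size=256, input_layer="pe")
    encoder.eval()

    feats = torch.randn(2, 100, 80)       # (batch, frames, feature dim)
    feat_lens = torch.tensor([100, 80])   # valid frames per utterance
    vad_indexes = torch.tensor([50, 40])  # assumed VAD positions per utterance

    with torch.no_grad():
        out, olens, _ = encoder(feats, feat_lens, vad_indexes)
    print(out.shape, olens)  # expected: (2, 100, 256) and the original lengths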