encoder.py

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer encoder definition."""

from typing import List
from typing import Optional
from typing import Tuple

import torch
from torch import nn
import logging

from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.register import tables


class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear`
            instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to apply layer norm before each sub-block.
        concat_after (bool): Whether to concat the attention layer's input and output.
            If True, an additional linear projection is applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear is applied, i.e. x -> x + att(x).
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip the residual computation and
            return its input as-is with the given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, 1, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            # incremental decoding: only the newest frame is used as the query,
            # while cached frames serve as keys/values.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, mask
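
# Usage sketch for EncoderLayer (illustrative only; the sizes, head count, and
# dropout values below are assumptions, not values prescribed by this module):
#
#     >>> layer = EncoderLayer(
#     ...     size=256,
#     ...     self_attn=MultiHeadedAttention(4, 256, 0.1),
#     ...     feed_forward=PositionwiseFeedForward(256, 1024, 0.1),
#     ...     dropout_rate=0.1,
#     ... )
#     >>> x = torch.randn(2, 50, 256)                    # (#batch, time, size)
#     >>> mask = torch.ones(2, 1, 50, dtype=torch.bool)  # (#batch, 1, time)
#     >>> y, y_mask = layer(x, mask)
#     >>> y.shape
#     torch.Size([2, 50, 256])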


@tables.register("encoder_classes", "TransformerEncoder")
class TransformerEncoder(nn.Module):
    """Transformer encoder module.

    Args:
        input_size: input dimension
        output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of encoder blocks
        dropout_rate: dropout rate
        attention_dropout_rate: dropout rate in attention
        positional_dropout_rate: dropout rate after adding positional encoding
        input_layer: input layer type
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output.
            If True, an additional linear is applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear is applied, i.e. x -> x + att(x).
        positionwise_layer_type: "linear", "conv1d", or "conv1d-linear"
        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
        padding_idx: padding_idx for input_layer=embed
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
    ):
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Expected to be assigned externally (e.g. by the enclosing ASR model)
        # when interctc_use_conditioning is True.
        self.conditioning_layer = None

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
        ctc: Optional[CTC] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Apply the input layer and encoder blocks.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input lengths (B)
            prev_states: not used for now
            ctc: CTC module for intermediate CTC conditioning

        Returns:
            encoded tensor (B, L', D'), output lengths (B), and None
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if self.embed is None:
            pass
        elif isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        # self-conditioned CTC: project the intermediate CTC
                        # posteriors and add them back into the encoder stream.
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
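
# Usage sketch for TransformerEncoder (illustrative only; the 80-dim features
# and lengths below are assumptions, not values required by this module):
#
#     >>> encoder = TransformerEncoder(input_size=80)     # conv2d front-end by default
#     >>> xs = torch.randn(4, 200, 80)                    # (B, L, D), zero-padded
#     >>> ilens = torch.tensor([200, 180, 160, 120])      # true frame counts
#     >>> out, olens, _ = encoder(xs, ilens)
#     >>> out.shape                                       # time subsampled ~4x by conv2d
#     torch.Size([4, 49, 256])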