transformer_encoder.py
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer encoder definition."""

import logging
from typing import List
from typing import Optional
from typing import Tuple

import torch
from torch import nn

from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
from funasr.models.transformer.utils.lightconv import LightweightConvolution
from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt


class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear`
            instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer will be applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear layer will be applied, i.e. x -> x + att(x).
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip the residual computation and return
            its input as-is with the given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        skip_layer = False
        # With stochastic depth, the residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(
                self.self_attn(x_q, x, x, mask)
            )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, mask
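
# A minimal usage sketch for a single layer (an illustration only, assuming
# the funasr `MultiHeadedAttention(n_head, n_feat, dropout_rate)` and
# `PositionwiseFeedForward(idim, hidden_units, dropout_rate)` signatures and
# an attention mask broadcastable to (#batch, 1, time)):
#
#     layer = EncoderLayer(
#         256,
#         MultiHeadedAttention(4, 256, 0.0),
#         PositionwiseFeedForward(256, 2048, 0.1),
#         dropout_rate=0.1,
#     )
#     x = torch.randn(2, 50, 256)                    # (#batch, time, size)
#     mask = torch.ones(2, 1, 50, dtype=torch.bool)  # all frames valid
#     y, mask = layer(x, mask)                       # y: (2, 50, 256)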


class TransformerEncoder_lm(nn.Module):
    """Transformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        conv_wshare (int): The number of kernels of convolution. Only used when
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        conv_kernel_length (Union[int, str]): Kernel size str of convolution
            (e.g. "71_71_71_71_71_71"). Only used when selfattention_layer_type
            == "lightconv*" or "dynamicconv*".
        conv_usebias (bool): Whether to use bias in convolution. Only used when
            selfattention_layer_type == "lightconv*" or "dynamicconv*".
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        pos_enc_class (torch.nn.Module): Positional encoding module class.
            `PositionalEncoding` or `ScaledPositionalEncoding`.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer will be applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear layer will be applied, i.e. x -> x + att(x).
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        selfattention_layer_type (str): Encoder attention layer type.
        padding_idx (int): Padding idx for input_layer="embed".
        stochastic_depth_rate (float): Maximum probability to skip an encoder layer.
        intermediate_layers (Union[List[int], None]): Indices of intermediate CTC
            layers; indices start from 1.
            If not None, intermediate outputs are returned (which changes the
            return type signature).
    """

    def __init__(
        self,
        idim,
        attention_dim=256,
        attention_heads=4,
        conv_wshare=4,
        conv_kernel_length="11",
        conv_usebias=False,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        pos_enc_class=PositionalEncoding,
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        selfattention_layer_type="selfattn",
        padding_idx=-1,
        stochastic_depth_rate=0.0,
        intermediate_layers=None,
        ctc_softmax=None,
        conditioning_layer_dim=None,
    ):
        """Construct an Encoder object."""
        super().__init__()
        self.conv_subsampling_factor = 1
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 4
        elif input_layer == "conv2d-scaled-pos-enc":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
            self.conv_subsampling_factor = 4
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 6
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate)
            self.conv_subsampling_factor = 8
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        positionwise_layer, positionwise_layer_args = self.get_positionwise_layer(
            positionwise_layer_type,
            attention_dim,
            linear_units,
            dropout_rate,
            positionwise_conv_kernel_size,
        )
        if selfattention_layer_type in [
            "selfattn",
            "rel_selfattn",
            "legacy_rel_selfattn",
        ]:
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = [
                (
                    attention_heads,
                    attention_dim,
                    attention_dropout_rate,
                )
            ] * num_blocks
        elif selfattention_layer_type == "lightconv":
            logging.info("encoder self-attention layer type = lightweight convolution")
            encoder_selfattn_layer = LightweightConvolution
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "lightconv2d":
            logging.info(
                "encoder self-attention layer "
                "type = lightweight convolution 2-dimensional"
            )
            encoder_selfattn_layer = LightweightConvolution2D
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "dynamicconv":
            logging.info("encoder self-attention layer type = dynamic convolution")
            encoder_selfattn_layer = DynamicConvolution
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        elif selfattention_layer_type == "dynamicconv2d":
            logging.info(
                "encoder self-attention layer type = dynamic convolution 2-dimensional"
            )
            encoder_selfattn_layer = DynamicConvolution2D
            encoder_selfattn_layer_args = [
                (
                    conv_wshare,
                    attention_dim,
                    attention_dropout_rate,
                    int(conv_kernel_length.split("_")[lnum]),
                    False,
                    conv_usebias,
                )
                for lnum in range(num_blocks)
            ]
        else:
            raise NotImplementedError(selfattention_layer_type)

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
                # The skip probability grows linearly with depth, reaching
                # stochastic_depth_rate at the last block.
                stochastic_depth_rate * float(1 + lnum) / num_blocks,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

        self.intermediate_layers = intermediate_layers
        self.use_conditioning = True if ctc_softmax is not None else False
        if self.use_conditioning:
            self.ctc_softmax = ctc_softmax
            self.conditioning_layer = torch.nn.Linear(
                conditioning_layer_dim, attention_dim
            )

    def get_positionwise_layer(
        self,
        positionwise_layer_type="linear",
        attention_dim=256,
        linear_units=2048,
        dropout_rate=0.1,
        positionwise_conv_kernel_size=1,
    ):
        """Define positionwise layer."""
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")
        return positionwise_layer, positionwise_layer_args
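
    # For example, a sketch of the mapping implemented above:
    #
    #     get_positionwise_layer("linear", 256, 2048, 0.1)
    #     -> (PositionwiseFeedForward, (256, 2048, 0.1))
    #     get_positionwise_layer("conv1d", 256, 2048, 0.1, 3)
    #     -> (MultiLayeredConv1d, (256, 2048, 3, 0.1))
    #
    # The returned class and args tuple are instantiated once per block in
    # __init__ via positionwise_layer(*positionwise_layer_args).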

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor (#batch, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(
            self.embed,
            (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8),
        ):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)

        if self.intermediate_layers is None:
            xs, masks = self.encoders(xs, masks)
        else:
            intermediate_outputs = []
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs, masks = encoder_layer(xs, masks)

                if (
                    self.intermediate_layers is not None
                    and layer_idx + 1 in self.intermediate_layers
                ):
                    encoder_output = xs
                    # intermediate branches also require normalization.
                    if self.normalize_before:
                        encoder_output = self.after_norm(encoder_output)
                    intermediate_outputs.append(encoder_output)

                    if self.use_conditioning:
                        intermediate_result = self.ctc_softmax(encoder_output)
                        xs = xs + self.conditioning_layer(intermediate_result)

        if self.normalize_before:
            xs = self.after_norm(xs)

        if self.intermediate_layers is not None:
            return xs, masks, intermediate_outputs
        return xs, masks
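
    # A hedged sketch of how the intermediate-CTC conditioning path is wired
    # (an assumption not fixed by this file: `ctc_softmax` maps hidden states
    # (#batch, time, attention_dim) to posteriors (#batch, time, odim), and
    # `conditioning_layer_dim` equals that odim):
    #
    #     encoder = TransformerEncoder_lm(
    #         idim=80,
    #         num_blocks=6,
    #         intermediate_layers=[3],
    #         ctc_softmax=ctc.softmax,          # hypothetical CTC module
    #         conditioning_layer_dim=odim,
    #     )
    #     xs, masks, intermediate_outputs = encoder(xs, masks)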

    def forward_one_step(self, xs, masks, cache=None):
        """Encode input frame.

        Args:
            xs (torch.Tensor): Input tensor.
            masks (torch.Tensor): Mask tensor.
            cache (List[torch.Tensor]): List of cache tensors.

        Returns:
            torch.Tensor: Output tensor.
            torch.Tensor: Mask tensor.
            List[torch.Tensor]: List of new cache tensors.
        """
        if isinstance(self.embed, Conv2dSubsampling):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)
        if cache is None:
            cache = [None for _ in range(len(self.encoders))]
        new_cache = []
        for c, e in zip(cache, self.encoders):
            xs, masks = e(xs, masks, cache=c)
            new_cache.append(xs)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks, new_cache
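

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): run a
    # batch of random features through the encoder. Assumes the funasr/ESPnet
    # mask convention, where the encoder padding mask is (#batch, 1, time)
    # with True at valid frames, built from make_pad_mask (imported above).
    encoder = TransformerEncoder_lm(idim=80, attention_dim=256, num_blocks=2)
    encoder.eval()  # disable dropout for a deterministic check

    ilens = torch.tensor([100, 80])
    xs = torch.randn(2, 100, 80)
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs.device)

    ys, out_masks = encoder(xs, masks)
    # The default "conv2d" front-end subsamples time by a factor of ~4.
    print(ys.shape, out_masks.shape)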