encoder.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. # Copyright 2019 Shigeki Karita
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Transformer encoder definition."""
  4. from typing import List
  5. from typing import Optional
  6. from typing import Tuple
  7. import torch
  8. from torch import nn
  9. import logging
  10. from funasr.models.transformer.attention import MultiHeadedAttention
  11. from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
  12. from funasr.models.transformer.embedding import PositionalEncoding
  13. from funasr.models.transformer.layer_norm import LayerNorm
  14. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  15. from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
  16. from funasr.models.transformer.utils.repeat import repeat
  17. from funasr.register import tables
  18. class EncoderLayer(nn.Module):
  19. """Encoder layer module.
  20. Args:
  21. size (int): Input dimension.
  22. self_attn (torch.nn.Module): Self-attention module instance.
  23. `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
  24. can be used as the argument.
  25. feed_forward (torch.nn.Module): Feed-forward module instance.
  26. `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
  27. can be used as the argument.
  28. dropout_rate (float): Dropout rate.
  29. normalize_before (bool): Whether to use layer_norm before the first block.
  30. concat_after (bool): Whether to concat attention layer's input and output.
  31. if True, additional linear will be applied.
  32. i.e. x -> x + linear(concat(x, att(x)))
  33. if False, no additional linear will be applied. i.e. x -> x + att(x)
  34. stochastic_depth_rate (float): Proability to skip this layer.
  35. During training, the layer may skip residual computation and return input
  36. as-is with given probability.
  37. """
  38. def __init__(
  39. self,
  40. size,
  41. self_attn,
  42. feed_forward,
  43. dropout_rate,
  44. normalize_before=True,
  45. concat_after=False,
  46. stochastic_depth_rate=0.0,
  47. ):
  48. """Construct an EncoderLayer object."""
  49. super(EncoderLayer, self).__init__()
  50. self.self_attn = self_attn
  51. self.feed_forward = feed_forward
  52. self.norm1 = LayerNorm(size)
  53. self.norm2 = LayerNorm(size)
  54. self.dropout = nn.Dropout(dropout_rate)
  55. self.size = size
  56. self.normalize_before = normalize_before
  57. self.concat_after = concat_after
  58. if self.concat_after:
  59. self.concat_linear = nn.Linear(size + size, size)
  60. self.stochastic_depth_rate = stochastic_depth_rate
  61. def forward(self, x, mask, cache=None):
  62. """Compute encoded features.
  63. Args:
  64. x_input (torch.Tensor): Input tensor (#batch, time, size).
  65. mask (torch.Tensor): Mask tensor for the input (#batch, time).
  66. cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  67. Returns:
  68. torch.Tensor: Output tensor (#batch, time, size).
  69. torch.Tensor: Mask tensor (#batch, time).
  70. """
  71. skip_layer = False
  72. # with stochastic depth, residual connection `x + f(x)` becomes
  73. # `x <- x + 1 / (1 - p) * f(x)` at training time.
  74. stoch_layer_coeff = 1.0
  75. if self.training and self.stochastic_depth_rate > 0:
  76. skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
  77. stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
  78. if skip_layer:
  79. if cache is not None:
  80. x = torch.cat([cache, x], dim=1)
  81. return x, mask
  82. residual = x
  83. if self.normalize_before:
  84. x = self.norm1(x)
  85. if cache is None:
  86. x_q = x
  87. else:
  88. assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
  89. x_q = x[:, -1:, :]
  90. residual = residual[:, -1:, :]
  91. mask = None if mask is None else mask[:, -1:, :]
  92. if self.concat_after:
  93. x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
  94. x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
  95. else:
  96. x = residual + stoch_layer_coeff * self.dropout(
  97. self.self_attn(x_q, x, x, mask)
  98. )
  99. if not self.normalize_before:
  100. x = self.norm1(x)
  101. residual = x
  102. if self.normalize_before:
  103. x = self.norm2(x)
  104. x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
  105. if not self.normalize_before:
  106. x = self.norm2(x)
  107. if cache is not None:
  108. x = torch.cat([cache, x], dim=1)
  109. return x, mask
  110. @tables.register("encoder_classes", "TransformerTextEncoder")
  111. class TransformerTextEncoder(nn.Module):
  112. """Transformer text encoder module.
  113. Args:
  114. input_size: input dim
  115. output_size: dimension of attention
  116. attention_heads: the number of heads of multi head attention
  117. linear_units: the number of units of position-wise feed forward
  118. num_blocks: the number of decoder blocks
  119. dropout_rate: dropout rate
  120. attention_dropout_rate: dropout rate in attention
  121. positional_dropout_rate: dropout rate after adding positional encoding
  122. input_layer: input layer type
  123. pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
  124. normalize_before: whether to use layer_norm before the first block
  125. concat_after: whether to concat attention layer's input and output
  126. if True, additional linear will be applied.
  127. i.e. x -> x + linear(concat(x, att(x)))
  128. if False, no additional linear will be applied.
  129. i.e. x -> x + att(x)
  130. positionwise_layer_type: linear of conv1d
  131. positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
  132. padding_idx: padding_idx for input_layer=embed
  133. """
  134. def __init__(
  135. self,
  136. input_size: int,
  137. output_size: int = 256,
  138. attention_heads: int = 4,
  139. linear_units: int = 2048,
  140. num_blocks: int = 6,
  141. dropout_rate: float = 0.1,
  142. positional_dropout_rate: float = 0.1,
  143. attention_dropout_rate: float = 0.0,
  144. pos_enc_class=PositionalEncoding,
  145. normalize_before: bool = True,
  146. concat_after: bool = False,
  147. ):
  148. super().__init__()
  149. self._output_size = output_size
  150. self.embed = torch.nn.Sequential(
  151. torch.nn.Embedding(input_size, output_size),
  152. pos_enc_class(output_size, positional_dropout_rate),
  153. )
  154. self.normalize_before = normalize_before
  155. positionwise_layer = PositionwiseFeedForward
  156. positionwise_layer_args = (
  157. output_size,
  158. linear_units,
  159. dropout_rate,
  160. )
  161. self.encoders = repeat(
  162. num_blocks,
  163. lambda lnum: EncoderLayer(
  164. output_size,
  165. MultiHeadedAttention(
  166. attention_heads, output_size, attention_dropout_rate
  167. ),
  168. positionwise_layer(*positionwise_layer_args),
  169. dropout_rate,
  170. normalize_before,
  171. concat_after,
  172. ),
  173. )
  174. if self.normalize_before:
  175. self.after_norm = LayerNorm(output_size)
  176. def output_size(self) -> int:
  177. return self._output_size
  178. def forward(
  179. self,
  180. xs_pad: torch.Tensor,
  181. ilens: torch.Tensor,
  182. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  183. """Embed positions in tensor.
  184. Args:
  185. xs_pad: input tensor (B, L, D)
  186. ilens: input length (B)
  187. Returns:
  188. position embedded tensor and mask
  189. """
  190. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  191. xs_pad = self.embed(xs_pad)
  192. xs_pad, masks = self.encoders(xs_pad, masks)
  193. if self.normalize_before:
  194. xs_pad = self.after_norm(xs_pad)
  195. olens = masks.squeeze(1).sum(1)
  196. return xs_pad, olens, None
  197. @tables.register("encoder_classes", "FusionSANEncoder")
  198. class SelfSrcAttention(nn.Module):
  199. """Single decoder layer module.
  200. Args:
  201. size (int): Input dimension.
  202. self_attn (torch.nn.Module): Self-attention module instance.
  203. `MultiHeadedAttention` instance can be used as the argument.
  204. src_attn (torch.nn.Module): Self-attention module instance.
  205. `MultiHeadedAttention` instance can be used as the argument.
  206. feed_forward (torch.nn.Module): Feed-forward module instance.
  207. `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
  208. can be used as the argument.
  209. dropout_rate (float): Dropout rate.
  210. normalize_before (bool): Whether to use layer_norm before the first block.
  211. concat_after (bool): Whether to concat attention layer's input and output.
  212. if True, additional linear will be applied.
  213. i.e. x -> x + linear(concat(x, att(x)))
  214. if False, no additional linear will be applied. i.e. x -> x + att(x)
  215. """
  216. def __init__(
  217. self,
  218. size,
  219. attention_heads,
  220. attention_dim,
  221. linear_units,
  222. self_attention_dropout_rate,
  223. src_attention_dropout_rate,
  224. positional_dropout_rate,
  225. dropout_rate,
  226. normalize_before=True,
  227. concat_after=False,
  228. ):
  229. """Construct an SelfSrcAttention object."""
  230. super(SelfSrcAttention, self).__init__()
  231. self.size = size
  232. self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
  233. self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
  234. self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
  235. self.norm1 = LayerNorm(size)
  236. self.norm2 = LayerNorm(size)
  237. self.norm3 = LayerNorm(size)
  238. self.dropout = nn.Dropout(dropout_rate)
  239. self.normalize_before = normalize_before
  240. self.concat_after = concat_after
  241. if self.concat_after:
  242. self.concat_linear1 = nn.Linear(size + size, size)
  243. self.concat_linear2 = nn.Linear(size + size, size)
  244. def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
  245. """Compute decoded features.
  246. Args:
  247. tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
  248. tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
  249. memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
  250. memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
  251. cache (List[torch.Tensor]): List of cached tensors.
  252. Each tensor shape should be (#batch, maxlen_out - 1, size).
  253. Returns:
  254. torch.Tensor: Output tensor(#batch, maxlen_out, size).
  255. torch.Tensor: Mask for output tensor (#batch, maxlen_out).
  256. torch.Tensor: Encoded memory (#batch, maxlen_in, size).
  257. torch.Tensor: Encoded memory mask (#batch, maxlen_in).
  258. """
  259. residual = tgt
  260. if self.normalize_before:
  261. tgt = self.norm1(tgt)
  262. if cache is None:
  263. tgt_q = tgt
  264. tgt_q_mask = tgt_mask
  265. else:
  266. # compute only the last frame query keeping dim: max_time_out -> 1
  267. assert cache.shape == (
  268. tgt.shape[0],
  269. tgt.shape[1] - 1,
  270. self.size,
  271. ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
  272. tgt_q = tgt[:, -1:, :]
  273. residual = residual[:, -1:, :]
  274. tgt_q_mask = None
  275. if tgt_mask is not None:
  276. tgt_q_mask = tgt_mask[:, -1:, :]
  277. if self.concat_after:
  278. tgt_concat = torch.cat(
  279. (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
  280. )
  281. x = residual + self.concat_linear1(tgt_concat)
  282. else:
  283. x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
  284. if not self.normalize_before:
  285. x = self.norm1(x)
  286. residual = x
  287. if self.normalize_before:
  288. x = self.norm2(x)
  289. if self.concat_after:
  290. x_concat = torch.cat(
  291. (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
  292. )
  293. x = residual + self.concat_linear2(x_concat)
  294. else:
  295. x, score = self.src_attn(x, memory, memory, memory_mask)
  296. x = residual + self.dropout(x)
  297. if not self.normalize_before:
  298. x = self.norm2(x)
  299. residual = x
  300. if self.normalize_before:
  301. x = self.norm3(x)
  302. x = residual + self.dropout(self.feed_forward(x))
  303. if not self.normalize_before:
  304. x = self.norm3(x)
  305. if cache is not None:
  306. x = torch.cat([cache, x], dim=1)
  307. return x, tgt_mask, memory, memory_mask
  308. @tables.register("encoder_classes", "ConvBiasPredictor")
  309. class ConvPredictor(nn.Module):
  310. def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048):
  311. super().__init__()
  312. self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
  313. self.norm1 = LayerNorm(size)
  314. self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate)
  315. self.norm2 = LayerNorm(size)
  316. self.pad = nn.ConstantPad1d((l_order, r_order), 0)
  317. self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
  318. self.output_linear = nn.Linear(size, 1)
  319. def forward(self, text_enc, asr_enc):
  320. # stage1 cross-attention
  321. residual = text_enc
  322. text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)
  323. # stage2 FFN
  324. residual = text_enc
  325. text_enc = self.norm1(text_enc)
  326. text_enc = residual + self.feed_forward(text_enc)
  327. # stage Conv predictor
  328. text_enc = self.norm2(text_enc)
  329. context = text_enc.transpose(1, 2)
  330. queries = self.pad(context)
  331. memory = self.conv1d(queries)
  332. output = memory + context
  333. output = output.transpose(1, 2)
  334. output = torch.relu(output)
  335. output = self.output_linear(output)
  336. if output.dim()==3:
  337. output = output.squeeze(2)
  338. return output