modules.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import torch
  6. import numpy as np
  7. import torch.nn as nn
  8. from enum import Enum, auto
  9. import torch.nn.functional as F
  10. from dataclasses import dataclass
  11. from funasr.models.emotion2vec.fairseq_modules import (
  12. LayerNorm,
  13. SamePad,
  14. TransposeLast,
  15. )
  16. class Modality(Enum):
  17. AUDIO = auto()
@dataclass
class D2vDecoderConfig:
    """Configuration for the data2vec convolutional decoder (see Decoder1d)."""

    decoder_dim: int = 384  # channel width of every decoder conv block
    decoder_groups: int = 16  # grouped-convolution groups per Conv1d
    decoder_kernel: int = 5  # conv kernel size (SamePad trims the even-pad overhang)
    decoder_layers: int = 5  # number of stacked conv blocks
    input_dropout: float = 0.1  # not referenced by Decoder1d in this file — presumably consumed by the caller
    add_positions_masked: bool = False  # not referenced in this file — presumably consumed by the caller
    add_positions_all: bool = False  # not referenced in this file — presumably consumed by the caller
    decoder_residual: bool = True  # enable skip connections between conv blocks (DecoderBase.add_residual)
    projection_layers: int = 1  # depth of the final linear projection head
    projection_ratio: float = 2.0  # hidden widening factor for the first extra projection layer
  30. class FixedPositionalEncoder(nn.Module):
  31. def __init__(self, pos_embed):
  32. super().__init__()
  33. self.positions = pos_embed
  34. def forward(self, x, padding_mask):
  35. return self.positions
  36. class TextFeatPositionalEncoder(nn.Module):
  37. """
  38. Original encoder expects (B, T) long input. This module wraps it to take
  39. local_encoder output which are (B, T, D) float tensors
  40. """
  41. def __init__(self, pos_encoder):
  42. super().__init__()
  43. self.pos_encoder = pos_encoder
  44. def forward(self, x, padding_mask):
  45. # assume padded token embeddings are 0s
  46. # TODO: consider using padding_mask as input
  47. return self.pos_encoder(x[..., 0])
  48. class BlockEncoder(nn.Module):
  49. def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
  50. super().__init__()
  51. self.blocks = blocks
  52. self.norm = norm_layer
  53. self.layer_norm_first = layer_norm_first
  54. self.layerdrop = layerdrop
  55. self.dropout = nn.Dropout(dropout, inplace=True)
  56. def forward(self, x, padding_mask, alibi_bias, alibi_scale):
  57. if self.norm is not None and not self.layer_norm_first:
  58. x = self.norm(x)
  59. x = self.dropout(x)
  60. for i, blk in enumerate(self.blocks):
  61. if (
  62. not self.training
  63. or self.layerdrop == 0
  64. or (np.random.random() > self.layerdrop)
  65. ):
  66. ab = alibi_bias
  67. if ab is not None and alibi_scale is not None:
  68. scale = (
  69. alibi_scale[i]
  70. if alibi_scale.size(0) > 1
  71. else alibi_scale.squeeze(0)
  72. )
  73. ab = ab * scale.type_as(ab)
  74. x, _ = blk(x, padding_mask, ab)
  75. if self.norm is not None and self.layer_norm_first:
  76. x = self.norm(x)
  77. return x
  78. class DecoderBase(nn.Module):
  79. decoder_cfg: D2vDecoderConfig
  80. def __init__(self, cfg: D2vDecoderConfig):
  81. super().__init__()
  82. self.decoder_cfg = cfg
  83. def reset_parameters(self):
  84. for mod in self.proj.modules():
  85. if isinstance(mod, nn.Linear):
  86. mod.reset_parameters()
  87. def add_residual(self, x, residual, i, mask_info):
  88. if (
  89. residual is None
  90. or not self.decoder_cfg.decoder_residual
  91. or residual.size(1) != x.size(1)
  92. ):
  93. return x
  94. ret = x + residual
  95. return ret
  96. class Decoder1d(DecoderBase):
  97. def __init__(self, cfg: D2vDecoderConfig, input_dim):
  98. super().__init__(cfg)
  99. def make_block(in_dim):
  100. block = [
  101. nn.Conv1d(
  102. in_dim,
  103. cfg.decoder_dim,
  104. kernel_size=cfg.decoder_kernel,
  105. padding=cfg.decoder_kernel // 2,
  106. groups=cfg.decoder_groups,
  107. ),
  108. SamePad(cfg.decoder_kernel),
  109. TransposeLast(),
  110. LayerNorm(cfg.decoder_dim, elementwise_affine=False),
  111. TransposeLast(),
  112. nn.GELU(),
  113. ]
  114. return nn.Sequential(*block)
  115. self.blocks = nn.Sequential(
  116. *[
  117. make_block(input_dim if i == 0 else cfg.decoder_dim)
  118. for i in range(cfg.decoder_layers)
  119. ]
  120. )
  121. projs = []
  122. curr_dim = cfg.decoder_dim
  123. for i in range(cfg.projection_layers - 1):
  124. next_dim = int(curr_dim * cfg.projection_ratio) if i == 0 else curr_dim
  125. projs.append(nn.Linear(curr_dim, next_dim))
  126. projs.append(nn.GELU())
  127. curr_dim = next_dim
  128. projs.append(nn.Linear(curr_dim, input_dim))
  129. if len(projs) == 1:
  130. self.proj = projs[0]
  131. else:
  132. self.proj = nn.Sequential(*projs)
  133. def forward(self, x, mask_info):
  134. x = x.transpose(1, 2)
  135. residual = x
  136. for i, layer in enumerate(self.blocks):
  137. x = layer(x)
  138. x = self.add_residual(x, residual, i, mask_info)
  139. residual = x
  140. x = x.transpose(1, 2)
  141. x = self.proj(x)
  142. return x
class AltBlock(nn.Module):
    """Transformer block (self-attention + MLP) ported from data2vec 2.0.

    forward() returns a tuple ``(x, t)``: ``x`` is the block output, ``t`` is
    the layer "target" exposed for teacher regression — the FFN output when
    ``ffn_targets`` is True, otherwise the block output itself.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,  # projection dropout inside the attention module
        attn_drop=0.0,  # dropout on the attention weights
        mlp_drop=0.0,  # dropout inside the MLP
        post_mlp_drop=0.0,  # dropout applied to the MLP output before the residual add
        drop_path=0.0,  # stochastic depth rate; 0 disables DropPath entirely
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,  # True = pre-norm branch, False = post-norm branch
        ffn_targets=False,
        cosine_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        # Local import keeps the timm-derived helpers off the module import path.
        from funasr.models.emotion2vec.timm_modules import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = AltAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)

    def forward(self, x, padding_mask=None, alibi_bias=None):
        if self.layer_norm_first:
            # Pre-norm branch: residual attention, then the FFN sub-block.
            x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
            r = x = self.mlp(self.norm2(x))
            t = x
            # NOTE(review): the residual added here is the MLP output itself
            # (r is x), not the attention output, so the FFN result is
            # effectively doubled at inference. This mirrors the upstream
            # data2vec 2.0 code this file was ported from — confirm against
            # fairseq before changing; pretrained checkpoints depend on it.
            x = r + self.drop_path(self.post_mlp_dropout(x))
            if not self.ffn_targets:
                t = x
        else:
            # Post-norm branch: normalize after each residual sum instead.
            x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
            if not self.ffn_targets:
                t = x
        return x, t
  203. class AltAttention(nn.Module):
  204. def __init__(
  205. self,
  206. dim,
  207. num_heads=8,
  208. qkv_bias=False,
  209. qk_scale=None,
  210. attn_drop=0.0,
  211. proj_drop=0.0,
  212. cosine_attention=False,
  213. ):
  214. super().__init__()
  215. self.num_heads = num_heads
  216. head_dim = dim // num_heads
  217. self.scale = qk_scale or head_dim ** -0.5
  218. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  219. self.attn_drop = nn.Dropout(attn_drop)
  220. self.proj = nn.Linear(dim, dim)
  221. self.proj_drop = nn.Dropout(proj_drop)
  222. self.cosine_attention = cosine_attention
  223. if cosine_attention:
  224. self.logit_scale = nn.Parameter(
  225. torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
  226. )
  227. def forward(self, x, padding_mask=None, alibi_bias=None):
  228. B, N, C = x.shape
  229. qkv = (
  230. self.qkv(x)
  231. .reshape(B, N, 3, self.num_heads, C // self.num_heads)
  232. .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
  233. )
  234. q, k, v = (
  235. qkv[0],
  236. qkv[1],
  237. qkv[2],
  238. ) # make torchscript happy (cannot use tensor as tuple)
  239. dtype = q.dtype
  240. if self.cosine_attention:
  241. # cosine attention
  242. attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
  243. logit_scale = torch.clamp(
  244. self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
  245. ).exp()
  246. attn = attn * logit_scale
  247. else:
  248. q = q * self.scale
  249. attn = q @ k.transpose(-2, -1)
  250. if alibi_bias is not None:
  251. attn = attn.type_as(alibi_bias)
  252. attn[:, : alibi_bias.size(1)] += alibi_bias
  253. if padding_mask is not None and padding_mask.any():
  254. attn = attn.masked_fill(
  255. padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
  256. float("-inf"),
  257. )
  258. attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
  259. attn = self.attn_drop(attn)
  260. x = (attn @ v).transpose(1, 2) #
  261. x = x.reshape(B, N, C)
  262. x = self.proj(x)
  263. x = self.proj_drop(x)
  264. return x