# mossformer.py
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn, einsum
  4. from einops import rearrange
  5. def identity(t, *args, **kwargs):
  6. return t
  7. def append_dims(x, num_dims):
  8. if num_dims <= 0:
  9. return x
  10. return x.view(*x.shape, *((1,) * num_dims))
  11. def exists(val):
  12. return val is not None
  13. def default(val, d):
  14. return val if exists(val) else d
  15. def padding_to_multiple_of(n, mult):
  16. remainder = n % mult
  17. if remainder == 0:
  18. return 0
  19. return mult - remainder
  20. class Transpose(nn.Module):
  21. """ Wrapper class of torch.transpose() for Sequential module. """
  22. def __init__(self, shape: tuple):
  23. super(Transpose, self).__init__()
  24. self.shape = shape
  25. def forward(self, x):
  26. return x.transpose(*self.shape)
  27. class DepthwiseConv1d(nn.Module):
  28. """
  29. When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
  30. this operation is termed in literature as depthwise convolution.
  31. Args:
  32. in_channels (int): Number of channels in the input
  33. out_channels (int): Number of channels produced by the convolution
  34. kernel_size (int or tuple): Size of the convolving kernel
  35. stride (int, optional): Stride of the convolution. Default: 1
  36. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  37. bias (bool, optional): If True, adds a learnable bias to the output. Default: True
  38. Inputs: inputs
  39. - **inputs** (batch, in_channels, time): Tensor containing input vector
  40. Returns: outputs
  41. - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
  42. """
  43. def __init__(
  44. self,
  45. in_channels: int,
  46. out_channels: int,
  47. kernel_size: int,
  48. stride: int = 1,
  49. padding: int = 0,
  50. bias: bool = False,
  51. ) -> None:
  52. super(DepthwiseConv1d, self).__init__()
  53. assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
  54. self.conv = nn.Conv1d(
  55. in_channels=in_channels,
  56. out_channels=out_channels,
  57. kernel_size=kernel_size,
  58. groups=in_channels,
  59. stride=stride,
  60. padding=padding,
  61. bias=bias,
  62. )
  63. def forward(self, inputs):
  64. return self.conv(inputs)
  65. class ConvModule(nn.Module):
  66. """
  67. Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
  68. This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
  69. to aid training deep models.
  70. Args:
  71. in_channels (int): Number of channels in the input
  72. kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
  73. dropout_p (float, optional): probability of dropout
  74. Inputs: inputs
  75. inputs (batch, time, dim): Tensor contains input sequences
  76. Outputs: outputs
  77. outputs (batch, time, dim): Tensor produces by conformer convolution module.
  78. """
  79. def __init__(
  80. self,
  81. in_channels: int,
  82. kernel_size: int = 17,
  83. expansion_factor: int = 2,
  84. dropout_p: float = 0.1,
  85. ) -> None:
  86. super(ConvModule, self).__init__()
  87. assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
  88. assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
  89. self.sequential = nn.Sequential(
  90. Transpose(shape=(1, 2)),
  91. DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
  92. )
  93. def forward(self, inputs):
  94. return inputs + self.sequential(inputs).transpose(1, 2)
  95. class OffsetScale(nn.Module):
  96. def __init__(self, dim, heads = 1):
  97. super().__init__()
  98. self.gamma = nn.Parameter(torch.ones(heads, dim))
  99. self.beta = nn.Parameter(torch.zeros(heads, dim))
  100. nn.init.normal_(self.gamma, std = 0.02)
  101. def forward(self, x):
  102. out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
  103. return out.unbind(dim = -2)
  104. class FFConvM(nn.Module):
  105. def __init__(
  106. self,
  107. dim_in,
  108. dim_out,
  109. norm_klass = nn.LayerNorm,
  110. dropout = 0.1
  111. ):
  112. super().__init__()
  113. self.mdl = nn.Sequential(
  114. norm_klass(dim_in),
  115. nn.Linear(dim_in, dim_out),
  116. nn.SiLU(),
  117. ConvModule(dim_out),
  118. nn.Dropout(dropout)
  119. )
  120. def forward(
  121. self,
  122. x,
  123. ):
  124. output = self.mdl(x)
  125. return output
  126. class FLASH_ShareA_FFConvM(nn.Module):
  127. def __init__(
  128. self,
  129. *,
  130. dim,
  131. group_size = 256,
  132. query_key_dim = 128,
  133. expansion_factor = 1.,
  134. causal = False,
  135. dropout = 0.1,
  136. rotary_pos_emb = None,
  137. norm_klass = nn.LayerNorm,
  138. shift_tokens = True
  139. ):
  140. super().__init__()
  141. hidden_dim = int(dim * expansion_factor)
  142. self.group_size = group_size
  143. self.causal = causal
  144. self.shift_tokens = shift_tokens
  145. # positional embeddings
  146. self.rotary_pos_emb = rotary_pos_emb
  147. # norm
  148. self.dropout = nn.Dropout(dropout)
  149. # projections
  150. self.to_hidden = FFConvM(
  151. dim_in = dim,
  152. dim_out = hidden_dim,
  153. norm_klass = norm_klass,
  154. dropout = dropout,
  155. )
  156. self.to_qk = FFConvM(
  157. dim_in = dim,
  158. dim_out = query_key_dim,
  159. norm_klass = norm_klass,
  160. dropout = dropout,
  161. )
  162. self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
  163. self.to_out = FFConvM(
  164. dim_in = dim*2,
  165. dim_out = dim,
  166. norm_klass = norm_klass,
  167. dropout = dropout,
  168. )
  169. self.gateActivate=nn.Sigmoid()
  170. def forward(
  171. self,
  172. x,
  173. *,
  174. mask = None
  175. ):
  176. """
  177. b - batch
  178. n - sequence length (within groups)
  179. g - group dimension
  180. d - feature dimension (keys)
  181. e - feature dimension (values)
  182. i - sequence dimension (source)
  183. j - sequence dimension (target)
  184. """
  185. normed_x = x
  186. # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
  187. residual = x
  188. if self.shift_tokens:
  189. x_shift, x_pass = normed_x.chunk(2, dim = -1)
  190. x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
  191. normed_x = torch.cat((x_shift, x_pass), dim = -1)
  192. # initial projections
  193. v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
  194. qk = self.to_qk(normed_x)
  195. # offset and scale
  196. quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
  197. att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
  198. out = (att_u*v ) * self.gateActivate(att_v*u)
  199. x = x + self.to_out(out)
  200. return x
  201. def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
  202. b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
  203. if exists(mask):
  204. lin_mask = rearrange(mask, '... -> ... 1')
  205. lin_k = lin_k.masked_fill(~lin_mask, 0.)
  206. # rotate queries and keys
  207. if exists(self.rotary_pos_emb):
  208. quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
  209. # padding for groups
  210. padding = padding_to_multiple_of(n, g)
  211. if padding > 0:
  212. quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
  213. mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
  214. mask = F.pad(mask, (0, padding), value = False)
  215. # group along sequence
  216. quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
  217. if exists(mask):
  218. mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
  219. # calculate quadratic attention output
  220. sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
  221. attn = F.relu(sim) ** 2
  222. attn = self.dropout(attn)
  223. if exists(mask):
  224. attn = attn.masked_fill(~mask, 0.)
  225. if self.causal:
  226. causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
  227. attn = attn.masked_fill(causal_mask, 0.)
  228. quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
  229. quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
  230. # calculate linear attention output
  231. if self.causal:
  232. lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
  233. # exclusive cumulative sum along group dimension
  234. lin_kv = lin_kv.cumsum(dim = 1)
  235. lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
  236. lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
  237. lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
  238. # exclusive cumulative sum along group dimension
  239. lin_ku = lin_ku.cumsum(dim = 1)
  240. lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
  241. lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
  242. else:
  243. lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
  244. lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
  245. lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
  246. lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
  247. # fold back groups into full sequence, and excise out padding
  248. return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))