@@ -0,0 +1,307 @@
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange
+
+
+def identity(t, *args, **kwargs):
+    return t
+
+def append_dims(x, num_dims):
+    if num_dims <= 0:
+        return x
+    return x.view(*x.shape, *((1,) * num_dims))
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def padding_to_multiple_of(n, mult):
+    remainder = n % mult
+    if remainder == 0:
+        return 0
+    return mult - remainder
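+# Illustrative note (not in the original source): this returns how much padding
+# brings a length n up to a multiple of mult, e.g.
+#
+#   padding_to_multiple_of(10, 4)  # -> 2, since 10 + 2 == 12 is a multiple of 4
+#   padding_to_multiple_of(8, 4)   # -> 0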
+
+
+class Transpose(nn.Module):
+    """ Wrapper class of torch.transpose() for Sequential module. """
+    def __init__(self, shape: tuple):
+        super(Transpose, self).__init__()
+        self.shape = shape
+
+    def forward(self, x):
+        return x.transpose(*self.shape)
+
+
+class DepthwiseConv1d(nn.Module):
+    """
+    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
+    this operation is termed in the literature as depthwise convolution.
+    Args:
+        in_channels (int): Number of channels in the input
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        bias (bool, optional): If True, adds a learnable bias to the output. Default: False
+    Inputs: inputs
+        - **inputs** (batch, in_channels, time): Tensor containing the input vector
+    Returns: outputs
+        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution.
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        bias: bool = False,
+    ) -> None:
+        super(DepthwiseConv1d, self).__init__()
+        assert out_channels % in_channels == 0, "out_channels should be a constant multiple of in_channels"
+        self.conv = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            groups=in_channels,
+            stride=stride,
+            padding=padding,
+            bias=bias,
+        )
+
+    def forward(self, inputs):
+        return self.conv(inputs)
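+# Hypothetical usage sketch (illustrative, not part of the original file): with
+# 'SAME' padding the time dimension is preserved, and groups=in_channels means
+# each channel is convolved independently.
+#
+#   conv = DepthwiseConv1d(in_channels=64, out_channels=64, kernel_size=17, padding=8)
+#   out = conv(torch.randn(2, 64, 100))  # -> shape (2, 64, 100)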
+
+
+class ConvModule(nn.Module):
+    """
+    Simplified Conformer-style convolution module: a single 1-D depthwise
+    convolution with 'SAME' padding, wrapped in a residual connection.
+    Args:
+        in_channels (int): Number of channels in the input
+        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 17
+        dropout_p (float, optional): probability of dropout
+    Inputs: inputs
+        inputs (batch, time, dim): Tensor containing input sequences
+    Outputs: outputs
+        outputs (batch, time, dim): Tensor produced by the convolution module.
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        kernel_size: int = 17,
+        expansion_factor: int = 2,
+        dropout_p: float = 0.1,
+    ) -> None:
+        super(ConvModule, self).__init__()
+        assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
+        assert expansion_factor == 2, "Currently, only supports expansion_factor 2"
+
+        self.sequential = nn.Sequential(
+            Transpose(shape=(1, 2)),
+            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
+        )
+
+    def forward(self, inputs):
+        # (batch, time, dim) -> conv over time -> back to (batch, time, dim), plus residual
+        return inputs + self.sequential(inputs).transpose(1, 2)
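+# Minimal sketch (assumed shapes, not from the original source): ConvModule
+# takes channel-last input and returns the same shape thanks to the residual.
+#
+#   module = ConvModule(in_channels=512)
+#   out = module(torch.randn(2, 100, 512))  # -> shape (2, 100, 512)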
+
+
+class OffsetScale(nn.Module):
+    def __init__(self, dim, heads = 1):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(heads, dim))
+        self.beta = nn.Parameter(torch.zeros(heads, dim))
+        nn.init.normal_(self.gamma, std = 0.02)
+
+    def forward(self, x):
+        # learn a per-head affine (scale + offset) of a single shared projection
+        out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
+        return out.unbind(dim = -2)
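+# Usage note (illustrative, not in the original): with heads = 4 a single shared
+# qk projection is turned into four cheap per-head affine variants, which the
+# FLASH layer below unpacks as quad_q, lin_q, quad_k and lin_k.
+#
+#   offset_scale = OffsetScale(128, heads = 4)
+#   quad_q, lin_q, quad_k, lin_k = offset_scale(torch.randn(2, 100, 128))  # each (2, 100, 128)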
+
+
+class FFConvM(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        norm_klass = nn.LayerNorm,
+        dropout = 0.1
+    ):
+        super().__init__()
+        self.mdl = nn.Sequential(
+            norm_klass(dim_in),
+            nn.Linear(dim_in, dim_out),
+            nn.SiLU(),
+            ConvModule(dim_out),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.mdl(x)
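+# Hedged sketch (shapes assumed, not from the original file): FFConvM is the
+# norm -> linear -> SiLU -> ConvModule -> dropout block reused for all of the
+# projections in FLASH_ShareA_FFConvM below.
+#
+#   ff = FFConvM(dim_in = 512, dim_out = 256)
+#   out = ff(torch.randn(2, 100, 512))  # -> shape (2, 100, 256)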
+
+
+class FLASH_ShareA_FFConvM(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        group_size = 256,
+        query_key_dim = 128,
+        expansion_factor = 1.,
+        causal = False,
+        dropout = 0.1,
+        rotary_pos_emb = None,
+        norm_klass = nn.LayerNorm,
+        shift_tokens = True
+    ):
+        super().__init__()
+        hidden_dim = int(dim * expansion_factor)
+        self.group_size = group_size
+        self.causal = causal
+        self.shift_tokens = shift_tokens
+
+        # positional embeddings
+        self.rotary_pos_emb = rotary_pos_emb
+
+        # dropout applied to the quadratic attention weights
+        self.dropout = nn.Dropout(dropout)
+
+        # projections
+        self.to_hidden = FFConvM(
+            dim_in = dim,
+            dim_out = hidden_dim,
+            norm_klass = norm_klass,
+            dropout = dropout,
+        )
+        self.to_qk = FFConvM(
+            dim_in = dim,
+            dim_out = query_key_dim,
+            norm_klass = norm_klass,
+            dropout = dropout,
+        )
+
+        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
+
+        self.to_out = FFConvM(
+            dim_in = hidden_dim // 2,  # width of the gated output below (equals dim * 2 when expansion_factor = 4)
+            dim_out = dim,
+            norm_klass = norm_klass,
+            dropout = dropout,
+        )
+
+        self.gateActivate = nn.Sigmoid()
+
+    def forward(
+        self,
+        x,
+        *,
+        mask = None
+    ):
+        """
+        b - batch
+        n - sequence length (within groups)
+        g - group dimension
+        d - feature dimension (keys)
+        e - feature dimension (values)
+        i - sequence dimension (source)
+        j - sequence dimension (target)
+        """
+        normed_x = x
+
+        # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
+        if self.shift_tokens:
+            x_shift, x_pass = normed_x.chunk(2, dim = -1)
+            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
+            normed_x = torch.cat((x_shift, x_pass), dim = -1)
+
+        # initial projections
+        v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
+        qk = self.to_qk(normed_x)
+
+        # offset and scale into quadratic / linear queries and keys
+        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
+        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u, mask = mask)
+
+        # dual gating: each attention branch gates the other's value stream
+        out = (att_u * v) * self.gateActivate(att_v * u)
+
+        # residual connection
+        x = x + self.to_out(out)
+        return x
+
+    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
+        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
+
+        # zero out masked keys for the linear attention branch
+        if exists(mask):
+            lin_mask = rearrange(mask, '... -> ... 1')
+            lin_k = lin_k.masked_fill(~lin_mask, 0.)
+
+        # rotate queries and keys
+        if exists(self.rotary_pos_emb):
+            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
+
+        # padding for groups
+        padding = padding_to_multiple_of(n, g)
+
+        if padding > 0:
+            quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
+
+            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
+            mask = F.pad(mask, (0, padding), value = False)
+
+        # group along sequence
+        quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
+
+        if exists(mask):
+            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
+
+        # calculate quadratic attention output
+        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
+
+        attn = F.relu(sim) ** 2
+        attn = self.dropout(attn)
+
+        if exists(mask):
+            attn = attn.masked_fill(~mask, 0.)
+
+        if self.causal:
+            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
+            attn = attn.masked_fill(causal_mask, 0.)
+
+        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
+        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
+
+        # calculate linear attention output
+        if self.causal:
+            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
+            # exclusive cumulative sum along group dimension
+            lin_kv = lin_kv.cumsum(dim = 1)
+            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
+            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
+
+            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
+            # exclusive cumulative sum along group dimension
+            lin_ku = lin_ku.cumsum(dim = 1)
+            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
+            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
+        else:
+            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
+            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
+
+            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
+            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
+
+        # fold back groups into full sequence, and excise out padding
+        return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v + lin_out_v, quad_out_u + lin_out_u))
+
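+
+# A minimal smoke test (illustrative only, not part of the original file). The
+# hyperparameters below are assumptions: dim must be even for the token shift's
+# chunk(2), and expansion_factor = 4. makes the chunked v/u width equal dim * 2.
+if __name__ == '__main__':
+    block = FLASH_ShareA_FFConvM(dim = 64, group_size = 16, query_key_dim = 32, expansion_factor = 4.)
+    x = torch.randn(2, 50, 64)
+    mask = torch.ones(2, 50, dtype = torch.bool)
+    out = block(x, mask = mask)
+    print(out.shape)  # expected: torch.Size([2, 50, 64])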