import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange


def identity(t, *args, **kwargs):
    return t


def append_dims(x, num_dims):
    # Append `num_dims` trailing singleton dimensions to x.
    if num_dims <= 0:
        return x
    return x.view(*x.shape, *((1,) * num_dims))


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def padding_to_multiple_of(n, mult):
    # Amount of right-padding needed so that n becomes a multiple of mult.
    remainder = n % mult
    if remainder == 0:
        return 0
    return mult - remainder
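
# The sketch below is illustrative only (not part of the original module): it shows
# how padding_to_multiple_of is used to round a sequence length up to the group size
# before folding the sequence into fixed-size groups. The helper name `_demo_grouping`,
# the length 300, and the group size 256 are assumptions.
def _demo_grouping():
    b, n, d, group_size = 2, 300, 64, 256
    x = torch.randn(b, n, d)
    pad = padding_to_multiple_of(n, group_size)         # 212 for n=300, mult=256
    x = F.pad(x, (0, 0, 0, pad))                        # right-pad the time axis
    x = rearrange(x, 'b (g n) d -> b g n d', n=group_size)
    return x.shape                                      # torch.Size([2, 2, 256, 64])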
class Transpose(nn.Module):
    """ Wrapper class of torch.transpose() for Sequential module. """

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.transpose(*self.shape)
class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in the literature as depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input sequence

    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be a constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs):
        return self.conv(inputs)
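
# A minimal usage sketch (assumed shapes, not part of the original file): with 'SAME'
# padding of (kernel_size - 1) // 2 and stride 1, the depthwise convolution keeps the
# time dimension unchanged while mixing information only within each channel. The
# helper name `_demo_depthwise_conv` is hypothetical.
def _demo_depthwise_conv():
    conv = DepthwiseConv1d(in_channels=64, out_channels=64, kernel_size=17, padding=8)
    x = torch.randn(2, 64, 100)           # (batch, channels, time)
    y = conv(x)
    return y.shape                         # torch.Size([2, 64, 100])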
class ConvModule(nn.Module):
    """
    Simplified convolution module derived from the Conformer convolution block.
    Here it reduces to a single depthwise 1-D convolution with 'SAME' padding,
    wrapped in a residual connection, and is used to add local context to the
    feed-forward projections.

    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int, optional): Size of the convolving kernel. Default: 17
        expansion_factor (int, optional): Kept for interface compatibility; only 2 is accepted and it is unused
        dropout_p (float, optional): Kept for interface compatibility; unused

    Inputs: inputs
        inputs (batch, time, dim): Tensor containing the input sequence

    Outputs: outputs
        outputs (batch, time, dim): Tensor produced by the convolution module
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only expansion_factor = 2 is supported"
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),    # (batch, time, dim) -> (batch, dim, time) for Conv1d
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs):
        # depthwise convolution over time, transposed back to (batch, time, dim), plus a residual connection
        return inputs + self.sequential(inputs).transpose(1, 2)
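
# A minimal usage sketch (assumed shapes, not part of the original file): the module
# operates on channel-last sequences and returns a tensor of the same shape thanks to
# the 'SAME' padding and the residual connection. The helper name `_demo_conv_module`
# is hypothetical.
def _demo_conv_module():
    module = ConvModule(in_channels=64, kernel_size=17)
    x = torch.randn(2, 100, 64)           # (batch, time, dim)
    y = module(x)
    return y.shape                         # torch.Size([2, 100, 64])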
class OffsetScale(nn.Module):
    # Per-head learned affine transform: expands a single (..., dim) input into
    # `heads` offset-and-scaled copies, returned as a tuple of tensors.
    def __init__(self, dim, heads = 1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
        return out.unbind(dim = -2)
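
# A minimal usage sketch (assumed shapes, not part of the original file): with heads = 4,
# one shared query/key projection is cheaply expanded into the four variants used below
# (quadratic query/key and linear query/key).
def _demo_offset_scale():
    qk = torch.randn(2, 100, 128)                          # shared q/k projection
    quad_q, lin_q, quad_k, lin_k = OffsetScale(128, heads=4)(qk)
    return quad_q.shape                                     # torch.Size([2, 100, 128])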
class FFConvM(nn.Module):
    # Feed-forward block with a convolution module: norm -> Linear -> SiLU
    # -> ConvModule (local context) -> Dropout.
    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass = nn.LayerNorm,
        dropout = 0.1
    ):
        super().__init__()
        self.mdl = nn.Sequential(
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            ConvModule(dim_out),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        output = self.mdl(x)
        return output
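
# A minimal usage sketch (assumed shapes, not part of the original file): FFConvM maps
# (batch, time, dim_in) to (batch, time, dim_out); it is the building block used for
# every projection inside FLASH_ShareA_FFConvM below.
def _demo_ffconvm():
    ff = FFConvM(dim_in=256, dim_out=512)
    x = torch.randn(2, 100, 256)
    return ff(x).shape                     # torch.Size([2, 100, 512])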
class FLASH_ShareA_FFConvM(nn.Module):
    def __init__(
        self,
        *,
        dim,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 4.,
        causal = False,
        dropout = 0.1,
        rotary_pos_emb = None,
        norm_klass = nn.LayerNorm,
        shift_tokens = True
    ):
        super().__init__()
        # hidden_dim must equal 4 * dim: to_hidden is chunked into v and u of width
        # 2 * dim each, which is what to_out (dim_in = dim * 2) expects.
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens

        # positional embeddings
        self.rotary_pos_emb = rotary_pos_emb
        # dropout applied to the quadratic attention weights
        self.dropout = nn.Dropout(dropout)

        # projections
        self.to_hidden = FFConvM(
            dim_in = dim,
            dim_out = hidden_dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )
        self.to_qk = FFConvM(
            dim_in = dim,
            dim_out = query_key_dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )
        # one shared q/k projection, expanded into quadratic and linear variants
        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
        self.to_out = FFConvM(
            dim_in = dim * 2,
            dim_out = dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )
        self.gateActivate = nn.Sigmoid()
    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """
        b - batch
        n - sequence length (within groups)
        g - group dimension
        d - feature dimension (keys)
        e - feature dimension (values)
        i - sequence dimension (source)
        j - sequence dimension (target)
        """
        residual = x
        normed_x = x

        # token shift - a great, costless trick (credited to an independent AI researcher
        # in Shenzhen): half of the features are shifted one step along the time axis
        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim = -1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
            normed_x = torch.cat((x_shift, x_pass), dim = -1)

        # initial projections
        v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
        qk = self.to_qk(normed_x)

        # offset and scale the shared q/k projection into its four variants
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u, mask = mask)

        # gated output: each attention stream gates the other branch
        out = (att_u * v) * self.gateActivate(att_v * u)
        x = residual + self.to_out(out)
        return x
    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size

        # mask out padded positions from the linear attention keys
        if exists(mask):
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask, 0.)

        # rotate queries and keys
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))

        # pad the sequence so it splits evenly into groups of size g
        padding = padding_to_multiple_of(n, g)
        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
            mask = F.pad(mask, (0, padding), value = False)

        # group along the sequence dimension
        quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
        if exists(mask):
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)

        # quadratic (within-group) attention with squared-ReLU weights
        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
        attn = F.relu(sim) ** 2
        attn = self.dropout(attn)
        if exists(mask):
            attn = attn.masked_fill(~mask, 0.)
        if self.causal:
            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)
        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)

        # linear (across-group) attention
        if self.causal:
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
            # exclusive cumulative sum along the group dimension
            lin_kv = lin_kv.cumsum(dim = 1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)

            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
            # exclusive cumulative sum along the group dimension
            lin_ku = lin_ku.cumsum(dim = 1)
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
        else:
            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)

            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)

        # fold groups back into the full sequence and excise the padding
        return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v + lin_out_v, quad_out_u + lin_out_u))
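
# A minimal end-to-end sketch (assumed shapes and hyperparameters, not part of the
# original file): one FLASH_ShareA_FFConvM block is a residual map from
# (batch, time, dim) to (batch, time, dim). The sequence length does not need to be a
# multiple of group_size; cal_attention pads and unpads internally.
def _demo_flash_block():
    block = FLASH_ShareA_FFConvM(dim=256, group_size=64, query_key_dim=128, causal=False)
    x = torch.randn(2, 150, 256)                            # 150 is not a multiple of 64
    mask = torch.ones(2, 150, dtype=torch.bool)             # all positions are valid
    y = block(x, mask=mask)
    return y.shape                                           # torch.Size([2, 150, 256])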
|