@@ -0,0 +1,632 @@
+"""Attention (time mixing) modules for RWKV block.
+
+Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py.
+
+Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
+
+"""  # noqa
+
+import math
+from importlib.util import find_spec
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+wkv_kernel_encoder = None
+wkv_kernel_decoder = None
+
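+# Note (summary added for readability; see the RWKV paper for the derivation):
+# ``wkv_kernel_encoder`` and ``wkv_kernel_decoder`` hold the JIT-compiled CUDA
+# extensions, built once per context size by the ``load_*_wkv_kernel`` helpers
+# below. For every channel, the kernels evaluate the WKV term
+#
+#   wkv_t = (sum_{i<t} exp(k_i + (t - 1 - i) * w) * v_i + exp(u + k_t) * v_t)
+#           / (sum_{i<t} exp(k_i + (t - 1 - i) * w) + exp(u + k_t))
+#
+# where u = time_first and w = -exp(time_decay); the ``-exp`` is applied in the
+# autograd functions below before the kernel call.
+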
+class WKVLinearAttentionEncoder(torch.autograd.Function):
+    """WKVLinearAttentionEncoder function definition."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        time_decay: torch.Tensor,
+        time_first: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        """WKVLinearAttentionEncoder function forward pass.
+
+        Args:
+            time_decay: Channel-wise time decay vector. (D_att)
+            time_first: Channel-wise time first vector. (D_att)
+            key: Key tensor. (B, U, D_att)
+            value: Value tensor. (B, U, D_att)
+
+        Returns:
+            out: Weighted Key-Value tensor. (B, U, D_att)
+
+        """
+        batch, length, dim = key.size()
+
+        assert length <= wkv_kernel_encoder.context_size, (
+            f"Cannot process key of length {length} while context_size "
+            f"is ({wkv_kernel_encoder.context_size}). Limit should be increased."
+        )
+
+        assert batch * dim % min(dim, 32) == 0, (
+            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
+            f"{min(dim, 32)}"
+        )
+
+        ctx.input_dtype = key.dtype
+
+        time_decay = -torch.exp(time_decay.float().contiguous())
+        time_first = time_first.float().contiguous()
+
+        key = key.float().contiguous()
+        value = value.float().contiguous()
+
+        out = torch.empty_like(key, memory_format=torch.contiguous_format)
+
+        wkv_kernel_encoder.forward(time_decay, time_first, key, value, out)
+        ctx.save_for_backward(time_decay, time_first, key, value, out)
+
+        return out
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """WKVLinearAttentionEncoder function backward pass.
+
+        Args:
+            grad_output: Output gradient. (B, U, D_att)
+
+        Returns:
+            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
+            grad_time_first: Gradient for channel-wise time first vector. (D_att)
+            grad_key: Gradient for key tensor. (B, U, D_att)
+            grad_value: Gradient for value tensor. (B, U, D_att)
+
+        """
+        time_decay, time_first, key, value, output = ctx.saved_tensors
+        grad_dtype = ctx.input_dtype
+
+        batch, _, dim = key.size()
+
+        grad_time_decay = torch.empty(
+            (batch, dim),
+            memory_format=torch.contiguous_format,
+            dtype=time_decay.dtype,
+            device=time_decay.device,
+        )
+
+        grad_time_first = torch.empty(
+            (batch, dim),
+            memory_format=torch.contiguous_format,
+            dtype=time_decay.dtype,
+            device=time_decay.device,
+        )
+
+        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
+        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
+
+        wkv_kernel_encoder.backward(
+            time_decay,
+            time_first,
+            key,
+            value,
+            output,
+            grad_output.contiguous(),
+            grad_time_decay,
+            grad_time_first,
+            grad_key,
+            grad_value,
+        )
+
+        grad_time_decay = torch.sum(grad_time_decay, dim=0)
+        grad_time_first = torch.sum(grad_time_first, dim=0)
+
+        return (
+            grad_time_decay,
+            grad_time_first,
+            grad_key,
+            grad_value,
+        )
+
+class WKVLinearAttentionDecoder(torch.autograd.Function):
+    """WKVLinearAttentionDecoder function definition."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        time_decay: torch.Tensor,
+        time_first: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        """WKVLinearAttentionDecoder function forward pass.
+
+        Args:
+            time_decay: Channel-wise time decay vector. (D_att)
+            time_first: Channel-wise time first vector. (D_att)
+            key: Key tensor. (B, U, D_att)
+            value: Value tensor. (B, U, D_att)
+
+        Returns:
+            out: Weighted Key-Value tensor. (B, U, D_att)
+
+        """
+        batch, length, dim = key.size()
+
+        assert length <= wkv_kernel_decoder.context_size, (
+            f"Cannot process key of length {length} while context_size "
+            f"is ({wkv_kernel_decoder.context_size}). Limit should be increased."
+        )
+
+        assert batch * dim % min(dim, 32) == 0, (
+            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
+            f"{min(dim, 32)}"
+        )
+
+        ctx.input_dtype = key.dtype
+
+        time_decay = -torch.exp(time_decay.float().contiguous())
+        time_first = time_first.float().contiguous()
+
+        key = key.float().contiguous()
+        value = value.float().contiguous()
+
+        out = torch.empty_like(key, memory_format=torch.contiguous_format)
+
+        wkv_kernel_decoder.forward(time_decay, time_first, key, value, out)
+        ctx.save_for_backward(time_decay, time_first, key, value, out)
+
+        return out
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """WKVLinearAttentionDecoder function backward pass.
+
+        Args:
+            grad_output: Output gradient. (B, U, D_att)
+
+        Returns:
+            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
+            grad_time_first: Gradient for channel-wise time first vector. (D_att)
+            grad_key: Gradient for key tensor. (B, U, D_att)
+            grad_value: Gradient for value tensor. (B, U, D_att)
+
+        """
+        time_decay, time_first, key, value, output = ctx.saved_tensors
+        grad_dtype = ctx.input_dtype
+
+        batch, _, dim = key.size()
+
+        grad_time_decay = torch.empty(
+            (batch, dim),
+            memory_format=torch.contiguous_format,
+            dtype=time_decay.dtype,
+            device=time_decay.device,
+        )
+
+        grad_time_first = torch.empty(
+            (batch, dim),
+            memory_format=torch.contiguous_format,
+            dtype=time_decay.dtype,
+            device=time_decay.device,
+        )
+
+        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
+        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
+
+        wkv_kernel_decoder.backward(
+            time_decay,
+            time_first,
+            key,
+            value,
+            output,
+            grad_output.contiguous(),
+            grad_time_decay,
+            grad_time_first,
+            grad_key,
+            grad_value,
+        )
+
+        grad_time_decay = torch.sum(grad_time_decay, dim=0)
+        grad_time_first = torch.sum(grad_time_first, dim=0)
+
+        return (
+            grad_time_decay,
+            grad_time_first,
+            grad_key,
+            grad_value,
+        )
+
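+# Illustrative reference (added for clarity; not used by this module and not
+# part of the original kernels): a pure-PyTorch, numerically stabilised version
+# of the same WKV computation, useful e.g. for sanity-checking the CUDA kernels
+# on short sequences. Argument names mirror the autograd functions above.
+def naive_wkv(
+    time_decay: torch.Tensor,
+    time_first: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+) -> torch.Tensor:
+    """Compute WKV sequentially in PyTorch (illustrative reference).
+
+    Args:
+        time_decay: Channel-wise time decay vector. (D_att)
+        time_first: Channel-wise time first vector. (D_att)
+        key: Key tensor. (B, U, D_att)
+        value: Value tensor. (B, U, D_att)
+
+    Returns:
+        out: Weighted Key-Value tensor. (B, U, D_att)
+
+    """
+    time_decay = -torch.exp(time_decay.float())
+    time_first = time_first.float()
+    key, value = key.float(), value.float()
+
+    out = torch.zeros_like(key)
+
+    # Running numerator, denominator and maximum exponent per channel.
+    num_state = torch.zeros_like(key[:, 0])
+    den_state = torch.zeros_like(key[:, 0])
+    max_state = torch.full_like(key[:, 0], -1e38)
+
+    for t in range(key.size(1)):
+        k_t, v_t = key[:, t], value[:, t]
+
+        # Output at step t applies the bonus (time_first) to the current token.
+        max_out = torch.maximum(max_state, time_first + k_t)
+        e1 = torch.exp(max_state - max_out)
+        e2 = torch.exp(time_first + k_t - max_out)
+        out[:, t] = (e1 * num_state + e2 * v_t) / (e1 * den_state + e2)
+
+        # State update applies the per-channel decay to the running maximum.
+        max_new = torch.maximum(max_state + time_decay, k_t)
+        e1 = torch.exp(max_state + time_decay - max_new)
+        e2 = torch.exp(k_t - max_new)
+        num_state = e1 * num_state + e2 * v_t
+        den_state = e1 * den_state + e2
+        max_state = max_new
+
+    return out
+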
+def load_encoder_wkv_kernel(context_size: int) -> None:
+    """Load the encoder WKV CUDA kernel.
+
+    Args:
+        context_size: Context size.
+
+    """
+    from torch.utils.cpp_extension import load
+
+    global wkv_kernel_encoder
+
+    if (
+        wkv_kernel_encoder is not None
+        and wkv_kernel_encoder.context_size == context_size
+    ):
+        return
+
+    if find_spec("ninja") is None:
+        raise ImportError(
+            "Ninja package was not found. WKV kernel module can't be loaded "
+            "for training. Please run 'pip install ninja' in your environment."
+        )
+
+    if not torch.cuda.is_available():
+        raise ImportError(
+            "CUDA is currently a requirement for WKV kernel loading. "
+            "Please set your devices properly and launch again."
+        )
+
+    kernel_folder = Path(__file__).resolve().parent / "cuda_encoder"
+    kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]
+
+    kernel_cflags = [
+        "-res-usage",
+        "--maxrregcount 60",
+        "--use_fast_math",
+        "-O3",
+        "-Xptxas -O3",
+        f"-DTmax={context_size}",
+    ]
+    wkv_kernel_encoder = load(
+        name=f"encoder_wkv_{context_size}",
+        sources=kernel_files,
+        verbose=True,
+        extra_cuda_cflags=kernel_cflags,
+    )
+    wkv_kernel_encoder.context_size = context_size
+
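+# Note: EncoderSelfAttention.__init__ calls load_encoder_wkv_kernel (and
+# DecoderSelfAttention.__init__ its decoder counterpart), so the kernels are
+# normally compiled on module construction. They can also be pre-compiled
+# explicitly, e.g. ``load_encoder_wkv_kernel(1024)`` where 1024 is an
+# illustrative maximum sequence length (Tmax). Compilation needs a CUDA device
+# and the ``ninja`` package; the result is cached until a different context
+# size is requested.
+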
+def load_decoder_wkv_kernel(context_size: int) -> None:
+    """Load the decoder WKV CUDA kernel.
+
+    Args:
+        context_size: Context size.
+
+    """
+    from torch.utils.cpp_extension import load
+
+    global wkv_kernel_decoder
+
+    if (
+        wkv_kernel_decoder is not None
+        and wkv_kernel_decoder.context_size == context_size
+    ):
+        return
+
+    if find_spec("ninja") is None:
+        raise ImportError(
+            "Ninja package was not found. WKV kernel module can't be loaded "
+            "for training. Please run 'pip install ninja' in your environment."
+        )
+
+    if not torch.cuda.is_available():
+        raise ImportError(
+            "CUDA is currently a requirement for WKV kernel loading. "
+            "Please set your devices properly and launch again."
+        )
+
+    kernel_folder = Path(__file__).resolve().parent / "cuda_decoder"
+    kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]
+
+    kernel_cflags = [
+        "-res-usage",
+        "--maxrregcount 60",
+        "--use_fast_math",
+        "-O3",
+        "-Xptxas -O3",
+        f"-DTmax={context_size}",
+    ]
+    wkv_kernel_decoder = load(
+        name=f"decoder_wkv_{context_size}",
+        sources=kernel_files,
+        verbose=True,
+        extra_cuda_cflags=kernel_cflags,
+    )
+    wkv_kernel_decoder.context_size = context_size
+
+class SelfAttention(torch.nn.Module):
+    """SelfAttention module definition.
+
+    Args:
+        size: Input/Output size.
+        attention_size: Attention hidden size.
+        block_id: Block index.
+        dropout_rate: Dropout rate.
+        num_blocks: Number of blocks in the architecture.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        attention_size: int,
+        block_id: int,
+        dropout_rate: float,
+        num_blocks: int,
+    ) -> None:
+        """Construct a SelfAttention object."""
+        super().__init__()
+        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
+
+        self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
+        self.time_first = torch.nn.Parameter(torch.empty(attention_size))
+
+        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
+        self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
+        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))
+
+        self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
+        self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
+        self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)
+
+        self.proj_output = torch.nn.Linear(attention_size, size, bias=True)
+
+        self.block_id = block_id
+
+        self.reset_parameters(size, attention_size, block_id, num_blocks)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+
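+    # Parameter roles (summary added for readability; see the RWKV paper for
+    # details): ``time_decay`` and ``time_first`` are the per-channel w and u
+    # vectors consumed by the WKV kernels, while the ``time_mix_*`` vectors
+    # interpolate between the current input x_t and the time-shifted input
+    # x_{t-1} before the corresponding linear projections.
+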
+    def reset_parameters(
+        self, size: int, attention_size: int, block_id: int, num_blocks: int
+    ) -> None:
+        """Reset module parameters.
+
+        Args:
+            size: Block size.
+            attention_size: Attention hidden size.
+            block_id: Block index.
+            num_blocks: Number of blocks in the architecture.
+
+        """
+        ratio_0_to_1 = block_id / (num_blocks - 1)
+        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)
+
+        time_weight = torch.ones(1, 1, size)
+
+        for i in range(size):
+            time_weight[0, 0, i] = i / size
+
+        decay_speed = [
+            -5 + 8 * (h / (attention_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+            for h in range(attention_size)
+        ]
+        decay_speed = torch.tensor(
+            decay_speed, dtype=self.time_decay.dtype, device=self.time_decay.device
+        )
+
+        zigzag = (
+            torch.tensor(
+                [(i + 1) % 3 - 1 for i in range(attention_size)],
+                dtype=self.time_first.dtype,
+                device=self.time_first.device,
+            )
+            * 0.5
+        )
+
+        with torch.no_grad():
+            self.time_decay.data = decay_speed
+            self.time_first.data = (
+                torch.ones_like(self.time_first) * math.log(0.3) + zigzag
+            )
+
+            self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
+            self.time_mix_value.data = (
+                torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
+            )
+            self.time_mix_receptance.data = torch.pow(
+                time_weight, 0.5 * ratio_1_to_almost0
+            )
+
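+    # The inference-time WKV below is computed as a running fraction
+    # numerator/denominator, with ``max_state`` tracking the largest exponent
+    # seen so far so that every ``torch.exp`` argument stays non-positive
+    # (a log-sum-exp style stabilisation). It mirrors the recurrence used
+    # inside the CUDA kernels, but operates on a single step with an explicit
+    # state. (Comment added for readability.)
+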
+    @torch.no_grad()
+    def wkv_linear_attention(
+        self,
+        time_decay: torch.Tensor,
+        time_first: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        """Compute WKV with state (i.e.: for inference).
+
+        Args:
+            time_decay: Channel-wise time decay vector. (D_att)
+            time_first: Channel-wise time first vector. (D_att)
+            key: Key tensor. (B, 1, D_att)
+            value: Value tensor. (B, 1, D_att)
+            state: Decoder hidden states. [3 x (B, 1, D_att)]
+
+        Returns:
+            wkv: Weighted Key-Value. (B, 1, D_att)
+            state: Decoder hidden states. [3 x (B, 1, D_att)]
+
+        """
+        num_state, den_state, max_state = state
+
+        max_for_output = torch.maximum(max_state, (time_first + key))
+
+        e1 = torch.exp(max_state - max_for_output)
+        e2 = torch.exp((time_first + key) - max_for_output)
+
+        numerator = e1 * num_state + e2 * value
+        denominator = e1 * den_state + e2
+
+        max_for_state = torch.maximum(key, (max_state + time_decay))
+
+        e1 = torch.exp((max_state + time_decay) - max_for_state)
+        e2 = torch.exp(key - max_for_state)
+
+        wkv = numerator / denominator
+
+        state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]
+
+        return wkv, state
+
+
+class DecoderSelfAttention(SelfAttention):
+    """DecoderSelfAttention module definition.
+
+    Args:
+        size: Input/Output size.
+        attention_size: Attention hidden size.
+        context_size: Context size for WKV kernel.
+        block_id: Block index.
+        dropout_rate: Dropout rate.
+        num_blocks: Number of blocks in the architecture.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        attention_size: int,
+        context_size: int,
+        block_id: int,
+        dropout_rate: float,
+        num_blocks: int,
+    ) -> None:
+        """Construct a DecoderSelfAttention object."""
+        super().__init__(
+            size,
+            attention_size,
+            block_id,
+            dropout_rate,
+            num_blocks,
+        )
+        load_decoder_wkv_kernel(context_size)
+
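+    # Two execution paths are used below: with ``state`` (step-by-step
+    # decoding) the shifted input and the WKV statistics are read from and
+    # written back to the state tensors; without ``state`` (full-sequence
+    # training) the token shift is the ZeroPad2d and the WKV term is computed
+    # by the compiled CUDA kernel through WKVLinearAttentionDecoder. Per-block
+    # state layout: index 1 holds the attention shift, indices 2-4 hold the
+    # WKV numerator/denominator/maximum (index 0 is presumably the shift kept
+    # by the channel-mixing module). (Comment added for readability.)
+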
+    def forward(
+        self,
+        x: torch.Tensor,
+        state: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+        """Compute time mixing.
+
+        Args:
+            x: SelfAttention input sequences. (B, U, size)
+            state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+        Returns:
+            x: SelfAttention output sequences. (B, U, size)
+            state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+        """
+        shifted_x = (
+            self.time_shift(x) if state is None else state[1][..., self.block_id]
+        )
+
+        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
+        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
+        receptance = x * self.time_mix_receptance + shifted_x * (
+            1 - self.time_mix_receptance
+        )
+
+        key = self.proj_key(key)
+        value = self.proj_value(value)
+        receptance = torch.sigmoid(self.proj_receptance(receptance))
+
+        if state is not None:
+            state[1][..., self.block_id] = x
+
+            wkv, att_state = self.wkv_linear_attention(
+                self.time_decay,
+                self.time_first,
+                key,
+                value,
+                tuple(s[..., self.block_id] for s in state[2:]),
+            )
+
+            state[2][..., self.block_id] = att_state[0]
+            state[3][..., self.block_id] = att_state[1]
+            state[4][..., self.block_id] = att_state[2]
+        else:
+            wkv = WKVLinearAttentionDecoder.apply(
+                self.time_decay, self.time_first, key, value
+            )
+
+        wkv = self.dropout(wkv)
+        x = self.proj_output(receptance * wkv)
+
+        return x, state
+
+class EncoderSelfAttention(SelfAttention):
+    """EncoderSelfAttention module definition.
+
+    Args:
+        size: Input/Output size.
+        attention_size: Attention hidden size.
+        context_size: Context size for WKV kernel.
+        block_id: Block index.
+        dropout_rate: Dropout rate.
+        num_blocks: Number of blocks in the architecture.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        attention_size: int,
+        context_size: int,
+        block_id: int,
+        dropout_rate: float,
+        num_blocks: int,
+    ) -> None:
+        """Construct an EncoderSelfAttention object."""
+        super().__init__(
+            size,
+            attention_size,
+            block_id,
+            dropout_rate,
+            num_blocks,
+        )
+        load_encoder_wkv_kernel(context_size)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        state: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+        """Compute time mixing.
+
+        Args:
+            x: SelfAttention input sequences. (B, U, size)
+            state: Hidden states. [5 x (B, 1, D_att, N)]
+
+        Returns:
+            x: SelfAttention output sequences. (B, U, size)
+            state: Hidden states. [5 x (B, 1, D_att, N)]
+
+        """
+        shifted_x = (
+            self.time_shift(x) if state is None else state[1][..., self.block_id]
+        )
+
+        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
+        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
+        receptance = x * self.time_mix_receptance + shifted_x * (
+            1 - self.time_mix_receptance
+        )
+
+        key = self.proj_key(key)
+        value = self.proj_value(value)
+        receptance = torch.sigmoid(self.proj_receptance(receptance))
+
+        if state is not None:
+            state[1][..., self.block_id] = x
+
+            wkv, att_state = self.wkv_linear_attention(
+                self.time_decay,
+                self.time_first,
+                key,
+                value,
+                tuple(s[..., self.block_id] for s in state[2:]),
+            )
+
+            state[2][..., self.block_id] = att_state[0]
+            state[3][..., self.block_id] = att_state[1]
+            state[4][..., self.block_id] = att_state[2]
+        else:
+            wkv = WKVLinearAttentionEncoder.apply(
+                self.time_decay, self.time_first, key, value
+            )
+
+        wkv = self.dropout(wkv)
+        x = self.proj_output(receptance * wkv)
+
+        return x, state