- """Attention (time mixing) modules for RWKV block.
- Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py.
- Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
- """ # noqa
- import math
- from importlib.util import find_spec
- from pathlib import Path
- from typing import List, Optional, Tuple, Union
- import torch
- wkv_kernel_encoder = None
- wkv_kernel_decoder = None


class WKVLinearAttentionEncoder(torch.autograd.Function):
    """WKVLinearAttention function definition for the encoder WKV kernel."""

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
- """WKVLinearAttention function forward pass.
- Args:
- time_decay: Channel-wise time decay vector. (D_att)
- time_first: Channel-wise time first vector. (D_att)
- key: Key tensor. (B, U, D_att)
- value: Value tensor. (B, U, D_att)
- Returns:
- out: Weighted Key-Value tensor. (B, U, D_att)
- """
- batch, length, dim = key.size()
- assert length <= wkv_kernel_encoder.context_size, (
- f"Cannot process key of length {length} while context_size "
- f"is ({wkv_kernel_encoder.context_size}). Limit should be increased."
- )
- assert batch * dim % min(dim, 32) == 0, (
- f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
- f"{min(dim, 32)}"
- )
- ctx.input_dtype = key.dtype
- time_decay = -torch.exp(time_decay.float().contiguous())
- time_first = time_first.float().contiguous()
- key = key.float().contiguous()
- value = value.float().contiguous()
- out = torch.empty_like(key, memory_format=torch.contiguous_format)
- wkv_kernel_encoder.forward(time_decay, time_first, key, value, out)
- ctx.save_for_backward(time_decay, time_first, key, value, out)
- return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors
        grad_dtype = ctx.input_dtype

        batch, _, dim = key.size()

        grad_time_decay = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )
        grad_time_first = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )
        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        wkv_kernel_encoder.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )

        grad_time_decay = torch.sum(grad_time_decay, dim=0)
        grad_time_first = torch.sum(grad_time_first, dim=0)

        # Cast the float32 kernel gradients back to the dtype of the inputs.
        return (
            grad_time_decay.to(grad_dtype),
            grad_time_first.to(grad_dtype),
            grad_key.to(grad_dtype),
            grad_value.to(grad_dtype),
        )


class WKVLinearAttentionDecoder(torch.autograd.Function):
    """WKVLinearAttention function definition for the decoder WKV kernel."""

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
- """WKVLinearAttention function forward pass.
- Args:
- time_decay: Channel-wise time decay vector. (D_att)
- time_first: Channel-wise time first vector. (D_att)
- key: Key tensor. (B, U, D_att)
- value: Value tensor. (B, U, D_att)
- Returns:
- out: Weighted Key-Value tensor. (B, U, D_att)
- """
- batch, length, dim = key.size()
- assert length <= wkv_kernel_decoder.context_size, (
- f"Cannot process key of length {length} while context_size "
- f"is ({wkv_kernel.context_size}). Limit should be increased."
- )
- assert batch * dim % min(dim, 32) == 0, (
- f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
- f"{min(dim, 32)}"
- )
- ctx.input_dtype = key.dtype
- time_decay = -torch.exp(time_decay.float().contiguous())
- time_first = time_first.float().contiguous()
- key = key.float().contiguous()
- value = value.float().contiguous()
- out = torch.empty_like(key, memory_format=torch.contiguous_format)
- wkv_kernel_decoder.forward(time_decay, time_first, key, value, out)
- ctx.save_for_backward(time_decay, time_first, key, value, out)
- return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors
        grad_dtype = ctx.input_dtype

        batch, _, dim = key.size()

        grad_time_decay = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )
        grad_time_first = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )
        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        wkv_kernel_decoder.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )

        grad_time_decay = torch.sum(grad_time_decay, dim=0)
        grad_time_first = torch.sum(grad_time_first, dim=0)

        # Cast the float32 kernel gradients back to the dtype of the inputs.
        return (
            grad_time_decay.to(grad_dtype),
            grad_time_first.to(grad_dtype),
            grad_key.to(grad_dtype),
            grad_value.to(grad_dtype),
        )
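

# The two autograd functions above only wrap the fused CUDA kernels. As a
# rough guide to what those kernels compute, the sketch below is a plain,
# per-timestep PyTorch reference of the numerically stabilised WKV recurrence
# (mirroring the stateful logic in SelfAttention.wkv_linear_attention further
# down). The name `_wkv_reference` is illustrative only, is not used anywhere
# else in this module, and is far slower than the kernels; it is intended
# purely for reading or debugging.
def _wkv_reference(
    time_decay: torch.Tensor,
    time_first: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
) -> torch.Tensor:
    """Illustrative pure-PyTorch WKV computation (no CUDA kernel)."""
    _, length, _ = key.size()

    output = torch.zeros_like(key)

    num_state = torch.zeros_like(key[:, 0])
    den_state = torch.zeros_like(key[:, 0])
    max_state = torch.full_like(key[:, 0], -1e38)

    # Same re-parameterisation as the autograd functions above.
    time_decay = -torch.exp(time_decay)

    for t in range(length):
        current_key = key[:, t]
        current_value = value[:, t]

        # Output at step t applies the "time_first" bonus to the current key.
        max_for_output = torch.maximum(max_state, time_first + current_key)
        e1 = torch.exp(max_state - max_for_output)
        e2 = torch.exp(time_first + current_key - max_for_output)
        output[:, t] = (e1 * num_state + e2 * current_value) / (e1 * den_state + e2)

        # State update applies the per-channel decay before the next step.
        max_for_state = torch.maximum(max_state + time_decay, current_key)
        e1 = torch.exp(max_state + time_decay - max_for_state)
        e2 = torch.exp(current_key - max_for_state)
        num_state = e1 * num_state + e2 * current_value
        den_state = e1 * den_state + e2
        max_state = max_for_state

    return output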


def load_encoder_wkv_kernel(context_size: int) -> None:
    """Load the encoder WKV CUDA kernel.

    Args:
        context_size: Context size.

    """
    from torch.utils.cpp_extension import load

    global wkv_kernel_encoder

    if (
        wkv_kernel_encoder is not None
        and wkv_kernel_encoder.context_size == context_size
    ):
        return

    if find_spec("ninja") is None:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please, 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    kernel_folder = Path(__file__).resolve().parent / "cuda_encoder"
    kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]

    kernel_cflags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        f"-DTmax={context_size}",
    ]

    wkv_kernel_encoder = load(
        name=f"encoder_wkv_{context_size}",
        sources=kernel_files,
        verbose=True,
        extra_cuda_cflags=kernel_cflags,
    )
    wkv_kernel_encoder.context_size = context_size


def load_decoder_wkv_kernel(context_size: int) -> None:
    """Load the decoder WKV CUDA kernel.

    Args:
        context_size: Context size.

    """
    from torch.utils.cpp_extension import load

    global wkv_kernel_decoder

    if (
        wkv_kernel_decoder is not None
        and wkv_kernel_decoder.context_size == context_size
    ):
        return

    if find_spec("ninja") is None:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please, 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    kernel_folder = Path(__file__).resolve().parent / "cuda_decoder"
    kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]

    kernel_cflags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        f"-DTmax={context_size}",
    ]

    wkv_kernel_decoder = load(
        name=f"decoder_wkv_{context_size}",
        sources=kernel_files,
        verbose=True,
        extra_cuda_cflags=kernel_cflags,
    )
    wkv_kernel_decoder.context_size = context_size
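

# One of the loaders above has to run before the corresponding autograd
# function is applied, since WKVLinearAttentionEncoder and
# WKVLinearAttentionDecoder both dereference the module-level kernel handles.
# The encoder/decoder modules below take care of this in their constructors.
# A minimal, hypothetical call sequence for manual use would be:
#
#     load_encoder_wkv_kernel(context_size=1024)
#     wkv = WKVLinearAttentionEncoder.apply(time_decay, time_first, key, value)
#
# where context_size only needs to be at least as large as the longest
# sequence fed to the kernel; 1024 is an arbitrary example value.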


class SelfAttention(torch.nn.Module):
    """SelfAttention module definition.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        block_id: Block index.
        dropout_rate: Dropout rate.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a SelfAttention object."""
        super().__init__()

        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

        self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
        self.time_first = torch.nn.Parameter(torch.empty(attention_size))

        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))

        self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_output = torch.nn.Linear(attention_size, size, bias=True)

        self.block_id = block_id

        self.reset_parameters(size, attention_size, block_id, num_blocks)

        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def reset_parameters(
        self, size: int, attention_size: int, block_id: int, num_blocks: int
    ) -> None:
        """Reset module parameters.

        Args:
            size: Block size.
            attention_size: Attention hidden size.
            block_id: Block index.
            num_blocks: Number of blocks in the architecture.

        """
        ratio_0_to_1 = block_id / (num_blocks - 1)
        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

        time_weight = torch.ones(1, 1, size)

        for i in range(size):
            time_weight[0, 0, i] = i / size

        decay_speed = [
            -5 + 8 * (h / (attention_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            for h in range(attention_size)
        ]
        decay_speed = torch.tensor(
            decay_speed, dtype=self.time_decay.dtype, device=self.time_decay.device
        )

        zigzag = (
            torch.tensor(
                [(i + 1) % 3 - 1 for i in range(attention_size)],
                dtype=self.time_first.dtype,
                device=self.time_first.device,
            )
            * 0.5
        )

        with torch.no_grad():
            self.time_decay.data = decay_speed
            self.time_first.data = (
                torch.ones_like(self.time_first) * math.log(0.3) + zigzag
            )

            self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
            self.time_mix_value.data = (
                torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
            )
            self.time_mix_receptance.data = torch.pow(
                time_weight, 0.5 * ratio_1_to_almost0
            )
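
    # Note on the initialisation above: decay_speed spans roughly [-5, 3]
    # across channels (the parameter is later mapped through -exp() inside the
    # WKV kernels), the zigzag term adds a small +/-0.5 alternation to
    # time_first around log(0.3), and the time_mix_* tensors put more weight
    # on the time-shifted input in earlier blocks. This mirrors the upstream
    # RWKV-v4 initialisation referenced in the module docstring.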

    @torch.no_grad()
    def wkv_linear_attention(
        self,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute WKV with state (i.e.: for inference).

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, 1, D_att)
            value: Value tensor. (B, 1, D_att)
            state: Decoder hidden states. [3 x (B, D_att)]

        Returns:
            output: Weighted Key-Value. (B, 1, D_att)
            state: Decoder hidden states. [3 x (B, 1, D_att)]

        """
        num_state, den_state, max_state = state

        max_for_output = torch.maximum(max_state, (time_first + key))

        e1 = torch.exp(max_state - max_for_output)
        e2 = torch.exp((time_first + key) - max_for_output)

        numerator = e1 * num_state + e2 * value
        denominator = e1 * den_state + e2

        max_for_state = torch.maximum(key, (max_state + time_decay))

        e1 = torch.exp((max_state + time_decay) - max_for_state)
        e2 = torch.exp(key - max_for_state)

        wkv = numerator / denominator
        state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]

        return wkv, state
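

# Note on the stateful recurrence above: num_state and den_state accumulate,
# per channel, the decayed sums sum_i exp(k_i + (t - 1 - i) * w) * v_i and
# sum_i exp(k_i + (t - 1 - i) * w), with w = -exp(time_decay), while max_state
# tracks a running maximum so the exponentials stay in a numerically safe
# range. The returned wkv is then roughly
#
#     wkv_t = (num_state + exp(u + k_t) * v_t) / (den_state + exp(u + k_t))
#
# with u = time_first. This gloss is a paraphrase of the RWKV-v4 formulation,
# not an exact transcription of the CUDA kernels' indexing.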


class DecoderSelfAttention(SelfAttention):
    """DecoderSelfAttention module definition.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        dropout_rate: Dropout rate.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a DecoderSelfAttention object."""
        super().__init__(
            size,
            attention_size,
            block_id,
            dropout_rate,
            num_blocks,
        )

        load_decoder_wkv_kernel(context_size)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Decoder hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: Decoder hidden states. [5 x (B, 1, D_att, N)]

        """
        shifted_x = (
            self.time_shift(x) if state is None else state[1][..., self.block_id]
        )

        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )

        key = self.proj_key(key)
        value = self.proj_value(value)
        receptance = torch.sigmoid(self.proj_receptance(receptance))

        if state is not None:
            state[1][..., self.block_id] = x

            wkv, att_state = self.wkv_linear_attention(
                self.time_decay,
                self.time_first,
                key,
                value,
                tuple(s[..., self.block_id] for s in state[2:]),
            )

            state[2][..., self.block_id] = att_state[0]
            state[3][..., self.block_id] = att_state[1]
            state[4][..., self.block_id] = att_state[2]
        else:
            wkv = WKVLinearAttentionDecoder.apply(
                self.time_decay, self.time_first, key, value
            )

        wkv = self.dropout(wkv)

        x = self.proj_output(receptance * wkv)

        return x, state


class EncoderSelfAttention(SelfAttention):
    """EncoderSelfAttention module definition.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        dropout_rate: Dropout rate.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct an EncoderSelfAttention object."""
        super().__init__(
            size,
            attention_size,
            block_id,
            dropout_rate,
            num_blocks,
        )

        load_encoder_wkv_kernel(context_size)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: Hidden states. [5 x (B, 1, D_att, N)]

        """
        shifted_x = (
            self.time_shift(x) if state is None else state[1][..., self.block_id]
        )

        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )

        key = self.proj_key(key)
        value = self.proj_value(value)
        receptance = torch.sigmoid(self.proj_receptance(receptance))

        if state is not None:
            state[1][..., self.block_id] = x

            wkv, att_state = self.wkv_linear_attention(
                self.time_decay,
                self.time_first,
                key,
                value,
                tuple(s[..., self.block_id] for s in state[2:]),
            )

            state[2][..., self.block_id] = att_state[0]
            state[3][..., self.block_id] = att_state[1]
            state[4][..., self.block_id] = att_state[2]
        else:
            wkv = WKVLinearAttentionEncoder.apply(
                self.time_decay, self.time_first, key, value
            )

        wkv = self.dropout(wkv)

        x = self.proj_output(receptance * wkv)

        return x, state
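

# A minimal, hypothetical smoke test for the encoder module. It needs a CUDA
# device and the ninja package, since constructing the module compiles the
# WKV kernel; all sizes below are illustrative only:
#
#     att = EncoderSelfAttention(
#         size=256,
#         attention_size=256,
#         context_size=512,
#         block_id=0,
#         dropout_rate=0.0,
#         num_blocks=4,
#     ).cuda()
#     x = torch.randn(2, 128, 256, device="cuda")
#     out, _ = att(x)  # out: (2, 128, 256)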