rwkv_feed_forward.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. """Feed-forward (channel mixing) module for RWKV block.
  2. Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py
  3. Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
  4. """ # noqa
  5. from typing import List, Optional, Tuple
  6. import torch
  class FeedForward(torch.nn.Module):
      """FeedForward (channel mixing) module definition.

      Args:
          size: Input/Output size.
          hidden_size: Hidden size.
          block_id: Block index.
          dropout_rate: Dropout probability applied to the squared-ReLU key
              activations before the value projection.
          num_blocks: Number of blocks in the architecture.

      """

      def __init__(
          self,
          size: int,
          hidden_size: int,
          block_id: int,
          dropout_rate: float,
          num_blocks: int,
      ) -> None:
          """Construct a FeedForward object."""
          super().__init__()

          # ZeroPad2d pads (left, right, top, bottom) of the last two dims: one
          # zero step is added at the start of the time axis and the last step
          # is cropped, so position t receives the input of position t - 1.
          self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

          # Per-feature weights interpolating between the current and the
          # shifted input; their values are set by reset_parameters().
          self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
          self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))

          self.proj_key = torch.nn.Linear(size, hidden_size, bias=True)
          self.proj_value = torch.nn.Linear(hidden_size, size, bias=True)
          self.proj_receptance = torch.nn.Linear(size, size, bias=True)

          # Index into the last axis of state[0] used by this block in step mode.
          self.block_id = block_id

          self.reset_parameters(size, block_id, num_blocks)

          self.dropout = torch.nn.Dropout(p=dropout_rate)

      def reset_parameters(self, size: int, block_id: int, num_blocks: int) -> None:
          """Reset module parameters.

          Both time-mix parameters are set to ``(i / size) ** r`` for feature
          index ``i``, with ``r = 1 - block_id / num_blocks`` (1.0 for block 0,
          shrinking as ``block_id`` grows).

          Args:
              size: Input/Output size.
              block_id: Block index.
              num_blocks: Number of blocks in the architecture.

          """
          ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

          # Ramp 0, 1/size, ..., (size - 1)/size, computed in float32.
          time_weight = torch.arange(size, dtype=torch.float32) / size
          time_mix = torch.pow(time_weight.view(1, 1, size), ratio_1_to_almost0)

          with torch.no_grad():
              # copy_ writes into the existing parameter storage, so the
              # parameters keep their current dtype/device (e.g. after
              # .to(...) or .double()); rebinding .data would instead reset
              # them to a CPU float32 tensor.
              self.time_mix_key.copy_(time_mix)
              self.time_mix_receptance.copy_(time_mix)

      def forward(
          self, x: torch.Tensor, state: Optional[List[torch.Tensor]] = None
      ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
          """Compute channel mixing.

          Without ``state`` the whole sequence is processed and the token shift
          is taken from ``x`` itself. With ``state``, the previous input for
          this block is read from ``state[0][..., block_id]`` and ``x`` is
          written back there in place; only ``state[0]`` is touched and the
          same list object is returned.

          Args:
              x: FeedForward input sequences. (B, U, size)
              state: Decoder hidden state. [5 x (B, 1, size, N)]

          Returns:
              x: FeedForward output sequences. (B, U, size)
              state: Decoder hidden state. [5 x (B, 1, size, N)]

          """
          shifted_x = (
              self.time_shift(x) if state is None else state[0][..., self.block_id]
          )

          # Interpolate each feature between the current and the shifted input.
          key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
          receptance = x * self.time_mix_receptance + shifted_x * (
              1 - self.time_mix_receptance
          )

          key = torch.square(torch.relu(self.proj_key(key)))
          value = self.proj_value(self.dropout(key))
          receptance = torch.sigmoid(self.proj_receptance(receptance))

          if state is not None:
              # NOTE(review): shifted_x is a view of state[0] and has already
              # been consumed above. With the documented state shape this write
              # expects x to be a single step (U == 1). If gradients are ever
              # needed in step mode, autograd may reject this in-place update;
              # cloning shifted_x would avoid that.
              state[0][..., self.block_id] = x

          x = receptance * value

          return x, state