rwkv.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. """Receptance Weighted Key Value (RWKV) block definition.
  2. Based/modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py
  3. """
from typing import Dict, List, Optional, Tuple

import torch

from funasr.modules.layer_norm import LayerNorm
from funasr.modules.rwkv_attention import DecoderSelfAttention, EncoderSelfAttention
from funasr.modules.rwkv_feed_forward import FeedForward
  9. class RWKV(torch.nn.Module):
  10. """RWKV module.
  11. Args:
  12. size: Input/Output size.
  13. linear_size: Feed-forward hidden size.
  14. attention_size: SelfAttention hidden size.
  15. context_size: Context size for WKV computation.
  16. block_id: Block index.
  17. num_blocks: Number of blocks in the architecture.
  18. normalization_class: Normalization layer class.
  19. normalization_args: Normalization layer arguments.
  20. att_dropout_rate: Dropout rate for the attention module.
  21. ffn_dropout_rate: Dropout rate for the feed-forward module.
  22. """
  23. def __init__(
  24. self,
  25. size: int,
  26. linear_size: int,
  27. attention_size: int,
  28. context_size: int,
  29. block_id: int,
  30. num_blocks: int,
  31. att_dropout_rate: float = 0.0,
  32. ffn_dropout_rate: float = 0.0,
  33. dropout_rate: float = 0.0,
  34. ) -> None:
  35. """Construct a RWKV object."""
  36. super().__init__()
  37. self.layer_norm_att = LayerNorm(size)
  38. self.layer_norm_ffn = LayerNorm(size)
  39. self.att = EncoderSelfAttention(
  40. size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
  41. )
  42. self.dropout_att = torch.nn.Dropout(p=dropout_rate)
  43. self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks)
  44. self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)
  45. def forward(
  46. self,
  47. x: torch.Tensor,
  48. state: Optional[torch.Tensor] = None,
  49. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  50. """Compute receptance weighted key value.
  51. Args:
  52. x: RWKV input sequences. (B, L, size)
  53. state: Decoder hidden states. [5 x (B, D_att/size, N)]
  54. Returns:
  55. x: RWKV output sequences. (B, L, size)
  56. x: Decoder hidden states. [5 x (B, D_att/size, N)]
  57. """
  58. att, state = self.att(self.layer_norm_att(x), state=state)
  59. x = x + self.dropout_att(att)
  60. ffn, state = self.ffn(self.layer_norm_ffn(x), state=state)
  61. x = x + self.dropout_ffn(ffn)
  62. return x, state
  63. class RWKVDecoderLayer(torch.nn.Module):
  64. """RWKV module.
  65. Args:
  66. size: Input/Output size.
  67. linear_size: Feed-forward hidden size.
  68. attention_size: SelfAttention hidden size.
  69. context_size: Context size for WKV computation.
  70. block_id: Block index.
  71. num_blocks: Number of blocks in the architecture.
  72. normalization_class: Normalization layer class.
  73. normalization_args: Normalization layer arguments.
  74. att_dropout_rate: Dropout rate for the attention module.
  75. ffn_dropout_rate: Dropout rate for the feed-forward module.
  76. """
  77. def __init__(
  78. self,
  79. size: int,
  80. linear_size: int,
  81. attention_size: int,
  82. context_size: int,
  83. block_id: int,
  84. num_blocks: int,
  85. att_dropout_rate: float = 0.0,
  86. ffn_dropout_rate: float = 0.0,
  87. dropout_rate: float = 0.0,
  88. ) -> None:
  89. """Construct a RWKV object."""
  90. super().__init__()
  91. self.layer_norm_att = LayerNorm(size)
  92. self.layer_norm_ffn = LayerNorm(size)
  93. self.att = DecoderSelfAttention(
  94. size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
  95. )
  96. self.dropout_att = torch.nn.Dropout(p=dropout_rate)
  97. self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks)
  98. self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)
  99. def forward(
  100. self,
  101. x: torch.Tensor,
  102. state: Optional[torch.Tensor] = None,
  103. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  104. """Compute receptance weighted key value.
  105. Args:
  106. x: RWKV input sequences. (B, L, size)
  107. state: Decoder hidden states. [5 x (B, D_att/size, N)]
  108. Returns:
  109. x: RWKV output sequences. (B, L, size)
  110. x: Decoder hidden states. [5 x (B, D_att/size, N)]
  111. """
  112. att, state = self.att(self.layer_norm_att(x), state=state)
  113. x = x + self.dropout_att(att)
  114. ffn, state = self.ffn(self.layer_norm_ffn(x), state=state)
  115. x = x + self.dropout_ffn(ffn)
  116. return x, state