"""RWKV encoder definition for Transducer models."""
import math
from typing import Dict, List, Optional, Tuple

import torch

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.rwkv import RWKV
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.rwkv_subsampling import RWKVConvInput
from funasr.modules.nets_utils import make_source_mask
  10. class RWKVEncoder(AbsEncoder):
  11. """RWKV encoder module.
  12. Based on https://arxiv.org/pdf/2305.13048.pdf.
  13. Args:
  14. vocab_size: Vocabulary size.
  15. output_size: Input/Output size.
  16. context_size: Context size for WKV computation.
  17. linear_size: FeedForward hidden size.
  18. attention_size: SelfAttention hidden size.
  19. normalization_type: Normalization layer type.
  20. normalization_args: Normalization layer arguments.
  21. num_blocks: Number of RWKV blocks.
  22. embed_dropout_rate: Dropout rate for embedding layer.
  23. att_dropout_rate: Dropout rate for the attention module.
  24. ffn_dropout_rate: Dropout rate for the feed-forward module.
  25. """
  26. def __init__(
  27. self,
  28. input_size: int,
  29. output_size: int = 512,
  30. context_size: int = 1024,
  31. linear_size: Optional[int] = None,
  32. attention_size: Optional[int] = None,
  33. num_blocks: int = 4,
  34. att_dropout_rate: float = 0.0,
  35. ffn_dropout_rate: float = 0.0,
  36. dropout_rate: float = 0.0,
  37. subsampling_factor: int =4,
  38. time_reduction_factor: int = 1,
  39. kernel: int = 3,
  40. ) -> None:
  41. """Construct a RWKVEncoder object."""
  42. super().__init__()
  43. self.embed = RWKVConvInput(
  44. input_size,
  45. [output_size//4, output_size//2, output_size],
  46. subsampling_factor,
  47. conv_kernel_size=kernel,
  48. output_size=output_size,
  49. )
  50. self.subsampling_factor = subsampling_factor
  51. linear_size = output_size * 4 if linear_size is None else linear_size
  52. attention_size = output_size if attention_size is None else attention_size
  53. self.rwkv_blocks = torch.nn.ModuleList(
  54. [
  55. RWKV(
  56. output_size,
  57. linear_size,
  58. attention_size,
  59. context_size,
  60. block_id,
  61. num_blocks,
  62. att_dropout_rate=att_dropout_rate,
  63. ffn_dropout_rate=ffn_dropout_rate,
  64. dropout_rate=dropout_rate,
  65. )
  66. for block_id in range(num_blocks)
  67. ]
  68. )
  69. self.embed_norm = LayerNorm(output_size)
  70. self.final_norm = LayerNorm(output_size)
  71. self._output_size = output_size
  72. self.context_size = context_size
  73. self.num_blocks = num_blocks
  74. self.time_reduction_factor = time_reduction_factor
  75. def output_size(self) -> int:
  76. return self._output_size
  77. def forward(self, x: torch.Tensor, x_len) -> torch.Tensor:
  78. """Encode source label sequences.
  79. Args:
  80. x: Encoder input sequences. (B, L)
  81. Returns:
  82. out: Encoder output sequences. (B, U, D)
  83. """
  84. _, length, _ = x.size()
  85. assert (
  86. length <= self.context_size * self.subsampling_factor
  87. ), "Context size is too short for current length: %d versus %d" % (
  88. length,
  89. self.context_size * self.subsampling_factor,
  90. )
  91. mask = make_source_mask(x_len).to(x.device)
  92. x, mask = self.embed(x, mask, None)
  93. x = self.embed_norm(x)
  94. olens = mask.eq(0).sum(1)
  95. # for training
  96. # for block in self.rwkv_blocks:
  97. # x, _ = block(x)
  98. # for streaming inference
  99. x = self.rwkv_infer(x)
  100. x = self.final_norm(x)
  101. if self.time_reduction_factor > 1:
  102. x = x[:,::self.time_reduction_factor,:]
  103. olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
  104. return x, olens, None
  105. def rwkv_infer(self, xs_pad):
  106. batch_size = xs_pad.shape[0]
  107. hidden_sizes = [
  108. self._output_size for i in range(5)
  109. ]
  110. state = [
  111. torch.zeros(
  112. (batch_size, 1, hidden_sizes[i], self.num_blocks),
  113. dtype=torch.float32,
  114. device=xs_pad.device,
  115. )
  116. for i in range(5)
  117. ]
  118. state[4] -= 1e-30
  119. xs_out = []
  120. for t in range(xs_pad.shape[1]):
  121. x_t = xs_pad[:,t,:]
  122. for idx, block in enumerate(self.rwkv_blocks):
  123. x_t, state = block(x_t, state=state)
  124. xs_out.append(x_t)
  125. xs_out = torch.cat(xs_out, dim=1)
  126. return xs_out