# rwkv_encoder.py
"""RWKV encoder definition for Transducer models."""
import math
from typing import Dict, List, Optional, Tuple

import torch
from typeguard import check_argument_types

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.nets_utils import make_source_mask
from funasr.modules.rwkv import RWKV
from funasr.modules.rwkv_subsampling import RWKVConvInput

  11. class RWKVEncoder(AbsEncoder):
  12. """RWKV encoder module.
  13. Based on https://arxiv.org/pdf/2305.13048.pdf.
  14. Args:
  15. vocab_size: Vocabulary size.
  16. output_size: Input/Output size.
  17. context_size: Context size for WKV computation.
  18. linear_size: FeedForward hidden size.
  19. attention_size: SelfAttention hidden size.
  20. normalization_type: Normalization layer type.
  21. normalization_args: Normalization layer arguments.
  22. num_blocks: Number of RWKV blocks.
  23. embed_dropout_rate: Dropout rate for embedding layer.
  24. att_dropout_rate: Dropout rate for the attention module.
  25. ffn_dropout_rate: Dropout rate for the feed-forward module.
  26. """
  27. def __init__(
  28. self,
  29. input_size: int,
  30. output_size: int = 512,
  31. context_size: int = 1024,
  32. linear_size: Optional[int] = None,
  33. attention_size: Optional[int] = None,
  34. num_blocks: int = 4,
  35. att_dropout_rate: float = 0.0,
  36. ffn_dropout_rate: float = 0.0,
  37. dropout_rate: float = 0.0,
  38. subsampling_factor: int =4,
  39. time_reduction_factor: int = 1,
  40. kernel: int = 3,
  41. ) -> None:
  42. """Construct a RWKVEncoder object."""
  43. super().__init__()
  44. assert check_argument_types()
  45. self.embed = RWKVConvInput(
  46. input_size,
  47. [output_size//4, output_size//2, output_size],
  48. subsampling_factor,
  49. conv_kernel_size=kernel,
  50. output_size=output_size,
  51. )
  52. self.subsampling_factor = subsampling_factor
  53. linear_size = output_size * 4 if linear_size is None else linear_size
  54. attention_size = output_size if attention_size is None else attention_size
  55. self.rwkv_blocks = torch.nn.ModuleList(
  56. [
  57. RWKV(
  58. output_size,
  59. linear_size,
  60. attention_size,
  61. context_size,
  62. block_id,
  63. num_blocks,
  64. att_dropout_rate=att_dropout_rate,
  65. ffn_dropout_rate=ffn_dropout_rate,
  66. dropout_rate=dropout_rate,
  67. )
  68. for block_id in range(num_blocks)
  69. ]
  70. )
  71. self.embed_norm = LayerNorm(output_size)
  72. self.final_norm = LayerNorm(output_size)
  73. self._output_size = output_size
  74. self.context_size = context_size
  75. self.num_blocks = num_blocks
  76. self.time_reduction_factor = time_reduction_factor
  77. def output_size(self) -> int:
  78. return self._output_size
  79. def forward(self, x: torch.Tensor, x_len) -> torch.Tensor:
  80. """Encode source label sequences.
  81. Args:
  82. x: Encoder input sequences. (B, L)
  83. Returns:
  84. out: Encoder output sequences. (B, U, D)
  85. """
  86. _, length, _ = x.size()
  87. assert (
  88. length <= self.context_size * self.subsampling_factor
  89. ), "Context size is too short for current length: %d versus %d" % (
  90. length,
  91. self.context_size * self.subsampling_factor,
  92. )
  93. mask = make_source_mask(x_len).to(x.device)
  94. x, mask = self.embed(x, mask, None)
  95. x = self.embed_norm(x)
  96. olens = mask.eq(0).sum(1)
  97. for block in self.rwkv_blocks:
  98. x, _ = block(x)
  99. # for streaming inference
  100. # xs_pad = self.rwkv_infer(xs_pad)
  101. x = self.final_norm(x)
  102. if self.time_reduction_factor > 1:
  103. x = x[:,::self.time_reduction_factor,:]
  104. olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
  105. return x, olens, None
  106. def rwkv_infer(self, xs_pad):
  107. batch_size = xs_pad.shape[0]
  108. hidden_sizes = [
  109. self._output_size for i in range(5)
  110. ]
  111. state = [
  112. torch.zeros(
  113. (batch_size, 1, hidden_sizes[i], self.num_rwkv_blocks),
  114. dtype=torch.float32,
  115. device=self.device,
  116. )
  117. for i in range(5)
  118. ]
  119. state[4] -= 1e-30
  120. xs_out = []
  121. for t in range(xs_pad.shape[1]):
  122. x_t = xs_pad[:,t,:]
  123. for idx, block in enumerate(self.rwkv_blocks):
  124. x_t, state = block(x_t, state=state)
  125. xs_out.append(x_t)
  126. xs_out = torch.stack(xs_out, dim=1)
  127. return xs_out