#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch
from typing import Dict, List, Optional, Tuple

from funasr.register import tables
from funasr.models.rwkv_bat.rwkv import RWKV
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.nets_utils import make_source_mask
from funasr.models.rwkv_bat.rwkv_subsampling import RWKVConvInput

  12. @tables.register("encoder_classes", "RWKVEncoder")
  13. class RWKVEncoder(torch.nn.Module):
  14. """RWKV encoder module.
  15. Based on https://arxiv.org/pdf/2305.13048.pdf.
  16. Args:
  17. vocab_size: Vocabulary size.
  18. output_size: Input/Output size.
  19. context_size: Context size for WKV computation.
  20. linear_size: FeedForward hidden size.
  21. attention_size: SelfAttention hidden size.
  22. normalization_type: Normalization layer type.
  23. normalization_args: Normalization layer arguments.
  24. num_blocks: Number of RWKV blocks.
  25. embed_dropout_rate: Dropout rate for embedding layer.
  26. att_dropout_rate: Dropout rate for the attention module.
  27. ffn_dropout_rate: Dropout rate for the feed-forward module.
  28. """
  29. def __init__(
  30. self,
  31. input_size: int,
  32. output_size: int = 512,
  33. context_size: int = 1024,
  34. linear_size: Optional[int] = None,
  35. attention_size: Optional[int] = None,
  36. num_blocks: int = 4,
  37. att_dropout_rate: float = 0.0,
  38. ffn_dropout_rate: float = 0.0,
  39. dropout_rate: float = 0.0,
  40. subsampling_factor: int =4,
  41. time_reduction_factor: int = 1,
  42. kernel: int = 3,
  43. **kwargs,
  44. ) -> None:
  45. """Construct a RWKVEncoder object."""
  46. super().__init__()
  47. self.embed = RWKVConvInput(
  48. input_size,
  49. [output_size//4, output_size//2, output_size],
  50. subsampling_factor,
  51. conv_kernel_size=kernel,
  52. output_size=output_size,
  53. )
  54. self.subsampling_factor = subsampling_factor
  55. linear_size = output_size * 4 if linear_size is None else linear_size
  56. attention_size = output_size if attention_size is None else attention_size
  57. self.rwkv_blocks = torch.nn.ModuleList(
  58. [
  59. RWKV(
  60. output_size,
  61. linear_size,
  62. attention_size,
  63. context_size,
  64. block_id,
  65. num_blocks,
  66. att_dropout_rate=att_dropout_rate,
  67. ffn_dropout_rate=ffn_dropout_rate,
  68. dropout_rate=dropout_rate,
  69. )
  70. for block_id in range(num_blocks)
  71. ]
  72. )
  73. self.embed_norm = LayerNorm(output_size)
  74. self.final_norm = LayerNorm(output_size)
  75. self._output_size = output_size
  76. self.context_size = context_size
  77. self.num_blocks = num_blocks
  78. self.time_reduction_factor = time_reduction_factor
  79. def output_size(self) -> int:
  80. return self._output_size
  81. def forward(self, x: torch.Tensor, x_len) -> torch.Tensor:
  82. """Encode source label sequences.
  83. Args:
  84. x: Encoder input sequences. (B, L)
  85. Returns:
  86. out: Encoder output sequences. (B, U, D)
  87. """
  88. _, length, _ = x.size()
  89. assert (
  90. length <= self.context_size * self.subsampling_factor
  91. ), "Context size is too short for current length: %d versus %d" % (
  92. length,
  93. self.context_size * self.subsampling_factor,
  94. )
  95. mask = make_source_mask(x_len).to(x.device)
  96. x, mask = self.embed(x, mask, None)
  97. x = self.embed_norm(x)
  98. olens = mask.eq(0).sum(1)
  99. if self.training:
  100. for block in self.rwkv_blocks:
  101. x, _ = block(x)
  102. else:
  103. x = self.rwkv_infer(x)
  104. x = self.final_norm(x)
  105. if self.time_reduction_factor > 1:
  106. x = x[:,::self.time_reduction_factor,:]
  107. olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
  108. return x, olens, None
  109. def rwkv_infer(self, xs_pad):
  110. batch_size = xs_pad.shape[0]
  111. hidden_sizes = [
  112. self._output_size for i in range(5)
  113. ]
  114. state = [
  115. torch.zeros(
  116. (batch_size, 1, hidden_sizes[i], self.num_blocks),
  117. dtype=torch.float32,
  118. device=xs_pad.device,
  119. )
  120. for i in range(5)
  121. ]
  122. state[4] -= 1e-30
  123. xs_out = []
  124. for t in range(xs_pad.shape[1]):
  125. x_t = xs_pad[:,t,:]
  126. for idx, block in enumerate(self.rwkv_blocks):
  127. x_t, state = block(x_t, state=state)
  128. xs_out.append(x_t)
  129. xs_out = torch.cat(xs_out, dim=1)
  130. return xs_out