| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- """RWKV encoder definition for Transducer models."""
- import math
- from typing import Dict, List, Optional, Tuple
- import torch
- from typeguard import check_argument_types
- from funasr.models.encoder.abs_encoder import AbsEncoder
- from funasr.modules.rwkv import RWKV
- from funasr.modules.layer_norm import LayerNorm
- from funasr.modules.rwkv_subsampling import RWKVConvInput
- from funasr.modules.nets_utils import make_source_mask
class RWKVEncoder(AbsEncoder):
    """RWKV encoder module.

    Based on https://arxiv.org/pdf/2305.13048.pdf.

    Args:
        input_size: Input feature size.
        output_size: Input/Output size of the RWKV blocks.
        context_size: Context size for WKV computation. Inputs longer than
            ``context_size * subsampling_factor`` frames are rejected.
        linear_size: FeedForward hidden size. Defaults to ``4 * output_size``.
        attention_size: SelfAttention hidden size. Defaults to ``output_size``.
        num_blocks: Number of RWKV blocks.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate forwarded to each RWKV block.
        subsampling_factor: Subsampling factor of the convolutional input layer.
        time_reduction_factor: If > 1, only every ``time_reduction_factor``-th
            output frame is kept.
        kernel: Convolution kernel size of the input layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 512,
        context_size: int = 1024,
        linear_size: Optional[int] = None,
        attention_size: Optional[int] = None,
        num_blocks: int = 4,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
        subsampling_factor: int = 4,
        time_reduction_factor: int = 1,
        kernel: int = 3,
    ) -> None:
        """Construct a RWKVEncoder object."""
        super().__init__()
        assert check_argument_types()

        # Convolutional input layer mapping input_size features to output_size.
        self.embed = RWKVConvInput(
            input_size,
            [output_size // 4, output_size // 2, output_size],
            subsampling_factor,
            conv_kernel_size=kernel,
            output_size=output_size,
        )
        self.subsampling_factor = subsampling_factor

        linear_size = output_size * 4 if linear_size is None else linear_size
        attention_size = output_size if attention_size is None else attention_size

        self.rwkv_blocks = torch.nn.ModuleList(
            [
                RWKV(
                    output_size,
                    linear_size,
                    attention_size,
                    context_size,
                    block_id,
                    num_blocks,
                    att_dropout_rate=att_dropout_rate,
                    ffn_dropout_rate=ffn_dropout_rate,
                    dropout_rate=dropout_rate,
                )
                for block_id in range(num_blocks)
            ]
        )
        self.embed_norm = LayerNorm(output_size)
        self.final_norm = LayerNorm(output_size)

        self._output_size = output_size
        # Kept so step-wise inference can size the WKV state correctly.
        self.attention_size = attention_size
        self.context_size = context_size
        self.num_blocks = num_blocks
        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def forward(
        self, x: torch.Tensor, x_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """Encode input feature sequences.

        Args:
            x: Encoder input sequences. (B, T, D_in)
            x_len: Valid length of each input sequence. (B,)

        Returns:
            x: Encoder output sequences. (B, T', output_size)
            olens: Valid length of each output sequence. (B,)
            None: Placeholder third element; nothing else is returned.

        Raises:
            AssertionError: If T exceeds ``context_size * subsampling_factor``.
        """
        _, length, _ = x.size()
        max_length = self.context_size * self.subsampling_factor
        assert length <= max_length, (
            "Context size is too short for current length: %d versus %d"
            % (length, max_length)
        )

        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        x = self.embed_norm(x)
        # NOTE(review): assumes the mask flags padded frames as non-zero, so
        # counting zeros yields the valid (subsampled) lengths — confirm
        # against make_source_mask / RWKVConvInput.
        olens = mask.eq(0).sum(1)

        # Full-sequence (parallel) evaluation; `rwkv_infer` provides a
        # frame-by-frame recurrent alternative for streaming.
        for block in self.rwkv_blocks:
            x, _ = block(x)
        x = self.final_norm(x)

        if self.time_reduction_factor > 1:
            factor = self.time_reduction_factor
            x = x[:, ::factor, :]
            # A strided slice keeps ceil(n / factor) of n frames; explicit
            # floor rounding keeps this stable across torch versions and
            # maps n == 0 to 0.
            olens = torch.div(olens + factor - 1, factor, rounding_mode="floor")
        return x, olens, None

    def _init_rwkv_state(
        self, batch_size: int, device: torch.device
    ) -> List[torch.Tensor]:
        """Allocate the recurrent state used by `rwkv_infer`.

        Returns five float32 tensors of shape
        (batch_size, 1, hidden_size, num_blocks), one slice per block.
        """
        # NOTE(review): assumes the RWKV blocks use state[0] and state[1] for
        # token-shifted inputs (width output_size) and state[2:] for the WKV
        # accumulators (width attention_size) — confirm against
        # funasr.modules.rwkv. Identical to all-output_size when
        # attention_size == output_size (the default).
        hidden_sizes = [
            self._output_size,
            self._output_size,
            self.attention_size,
            self.attention_size,
            self.attention_size,
        ]
        state = [
            torch.zeros(
                (batch_size, 1, hidden_size, self.num_blocks),
                dtype=torch.float32,
                device=device,
            )
            for hidden_size in hidden_sizes
        ]
        # Last slot starts slightly below zero (kept from the original code;
        # presumably a running-max term of the WKV recurrence).
        state[4] -= 1e-30
        return state

    def rwkv_infer(self, xs_pad: torch.Tensor) -> torch.Tensor:
        """Run the RWKV blocks recurrently, one frame at a time.

        Args:
            xs_pad: Embedded input sequences. (B, T, output_size)

        Returns:
            Output sequences before the final norm. (B, T, output_size)
        """
        state = self._init_rwkv_state(xs_pad.size(0), xs_pad.device)
        xs_out = []
        for t in range(xs_pad.size(1)):
            # Keep a length-1 time axis so each frame is (B, 1, D), matching
            # the (B, 1, hidden, num_blocks) state allocated above.
            x_t = xs_pad[:, t : t + 1, :]
            for block in self.rwkv_blocks:
                x_t, state = block(x_t, state=state)
            xs_out.append(x_t)
        return torch.cat(xs_out, dim=1)
|