|
|
@@ -1,17 +1,20 @@
|
|
|
-"""RWKV encoder definition for Transducer models."""
|
|
|
-
|
|
|
-import math
|
|
|
-from typing import Dict, List, Optional, Tuple
|
|
|
+#!/usr/bin/env python3
|
|
|
+# -*- encoding: utf-8 -*-
|
|
|
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
|
|
+# MIT License (https://opensource.org/licenses/MIT)
|
|
|
|
|
|
import torch
|
|
|
+from typing import Dict, List, Optional, Tuple
|
|
|
|
|
|
-from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
|
+from funasr.register import tables
|
|
|
from funasr.models.rwkv_bat.rwkv import RWKV
|
|
|
from funasr.models.transformer.layer_norm import LayerNorm
|
|
|
-from funasr.models.rwkv_bat.rwkv_subsampling import RWKVConvInput
|
|
|
from funasr.models.transformer.utils.nets_utils import make_source_mask
|
|
|
+from funasr.models.rwkv_bat.rwkv_subsampling import RWKVConvInput
|
|
|
+
|
|
|
|
|
|
-class RWKVEncoder(AbsEncoder):
|
|
|
+@tables.register("encoder_classes", "RWKVEncoder")
|
|
|
+class RWKVEncoder(torch.nn.Module):
|
|
|
"""RWKV encoder module.
|
|
|
|
|
|
Based on https://arxiv.org/pdf/2305.13048.pdf.
|
|
|
@@ -44,6 +47,7 @@ class RWKVEncoder(AbsEncoder):
|
|
|
subsampling_factor: int =4,
|
|
|
time_reduction_factor: int = 1,
|
|
|
kernel: int = 3,
|
|
|
+ **kwargs,
|
|
|
) -> None:
|
|
|
"""Construct a RWKVEncoder object."""
|
|
|
super().__init__()
|