|
|
@@ -113,12 +113,12 @@ class RWKVEncoder(AbsEncoder):
|
|
|
x = self.embed_norm(x)
|
|
|
olens = mask.eq(0).sum(1)
|
|
|
|
|
|
- # for training
|
|
|
- # for block in self.rwkv_blocks:
|
|
|
- # x, _ = block(x)
|
|
|
-
|
|
|
- # for streaming inference
|
|
|
- x = self.rwkv_infer(x)
|
|
|
+ if self.training:
|
|
|
+ for block in self.rwkv_blocks:
|
|
|
+ x, _ = block(x)
|
|
|
+ else:
|
|
|
+ x = self.rwkv_infer(x)
|
|
|
+
|
|
|
x = self.final_norm(x)
|
|
|
|
|
|
if self.time_reduction_factor > 1:
|