aky15 пре 2 година
родитељ
комит
c7c97cb0fe
1 измењених фајлова са 6 додато и 6 уклоњено
  1. 6 6
      funasr/models/encoder/rwkv_encoder.py

+ 6 - 6
funasr/models/encoder/rwkv_encoder.py

@@ -113,12 +113,12 @@ class RWKVEncoder(AbsEncoder):
         x = self.embed_norm(x)
         olens = mask.eq(0).sum(1)
 
-        # for training
-        # for block in self.rwkv_blocks:
-        #     x, _ = block(x)
-
-        # for streaming inference
-        x = self.rwkv_infer(x)
+        if self.training:
+            for block in self.rwkv_blocks:
+                x, _ = block(x)
+        else:
+            x = self.rwkv_infer(x)
+            
         x = self.final_norm(x)
 
         if self.time_reduction_factor > 1: