Przeglądaj źródła

Update rwkv_encoder.py

aky15 2 lat temu
rodzic
commit
c7c97cb0fe
1 zmienionych plików z 6 dodań i 6 usunięć
  1. 6 6
      funasr/models/encoder/rwkv_encoder.py

+ 6 - 6
funasr/models/encoder/rwkv_encoder.py

@@ -113,12 +113,12 @@ class RWKVEncoder(AbsEncoder):
         x = self.embed_norm(x)
         olens = mask.eq(0).sum(1)
 
-        # for training
-        # for block in self.rwkv_blocks:
-        #     x, _ = block(x)
-
-        # for streaming inference
-        x = self.rwkv_infer(x)
+        if self.training:
+            for block in self.rwkv_blocks:
+                x, _ = block(x)
+        else:
+            x = self.rwkv_infer(x)
+            
         x = self.final_norm(x)
 
         if self.time_reduction_factor > 1: