Explorar el Código

Update rwkv_encoder.py

aky15 hace 2 años
padre
commit
c7c97cb0fe
Se han modificado 1 ficheros con 6 adiciones y 6 borrados
  1. 6 6
      funasr/models/encoder/rwkv_encoder.py

+ 6 - 6
funasr/models/encoder/rwkv_encoder.py

@@ -113,12 +113,12 @@ class RWKVEncoder(AbsEncoder):
         x = self.embed_norm(x)
         olens = mask.eq(0).sum(1)
 
-        # for training
-        # for block in self.rwkv_blocks:
-        #     x, _ = block(x)
-
-        # for streaming inference
-        x = self.rwkv_infer(x)
+        if self.training:
+            for block in self.rwkv_blocks:
+                x, _ = block(x)
+        else:
+            x = self.rwkv_infer(x)
+            
         x = self.final_norm(x)
 
         if self.time_reduction_factor > 1: