Răsfoiți Sursa

[Quantization] import torch.fx only if torch.__version__ >= 1.8

wanchen.swc 3 ani în urmă
părinte
comite
0438b966d6
1 a modificat fișierele cu 2 adăugiri și 2 ștergeri
  1. 2 2
      funasr/export/models/modules/multihead_att.py

+ 2 - 2
funasr/export/models/modules/multihead_att.py

@@ -75,8 +75,8 @@ def preprocess_for_attn(x, mask, cache, pad_fn):
     return x, cache
 
 
-torch_version = float(".".join(torch.__version__.split(".")[:2]))
-if torch_version >= 1.8:
+torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
+if torch_version >= (1, 8):
     import torch.fx
     torch.fx.wrap('preprocess_for_attn')