|
@@ -75,8 +75,8 @@ def preprocess_for_attn(x, mask, cache, pad_fn):
|
|
|
return x, cache
|
|
return x, cache
|
|
|
|
|
|
|
|
|
|
|
|
|
-torch_version = float(".".join(torch.__version__.split(".")[:2]))
|
|
|
|
|
-if torch_version >= 1.8:
|
|
|
|
|
|
|
+torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
|
|
|
|
|
+if torch_version >= (1, 8):
|
|
|
import torch.fx
|
|
import torch.fx
|
|
|
torch.fx.wrap('preprocess_for_attn')
|
|
torch.fx.wrap('preprocess_for_attn')
|
|
|
|
|
|