[Quantization] import torch.fx only if torch.__version__ >= 1.8

This commit is contained in:
wanchen.swc 2023-03-30 16:31:14 +08:00
parent 8207de9441
commit 0438b966d6

View File

@ -75,8 +75,8 @@ def preprocess_for_attn(x, mask, cache, pad_fn):
return x, cache
torch_version = float(".".join(torch.__version__.split(".")[:2]))
if torch_version >= 1.8:
torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
if torch_version >= (1, 8):
import torch.fx
torch.fx.wrap('preprocess_for_attn')