From 0438b966d6d9ef89fdd7a8d112d6e03b34122b1e Mon Sep 17 00:00:00 2001 From: "wanchen.swc" Date: Thu, 30 Mar 2023 16:31:14 +0800 Subject: [PATCH] [Quantization] import torch.fx only if torch.__version__ >= 1.8 --- funasr/export/models/modules/multihead_att.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py index 1983db8b6..6fce85166 100644 --- a/funasr/export/models/modules/multihead_att.py +++ b/funasr/export/models/modules/multihead_att.py @@ -75,8 +75,8 @@ def preprocess_for_attn(x, mask, cache, pad_fn): return x, cache -torch_version = float(".".join(torch.__version__.split(".")[:2])) -if torch_version >= 1.8: +torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]]) +if torch_version >= (1, 8): import torch.fx torch.fx.wrap('preprocess_for_attn')