diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py index 1983db8b6..6fce85166 100644 --- a/funasr/export/models/modules/multihead_att.py +++ b/funasr/export/models/modules/multihead_att.py @@ -75,8 +75,8 @@ def preprocess_for_attn(x, mask, cache, pad_fn): return x, cache -torch_version = float(".".join(torch.__version__.split(".")[:2])) -if torch_version >= 1.8: +torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]]) +if torch_version >= (1, 8): import torch.fx torch.fx.wrap('preprocess_for_attn')