mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
[Quantization] import torch.fx only if torch.__version__ >= 1.8
This commit is contained in:
parent
8207de9441
commit
0438b966d6
@ -75,8 +75,8 @@ def preprocess_for_attn(x, mask, cache, pad_fn):
|
||||
return x, cache
|
||||
|
||||
|
||||
torch_version = float(".".join(torch.__version__.split(".")[:2]))
|
||||
if torch_version >= 1.8:
|
||||
torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
|
||||
if torch_version >= (1, 8):
|
||||
import torch.fx
|
||||
torch.fx.wrap('preprocess_for_attn')
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user