Fix seaco onnx export bug (#2325)

This commit is contained in:
zhong zhuang 2024-12-21 17:14:35 +08:00 committed by GitHub
parent b5ad7c81be
commit fcb2102a60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -67,7 +67,11 @@ def _onnx(
device = kwargs.get("device", "cpu")
dummy_input = model.export_dummy_inputs()
dummy_input = (dummy_input[0].to(device), dummy_input[1].to(device))
if isinstance(dummy_input, torch.Tensor):
dummy_input = dummy_input.to(device)
else:
dummy_input = tuple([input.to(device) for input in dummy_input])
verbose = kwargs.get("verbose", False)