diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py index dbec50426..0691ef2d5 100644 --- a/funasr/utils/export_utils.py +++ b/funasr/utils/export_utils.py @@ -67,7 +67,11 @@ def _onnx( device = kwargs.get("device", "cpu") dummy_input = model.export_dummy_inputs() - dummy_input = (dummy_input[0].to(device), dummy_input[1].to(device)) + + if isinstance(dummy_input, torch.Tensor): + dummy_input = dummy_input.to(device) + else: + dummy_input = tuple([input.to(device) for input in dummy_input]) verbose = kwargs.get("verbose", False)