From 1d13575ba16623d711c682118ee118615383ba99 Mon Sep 17 00:00:00 2001 From: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Date: Sun, 22 Dec 2024 09:02:58 +0800 Subject: [PATCH] `ultralytics 8.3.53` New Export argument validation (#18185) Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Signed-off-by: Glenn Jocher Signed-off-by: UltralyticsAssistant Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- docs/en/reference/engine/exporter.md | 4 ++ ultralytics/__init__.py | 2 +- ultralytics/engine/exporter.py | 62 ++++++++++++++++++++-------- ultralytics/utils/benchmarks.py | 2 +- 4 files changed, 50 insertions(+), 20 deletions(-) diff --git a/docs/en/reference/engine/exporter.md b/docs/en/reference/engine/exporter.md index 98e81a8aaf..a0d1822dce 100644 --- a/docs/en/reference/engine/exporter.md +++ b/docs/en/reference/engine/exporter.md @@ -23,6 +23,10 @@ keywords: YOLOv8, export formats, ONNX, TensorRT, CoreML, machine learning model



+## ::: ultralytics.engine.exporter.validate_args + +



+ ## ::: ultralytics.engine.exporter.gd_outputs



diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 177afda2c3..824a1aa741 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.3.52" +__version__ = "8.3.53" import os diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index ae84cab9a1..ea2bc01ecb 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -101,23 +101,47 @@ from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_d def export_formats(): """Ultralytics YOLO export formats.""" x = [ - ["PyTorch", "-", ".pt", True, True], - ["TorchScript", "torchscript", ".torchscript", True, True], - ["ONNX", "onnx", ".onnx", True, True], - ["OpenVINO", "openvino", "_openvino_model", True, False], - ["TensorRT", "engine", ".engine", False, True], - ["CoreML", "coreml", ".mlpackage", True, False], - ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True], - ["TensorFlow GraphDef", "pb", ".pb", True, True], - ["TensorFlow Lite", "tflite", ".tflite", True, False], - ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False], - ["TensorFlow.js", "tfjs", "_web_model", True, False], - ["PaddlePaddle", "paddle", "_paddle_model", True, True], - ["MNN", "mnn", ".mnn", True, True], - ["NCNN", "ncnn", "_ncnn_model", True, True], - ["IMX", "imx", "_imx_model", True, True], + ["PyTorch", "-", ".pt", True, True, []], + ["TorchScript", "torchscript", ".torchscript", True, True, ["optimize", "batch"]], + ["ONNX", "onnx", ".onnx", True, True, ["half", "dynamic", "simplify", "opset", "batch"]], + ["OpenVINO", "openvino", "_openvino_model", True, False, ["half", "int8", "batch"]], + ["TensorRT", "engine", ".engine", False, True, ["half", "dynamic", "simplify", "int8", "batch"]], + ["CoreML", "coreml", ".mlpackage", True, False, ["half", "int8", "nms", "batch"]], + ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["keras", "int8", "batch"]], + ["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]], + ["TensorFlow Lite", "tflite", ".tflite", True, False, ["half", "int8", "batch"]], + ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []], + ["TensorFlow.js", "tfjs", "_web_model", True, False, ["half", "int8", "batch"]], + ["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]], + ["MNN", "mnn", ".mnn", True, True, ["batch", "int8", "half"]], + ["NCNN", "ncnn", "_ncnn_model", True, True, ["half", "batch"]], + ["IMX", "imx", "_imx_model", True, True, ["int8"]], ] - return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x))) + return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"], zip(*x))) + + +def validate_args(format, passed_args, valid_args): + """ + Validates arguments based on format. + + Args: + format (str): The export format. + passed_args (Namespace): The arguments used during export. + valid_args (dict): List of valid arguments for the format. + + Raises: + AssertionError: If an argument that's not supported by the export format is used, or if format doesn't have the supported arguments listed. + """ + # Only check valid usage of these args + export_args = ["half", "int8", "dynamic", "keras", "nms", "batch"] + + assert valid_args is not None, f"ERROR ❌️ valid arguments for '{format}' not listed." + custom = {"batch": 1, "data": None, "device": None} # exporter defaults + default_args = get_cfg(DEFAULT_CFG, custom) + for arg in export_args: + not_default = getattr(passed_args, arg, None) != getattr(default_args, arg, None) + if not_default: + assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'" def gd_outputs(gd): @@ -182,7 +206,8 @@ class Exporter: fmt = "engine" if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases fmt = "coreml" - fmts = tuple(export_formats()["Argument"][1:]) # available export formats + fmts_dict = export_formats() + fmts = tuple(fmts_dict["Argument"][1:]) # available export formats if fmt not in fmts: import difflib @@ -224,7 +249,8 @@ class Exporter: assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}." self.device = select_device("cpu" if self.args.device is None else self.args.device) - # Checks + # Argument compatibility checks + validate_args(fmt, self.args, fmts_dict["Arguments"][flags.index(True) + 1]) if imx and not self.args.int8: LOGGER.warning("WARNING ⚠️ IMX only supports int8 export, setting int8=True.") self.args.int8 = True diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py index e65d128876..e5a6c22ab6 100644 --- a/ultralytics/utils/benchmarks.py +++ b/ultralytics/utils/benchmarks.py @@ -90,7 +90,7 @@ def benchmark( y = [] t0 = time.time() - for i, (name, format, suffix, cpu, gpu) in enumerate(zip(*export_formats().values())): + for i, (name, format, suffix, cpu, gpu, _) in enumerate(zip(*export_formats().values())): emoji, filename = "❌", None # export defaults try: # Checks