ultralytics 8.3.110 New dataset file access speed checks (#20197)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-04-17 12:51:27 +02:00 committed by GitHub
parent f94058a90d
commit c872e9d474
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 100 additions and 18 deletions

View File

@ -76,11 +76,11 @@ Community leaders will follow these Community Impact Guidelines in determining t
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct/.
Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion).
For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.
For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq/. Translations are available at https://www.contributor-covenant.org/translations/.
## FAQ

View File

@ -236,13 +236,13 @@ YOLOE supports both text-based and visual prompting. Using prompts is straightfo
# Define visual prompts based on a separate reference image
visual_prompts = dict(
bboxes=np.array([[221.52, 405.8, 344.98, 857.54]]), # Box enclosing person
cls=np.array([0]) # ID to be assigned for person
cls=np.array([0]), # ID to be assigned for person
)
# Run prediction on a different image, using reference image to guide what to look for
results = model.predict(
"ultralytics/assets/zidane.jpg", # Target image for detection
refer_image="ultralytics/assets/bus.jpg", # Reference image used to get visual prompts
"ultralytics/assets/zidane.jpg", # Target image for detection
refer_image="ultralytics/assets/bus.jpg", # Reference image used to get visual prompts
visual_prompts=visual_prompts,
predictor=YOLOEVPSegPredictor,
)
@ -263,7 +263,7 @@ YOLOE supports both text-based and visual prompting. Using prompts is straightfo
model = YOLOE("yoloe-11l-seg.pt")
# Define visual prompts using bounding boxes and their corresponding class IDs.
# Each box highlights an example of the object you want the model to detect.
# Each box highlights an example of the object you want the model to detect.
visual_prompts = dict(
bboxes=[
np.array(

View File

@ -19,6 +19,10 @@ keywords: Ultralytics, dataset utils, data handling, image verification, Python,
<br><br><hr><br>
## ::: ultralytics.data.utils.check_file_speeds
<br><br><hr><br>
## ::: ultralytics.data.utils.get_hash
<br><br><hr><br>

View File

@ -109,7 +109,7 @@ For efficient data management, especially with large datasets or numerous experi
```bash
# Ensure Google Cloud SDK is installed and initialized
# If not installed: curl https://sdk.cloud.google.com | bash
# If not installed: curl https://sdk.cloud.google.com/ | bash
# Then initialize: gcloud init
# Example: Copy your dataset from a GCS bucket to your VM

View File

@ -5,7 +5,7 @@
<div class="banner-wrapper">
<div
class="banner-content-wrapper"
onclick="window.open('https://docs.ultralytics.com/models/yolo11')"
onclick="window.open('https://docs.ultralytics.com/models/yolo11/')"
>
<p>Introducing</p>
<img

View File

@ -1,6 +1,6 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
__version__ = "8.3.109"
__version__ = "8.3.110"
import os

View File

@ -14,7 +14,7 @@ import numpy as np
import psutil
from torch.utils.data import Dataset
from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS
from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS, check_file_speeds
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
@ -172,6 +172,7 @@ class BaseDataset(Dataset):
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
if self.fraction < 1:
im_files = im_files[: round(len(im_files) * self.fraction)] # retain a fraction of the dataset
check_file_speeds(im_files, prefix=self.prefix) # check image read speeds
return im_files
def update_labels(self, include_class: Optional[list]):

View File

@ -31,6 +31,7 @@ from .converter import merge_multi_segment
from .utils import (
HELP_URL,
LOGGER,
check_file_speeds,
get_hash,
img2label_paths,
load_dataset_cache_file,
@ -794,6 +795,7 @@ class ClassificationDataset:
path = Path(self.root).with_suffix(".cache") # *.cache file path
try:
check_file_speeds([file for (file, _) in self.samples[:5]], prefix=self.prefix) # check image read speeds
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash

View File

@ -47,6 +47,81 @@ def img2label_paths(img_paths):
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
def check_file_speeds(files, threshold_ms=10, max_files=5, prefix=""):
    """
    Check dataset file access speed and provide performance feedback.

    This function tests the access speed of dataset files by measuring ping (stat call) time and read speed.
    It randomly samples up to `max_files` files from the provided list and logs a warning if the mean ping time
    is not below `threshold_ms`. Read speed and mean file size are reported for information only and do not
    affect whether a warning is raised.

    Args:
        files (list): List of file paths to check for access speed.
        threshold_ms (float, optional): Threshold in milliseconds for ping time warnings.
        max_files (int, optional): The maximum number of files to check.
        prefix (str, optional): Prefix string to add to log messages.

    Returns:
        None. Results are reported through LOGGER only.

    Examples:
        >>> from pathlib import Path
        >>> image_files = list(Path("dataset/images").glob("*.jpg"))
        >>> check_file_speeds(image_files, threshold_ms=15)
    """
    if not files:
        LOGGER.warning(f"{prefix}WARNING ⚠️ Image speed checks: No files to check")
        return

    # Sample files (at most max_files) so the check stays cheap on large datasets
    files = random.sample(files, min(max_files, len(files)))

    # Test ping (stat time) and read throughput per sampled file
    ping_times = []
    file_sizes = []
    read_speeds = []

    for f in files:
        try:
            # Measure ping (stat call)
            start = time.perf_counter()
            file_size = os.stat(f).st_size
            ping_times.append((time.perf_counter() - start) * 1000)  # ms
            file_sizes.append(file_size)

            # Measure read speed
            start = time.perf_counter()
            with open(f, "rb") as file_obj:
                file_obj.read()
            read_time = time.perf_counter() - start
            # Skip zero-duration reads (division by zero) and empty files (a 0 MB/s sample is meaningless)
            if read_time > 0 and file_size > 0:
                read_speeds.append(file_size / (1 << 20) / read_time)  # MB/s
        except Exception as e:
            # Best-effort check: an unreadable file must never abort dataset loading, but leave a trace
            LOGGER.debug(f"{prefix}Image speed checks: skipping {f}: {e}")

    if not ping_times:
        LOGGER.warning(f"{prefix}WARNING ⚠️ Image speed checks: failed to access files")
        return

    # Calculate stats with uncertainties (sample std needs at least two measurements)
    avg_ping = np.mean(ping_times)
    std_ping = np.std(ping_times, ddof=1) if len(ping_times) > 1 else 0
    size_msg = f", size: {np.mean(file_sizes) / (1 << 10):.1f} KB"
    ping_msg = f"ping: {avg_ping:.1f}±{std_ping:.1f} ms"

    if read_speeds:
        avg_speed = np.mean(read_speeds)
        std_speed = np.std(read_speeds, ddof=1) if len(read_speeds) > 1 else 0
        speed_msg = f", read: {avg_speed:.1f}±{std_speed:.1f} MB/s"
    else:
        speed_msg = ""

    if avg_ping < threshold_ms:
        LOGGER.info(f"{prefix}Fast image access ✅ ({ping_msg}{speed_msg}{size_msg})")
    else:
        LOGGER.warning(
            f"{prefix}WARNING ⚠️ Slow image access detected ({ping_msg}{speed_msg}{size_msg}). "
            f"Use local storage instead of remote/mounted storage for better performance. "
            f"See https://docs.ultralytics.com/guides/model-training-tips/"
        )
def get_hash(paths):
"""Returns a single hash value of a list of paths (files or dirs)."""
size = 0

View File

@ -1000,7 +1000,7 @@ class Exporter:
@try_export
def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
"""YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
"""YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen-Graph-TensorFlow."""
import tensorflow as tf # noqa
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
@ -1281,7 +1281,7 @@ class Exporter:
return f, None
def _add_tflite_metadata(self, file):
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
"""Add metadata to *.tflite models per https://ai.google.dev/edge/litert/models/metadata."""
import flatbuffers
try:

View File

@ -397,7 +397,7 @@ class AutoBackend(nn.Module):
pass
# TFLite or TFLite Edge TPU
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
elif tflite or edgetpu: # https://ai.google.dev/edge/litert/microcontrollers/python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:

View File

@ -422,7 +422,7 @@ class C3Ghost(C3):
class GhostBottleneck(nn.Module):
"""Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""
"""Ghost Bottleneck https://github.com/huawei-noah/Efficient-AI-Backbones."""
def __init__(self, c1, c2, k=3, s=1):
"""

View File

@ -337,7 +337,7 @@ class GhostConv(nn.Module):
cv2 (Conv): Cheap operation convolution.
References:
https://github.com/huawei-noah/ghostnet
https://github.com/huawei-noah/Efficient-AI-Backbones
"""
def __init__(self, c1, c2, k=1, s=1, g=1, act=True):

View File

@ -58,7 +58,7 @@ class ParkingPtsSelection:
"Linux": "sudo apt install python3-tk (Debian/Ubuntu) | sudo dnf install python3-tkinter (Fedora) | "
"sudo pacman -S tk (Arch)",
"Windows": "reinstall Python and enable the checkbox `tcl/tk and IDLE` on **Optional Features** during installation",
"Darwin": "reinstall Python from https://www.python.org/downloads/mac-osx/ or `brew install python-tk`",
"Darwin": "reinstall Python from https://www.python.org/downloads/macos/ or `brew install python-tk`",
}.get(platform.system(), "Unknown OS. Check your Python installation.")
LOGGER.warning(f"WARNING ⚠️ Tkinter is not configured or supported. Potential fix: {install_cmd}")

View File

@ -87,7 +87,7 @@ def generate_ddp_command(world_size, trainer):
cmd (List[str]): The command to execute for distributed training.
file (str): Path to the temporary file created for DDP training.
"""
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218
if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir

View File

@ -52,7 +52,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
def box_iou(box1, box2, eps=1e-7):
"""
Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.
Based on https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py.
Args:
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.