mirror of
https://github.com/ultralytics/ultralytics.git
synced 2025-09-15 15:48:41 +08:00
ultralytics 8.3.110 New dataset file access speed checks (#20197)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
f94058a90d
commit
c872e9d474
@ -76,11 +76,11 @@ Community leaders will follow these Community Impact Guidelines in determining t
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct/.
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion).
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.
|
||||
For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq/. Translations are available at https://www.contributor-covenant.org/translations/.
|
||||
|
||||
## FAQ
|
||||
|
||||
|
||||
@ -236,13 +236,13 @@ YOLOE supports both text-based and visual prompting. Using prompts is straightfo
|
||||
# Define visual prompts based on a separate reference image
|
||||
visual_prompts = dict(
|
||||
bboxes=np.array([[221.52, 405.8, 344.98, 857.54]]), # Box enclosing person
|
||||
cls=np.array([0]) # ID to be assigned for person
|
||||
cls=np.array([0]), # ID to be assigned for person
|
||||
)
|
||||
|
||||
# Run prediction on a different image, using reference image to guide what to look for
|
||||
results = model.predict(
|
||||
"ultralytics/assets/zidane.jpg", # Target image for detection
|
||||
refer_image="ultralytics/assets/bus.jpg", # Reference image used to get visual prompts
|
||||
"ultralytics/assets/zidane.jpg", # Target image for detection
|
||||
refer_image="ultralytics/assets/bus.jpg", # Reference image used to get visual prompts
|
||||
visual_prompts=visual_prompts,
|
||||
predictor=YOLOEVPSegPredictor,
|
||||
)
|
||||
@ -263,7 +263,7 @@ YOLOE supports both text-based and visual prompting. Using prompts is straightfo
|
||||
model = YOLOE("yoloe-11l-seg.pt")
|
||||
|
||||
# Define visual prompts using bounding boxes and their corresponding class IDs.
|
||||
# Each box highlights an example of the object you want the model to detect.
|
||||
# Each box highlights an example of the object you want the model to detect.
|
||||
visual_prompts = dict(
|
||||
bboxes=[
|
||||
np.array(
|
||||
|
||||
@ -19,6 +19,10 @@ keywords: Ultralytics, dataset utils, data handling, image verification, Python,
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.utils.check_file_speeds
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.utils.get_hash
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
@ -109,7 +109,7 @@ For efficient data management, especially with large datasets or numerous experi
|
||||
|
||||
```bash
|
||||
# Ensure Google Cloud SDK is installed and initialized
|
||||
# If not installed: curl https://sdk.cloud.google.com | bash
|
||||
# If not installed: curl https://sdk.cloud.google.com/ | bash
|
||||
# Then initialize: gcloud init
|
||||
|
||||
# Example: Copy your dataset from a GCS bucket to your VM
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
<div class="banner-wrapper">
|
||||
<div
|
||||
class="banner-content-wrapper"
|
||||
onclick="window.open('https://docs.ultralytics.com/models/yolo11')"
|
||||
onclick="window.open('https://docs.ultralytics.com/models/yolo11/')"
|
||||
>
|
||||
<p>Introducing</p>
|
||||
<img
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
__version__ = "8.3.109"
|
||||
__version__ = "8.3.110"
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ import numpy as np
|
||||
import psutil
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS
|
||||
from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS, check_file_speeds
|
||||
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
|
||||
|
||||
|
||||
@ -172,6 +172,7 @@ class BaseDataset(Dataset):
|
||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
||||
if self.fraction < 1:
|
||||
im_files = im_files[: round(len(im_files) * self.fraction)] # retain a fraction of the dataset
|
||||
check_file_speeds(im_files, prefix=self.prefix) # check image read speeds
|
||||
return im_files
|
||||
|
||||
def update_labels(self, include_class: Optional[list]):
|
||||
|
||||
@ -31,6 +31,7 @@ from .converter import merge_multi_segment
|
||||
from .utils import (
|
||||
HELP_URL,
|
||||
LOGGER,
|
||||
check_file_speeds,
|
||||
get_hash,
|
||||
img2label_paths,
|
||||
load_dataset_cache_file,
|
||||
@ -794,6 +795,7 @@ class ClassificationDataset:
|
||||
path = Path(self.root).with_suffix(".cache") # *.cache file path
|
||||
|
||||
try:
|
||||
check_file_speeds([file for (file, _) in self.samples[:5]], prefix=self.prefix) # check image read speeds
|
||||
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
|
||||
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
||||
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
||||
|
||||
@ -47,6 +47,81 @@ def img2label_paths(img_paths):
|
||||
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
|
||||
|
||||
|
||||
def check_file_speeds(files, threshold_ms=10, max_files=5, prefix=""):
|
||||
"""
|
||||
Check dataset file access speed and provide performance feedback.
|
||||
|
||||
This function tests the access speed of dataset files by measuring ping (stat call) time and read speed.
|
||||
It samples up to 5 files from the provided list and warns if access times exceed the threshold.
|
||||
|
||||
Args:
|
||||
files (list): List of file paths to check for access speed.
|
||||
threshold_ms (float, optional): Threshold in milliseconds for ping time warnings.
|
||||
max_files (int, optional): The maximum number of files to check.
|
||||
prefix (str, optional): Prefix string to add to log messages.
|
||||
|
||||
Examples:
|
||||
>>> from pathlib import Path
|
||||
>>> image_files = list(Path("dataset/images").glob("*.jpg"))
|
||||
>>> check_file_speeds(image_files, threshold_ms=15)
|
||||
"""
|
||||
if not files or len(files) == 0:
|
||||
LOGGER.warning(f"{prefix}WARNING ⚠️ Image speed checks: No files to check")
|
||||
return
|
||||
|
||||
# Sample files (max 5)
|
||||
files = random.sample(files, min(max_files, len(files)))
|
||||
|
||||
# Test ping (stat time)
|
||||
ping_times = []
|
||||
file_sizes = []
|
||||
read_speeds = []
|
||||
|
||||
for f in files:
|
||||
try:
|
||||
# Measure ping (stat call)
|
||||
start = time.perf_counter()
|
||||
file_size = os.stat(f).st_size
|
||||
ping_times.append((time.perf_counter() - start) * 1000) # ms
|
||||
file_sizes.append(file_size)
|
||||
|
||||
# Measure read speed
|
||||
start = time.perf_counter()
|
||||
with open(f, "rb") as file_obj:
|
||||
_ = file_obj.read()
|
||||
read_time = time.perf_counter() - start
|
||||
if read_time > 0: # Avoid division by zero
|
||||
read_speeds.append(file_size / (1 << 20) / read_time) # MB/s
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not ping_times:
|
||||
LOGGER.warning(f"{prefix}WARNING ⚠️ Image speed checks: failed to access files")
|
||||
return
|
||||
|
||||
# Calculate stats with uncertainties
|
||||
avg_ping = np.mean(ping_times)
|
||||
std_ping = np.std(ping_times, ddof=1) if len(ping_times) > 1 else 0
|
||||
size_msg = f", size: {np.mean(file_sizes) / (1 << 10):.1f} KB"
|
||||
ping_msg = f"ping: {avg_ping:.1f}±{std_ping:.1f} ms"
|
||||
|
||||
if read_speeds:
|
||||
avg_speed = np.mean(read_speeds)
|
||||
std_speed = np.std(read_speeds, ddof=1) if len(read_speeds) > 1 else 0
|
||||
speed_msg = f", read: {avg_speed:.1f}±{std_speed:.1f} MB/s"
|
||||
else:
|
||||
speed_msg = ""
|
||||
|
||||
if avg_ping < threshold_ms:
|
||||
LOGGER.info(f"{prefix}Fast image access ✅ ({ping_msg}{speed_msg}{size_msg})")
|
||||
else:
|
||||
LOGGER.warning(
|
||||
f"{prefix}WARNING ⚠️ Slow image access detected ({ping_msg}{speed_msg}{size_msg}). "
|
||||
f"Use local storage instead of remote/mounted storage for better performance. "
|
||||
f"See https://docs.ultralytics.com/guides/model-training-tips/"
|
||||
)
|
||||
|
||||
|
||||
def get_hash(paths):
|
||||
"""Returns a single hash value of a list of paths (files or dirs)."""
|
||||
size = 0
|
||||
|
||||
@ -1000,7 +1000,7 @@ class Exporter:
|
||||
|
||||
@try_export
|
||||
def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
|
||||
"""YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
|
||||
"""YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen-Graph-TensorFlow."""
|
||||
import tensorflow as tf # noqa
|
||||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
|
||||
|
||||
@ -1281,7 +1281,7 @@ class Exporter:
|
||||
return f, None
|
||||
|
||||
def _add_tflite_metadata(self, file):
|
||||
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
|
||||
"""Add metadata to *.tflite models per https://ai.google.dev/edge/litert/models/metadata."""
|
||||
import flatbuffers
|
||||
|
||||
try:
|
||||
|
||||
@ -397,7 +397,7 @@ class AutoBackend(nn.Module):
|
||||
pass
|
||||
|
||||
# TFLite or TFLite Edge TPU
|
||||
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||
elif tflite or edgetpu: # https://ai.google.dev/edge/litert/microcontrollers/python
|
||||
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
||||
from tflite_runtime.interpreter import Interpreter, load_delegate
|
||||
except ImportError:
|
||||
|
||||
@ -422,7 +422,7 @@ class C3Ghost(C3):
|
||||
|
||||
|
||||
class GhostBottleneck(nn.Module):
|
||||
"""Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""
|
||||
"""Ghost Bottleneck https://github.com/huawei-noah/Efficient-AI-Backbones."""
|
||||
|
||||
def __init__(self, c1, c2, k=3, s=1):
|
||||
"""
|
||||
|
||||
@ -337,7 +337,7 @@ class GhostConv(nn.Module):
|
||||
cv2 (Conv): Cheap operation convolution.
|
||||
|
||||
References:
|
||||
https://github.com/huawei-noah/ghostnet
|
||||
https://github.com/huawei-noah/Efficient-AI-Backbones
|
||||
"""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
|
||||
|
||||
@ -58,7 +58,7 @@ class ParkingPtsSelection:
|
||||
"Linux": "sudo apt install python3-tk (Debian/Ubuntu) | sudo dnf install python3-tkinter (Fedora) | "
|
||||
"sudo pacman -S tk (Arch)",
|
||||
"Windows": "reinstall Python and enable the checkbox `tcl/tk and IDLE` on **Optional Features** during installation",
|
||||
"Darwin": "reinstall Python from https://www.python.org/downloads/mac-osx/ or `brew install python-tk`",
|
||||
"Darwin": "reinstall Python from https://www.python.org/downloads/macos/ or `brew install python-tk`",
|
||||
}.get(platform.system(), "Unknown OS. Check your Python installation.")
|
||||
|
||||
LOGGER.warning(f"WARNING ⚠️ Tkinter is not configured or supported. Potential fix: {install_cmd}")
|
||||
|
||||
@ -87,7 +87,7 @@ def generate_ddp_command(world_size, trainer):
|
||||
cmd (List[str]): The command to execute for distributed training.
|
||||
file (str): Path to the temporary file created for DDP training.
|
||||
"""
|
||||
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218
|
||||
|
||||
if not trainer.resume:
|
||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||
|
||||
@ -52,7 +52,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
||||
def box_iou(box1, box2, eps=1e-7):
|
||||
"""
|
||||
Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
||||
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.
|
||||
Based on https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py.
|
||||
|
||||
Args:
|
||||
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user