ultralytics 8.2.52 fix CenterCrop transforms for PIL Image inputs (#14308)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Lucas Buligon Antunes <lukasbuligonantunes@gmail.com>
Glenn Jocher 2024-07-10 03:00:14 +02:00 committed by GitHub
parent 755dcd6ca0
commit 997f2c92cd
20 changed files with 38 additions and 30 deletions

View File

@ -8,13 +8,13 @@ keywords: Computer Vision Models, AI Model Monitoring, Data Drift Detection, Ano
## Introduction
If you are here, we can assume you've completed many [steps in your computer vision project](./steps-of-a-cv-project.md): from [gathering requirements](./defining-project-goals.md), [annotating data](./data-collection-and-annotation.md), and [training the model](./model-training-tips.md) to finally [deploying](./model-deployment-practices.md) it. Your application is now running in production, but your project doesn't end here. The most important part of a computer vision project is making sure your model continues to fulfill your [project's objectives](./defining-project-goals.md) over time, and thats where monitoring, maintaining, and documenting your computer vision model enters the picture.
If you are here, we can assume you've completed many [steps in your computer vision project](./steps-of-a-cv-project.md): from [gathering requirements](./defining-project-goals.md), [annotating data](./data-collection-and-annotation.md), and [training the model](./model-training-tips.md) to finally [deploying](./model-deployment-practices.md) it. Your application is now running in production, but your project doesn't end here. The most important part of a computer vision project is making sure your model continues to fulfill your [project's objectives](./defining-project-goals.md) over time, and that's where monitoring, maintaining, and documenting your computer vision model enters the picture.
In this guide, we'll take a closer look at how you can maintain your computer vision models after deployment. We'll explore how model monitoring can help you catch problems early on, how to keep your model accurate and up-to-date, and why documentation is important for troubleshooting.
## Model Monitoring is Key
Keeping a close eye on your deployed computer vision models is essential. Without proper monitoring, models can lose accuracy. A common issue is data distribution shift or data drift, where the data the model encounters changes from what it was trained on. When the model has to make predictions on data it doesnt recognize, it can lead to misinterpretations and poor performance. Outliers, or unusual data points, can also throw off the models accuracy.
Keeping a close eye on your deployed computer vision models is essential. Without proper monitoring, models can lose accuracy. A common issue is data distribution shift or data drift, where the data the model encounters changes from what it was trained on. When the model has to make predictions on data it doesn't recognize, it can lead to misinterpretations and poor performance. Outliers, or unusual data points, can also throw off the model's accuracy.
Regular model monitoring helps developers track the [model's performance](./model-evaluation-insights.md), spot anomalies, and quickly address problems like data drift. It also helps manage resources by indicating when updates are needed, avoiding expensive overhauls, and keeping the model relevant.
@ -22,9 +22,9 @@ Regular model monitoring helps developers track the [model's performance](./mode
Here are some best practices to keep in mind while monitoring your computer vision model in production:
- **Track Performance Regularly**: Continuously monitor the models performance to detect changes over time.
- **Track Performance Regularly**: Continuously monitor the model's performance to detect changes over time.
- **Double Check the Data Quality**: Check for missing values or anomalies in the data.
- **Use Diverse Data Sources**: Monitor data from various sources to get a comprehensive view of the models performance.
- **Use Diverse Data Sources**: Monitor data from various sources to get a comprehensive view of the model's performance.
- **Combine Monitoring Techniques**: Use a mix of drift detection algorithms and rule-based approaches to identify a wide range of issues.
- **Monitor Inputs and Outputs**: Keep an eye on both the data the model processes and the results it produces to make sure everything is functioning correctly.
- **Set Up Alerts**: Implement alerts for unusual behavior, such as performance drops, to be able to make quick corrective actions.
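
As an illustration of the "Set Up Alerts" practice above, here is a minimal sketch of a rolling-window alert. The class name `PerformanceAlert`, its parameters, and the thresholds are hypothetical and not part of the guide or of Ultralytics. It assumes a scalar quality score (for example precision on a labeled sample) is available for each evaluation batch.

```python
from collections import deque


class PerformanceAlert:
    """Hypothetical helper: flags a sustained drop below a baseline score."""

    def __init__(self, baseline, max_drop=0.05, window=50):
        self.baseline = baseline  # score measured at deployment time
        self.max_drop = max_drop  # tolerated drop before alerting
        self.scores = deque(maxlen=window)  # most recent scores only

    def update(self, score):
        """Record a new score and return True if an alert should fire."""
        self.scores.append(score)
        if len(self.scores) < self.scores.maxlen:
            return False  # not enough history yet
        mean = sum(self.scores) / len(self.scores)
        return (self.baseline - mean) > self.max_drop


monitor = PerformanceAlert(baseline=0.90, max_drop=0.05, window=3)
for score in (0.90, 0.90, 0.90, 0.70):
    if monitor.update(score):
        print(f"Alert: rolling mean dropped more than {monitor.max_drop} below baseline")
```

Averaging over a window rather than alerting on a single low score is one way to avoid reacting to isolated outliers; the window size and threshold would need tuning for a real deployment.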
@ -67,7 +67,7 @@ Data drift detection is a concept that helps identify when the statistical prope
Here are several methods to detect data drift:
**Continuous Monitoring**: Regularly monitor the models input data and outputs for signs of drift. Track key metrics and compare them against historical data to identify significant changes.
**Continuous Monitoring**: Regularly monitor the model's input data and outputs for signs of drift. Track key metrics and compare them against historical data to identify significant changes.
**Statistical Techniques**: Use methods like the Kolmogorov-Smirnov test or Population Stability Index (PSI) to detect changes in data distributions. These tests compare the distribution of new data with the training data to identify significant differences.
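
To make the Population Stability Index mentioned above concrete, the following is a minimal NumPy sketch. The function name, the choice of ten quantile bins, and the `eps` floor are assumptions made for illustration; the guide does not prescribe an implementation. It assumes each image has been reduced to one scalar feature (for example mean brightness) for both the training data and the new data.

```python
import numpy as np


def population_stability_index(expected, actual, bins=10, eps=1e-6):
    """Hypothetical helper: PSI between a reference sample and a new sample."""
    expected = np.asarray(expected, dtype=float)
    actual = np.asarray(actual, dtype=float)

    # Interior bin edges taken from quantiles of the reference (training) data
    interior = np.quantile(expected, np.linspace(0.0, 1.0, bins + 1))[1:-1]

    # Count how many values of each sample fall into each bin
    exp_counts = np.bincount(np.searchsorted(interior, expected), minlength=bins)
    act_counts = np.bincount(np.searchsorted(interior, actual), minlength=bins)

    # Convert to fractions and floor them to avoid log(0)
    exp_frac = np.clip(exp_counts / expected.size, eps, None)
    act_frac = np.clip(act_counts / actual.size, eps, None)

    return float(np.sum((act_frac - exp_frac) * np.log(act_frac / exp_frac)))


rng = np.random.default_rng(0)
train_feature = rng.normal(0.0, 1.0, 5000)  # stand-in for a training-time feature
prod_feature = train_feature + 2.0  # same feature after a simulated shift
print(population_stability_index(train_feature, prod_feature))
```

A PSI near zero suggests the two distributions match. Thresholds such as 0.1 or 0.25 are common rules of thumb rather than fixed standards, so treat any cutoff as something to calibrate on your own data.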
@ -111,7 +111,7 @@ These are some of the key elements that should be included in project documentat
- **[Training Process](./model-training-tips.md)**: Document the training procedure, including the datasets used, training parameters, and loss functions. Explain how the model was trained and any challenges encountered during training.
- **[Evaluation Metrics](./model-evaluation-insights.md)**: Specify the metrics used to evaluate the model's performance, such as accuracy, precision, recall, and F1-score. Include performance results and an analysis of these metrics.
- **[Deployment Steps](./model-deployment-options.md)**: Outline the steps taken to deploy the model, including the tools and platforms used, deployment configurations, and any specific challenges or considerations.
- **Monitoring and Maintenance Procedure**: Provide a detailed plan for monitoring the models performance post-deployment. Include methods for detecting and addressing data and model drift, and describe the process for regular updates and retraining.
- **Monitoring and Maintenance Procedure**: Provide a detailed plan for monitoring the model's performance post-deployment. Include methods for detecting and addressing data and model drift, and describe the process for regular updates and retraining.
### Tools for Documentation

View File

@ -90,7 +90,7 @@ def test_predict_img(model_name):
batch = [
str(SOURCE), # filename
Path(SOURCE), # Path
"https://ultralytics.com/images/zidane.jpg" if ONLINE else SOURCE, # URI
"https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg" if ONLINE else SOURCE, # URI
cv2.imread(str(SOURCE)), # OpenCV
Image.open(SOURCE), # PIL
np.zeros((320, 640, 3), dtype=np.uint8), # numpy
@ -149,7 +149,7 @@ def test_track_stream():
Note imgsz=160 required for tracking for higher confidence and better matches.
"""
video_url = "https://ultralytics.com/assets/decelera_portrait_min.mov"
video_url = "https://github.com/ultralytics/yolov5/releases/download/v1.0/decelera_portrait_min.mov"
model = YOLO(MODEL)
model.track(video_url, imgsz=160, tracker="bytetrack.yaml")
model.track(video_url, imgsz=160, tracker="botsort.yaml", save_frames=True) # test frame saving also

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.51"
__version__ = "8.2.52"
import os

View File

@ -21,4 +21,4 @@ names:
3: zebra
# Download script/URL (optional)
download: https://ultralytics.com/assets/african-wildlife.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/african-wildlife.zip

View File

@ -19,4 +19,4 @@ names:
1: positive
# Download script/URL (optional)
download: https://ultralytics.com/assets/brain-tumor.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/brain-tumor.zip

View File

@ -40,4 +40,4 @@ names:
22: wheel
# Download script/URL (optional)
download: https://ultralytics.com/assets/carparts-seg.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/carparts-seg.zip

View File

@ -97,4 +97,4 @@ names:
79: toothbrush
# Download script/URL (optional)
download: https://ultralytics.com/assets/coco128-seg.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128-seg.zip

View File

@ -97,4 +97,4 @@ names:
79: toothbrush
# Download script/URL (optional)
download: https://ultralytics.com/assets/coco128.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip

View File

@ -22,4 +22,4 @@ names:
0: person
# Download script/URL (optional)
download: https://ultralytics.com/assets/coco8-pose.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco8-pose.zip

View File

@ -97,4 +97,4 @@ names:
79: toothbrush
# Download script/URL (optional)
download: https://ultralytics.com/assets/coco8-seg.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco8-seg.zip

View File

@ -97,4 +97,4 @@ names:
79: toothbrush
# Download script/URL (optional)
download: https://ultralytics.com/assets/coco8.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco8.zip

View File

@ -18,4 +18,4 @@ names:
0: crack
# Download script/URL (optional)
download: https://ultralytics.com/assets/crack-seg.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/crack-seg.zip

View File

@ -18,4 +18,4 @@ names:
0: package
# Download script/URL (optional)
download: https://ultralytics.com/assets/package-seg.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/package-seg.zip

View File

@ -17,4 +17,4 @@ names:
0: signature
# Download script/URL (optional)
download: https://ultralytics.com/assets/signature.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/signature.zip

View File

@ -21,4 +21,4 @@ names:
0: tiger
# Download script/URL (optional)
download: https://ultralytics.com/assets/tiger-pose.zip
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/tiger-pose.zip

View File

@ -1401,6 +1401,8 @@ class CenterCrop:
Returns:
(numpy.ndarray): The center-cropped and resized image as a numpy array.
"""
if isinstance(im, Image.Image): # convert from PIL to numpy array if required
im = np.asarray(im)
imh, imw = im.shape[:2]
m = min(imh, imw) # min dimension
top, left = (imh - m) // 2, (imw - m) // 2
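
The crop arithmetic in this hunk can be tried in isolation. The sketch below is a simplified stand-in, not the library's `CenterCrop`: the function name is hypothetical, it omits the resize step mentioned in the docstring, and it calls `np.asarray` unconditionally instead of checking for `Image.Image` first. `np.asarray` also accepts PIL images, which is what the fix relies on.

```python
import numpy as np


def center_crop_square(im):
    """Hypothetical helper: return the largest centered square of an image."""
    im = np.asarray(im)  # array-like input (including PIL images) -> numpy array
    imh, imw = im.shape[:2]
    m = min(imh, imw)  # side length of the square crop
    top, left = (imh - m) // 2, (imw - m) // 2
    return im[top : top + m, left : left + m]


image = np.zeros((320, 640, 3), dtype=np.uint8)  # height 320, width 640
print(center_crop_square(image).shape)
```

Without the conversion, a PIL image has no `.shape` attribute, so the `im.shape[:2]` line fails; converting first lets the same code path serve both input types.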

View File

@ -15,6 +15,7 @@ from torch.utils.data import ConcatDataset
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments
from ultralytics.utils.torch_utils import TORCH_1_13
from .augment import (
Compose,
@ -263,7 +264,7 @@ class YOLOMultiModalDataset(YOLODataset):
super().__init__(*args, data=data, task=task, **kwargs)
def update_labels_info(self, label):
"""Add texts information for multi modal model training."""
"""Add texts information for multi-modal model training."""
labels = super().update_labels_info(label)
# NOTE: some categories are concatenated with its synonyms by `/`.
labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
@ -296,10 +297,10 @@ class GroundingDataset(YOLODataset):
with open(self.json_file, "r") as f:
annotations = json.load(f)
images = {f'{x["id"]:d}': x for x in annotations["images"]}
imgToAnns = defaultdict(list)
img_to_anns = defaultdict(list)
for ann in annotations["annotations"]:
imgToAnns[ann["image_id"]].append(ann)
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"):
img_to_anns[ann["image_id"]].append(ann)
for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
img = images[f"{img_id:d}"]
h, w, f = img["height"], img["width"], img["file_name"]
im_file = Path(self.img_path) / f
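
The grouping step in this hunk follows a common pattern for COCO-style annotation files. Here is a minimal standalone sketch of it; the function name and the sample dictionary are made up for illustration and only mirror the keys visible in the diff (`images`, `annotations`, `image_id`, `height`, `width`, `file_name`).

```python
from collections import defaultdict


def group_annotations_by_image(annotations):
    """Hypothetical helper: map each image id to its list of annotations."""
    img_to_anns = defaultdict(list)
    for ann in annotations["annotations"]:
        img_to_anns[ann["image_id"]].append(ann)
    return img_to_anns


# Made-up sample data in the same shape the hunk reads from the JSON file
sample = {
    "images": [
        {"id": 1, "height": 480, "width": 640, "file_name": "a.jpg"},
        {"id": 2, "height": 720, "width": 1280, "file_name": "b.jpg"},
    ],
    "annotations": [
        {"image_id": 1, "bbox": [10, 20, 30, 40]},
        {"image_id": 1, "bbox": [50, 60, 70, 80]},
        {"image_id": 2, "bbox": [5, 5, 15, 15]},
    ],
}

images = {f'{x["id"]:d}': x for x in sample["images"]}
for img_id, anns in group_annotations_by_image(sample).items():
    img = images[f"{img_id:d}"]
    print(img["file_name"], len(anns))
```

Using `defaultdict(list)` avoids checking whether a key exists before appending, which keeps the grouping loop to a single line.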
@ -416,7 +417,10 @@ class ClassificationDataset:
import torchvision # scope for faster 'import ultralytics'
# Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
if TORCH_1_13: # 'allow_empty' argument first introduced in torch 1.13
self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
else:
self.base = torchvision.datasets.ImageFolder(root=root)
self.samples = self.base.samples
self.root = self.base.root
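
The hunk above gates an optional keyword argument on a version flag. A different technique with the same goal is to inspect the callable's signature at runtime. The sketch below shows that alternative using only the standard library; it is not what the commit does, and `supports_kwarg`, `call_with_supported`, and the two loader functions are hypothetical names used for illustration.

```python
import inspect


def supports_kwarg(func, name):
    """Hypothetical helper: True if `func` declares a parameter called `name`."""
    try:
        return name in inspect.signature(func).parameters
    except (TypeError, ValueError):  # signature not introspectable
        return False


def call_with_supported(func, *args, **optional):
    """Call `func`, passing only the optional kwargs it declares."""
    kwargs = {k: v for k, v in optional.items() if supports_kwarg(func, k)}
    return func(*args, **kwargs)


# Stand-ins for an older and a newer API (not real torchvision functions)
def old_loader(root):
    return {"root": root}


def new_loader(root, allow_empty=False):
    return {"root": root, "allow_empty": allow_empty}


print(call_with_supported(old_loader, "data", allow_empty=True))
print(call_with_supported(new_loader, "data", allow_empty=True))
```

Signature inspection avoids hard-coding a version number, but it cannot tell whether a parameter behaves the same way across releases, so an explicit version flag like the one in the diff can still be the clearer choice.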

View File

@ -195,7 +195,7 @@ class RF100Benchmark:
(shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
os.chdir("rf-100")
os.mkdir("ultralytics-benchmarks")
safe_download("https://ultralytics.com/assets/datasets_links.txt")
safe_download("https://github.com/ultralytics/yolov5/releases/download/v1.0/datasets_links.txt")
with open(ds_link_txt, "r") as file:
for line in file:

View File

@ -315,7 +315,7 @@ def check_font(font="Arial.ttf"):
return matches[0]
# Download to USER_CONFIG_DIR if missing
url = f"https://ultralytics.com/assets/{name}"
url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{name}"
if downloads.is_url(url, check=True):
downloads.safe_download(url=url, file=file)
return file

View File

@ -194,12 +194,14 @@ def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=Fals
return path # return unzip dir
def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", path=Path.cwd(), sf=1.5, hard=True):
def check_disk_space(
url="https://github.com/ultralytics/yolov5/releases/download/v1.0/coco8.zip", path=Path.cwd(), sf=1.5, hard=True
):
"""
Check if there is sufficient disk space to download and store a file.
Args:
url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco128.zip'.
url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco8.zip'.
path (str | Path, optional): The path or drive to check the available free space on.
sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5.
hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.
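
The role of the safety factor `sf` can be shown with a small standalone sketch. This is not the library's `check_disk_space`: the function name `has_enough_space` is hypothetical, and it takes the file size as an argument instead of reading it from the URL, so no network access is needed.

```python
import shutil
from pathlib import Path


def has_enough_space(required_bytes, path=".", sf=1.5):
    """Hypothetical helper: True if `required_bytes * sf` fits in the free space at `path`."""
    free = shutil.disk_usage(Path(path)).free  # free bytes on the drive holding `path`
    return required_bytes * sf <= free


archive_size = 7 * 1024 * 1024  # assumed size of a small dataset archive, in bytes
if has_enough_space(archive_size, sf=1.5):
    print("Enough free space to download and extract")
else:
    print("Insufficient free space")
```

Multiplying by a factor above 1 leaves headroom for the extracted contents, which usually take more space than the compressed archive.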