Add custom example for classification augmentation (#20949)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2025-06-07 12:26:53 +08:00 committed by GitHub
parent 766dc10b7d
commit 9fed5624f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 0 deletions

View File

@ -382,3 +382,7 @@ You can find the full list of applied transformations in our [technical document
### When starting a training, I don't see any reference to albumentations. Why?
Check if the `albumentations` package is installed. If not, you can install it by running `pip install albumentations`. Once installed, the package should be automatically detected and used by Ultralytics.
### How do I customize my augmentations?
You can customize augmentations by creating a custom dataset class and trainer. For example, you can replace the default Ultralytics classification augmentations with PyTorch's [torchvision.transforms.Resize](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) or other transforms. See the [custom training example](../tasks/classify.md#train) in the classification documentation for implementation details.

View File

@ -72,6 +72,54 @@ Train YOLO11n-cls on the MNIST160 dataset for 100 [epochs](https://www.ultralyti
yolo classify train data=mnist160 model=yolo11n-cls.yaml pretrained=yolo11n-cls.pt epochs=100 imgsz=64
```
!!! tip
Ultralytics YOLO classification uses [torchvision.transforms.RandomResizedCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.RandomResizedCrop.html) for training augmentation and [torchvision.transforms.CenterCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.CenterCrop.html) for validation/inference.
For images with extreme aspect ratios, consider using [torchvision.transforms.Resize](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) instead. The example below shows how to customize augmentations for classification training.
```python
import torch
import torchvision.transforms as T
from ultralytics.data.dataset import ClassificationDataset
from ultralytics.models.yolo.classify import ClassificationTrainer
class CustomizedDataset(ClassificationDataset):
def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
super().__init__(root, args, augment, prefix)
train_transforms = T.Compose(
[
T.Resize((args.imgsz, args.imgsz)),
T.RandomHorizontalFlip(p=args.fliplr),
T.RandomVerticalFlip(p=args.flipud),
T.RandAugment(interpolation=T.InterpolationMode.BILINEAR),
T.ColorJitter(brightness=args.hsv_v, contrast=args.hsv_v, saturation=args.hsv_s, hue=args.hsv_h),
T.ToTensor(),
T.Normalize(mean=torch.tensor(0), std=torch.tensor(1)),
T.RandomErasing(p=args.erasing, inplace=True),
]
)
val_transforms = T.Compose(
[
T.Resize((args.imgsz, args.imgsz)),
T.ToTensor(),
T.Normalize(mean=torch.tensor(0), std=torch.tensor(1)),
]
)
self.torch_transforms = train_transforms if augment else val_transforms
class CustomizedTrainer(ClassificationTrainer):
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
return CustomizedDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
from ultralytics import YOLO
model = YOLO("yolo11n-cls.pt")
model.train(data="imagenet1000", trainer=CustomizedTrainer, epochs=10, imgsz=224, batch=64)
```
### Dataset format
YOLO classification dataset format can be found in detail in the [Dataset Guide](../datasets/classify/index.md).