mirror of
https://github.com/ultralytics/ultralytics.git
synced 2025-09-15 15:48:41 +08:00
Add custom example for classification augmentation (#20949)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
766dc10b7d
commit
9fed5624f8
@ -382,3 +382,7 @@ You can find the full list of applied transformations in our [technical document
|
||||
### When starting a training, I don't see any reference to albumentations. Why?
|
||||
|
||||
Check if the `albumentations` package is installed. If not, you can install it by running `pip install albumentations`. Once installed, the package should be automatically detected and used by Ultralytics.
|
||||
|
||||
### How do I customize my augmentations?
|
||||
|
||||
You can customize augmentations by creating a custom dataset class and trainer. For example, you can replace the default Ultralytics classification augmentations with PyTorch's [torchvision.transforms.Resize](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) or other transforms. See the [custom training example](../tasks/classify.md#train) in the classification documentation for implementation details.
|
||||
|
||||
@ -72,6 +72,54 @@ Train YOLO11n-cls on the MNIST160 dataset for 100 [epochs](https://www.ultralyti
|
||||
yolo classify train data=mnist160 model=yolo11n-cls.yaml pretrained=yolo11n-cls.pt epochs=100 imgsz=64
|
||||
```
|
||||
|
||||
!!! tip
|
||||
|
||||
Ultralytics YOLO classification uses [torchvision.transforms.RandomResizedCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.RandomResizedCrop.html) for training augmentation and [torchvision.transforms.CenterCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.CenterCrop.html) for validation/inference.
|
||||
For images with extreme aspect ratios, consider using [torchvision.transforms.Resize](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) instead. The example below shows how to customize augmentations for classification training.
|
||||
```python
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
|
||||
from ultralytics.data.dataset import ClassificationDataset
|
||||
from ultralytics.models.yolo.classify import ClassificationTrainer
|
||||
|
||||
|
||||
class CustomizedDataset(ClassificationDataset):
|
||||
def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
|
||||
super().__init__(root, args, augment, prefix)
|
||||
train_transforms = T.Compose(
|
||||
[
|
||||
T.Resize((args.imgsz, args.imgsz)),
|
||||
T.RandomHorizontalFlip(p=args.fliplr),
|
||||
T.RandomVerticalFlip(p=args.flipud),
|
||||
T.RandAugment(interpolation=T.InterpolationMode.BILINEAR),
|
||||
T.ColorJitter(brightness=args.hsv_v, contrast=args.hsv_v, saturation=args.hsv_s, hue=args.hsv_h),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=torch.tensor(0), std=torch.tensor(1)),
|
||||
T.RandomErasing(p=args.erasing, inplace=True),
|
||||
]
|
||||
)
|
||||
val_transforms = T.Compose(
|
||||
[
|
||||
T.Resize((args.imgsz, args.imgsz)),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=torch.tensor(0), std=torch.tensor(1)),
|
||||
]
|
||||
)
|
||||
self.torch_transforms = train_transforms if augment else val_transforms
|
||||
|
||||
|
||||
class CustomizedTrainer(ClassificationTrainer):
|
||||
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
|
||||
return CustomizedDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
|
||||
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO("yolo11n-cls.pt")
|
||||
model.train(data="imagenet1000", trainer=CustomizedTrainer, epochs=10, imgsz=224, batch=64)
|
||||
```
|
||||
|
||||
### Dataset format
|
||||
|
||||
YOLO classification dataset format can be found in detail in the [Dataset Guide](../datasets/classify/index.md).
|
||||
|
||||
Loading…
Reference in New Issue
Block a user