From 3b2427ff632c5eb4ffcc388dc4e07692e75bfa4c Mon Sep 17 00:00:00 2001
From: Laughing <61612323+Laughing-q@users.noreply.github.com>
Date: Tue, 29 Apr 2025 16:32:36 +0800
Subject: [PATCH] YOLOE: Fix visual prompt training (#20413)

Co-authored-by: UltralyticsAssistant
---
 .gitignore                             | 1 +
 docs/en/models/yoloe.md                | 4 ++--
 ultralytics/data/dataset.py            | 2 +-
 ultralytics/models/yolo/yoloe/train.py | 3 ++-
 ultralytics/utils/loss.py              | 7 ++++---
 5 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 35d7c17918..809c3380b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,6 +150,7 @@ wandb/
 weights/
 *.weights
 *.pt
+*.ts
 *.pb
 *.onnx
 *.engine
diff --git a/docs/en/models/yoloe.md b/docs/en/models/yoloe.md
index 2779eea66b..95263d92d6 100644
--- a/docs/en/models/yoloe.md
+++ b/docs/en/models/yoloe.md
@@ -462,7 +462,7 @@ Model validation on a dataset is streamlined as follows:
 
         ```python
         from ultralytics import YOLOE
-        from ultralytics.models.yolo.yoloe import YOLOEVPTrainer
+        from ultralytics.models.yolo.yoloe import YOLOESegVPTrainer
 
         data = dict(
             train=dict(
@@ -503,7 +503,7 @@ Model validation on a dataset is streamlined as follows:
             weight_decay=0.025,
             momentum=0.9,
             workers=4,
-            trainer=YOLOEVPTrainer,
+            trainer=YOLOESegVPTrainer,  # use YOLOEVPTrainer if converted to detection model
             device="0,1,2,3,4,5,6,7",
             freeze=freeze,
         )
diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py
index c62c3f92a1..2c03ebf788 100644
--- a/ultralytics/data/dataset.py
+++ b/ultralytics/data/dataset.py
@@ -443,7 +443,7 @@ class GroundingDataset(YOLODataset):
         """
         assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         self.json_file = json_file
-        super().__init__(*args, task=task, data={}, **kwargs)
+        super().__init__(*args, task=task, data={"channels": 3}, **kwargs)
 
     def get_img_files(self, img_path):
         """
diff --git a/ultralytics/models/yolo/yoloe/train.py b/ultralytics/models/yolo/yoloe/train.py
index d2b1c61cf3..c9b8e6c505 100644
--- a/ultralytics/models/yolo/yoloe/train.py
+++ b/ultralytics/models/yolo/yoloe/train.py
@@ -291,8 +291,9 @@ class YOLOETrainerFromScratch(YOLOETrainer):
         # NOTE: to make training work properly, set `nc` and `names`
         final_data["nc"] = data["val"][0]["nc"]
         final_data["names"] = data["val"][0]["names"]
-        # NOTE: add path with lvis path
+        # NOTE: add path with lvis path and image channels
         final_data["path"] = data["val"][0]["path"]
+        final_data["channels"] = data["val"][0]["channels"]
         self.data = final_data
         if self.args.single_cls:  # consistent with base trainer
             LOGGER.info("Overriding class names with single class.")
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index 1ed2178a89..8bfd48a20c 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -794,15 +794,16 @@ class TVPSegmentLoss(TVPDetectLoss):
 
     def __init__(self, model):
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
+        super().__init__(model)
         self.vp_criterion = v8SegmentationLoss(model)
 
     def __call__(self, preds, batch):
         """Calculate the loss for text-visual prompt segmentation."""
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-        assert self.tp_criterion.reg_max == self.vp_criterion.reg_max
+        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
-        if self.tp_criterion.reg_max * 4 + self.tp_criterion.nc == feats[0].shape[1]:
-            loss = torch.zeros(4, device=self.tp_criterion.device, requires_grad=True)
+        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+            loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()
 
         vp_feats = self._get_vp_features(feats)
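
The docs hunk above switches the segmentation visual-prompt training example to YOLOESegVPTrainer. A minimal sketch of that usage follows; the checkpoint name, dataset entries, and hyperparameters are illustrative placeholders rather than part of this patch.

```python
# Minimal sketch of visual-prompt training with the segmentation trainer referenced in the
# docs change above. Checkpoint name, dataset entries, and hyperparameters are placeholders.
from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOESegVPTrainer

# Dataset dict in the same shape as the docs example: YOLO-format data for train/val.
data = dict(
    train=dict(yolo_data=["coco128-seg.yaml"]),  # placeholder segmentation dataset
    val=dict(yolo_data=["coco128-seg.yaml"]),
)

model = YOLOE("yoloe-11l-seg.pt")  # placeholder segmentation checkpoint
model.train(
    data=data,
    epochs=2,
    workers=4,
    trainer=YOLOESegVPTrainer,  # use YOLOEVPTrainer instead if the model was converted to detection
)
```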