YOLOE: Fix visual prompt training (#20413)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Laughing 2025-04-29 16:32:36 +08:00 committed by GitHub
parent e700646ea2
commit 3b2427ff63
5 changed files with 10 additions and 7 deletions

.gitignore

@@ -150,6 +150,7 @@ wandb/
 weights/
 *.weights
 *.pt
+*.ts
 *.pb
 *.onnx
 *.engine

docs/en/models/yoloe.md

@@ -462,7 +462,7 @@ Model validation on a dataset is streamlined as follows:
 ```python
 from ultralytics import YOLOE
-from ultralytics.models.yolo.yoloe import YOLOEVPTrainer
+from ultralytics.models.yolo.yoloe import YOLOESegVPTrainer

 data = dict(
     train=dict(
@@ -503,7 +503,7 @@ Model validation on a dataset is streamlined as follows:
     weight_decay=0.025,
     momentum=0.9,
     workers=4,
-    trainer=YOLOEVPTrainer,
+    trainer=YOLOESegVPTrainer,  # use YOLOEVPTrainer if converted to detection model
     device="0,1,2,3,4,5,6,7",
     freeze=freeze,
 )
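
For readers skimming the doc change: the example now picks the trainer that matches the model's head. A minimal sketch of that choice, assuming only the two trainer classes named in the diff (the helper function itself is hypothetical):

```python
from ultralytics.models.yolo.yoloe import YOLOESegVPTrainer, YOLOEVPTrainer


def pick_vp_trainer(task: str):
    """Return the visual-prompt trainer for the given head type: segmentation
    checkpoints train with YOLOESegVPTrainer, while a model converted to a
    plain detection head would use YOLOEVPTrainer instead."""
    return YOLOESegVPTrainer if task == "segment" else YOLOEVPTrainer
```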

ultralytics/data/dataset.py

@@ -443,7 +443,7 @@ class GroundingDataset(YOLODataset):
         """
         assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         self.json_file = json_file
-        super().__init__(*args, task=task, data={}, **kwargs)
+        super().__init__(*args, task=task, data={"channels": 3}, **kwargs)

     def get_img_files(self, img_path):
         """

ultralytics/models/yolo/yoloe/train.py

@@ -291,8 +291,9 @@ class YOLOETrainerFromScratch(YOLOETrainer):
         # NOTE: to make training work properly, set `nc` and `names`
         final_data["nc"] = data["val"][0]["nc"]
         final_data["names"] = data["val"][0]["names"]
-        # NOTE: add path with lvis path
+        # NOTE: add path with lvis path and image channels
         final_data["path"] = data["val"][0]["path"]
+        final_data["channels"] = data["val"][0]["channels"]
         self.data = final_data
         if self.args.single_cls:  # consistent with base trainer
             LOGGER.info("Overriding class names with single class.")
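
A standalone sketch of the metadata merge this hunk completes: the trainer copies `nc`, `names`, `path`, and now `channels` from the first validation dataset into `final_data`. The sample values below are illustrative only (1203 is the LVIS class count), not taken from the diff:

```python
# Illustrative stand-in for the `data` dict assembled from dataset YAMLs.
data = {"val": [{"nc": 1203, "names": {0: "person"}, "path": "../datasets/lvis", "channels": 3}]}

final_data = {}
final_data["nc"] = data["val"][0]["nc"]
final_data["names"] = data["val"][0]["names"]
final_data["path"] = data["val"][0]["path"]
final_data["channels"] = data["val"][0]["channels"]  # the field this commit propagates
print(final_data["channels"])  # -> 3
```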

ultralytics/utils/loss.py

@@ -794,15 +794,16 @@ class TVPSegmentLoss(TVPDetectLoss):
     def __init__(self, model):
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
         super().__init__(model)
         self.vp_criterion = v8SegmentationLoss(model)

     def __call__(self, preds, batch):
         """Calculate the loss for text-visual prompt segmentation."""
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-        assert self.tp_criterion.reg_max == self.vp_criterion.reg_max
+        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

-        if self.tp_criterion.reg_max * 4 + self.tp_criterion.nc == feats[0].shape[1]:
-            loss = torch.zeros(4, device=self.tp_criterion.device, requires_grad=True)
+        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+            loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()

         vp_feats = self._get_vp_features(feats)
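
The rewritten guard returns early when the prediction tensor carries only the original text-prompt head's channels, i.e. `reg_max * 4` DFL box bins plus `nc` class scores, meaning no visual-prompt logits are present to supervise. A self-contained sketch of that arithmetic; the batch size, spatial size, and `reg_max`/`nc` values are assumptions, not from the diff:

```python
import torch

ori_reg_max, ori_nc = 16, 80  # assumed: 16-bin DFL head, 80 classes
feats0 = torch.randn(2, ori_reg_max * 4 + ori_nc, 80, 80)  # (batch, channels, h, w)

# Channel count matches the original head exactly -> no visual-prompt
# channels are present, so return zero losses that still track gradients,
# mirroring the early return in TVPSegmentLoss.__call__.
if ori_reg_max * 4 + ori_nc == feats0.shape[1]:
    loss = torch.zeros(4, device=feats0.device, requires_grad=True)
    print(loss, loss.detach())
```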