mirror of
https://github.com/ultralytics/ultralytics.git
synced 2025-09-15 15:48:41 +08:00
YOLOE: Fix visual prompt training (#20413)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
e700646ea2
commit
3b2427ff63
1
.gitignore
vendored
1
.gitignore
vendored
@ -150,6 +150,7 @@ wandb/
|
||||
weights/
|
||||
*.weights
|
||||
*.pt
|
||||
*.ts
|
||||
*.pb
|
||||
*.onnx
|
||||
*.engine
|
||||
|
||||
@ -462,7 +462,7 @@ Model validation on a dataset is streamlined as follows:
|
||||
|
||||
```python
|
||||
from ultralytics import YOLOE
|
||||
from ultralytics.models.yolo.yoloe import YOLOEVPTrainer
|
||||
from ultralytics.models.yolo.yoloe import YOLOESegVPTrainer
|
||||
|
||||
data = dict(
|
||||
train=dict(
|
||||
@ -503,7 +503,7 @@ Model validation on a dataset is streamlined as follows:
|
||||
weight_decay=0.025,
|
||||
momentum=0.9,
|
||||
workers=4,
|
||||
trainer=YOLOEVPTrainer,
|
||||
trainer=YOLOESegVPTrainer, # use YOLOEVPTrainer if converted to detection model
|
||||
device="0,1,2,3,4,5,6,7",
|
||||
freeze=freeze,
|
||||
)
|
||||
|
||||
@ -443,7 +443,7 @@ class GroundingDataset(YOLODataset):
|
||||
"""
|
||||
assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
|
||||
self.json_file = json_file
|
||||
super().__init__(*args, task=task, data={}, **kwargs)
|
||||
super().__init__(*args, task=task, data={"channels": 3}, **kwargs)
|
||||
|
||||
def get_img_files(self, img_path):
|
||||
"""
|
||||
|
||||
@ -291,8 +291,9 @@ class YOLOETrainerFromScratch(YOLOETrainer):
|
||||
# NOTE: to make training work properly, set `nc` and `names`
|
||||
final_data["nc"] = data["val"][0]["nc"]
|
||||
final_data["names"] = data["val"][0]["names"]
|
||||
# NOTE: add path with lvis path
|
||||
# NOTE: add path with lvis path and image channels
|
||||
final_data["path"] = data["val"][0]["path"]
|
||||
final_data["channels"] = data["val"][0]["channels"]
|
||||
self.data = final_data
|
||||
if self.args.single_cls: # consistent with base trainer
|
||||
LOGGER.info("Overriding class names with single class.")
|
||||
|
||||
@ -794,15 +794,16 @@ class TVPSegmentLoss(TVPDetectLoss):
|
||||
|
||||
def __init__(self, model):
|
||||
"""Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
|
||||
super().__init__(model)
|
||||
self.vp_criterion = v8SegmentationLoss(model)
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Calculate the loss for text-visual prompt segmentation."""
|
||||
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
|
||||
assert self.tp_criterion.reg_max == self.vp_criterion.reg_max
|
||||
assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
|
||||
|
||||
if self.tp_criterion.reg_max * 4 + self.tp_criterion.nc == feats[0].shape[1]:
|
||||
loss = torch.zeros(4, device=self.tp_criterion.device, requires_grad=True)
|
||||
if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
|
||||
loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
|
||||
return loss, loss.detach()
|
||||
|
||||
vp_feats = self._get_vp_features(feats)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user