Fix mask shape mismatch when validating with mask_ratio=1 (#22037)

This commit is contained in:
Jing Qiu 2025-09-10 23:52:55 +08:00 committed by GitHub
parent a6309f70e5
commit fd4acf7bd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 4 deletions

View File

@ -67,7 +67,15 @@ def test_detect():
def test_segment():
"""Test image segmentation training, validation, and prediction pipelines using YOLO models."""
overrides = {"data": "coco8-seg.yaml", "model": "yolo11n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False}
overrides = {
"data": "coco8-seg.yaml",
"model": "yolo11n-seg.yaml",
"imgsz": 32,
"epochs": 1,
"save": False,
"mask_ratio": 1,
"overlap_mask": False,
}
cfg = get_cfg(DEFAULT_CFG)
cfg.data = "coco8-seg.yaml"
cfg.imgsz = 32

View File

@ -140,9 +140,11 @@ class SegmentationValidator(DetectionValidator):
masks = (masks == index).float()
else:
masks = batch["masks"][batch["batch_idx"] == si]
if nl and self.process is ops.process_mask_native:
masks = F.interpolate(masks[None], prepared_batch["imgsz"], mode="bilinear", align_corners=False)[0]
masks = masks.gt_(0.5)
if nl:
mask_size = [s if self.process is ops.process_mask_native else s // 4 for s in prepared_batch["imgsz"]]
if masks.shape[1:] != mask_size:
masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
masks = masks.gt_(0.5)
prepared_batch["masks"] = masks
return prepared_batch