mirror of
https://github.com/ultralytics/ultralytics.git
synced 2025-09-15 15:48:41 +08:00
Fix mask shape mismatch when validating with mask_ratio=1 (#22037)
This commit is contained in:
parent
a6309f70e5
commit
fd4acf7bd0
@ -67,7 +67,15 @@ def test_detect():
|
||||
|
||||
def test_segment():
|
||||
"""Test image segmentation training, validation, and prediction pipelines using YOLO models."""
|
||||
overrides = {"data": "coco8-seg.yaml", "model": "yolo11n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False}
|
||||
overrides = {
|
||||
"data": "coco8-seg.yaml",
|
||||
"model": "yolo11n-seg.yaml",
|
||||
"imgsz": 32,
|
||||
"epochs": 1,
|
||||
"save": False,
|
||||
"mask_ratio": 1,
|
||||
"overlap_mask": False,
|
||||
}
|
||||
cfg = get_cfg(DEFAULT_CFG)
|
||||
cfg.data = "coco8-seg.yaml"
|
||||
cfg.imgsz = 32
|
||||
|
||||
@ -140,9 +140,11 @@ class SegmentationValidator(DetectionValidator):
|
||||
masks = (masks == index).float()
|
||||
else:
|
||||
masks = batch["masks"][batch["batch_idx"] == si]
|
||||
if nl and self.process is ops.process_mask_native:
|
||||
masks = F.interpolate(masks[None], prepared_batch["imgsz"], mode="bilinear", align_corners=False)[0]
|
||||
masks = masks.gt_(0.5)
|
||||
if nl:
|
||||
mask_size = [s if self.process is ops.process_mask_native else s // 4 for s in prepared_batch["imgsz"]]
|
||||
if masks.shape[1:] != mask_size:
|
||||
masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
|
||||
masks = masks.gt_(0.5)
|
||||
prepared_batch["masks"] = masks
|
||||
return prepared_batch
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user