mirror of
https://github.com/ultralytics/ultralytics.git
synced 2025-09-15 15:48:41 +08:00
ultralytics 8.2.60 refactor process_mask_upsample (#14474)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
0822710185
commit
dcde8bd23d
@ -99,10 +99,6 @@ keywords: Ultralytics, utility operations, non-max suppression, bounding box tra
|
||||
|
||||
<br><br>
|
||||
|
||||
## ::: ultralytics.utils.ops.process_mask_upsample
|
||||
|
||||
<br><br>
|
||||
|
||||
## ::: ultralytics.utils.ops.process_mask
|
||||
|
||||
<br><br>
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.2.59"
|
||||
__version__ = "8.2.60"
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
if self.args.save_json:
|
||||
check_requirements("pycocotools>=2.0.6")
|
||||
# more accurate vs faster
|
||||
self.process = ops.process_mask_upsample if self.args.save_json or self.args.save_txt else ops.process_mask
|
||||
self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
|
||||
self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
|
||||
|
||||
def get_desc(self):
|
||||
|
||||
@ -652,27 +652,6 @@ def crop_mask(masks, boxes):
|
||||
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
|
||||
|
||||
|
||||
def process_mask_upsample(protos, masks_in, bboxes, shape):
|
||||
"""
|
||||
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher quality
|
||||
but is slower.
|
||||
|
||||
Args:
|
||||
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
|
||||
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
|
||||
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
|
||||
shape (tuple): the size of the input image (h,w)
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The upsampled masks.
|
||||
"""
|
||||
c, mh, mw = protos.shape # CHW
|
||||
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
|
||||
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
|
||||
masks = crop_mask(masks, bboxes) # CHW
|
||||
return masks.gt_(0.0)
|
||||
|
||||
|
||||
def process_mask(protos, masks_in, bboxes, shape, upsample=False):
|
||||
"""
|
||||
Apply masks to bounding boxes using the output of the mask head.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user