mirror of
https://github.com/ultralytics/ultralytics.git
synced 2025-09-15 15:48:41 +08:00
ultralytics 8.3.12 SAM and SAM2 multi-point prompts (#16643)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
1c4d788aa1
commit
b89d6f4070
@ -90,8 +90,17 @@ You can download the model [here](https://github.com/ChaoningZhang/MobileSAM/blo
|
||||
# Load the model
|
||||
model = SAM("mobile_sam.pt")
|
||||
|
||||
# Predict a segment based on a point prompt
|
||||
# Predict a segment based on a single point prompt
|
||||
model.predict("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
|
||||
|
||||
# Predict multiple segments based on multiple points prompt
|
||||
model.predict("ultralytics/assets/zidane.jpg", points=[[400, 370], [900, 370]], labels=[1, 1])
|
||||
|
||||
# Predict a segment based on multiple points prompt per object
|
||||
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
||||
|
||||
# Predict a segment using both positive and negative prompts.
|
||||
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||
```
|
||||
|
||||
### Box Prompt
|
||||
@ -106,8 +115,17 @@ You can download the model [here](https://github.com/ChaoningZhang/MobileSAM/blo
|
||||
# Load the model
|
||||
model = SAM("mobile_sam.pt")
|
||||
|
||||
# Predict a segment based on a box prompt
|
||||
model.predict("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
|
||||
# Predict a segment based on a single point prompt
|
||||
model.predict("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
|
||||
|
||||
# Predict mutiple segments based on multiple points prompt
|
||||
model.predict("ultralytics/assets/zidane.jpg", points=[[400, 370], [900, 370]], labels=[1, 1])
|
||||
|
||||
# Predict a segment based on multiple points prompt per object
|
||||
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
||||
|
||||
# Predict a segment using both positive and negative prompts.
|
||||
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||
```
|
||||
|
||||
We have implemented `MobileSAM` and `SAM` using the same API. For more usage information, please see the [SAM page](sam.md).
|
||||
|
||||
@ -58,8 +58,17 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
|
||||
# Run inference with bboxes prompt
|
||||
results = model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
|
||||
|
||||
# Run inference with points prompt
|
||||
results = model("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
|
||||
# Run inference with single point
|
||||
results = predictor(points=[900, 370], labels=[1])
|
||||
|
||||
# Run inference with multiple points
|
||||
results = predictor(points=[[400, 370], [900, 370]], labels=[1, 1])
|
||||
|
||||
# Run inference with multiple points prompt per object
|
||||
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
||||
|
||||
# Run inference with negative points prompt
|
||||
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||
```
|
||||
|
||||
!!! example "Segment everything"
|
||||
@ -107,8 +116,16 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
|
||||
predictor.set_image("ultralytics/assets/zidane.jpg") # set with image file
|
||||
predictor.set_image(cv2.imread("ultralytics/assets/zidane.jpg")) # set with np.ndarray
|
||||
results = predictor(bboxes=[439, 437, 524, 709])
|
||||
|
||||
# Run inference with single point prompt
|
||||
results = predictor(points=[900, 370], labels=[1])
|
||||
|
||||
# Run inference with multiple points prompt
|
||||
results = predictor(points=[[400, 370], [900, 370]], labels=[[1, 1]])
|
||||
|
||||
# Run inference with negative points prompt
|
||||
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||
|
||||
# Reset image
|
||||
predictor.reset_image()
|
||||
```
|
||||
@ -245,6 +262,15 @@ model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
|
||||
|
||||
# Segment with points prompt
|
||||
model("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
|
||||
|
||||
# Segment with multiple points prompt
|
||||
model("ultralytics/assets/zidane.jpg", points=[[400, 370], [900, 370]], labels=[[1, 1]])
|
||||
|
||||
# Segment with multiple points prompt per object
|
||||
model("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
||||
|
||||
# Segment with negative points prompt.
|
||||
model("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||
```
|
||||
|
||||
Alternatively, you can run inference with SAM in the command line interface (CLI):
|
||||
|
||||
@ -97,9 +97,12 @@ def test_mobilesam():
|
||||
# Source
|
||||
source = ASSETS / "zidane.jpg"
|
||||
|
||||
# Predict a segment based on a point prompt
|
||||
# Predict a segment based on a 1D point prompt and 1D labels.
|
||||
model.predict(source, points=[900, 370], labels=[1])
|
||||
|
||||
# Predict a segment based on 3D points and 2D labels (multiple points per object).
|
||||
model.predict(source, points=[[[900, 370], [1000, 100]]], labels=[[1, 1]])
|
||||
|
||||
# Predict a segment based on a box prompt
|
||||
model.predict(source, bboxes=[439, 437, 524, 709], save=True)
|
||||
|
||||
|
||||
@ -127,9 +127,21 @@ def test_predict_sam():
|
||||
# Run inference with bboxes prompt
|
||||
model(SOURCE, bboxes=[439, 437, 524, 709], device=0)
|
||||
|
||||
# Run inference with points prompt
|
||||
# Run inference with no labels
|
||||
model(ASSETS / "zidane.jpg", points=[900, 370], device=0)
|
||||
|
||||
# Run inference with 1D points and 1D labels
|
||||
model(ASSETS / "zidane.jpg", points=[900, 370], labels=[1], device=0)
|
||||
|
||||
# Run inference with 2D points and 1D labels
|
||||
model(ASSETS / "zidane.jpg", points=[[900, 370]], labels=[1], device=0)
|
||||
|
||||
# Run inference with multiple 2D points and 1D labels
|
||||
model(ASSETS / "zidane.jpg", points=[[400, 370], [900, 370]], labels=[1, 1], device=0)
|
||||
|
||||
# Run inference with 3D points and 2D labels (multiple points per object)
|
||||
model(ASSETS / "zidane.jpg", points=[[[900, 370], [1000, 100]]], labels=[[1, 1]], device=0)
|
||||
|
||||
# Create SAMPredictor
|
||||
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model=WEIGHTS_DIR / "mobile_sam.pt")
|
||||
predictor = SAMPredictor(overrides=overrides)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.3.11"
|
||||
__version__ = "8.3.12"
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@ -213,11 +213,14 @@ class Predictor(BasePredictor):
|
||||
Args:
|
||||
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
|
||||
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background.
|
||||
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
|
||||
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
|
||||
masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
|
||||
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing:
|
||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
@ -240,11 +243,15 @@ class Predictor(BasePredictor):
|
||||
points = points[None] if points.ndim == 1 else points
|
||||
# Assuming labels are all positive if users don't pass labels.
|
||||
if labels is None:
|
||||
labels = np.ones(points.shape[0])
|
||||
labels = np.ones(points.shape[:-1])
|
||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
||||
assert (
|
||||
points.shape[-2] == labels.shape[-1]
|
||||
), f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}."
|
||||
points *= r
|
||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
||||
points, labels = points[:, None, :], labels[:, None]
|
||||
if points.ndim == 2:
|
||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
||||
points, labels = points[:, None, :], labels[:, None]
|
||||
if bboxes is not None:
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
|
||||
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
|
||||
|
||||
Loading…
Reference in New Issue
Block a user