Fix SAM2DynamicInteractivePredictor example in docs (#21955)

Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
Mohammed Yasin 2025-09-05 16:31:03 +08:00 committed by GitHub
parent e843b4c713
commit d61006647b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 5 deletions

View File

@ -260,7 +260,7 @@ It offers three significant enhancements:
predictor = SAM2DynamicInteractivePredictor(overrides=overrides, max_obj_num=10)
# Define a category by box prompt
predictor.inference(img="image1.jpg", bboxes=[[100, 100, 200, 200]], obj_ids=[1], update_memory=True)
predictor(source="image1.jpg", bboxes=[[100, 100, 200, 200]], obj_ids=[1], update_memory=True)
# Detect this particular object in a new image
results = predictor(source="image2.jpg")
@ -273,7 +273,7 @@ It offers three significant enhancements:
update_memory=True, # Add to memory
)
# Perform inference
results = predictor.inference(img="image5.jpg")
results = predictor(source="image5.jpg")
# Add refinement prompts to the same category to boost performance
# This helps when object appearance changes significantly

View File

@ -1761,7 +1761,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
@smart_inference_mode()
def inference(
self,
img: torch.Tensor | np.ndarray,
im: torch.Tensor | np.ndarray,
bboxes: list[list[float]] | None = None,
masks: torch.Tensor | np.ndarray | None = None,
points: list[list[float]] | None = None,
@ -1777,7 +1777,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
When update_memory is False, it will only run inference on the provided image without updating the memory.
Args:
img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
im (torch.Tensor | np.ndarray): The input image tensor or numpy array.
bboxes (List[List[float]] | None): Optional list of bounding boxes to update the memory.
masks (List[torch.Tensor | np.ndarray] | None): Optional masks to update the memory.
points (List[List[float]] | None): Optional list of points to update the memory, each point is [x, y].
@ -1789,7 +1789,7 @@ class SAM2DynamicInteractivePredictor(SAM2Predictor):
res_masks (torch.Tensor): The output masks in shape (C, H, W)
object_score_logits (torch.Tensor): Quality scores for each mask
"""
self.get_im_features(img)
self.get_im_features(im)
points, labels, masks = self._prepare_prompts(
dst_shape=self.imgsz,
src_shape=self.batch[1][0].shape[:2],