Skip to content

Instantly share code, notes, and snippets.

@humpydonkey
Last active November 5, 2023 22:21
Show Gist options
  • Save humpydonkey/c61f505c074be4a87106ded7521c607b to your computer and use it in GitHub Desktop.
Save humpydonkey/c61f505c074be4a87106ded7521c607b to your computer and use it in GitHub Desktop.
Extract segmentation masks via SAM
import logging
import numpy as np
import PIL.Image
from segment_anything import SamPredictor, sam_model_registry
import matplotlib.pyplot as plt
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Registry key and checkpoint file used to construct the SAM network.
# NOTE(review): the checkpoint is given as a bare file name, so it is
# presumably expected in the current working directory -- confirm.
model_type = "default"
_SAM_CKPT = "sam_vit_h_4b8939.pth"

# Build the network from the checkpoint, move it to the CUDA device, and wrap
# it in the single predictor instance that predict() uses.
sam = sam_model_registry[model_type](checkpoint=_SAM_CKPT)
sam = sam.to(device="cuda")
predictor = SamPredictor(sam)
def predict(
    img_pil: "PIL.Image.Image",
    points=None,
    labels=None,
    boxes=None,
) -> "tuple[np.ndarray, np.ndarray]":
    """Run the module-level SAM predictor on one image.

    Args:
        img_pil: Image to segment. It is converted to RGB before being handed
            to the predictor, so RGBA / greyscale / palette images also work.
        points: Optional point prompts, forwarded as ``point_coords``.
        labels: Optional labels for ``points``, forwarded as ``point_labels``.
        boxes: Optional box prompt as a 1-D array; a leading axis is added
            before it is forwarded as ``box``.

    Returns:
        The ``(masks, scores)`` pair produced by ``predictor.predict`` with
        ``multimask_output=True``. The logits it also returns are discarded.
    """
    if boxes is not None:
        boxes = boxes[None, :]

    # SamPredictor.set_image expects an HxWx3 array; converting to RGB first
    # avoids feeding it a 2-D or 4-channel array for non-RGB images.
    predictor.set_image(np.asarray(img_pil.convert("RGB")))
    masks, scores, _ = predictor.predict(
        point_coords=points,
        point_labels=labels,
        box=boxes,
        multimask_output=True,
    )
    return masks, scores
def get_best_mask(img_pil: "PIL.Image.Image") -> np.ndarray:
    """Segment ``img_pil`` with the default prompts and return the best mask.

    The prompts come from ``get_points`` and ``get_boxes``; ``predict`` then
    yields several candidate masks with one score each.

    Args:
        img_pil: Image to segment.

    Returns:
        The candidate mask with the highest score. On ties the earliest
        candidate wins.
    """
    points, labels = get_points(img_pil)
    bbox = get_boxes(img_pil)
    masks, scores = predict(img_pil, points=points, labels=labels, boxes=bbox)
    # np.argmax returns the first index of the maximum, which matches a
    # manual scan that only replaces the best on a strictly greater score.
    return masks[int(np.argmax(scores))]
def get_points(
    img: "PIL.Image.Image",
    offset: int = 35,
) -> "tuple[np.ndarray, np.ndarray]":
    """Build four point prompts, one near each corner of the image.

    Each point sits ``offset`` pixels in from its corner along both axes and
    is given label 0 (the label SAM documents as a background point).

    Args:
        img: Image whose ``size`` attribute is ``(width, height)``.
        offset: Distance in pixels between each point and its corner.
            NOTE(review): values larger than half the image side place the
            points past the centre or outside the image; not validated here.

    Returns:
        ``(points, labels)``: a ``(4, 2)`` array of ``[x, y]`` coordinates in
        the order top-left, top-right, bottom-left, bottom-right, and a
        length-4 integer array of zeros.
    """
    width, height = img.size
    near = offset
    far_x = width - offset
    far_y = height - offset
    points = np.array(
        [
            [near, near],    # top-left
            [far_x, near],   # top-right
            [near, far_y],   # bottom-left
            [far_x, far_y],  # bottom-right
        ]
    )
    labels = np.zeros(len(points), dtype=int)
    return points, labels
def get_boxes(img_pil: "PIL.Image.Image", margin: int = 25) -> np.ndarray:
    """Return a box prompt covering the image minus a horizontal margin.

    Args:
        img_pil: Image whose ``size`` attribute is ``(width, height)``.
        margin: Pixels trimmed from the left and right edges. The box always
            spans the full image height.

    Returns:
        A length-4 array ``[margin, 0, width - margin, height]``.
    """
    width, height = img_pil.size
    return np.array([margin, 0, width - margin, height])
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent coloured overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to ``(height, width, 1)`` before colouring.
        ax: Object with an ``imshow`` method that receives the RGBA overlay.
        random_color: When true, draw three random colour channels instead of
            using the fixed blue; alpha is 0.6 either way.
    """
    alpha = 0.6
    if random_color:
        rgba = np.append(np.random.random(3), alpha)
    else:
        rgba = np.array([30 / 255, 144 / 255, 1, alpha])

    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) colour so pixels
    # where the mask is zero/False end up fully transparent black.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
def show_points(coords, ax, marker_size=375):
    """Plot every row of ``coords`` on ``ax`` as a green star.

    Args:
        coords: 2-D array; column 0 is used as x and column 1 as y.
        ax: Object with a ``scatter`` method.
        marker_size: Passed to ``scatter`` as ``s``.
    """
    xs = coords[:, 0]
    ys = coords[:, 1]
    star_style = dict(
        color='green',
        marker='*',
        s=marker_size,
        edgecolor='white',
        linewidth=1.25,
    )
    ax.scatter(xs, ys, **star_style)
def show_box(box, ax):
    """Outline ``box`` on ``ax`` with an unfilled green rectangle.

    Args:
        box: Indexable whose first four entries are read as
            ``x0, y0, x1, y1``; width and height are the differences.
        ax: Object with an ``add_patch`` method.
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)
if __name__ == "__main__":
    # TODO: load your HF dataset. `dataset` is not defined anywhere in this
    # script, so it must be assigned here before the lines below can run.
    # Test on one image
    img = dataset[0]["image"]
    points, labels = get_points(img)
    # The box prompt depends only on the image, so compute it once and reuse
    # it for both prediction and plotting.
    box = get_boxes(img)
    masks, scores = predict(img, points=points, labels=labels, boxes=box)

    # One figure per candidate mask, with the prompts drawn on top.
    for i, (mask, score) in enumerate(zip(masks, scores), start=1):
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        show_mask(mask, plt.gca())
        show_points(points, plt.gca())
        show_box(box, plt.gca())
        plt.title(f"Mask {i}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment