Skip to content

Instantly share code, notes, and snippets.

@LukeAI
Last active July 9, 2023 08:05
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save LukeAI/6af4984c79a7534c9c1330958545367c to your computer and use it in GitHub Desktop.
Save LukeAI/6af4984c79a7534c9c1330958545367c to your computer and use it in GitHub Desktop.
How to process a directory of images with SAM (Segment Anything Model) and save visualisations of their masks
#!/usr/bin/env python
from __future__ import annotations
import os
from pathlib import Path
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import cv2
import numpy as np
import torch
from tqdm import tqdm
# ---- configuration ----
# input directory that is scanned for images / directory the results go to
in_dir = 'my_images'
out_dir = 'segmented'
# SAM variant and its checkpoint file; swap in one of the commented-out
# pairs below to use a different model (keep both lines of a pair together)
sam_model = "vit_l"
sam_check = "sam_vit_l_0b3195.pth"
# sam_model = "vit_h"
# sam_check = "sam_vit_h_4b8939.pth"
# sam_model = "vit_b"
# sam_check = "sam_vit_b_01ec64.pth"
# device the model is moved to
device = "cuda"
# weight applied to the mask colours when they are added onto the image
transparency = 0.3
# upper bound on the number of masks drawn per image
max_masks = 300
# ---- SAM mask generator parameters ----
points_per_batch = 64
points_per_side = 64
pred_iou_thresh = 0.86
stability_score_thresh = 0.92
crop_n_layers = 1
crop_n_points_downscale_factor = 2
min_mask_region_area = 100
# palette: one random RGB triple (floats drawn by np.random.random) for each
# mask slot, so mask i is always painted with colors[i] across all images
colors = [np.random.random(3) for _ in range(max_masks)]
def draw_segmentation(anns):
    """Render mask annotations as one colour-coded float image.

    Every entry of *anns* is read for two keys: 'segmentation' (a 2-D
    true/false mask used to index the canvas) and 'area' (used for sorting).
    Masks are painted from largest to smallest area, so smaller masks end up
    on top of larger ones. At most ``max_masks`` masks are drawn, the i-th
    largest one in ``colors[i]``.

    Returns an (h, w, 3) float64 array that is zero wherever no mask applies,
    or None when *anns* is empty.
    """
    if not len(anns):
        return None
    height, width = anns[0]['segmentation'].shape
    canvas = np.zeros((height, width, 3), dtype=np.float64)
    largest_first = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    # zip stops at whichever runs out first: the mask budget or the masks
    for slot, ann in zip(range(max_masks), largest_first):
        # paint every pixel covered by this mask with the slot's colour
        canvas[ann['segmentation']] = colors[slot]
    return canvas
def process_image(img_path, out_path, mask_generator):
    """Segment one image and save it with the coloured masks overlaid.

    Args:
        img_path: Path of the image to load (handed to ``cv2.imread``).
        out_path: Path the visualisation is written to (``cv2.imwrite``).
        mask_generator: Object whose ``generate(image)`` returns the list of
            mask annotations that ``draw_segmentation`` consumes.

    A file that OpenCV cannot decode is skipped with a warning on stderr
    instead of aborting the run. An image for which no masks are produced is
    written out without any overlay.
    """
    import sys

    bgr = cv2.imread(img_path)
    if bgr is None:
        # imread reports failure by returning None rather than raising
        print(f"skipping {img_path}: could not be read as an image",
              file=sys.stderr)
        return
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # mask generator wants the default uint8 image
    masks = mask_generator.generate(rgb)
    # work in float64 in the 0..1 range while blending
    blended = rgb.astype(np.float64) / 255
    overlay = draw_segmentation(masks)
    # draw_segmentation returns None when there are no masks to draw
    if overlay is not None:
        # add segmentation image on top of original image
        blended += transparency * overlay
    # the additive blend can exceed 1.0; clip so the uint8 cast below
    # saturates instead of wrapping around to a wrong colour
    np.clip(blended, 0.0, 1.0, out=blended)
    # convert back to uint8 for display/save
    result = (255 * blended).astype(np.uint8)
    result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
    # uncomment to preview the result before it is saved:
    # cv2.imshow("my img", result)
    # cv2.waitKey(-1)
    if not cv2.imwrite(out_path, result):
        # imwrite signals failure through its return value
        print(f"failed to write {out_path}", file=sys.stderr)
if __name__ == "__main__":
    # make sure output dir exists; exist_ok makes this a no-op when it does
    os.makedirs(out_dir, exist_ok=True)
    # load SAM model + create mask generator
    sam = sam_model_registry[sam_model](checkpoint=sam_check)
    sam.to(device=device)
    sam = torch.compile(sam)
    mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=points_per_side,
        # forward the configured batch size so the setting actually applies
        points_per_batch=points_per_batch,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        crop_n_layers=crop_n_layers,
        crop_n_points_downscale_factor=crop_n_points_downscale_factor,
        min_mask_region_area=min_mask_region_area,
    )
    # process input directory (sorted so the order is the same on every run)
    for img in tqdm(sorted(os.listdir(in_dir))):
        in_img = os.path.join(in_dir, img)
        # only handle files OpenCV reports it can read/decode as an image
        if not cv2.haveImageReader(in_img):
            continue
        # change extension of output image to .png
        # NOTE: inputs that differ only in extension (a.jpg / a.png) map to
        # the same output file, so the later one overwrites the earlier one
        out_img = os.path.join(out_dir, Path(img).stem + ".png")
        process_image(in_img, out_img, mask_generator)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment