Skip to content

Instantly share code, notes, and snippets.

@humpydonkey
Last active November 17, 2023 19:51
Show Gist options
  • Save humpydonkey/905bde3156d4c0119de33c9f3b474be7 to your computer and use it in GitHub Desktop.
Save humpydonkey/905bde3156d4c0119de33c9f3b474be7 to your computer and use it in GitHub Desktop.
Extract segmentation masks via rembg
import logging
import numpy as np
import PIL.Image
import matplotlib.pyplot as plt
from datasets import load_dataset
from rembg import remove
import cv2
def show_mask(mask, ax, random_color=False):
    """Overlay a binary mask on a matplotlib axis as a translucent color layer.

    Args:
        mask: 2-D array (boolean or 0/1); the last two dims are (h, w).
        ax: matplotlib axis to draw onto.
        random_color: pick a random RGB hue instead of the default blue.
    """
    # RGBA overlay color; alpha fixed at 0.6 in both branches.
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30 / 255, 144 / 255, 1, 0.6])
    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) mask against the (1, 1, 4) color vector.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
# 9x9 structuring element shared by the morphological operations in get_mask.
kernel = np.ones((9,9),np.uint8)
def get_largest_component(image: np.ndarray) -> np.ndarray:
    """Keep only the largest connected foreground component of a binary image.

    Args:
        image: 2-D binary array; nonzero pixels are foreground.

    Returns:
        Boolean array of the same shape, True only on the largest
        4-connected foreground component. All-False when the image has
        no foreground at all (the original crashed with IndexError on
        ``sizes[1]`` in that case).
    """
    image = image.astype('uint8')
    nb_components, labels, stats, _ = cv2.connectedComponentsWithStats(image, connectivity=4)
    # Label 0 is always the background; nb_components == 1 means no foreground.
    if nb_components <= 1:
        return np.zeros(image.shape, dtype=bool)
    # stats[:, -1] is CC_STAT_AREA; skip the background row, re-offset by 1.
    component_sizes = stats[1:, -1]
    max_label = int(np.argmax(component_sizes)) + 1
    return labels == max_label
def get_mask(img: PIL.Image.Image) -> np.ndarray:
    """Get the mask from the input image.

    1. remove the background of the input image
    2. convert the output of step #1 to grayscale as a mask
    3. get the largest connected component of step #2 as the mask
    """
    foreground = remove(img).convert("L")
    # Binarize: every nonzero grayscale pixel becomes 255, the rest stay 0.
    binary = (np.asarray(foreground) > 0).astype(np.uint8) * 255
    # Close small holes in the foreground before picking a component.
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    # Convert to a boolean mask and keep only the largest connected region.
    return get_largest_component(closed > 0)
def evaluate_all(dataset):
    """Plot every dataset image with its predicted mask for visual QA."""
    for i, row in enumerate(dataset):
        image = row["image"]
        predicted = get_mask(image)
        # One figure per row so each mask can be inspected individually.
        plt.figure(figsize=(10, 10))
        plt.title(f"Row index: {i}", fontsize=18)
        plt.imshow(image)
        show_mask(predicted, plt.gca())
        plt.axis('off')
        plt.show()
def evaluate_one(dataset):
    """Spot-check the mask pipeline on a single hard-coded dataset row."""
    sample = dataset[28]["image"]
    predicted = get_mask(sample)
    plt.figure(figsize=(10, 10))
    plt.imshow(sample)
    show_mask(predicted, plt.gca())
# Original had `if __main__:`, which raises NameError — the guard must
# compare the module's __name__ against "__main__".
if __name__ == "__main__":
    # Test on a single image
    dataset = load_dataset("Raspberry-ai/monse-v4")["train"]
    evaluate_one(dataset)
    # Run extract: attach a mask image to every row.
    new_dataset = dataset.map(lambda row: {
        "image": row["image"],
        "text": row["text"],
        # get_mask returns a boolean array; PIL.Image.fromarray cannot
        # handle dtype bool, so convert to a 0/255 uint8 image first.
        "mask": PIL.Image.fromarray(get_mask(row["image"]).astype(np.uint8) * 255),
    }, batched=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment