Skip to content

Instantly share code, notes, and snippets.

@cip8
Last active May 26, 2023 17:07
Show Gist options
  • Save cip8/64a7503b079fbf4dc6ab5f807702838f to your computer and use it in GitHub Desktop.
Save cip8/64a7503b079fbf4dc6ab5f807702838f to your computer and use it in GitHub Desktop.
Custom mask to SAM
def get_grid_points(self, mask_arr: np.ndarray, pad_ratio: float = 50) -> np.ndarray:
    """
    Return a regular grid of points that lie on the foreground of a mask.

    The spacing between neighbouring grid points scales with the mask size:
    ``spacing = max(1, int(sqrt(mask_arr.size) / pad_ratio))``. The lower bound
    of 1 keeps masks smaller than ``pad_ratio ** 2`` pixels from producing a
    zero step.

    Parameters:
    - mask_arr: A 2D array; any non-zero value is treated as foreground.
    - pad_ratio: Divisor used to derive the grid spacing. Larger values give a
      smaller spacing, i.e. a denser grid. Must be positive.

    Returns:
    - An (N, 2) integer array where each row is a point (x, y), i.e.
      (column, row), ordered row by row. N may be 0.

    Raises:
    - ValueError: if ``pad_ratio`` is not positive.
    """
    if pad_ratio <= 0:
        raise ValueError(f"pad_ratio must be positive, got {pad_ratio!r}")
    # Anything non-zero counts as foreground.
    foreground = mask_arr != 0
    # Distance between grid points; clamped to 1 so the strided slice is valid.
    spacing = max(1, int(np.sqrt(foreground.size) / pad_ratio))
    # Sample the mask only at grid positions. np.nonzero yields indices in
    # row-major order, so points come out row by row (y outer, x inner).
    rows, cols = np.nonzero(foreground[::spacing, ::spacing])
    # Map sub-sampled indices back to full-resolution (x, y) coordinates.
    return np.column_stack((cols, rows)) * spacing
def resize_mask(
    self, ref_mask: np.ndarray, longest_side: int = 256
) -> tuple[np.ndarray, int, int]:
    """
    Resize a mask so its longest side equals ``longest_side``, keeping the aspect ratio.

    Nearest-neighbour interpolation is used so no new values are introduced
    (a two-valued mask stays two-valued). The short side is truncated as
    before, but never below 1 pixel.

    Args:
        ref_mask (np.ndarray): The mask to resize; its first two dimensions
            are taken as (height, width).
        longest_side (int, optional): Target length of the longest side.
            Default is 256.

    Returns:
        tuple[np.ndarray, int, int]: The resized mask and its new height and width.

    Raises:
        ValueError: if ``ref_mask`` has zero height or width.
    """
    height, width = ref_mask.shape[:2]
    if height == 0 or width == 0:
        raise ValueError(f"cannot resize an empty mask of shape {ref_mask.shape}")
    if height > width:
        new_height = longest_side
        # Clamp to 1: for very elongated masks truncation would give 0, and
        # cv2.resize rejects an empty destination size.
        new_width = max(1, int(width * (new_height / height)))
    else:
        new_width = longest_side
        new_height = max(1, int(height * (new_width / width)))
    # Note: cv2.resize takes dsize as (width, height).
    resized = cv2.resize(
        ref_mask, (new_width, new_height), interpolation=cv2.INTER_NEAREST
    )
    return resized, new_height, new_width
def pad_mask(
    self,
    ref_mask: np.ndarray,
    new_height: int,
    new_width: int,
    pad_all_sides: bool = False,
    target_size: int = 256,
) -> np.ndarray:
    """
    Zero-pad a 2D mask into a ``target_size`` x ``target_size`` square.

    Args:
        ref_mask (np.ndarray): The 2D mask to pad, of shape (new_height, new_width).
        new_height (int): The height of ``ref_mask`` (after resizing).
        new_width (int): The width of ``ref_mask`` (after resizing).
        pad_all_sides (bool, optional): If True, split the padding between
            both ends of each axis (any odd extra pixel goes to the
            bottom/right). If False, pad only the bottom and right sides.
            Default is False.
        target_size (int, optional): Side length of the square output.
            Should match the ``longest_side`` used when resizing. Default is 256.

    Returns:
        np.ndarray: The padded mask, of shape (target_size, target_size).

    Raises:
        ValueError: if the mask is larger than ``target_size`` on either axis.
    """
    pad_height = target_size - new_height
    pad_width = target_size - new_width
    if pad_height < 0 or pad_width < 0:
        raise ValueError(
            f"mask of size {new_height}x{new_width} does not fit in a "
            f"{target_size}x{target_size} square"
        )
    if pad_all_sides:
        padding = (
            (pad_height // 2, pad_height - pad_height // 2),
            (pad_width // 2, pad_width - pad_width // 2),
        )
    else:
        padding = ((0, pad_height), (0, pad_width))
    # Padding value defaults to 0 when the `np.pad` mode is 'constant'.
    return np.pad(ref_mask, padding, mode="constant")
def reference_to_sam_mask(
    self, ref_mask: np.ndarray, threshold: int = 127, pad_all_sides: bool = False
) -> np.ndarray:
    """
    Convert a grayscale reference mask into a square {-1, 1} mask.

    The mask is binarized, resized with ``self.resize_mask`` (called with its
    default longest side), and padded to a square with ``self.pad_mask``.

    Args:
        ref_mask (np.ndarray): The grayscale mask to be processed.
        threshold (int, optional): Binarization threshold. Pixels strictly
            greater than it become 1; all others become -1. Default is 127.
        pad_all_sides (bool, optional): Whether to pad all sides of the mask
            equally. If False, padding is added to the bottom and right sides
            only. Default is False.

    Returns:
        np.ndarray: The processed square mask.
    """
    # Strictly-above-threshold pixels -> 1, everything else (including pixels
    # equal to the threshold) -> -1. The result is already in {-1, 1}, so no
    # clipping is needed.
    binary_mask = np.where(ref_mask > threshold, 1, -1)
    # Scale so the longest side matches resize_mask's default.
    resized_mask, new_height, new_width = self.resize_mask(binary_mask)
    # Pad the resized mask out to a square.
    return self.pad_mask(resized_mask, new_height, new_width, pad_all_sides)
[...]
# Obtain SAM compatible mask.
# reference_to_sam_mask binarizes ref_mask to {-1, 1} (default threshold 127),
# resizes it and zero-pads it to a square.
sam_mask: np.ndarray = self.reference_to_sam_mask(ref_mask)
# Initialize SAM predictor and set the image.
# NOTE(review): a new SamPredictor is constructed every time this code runs;
# consider reusing one if this path is called repeatedly.
predictor: SamPredictor = SamPredictor(self._models["sam"])
predictor.set_image(img_arr) # bbox cut image!
# Expand SAM mask's dimensions to 1xHxW (1x256x256).
sam_mask = np.expand_dims(sam_mask, axis=0)
# Run SAM predictor.
# NOTE(review): input_points is defined in the elided code above; presumably it
# comes from get_grid_points, an (N, 2) array of (x, y) points. Confirm those
# coordinates are in img_arr's pixel frame, i.e. that the mask they were
# sampled from has the same size as img_arr.
# Every point gets label 1 (np.ones); presumably 1 marks a foreground prompt
# in SAM's convention — confirm.
masks, scores, logits = predictor.predict(
multimask_output=True,
point_coords=input_points,
point_labels=np.ones(len(input_points)),
mask_input=sam_mask,
)
[...]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment