Skip to content

Instantly share code, notes, and snippets.

@pppoe
Created January 2, 2024 05:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pppoe/299a134a3453852192a0017d75c57284 to your computer and use it in GitHub Desktop.
Sample 2D points from a 2D binary 0-1 mask
import numpy as np
import cv2
import random
# Load the source image and the mask to sample points from.
I = cv2.imread('image.png')
if I is None:
    # cv2.imread reports a missing/unreadable file by returning None rather than raising.
    raise FileNotFoundError("could not read 'image.png'")
mask = np.load('mask.npy')
# Reduce to a 2D True/False matrix; a multi-channel mask uses its first channel.
if mask.ndim == 3:
    mask = mask[:, :, 0]
mask = mask > 0
if mask.shape != I.shape[:2]:
    # Sampled (row, col) coordinates are used to index I, so the sizes must agree.
    raise ValueError(f"mask shape {mask.shape} does not match image shape {I.shape[:2]}")
# Black canvas with the same shape/dtype as I; sampled pixels are copied onto it later.
visI = np.zeros_like(I)
def sample_prompt_pts(mask, h_resol=128, w_resol=128, max_pts=128, K=20):
    """Sample up to ``max_pts`` well-spread (row, col) points from a 2D mask.

    The mask is scanned on a regular grid with row stride
    ``max(mask.shape[0] // h_resol, 1)`` and column stride
    ``max(mask.shape[1] // w_resol, 1)``; every grid location where the mask
    is truthy becomes a candidate. If there are more than ``max_pts``
    candidates, a randomized farthest-point selection keeps ``max_pts`` of
    them: it starts from one random candidate, then at each step draws ``K``
    not-yet-kept candidates (uniformly, with replacement) and keeps the one
    whose distance to its nearest already-kept point is largest.

    Args:
        mask: 2D array; truthy entries mark valid locations.
        h_resol, w_resol: approximate grid resolution along rows / columns.
        max_pts: maximum number of points returned; values <= 0 yield no points.
        K: candidates examined per selection step; values < 1 are treated as 1.

    Returns:
        If there are at most ``max_pts`` candidates: a list of ``(row, col)``
        int tuples in row-major order. Otherwise: an integer ndarray of shape
        ``(max_pts, 2)`` whose rows are ``(row, col)`` in selection order.

    Randomness comes from the global ``np.random`` state; seed it for
    reproducible output.
    """
    h_step = max(mask.shape[0] // h_resol, 1)
    w_step = max(mask.shape[1] // w_resol, 1)
    # Strided view of the sampling grid. argwhere yields hits in row-major
    # order (same as a nested row/column loop); scale back to mask coordinates.
    grid_hits = np.argwhere(np.asarray(mask)[::h_step, ::w_step]) * (h_step, w_step)
    valid_pts = [tuple(p) for p in grid_hits.tolist()]
    if len(valid_pts) <= max_pts:
        return valid_pts

    pts = np.array(valid_pts).reshape(-1, 2)
    if max_pts <= 0:
        return pts[:0]
    n_cand = max(K, 1)  # an empty candidate draw would make argmax fail

    first = np.random.randint(0, len(pts))
    chosen = [first]
    available = np.ones(len(pts), dtype=bool)
    available[first] = False
    while len(chosen) < max_pts:
        # Draw candidates straight from the unkept pool; rejection sampling over
        # all points needs ~K*N/(N - len(chosen)) draws as the pool empties.
        pool = np.flatnonzero(available)
        sel = pool[np.random.randint(0, len(pool), size=n_cand)]
        # Distance from each candidate to its nearest already-chosen point.
        sel_dists = np.linalg.norm(
            pts[sel, None, :] - pts[None, np.array(chosen), :], axis=2
        ).min(axis=1)
        best = sel[np.argmax(sel_dists)]
        chosen.append(best)
        available[best] = False
    return pts[np.array(chosen)]
valid_pts = sample_prompt_pts(mask)
# Copy only the sampled pixels from I onto the black canvas. Normalizing to an
# (N, 2) int array handles both return forms of sample_prompt_pts and the
# empty case (indexing with empty arrays is a no-op).
pts = np.asarray(valid_pts, dtype=int).reshape(-1, 2)
visI[pts[:, 0], pts[:, 1]] = I[pts[:, 0], pts[:, 1]]
# cv2.imwrite signals failure by returning False rather than raising.
if not cv2.imwrite('test.png', visI):
    raise OSError("could not write 'test.png'")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment