Skip to content

Instantly share code, notes, and snippets.

@smeschke
Created June 18, 2019 00:08
Show Gist options
  • Save smeschke/3cc053d6b370aace7465604eb83d8c5e to your computer and use it in GitHub Desktop.
Save smeschke/3cc053d6b370aace7465604eb83d8c5e to your computer and use it in GitHub Desktop.
Applies GrabCut using a mask generated by a deep-learning (DL) model
import numpy as np
import cv2
from matplotlib import pyplot as plt
# Load image and mask.
# NOTE(review): hard-coded absolute paths from the author's machine; adjust before running.
IMAGE_PATH = '/home/stephen/Downloads/bird.jpg'
MASK_PATH = '/home/stephen/Downloads/bird_mask.png'

img = cv2.imread(IMAGE_PATH)
# cv2.imread does not raise on a missing/unreadable file -- it returns None,
# which would otherwise surface later as an opaque cvtColor error.
if img is None:
    raise FileNotFoundError('Could not read image: {}'.format(IMAGE_PATH))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w, _ = img.shape
deep_mask = cv2.imread(MASK_PATH, cv2.IMREAD_GRAYSCALE)
if deep_mask is None:
    raise FileNotFoundError('Could not read mask: {}'.format(MASK_PATH))
# Bring the network's mask to the image resolution (cv2.resize takes (width, height)).
deep_mask = cv2.resize(deep_mask, (w, h))
mask = np.zeros(img.shape[:2], np.uint8)
white_background = (255 - mask.copy())

# Initialize parameters for the GrabCut algorithm.
# grabCut keeps its internal GMM state in these (1, 65) float64 arrays.
bgdModel = np.zeros((1, 65), np.float64)
fgdModel = np.zeros((1, 65), np.float64)
# iters: number of erode/dilate passes; size: side length of the small kernel.
iters, size = 3, 5
grabcut_iters = 3  # number of GrabCut refinement iterations
kernel = np.ones((size, size), np.uint8)
big_kernel = np.ones((2 * size, 2 * size), np.uint8)

# Dilate the mask to make sure the whole object is covered by the mask.
dilation = cv2.dilate(deep_mask, big_kernel, iterations=iters)
# Start with a white background and subtract: pixels the dilated mask does not
# touch at all stay at 255 and are treated as sure background.
sure_background = white_background - dilation
# Erode to find the sure foreground.
sure_foreground = cv2.erode(deep_mask, kernel, iterations=iters)

# Build the GrabCut label mask:
#   GC_PR_BGD (2) - unsure pixels (OpenCV's "probable background" label)
#   GC_FGD    (1) - sure foreground pixels
#   GC_BGD    (0) - sure background pixels
# Order matters: sure foreground is written last, so it wins any overlap.
mask[:] = cv2.GC_PR_BGD
mask[sure_background == 255] = cv2.GC_BGD
mask[sure_foreground == 255] = cv2.GC_FGD

# GrabCut needs both foreground and background samples to initialise its GMMs
# and otherwise fails with an internal assertion; report the cause clearly.
if not np.any(mask == cv2.GC_FGD):
    raise ValueError('No sure-foreground pixels remain after erosion; reduce '
                     '`size`/`iters` or check that the mask uses 0/255 values.')
if not np.any((mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD)):
    raise ValueError('No background pixels in the GrabCut mask.')

# Apply GrabCut (it updates a copy so `mask` keeps the initial labels).
out_mask = mask.copy()
out_mask, _, _ = cv2.grabCut(img, out_mask, None, bgdModel, fgdModel,
                             grabcut_iters, cv2.GC_INIT_WITH_MASK)
# Collapse GrabCut's labels to binary: (probable) background -> 0, else -> 1.
out_mask = np.where((out_mask == cv2.GC_BGD) | (out_mask == cv2.GC_PR_BGD),
                    0, 1).astype('uint8')
# Zero out background pixels in every colour channel.
out_img = img * out_mask[:, :, np.newaxis]
# Plot with Matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


def _highlight_region(image, region):
    """Return *image* with pixels where *region* == 0 blacked out, blended 50/50 with *image*.

    The blend dims everything outside the region while leaving the region at
    full brightness (addWeighted is called with gamma 1).
    """
    overlay = image.copy()
    overlay[region == 0] = (0, 0, 0)
    return cv2.addWeighted(overlay, .5, image, .5, 1)


background = _highlight_region(img, sure_background)
foreground = _highlight_region(img, sure_foreground)

# 2x3 grid: inputs in column 0, derived sure regions in column 1,
# GrabCut results in column 2. Listed in the same order the images are drawn.
f, axarr = plt.subplots(2, 3, sharex=True)
for (row, col), picture, caption in (
        ((0, 0), img, 'Source Image'),
        ((1, 0), deep_mask, 'Mask from DL'),
        ((0, 1), background, 'Sure Background'),
        ((1, 1), foreground, 'Sure Foreground'),
        ((0, 2), out_mask, 'GrabCut Mask'),
        ((1, 2), out_img, 'GrabCut Image'),
):
    panel = axarr[row, col]
    panel.imshow(picture)
    panel.set_title(caption)
    panel.axis('off')
plt.show()
@smeschke
Copy link
Author

Figure_1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment