Skip to content

Instantly share code, notes, and snippets.

@yezhengli-Mr9
Forked from tonysyu/mask_creator.py
Created October 1, 2018 11:35
Show Gist options
  • Save yezhengli-Mr9/8e9a35c6db01903308bab8f9dc632031 to your computer and use it in GitHub Desktop.
Save yezhengli-Mr9/8e9a35c6db01903308bab8f9dc632031 to your computer and use it in GitHub Desktop.
Tool to create polygon mask in Matplotlib
"""
Interactive tool to draw mask on an image or image-like array.
Adapted from matplotlib/examples/event_handling/poly_editor.py
"""
import numpy as np
# import matplotlib as mpl
# mpl.use('tkagg')
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.path import Path

# Legacy helpers that are no longer shipped by recent Matplotlib releases.
# The imports are kept for old installations but guarded so that importing
# this module does not fail where they are missing.
try:
    from matplotlib.mlab import dist_point_to_segment
except ImportError:
    dist_point_to_segment = None
try:
    from matplotlib import nxutils
except ImportError:
    nxutils = None
class MaskCreator(object):
    """An interactive polygon editor used to draw a mask over an image.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes the polygon is added to and on which events are handled.
    poly_xy : list of (float, float), optional
        List of (x, y) coordinates used as vertices of the polygon. When
        omitted, ``default_vertices(ax)`` supplies the starting polygon.
    max_ds : float
        Max pixel distance to count as a vertex hit.

    Key-bindings
    ------------
    't' : toggle vertex markers on and off. When vertex markers are on,
        you can move them, delete them
    'd' : delete the vertex under point
    'i' : insert a vertex at point. You must be within max_ds of the
        line connecting two existing vertices
    """

    def __init__(self, ax, poly_xy=None, max_ds=10):
        self.showverts = True
        self.max_ds = max_ds
        if poly_xy is None:
            poly_xy = default_vertices(ax)
        self.poly = Polygon(poly_xy, animated=True,
                            fc='y', ec='none', alpha=0.4)
        ax.add_patch(self.poly)
        ax.set_clip_on(False)
        ax.set_title("Click and drag a point to move it; "
                     "'i' to insert; 'd' to delete.\n"
                     "Close figure when done.")
        self.ax = ax

        # Marker-only line drawn on top of the polygon to show its vertices.
        x, y = zip(*self.poly.xy)
        self.line = plt.Line2D(x, y, color='none', marker='o', mfc='r',
                               alpha=0.2, animated=True)
        self._update_line()
        self.ax.add_line(self.line)

        self.poly.add_callback(self.poly_changed)
        self._ind = None  # index of the vertex being dragged, if any

        canvas = self.poly.figure.canvas
        canvas.mpl_connect('draw_event', self.draw_callback)
        canvas.mpl_connect('button_press_event', self.button_press_callback)
        canvas.mpl_connect('button_release_event',
                           self.button_release_callback)
        canvas.mpl_connect('key_press_event', self.key_press_callback)
        canvas.mpl_connect('motion_notify_event',
                           self.motion_notify_callback)
        self.canvas = canvas

    def get_mask(self, shape):
        """Return a boolean image mask that is True inside the polygon.

        Parameters
        ----------
        shape : tuple of int
            Image shape. Only the first two entries (rows, columns) are
            used, so the shape of a multi-channel image may be passed as-is.

        Returns
        -------
        mask : (rows, columns) array of bool
            Polygon vertices are compared against integer (column, row)
            pixel coordinates.
        """
        h, w = shape[:2]
        y, x = np.mgrid[:h, :w]
        points = np.transpose((x.ravel(), y.ravel()))
        # Path.contains_points is the replacement for the removed
        # nxutils.points_inside_poly.
        mask = Path(self.verts).contains_points(points)
        return mask.reshape(h, w)

    def poly_changed(self, poly):
        """Callback invoked whenever the polygon artist is changed."""
        # Polygon properties are deliberately not copied onto the line; the
        # line just keeps its own visibility state.
        vis = self.line.get_visible()
        self.line.set_visible(vis)

    def draw_callback(self, event):
        """Cache the axes background, then draw the animated artists."""
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.ax.draw_artist(self.poly)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)

    def button_press_callback(self, event):
        """Start dragging the vertex under the cursor (left button only)."""
        ignore = (not self.showverts or event.inaxes is None or
                  event.button != 1)
        if ignore:
            return
        self._ind = self.get_ind_under_cursor(event)

    def button_release_callback(self, event):
        """Stop dragging when the left button is released."""
        ignore = not self.showverts or event.button != 1
        if ignore:
            return
        self._ind = None

    def key_press_callback(self, event):
        """Handle the 't' (toggle), 'd' (delete) and 'i' (insert) keys."""
        if not event.inaxes:
            return
        if event.key == 't':
            self.showverts = not self.showverts
            self.line.set_visible(self.showverts)
            if not self.showverts:
                self._ind = None
        elif event.key == 'd':
            ind = self.get_ind_under_cursor(event)
            if ind is None:
                return
            # First and last entries are the same (closing) vertex.
            if ind == 0 or ind == self.last_vert_ind:
                print("Cannot delete root node")
                return
            self.poly.xy = np.delete(np.asarray(self.poly.xy), ind, axis=0)
            self._update_line()
        elif event.key == 'i':
            # Compare in display (pixel) coordinates so max_ds is in pixels.
            xys = self.poly.get_transform().transform(self.poly.xy)
            p = event.x, event.y  # cursor coords
            for i, (s0, s1) in enumerate(zip(xys[:-1], xys[1:])):
                if self._dist_point_to_segment(p, s0, s1) <= self.max_ds:
                    self.poly.xy = np.insert(
                        np.asarray(self.poly.xy), i + 1,
                        (event.xdata, event.ydata), axis=0)
                    self._update_line()
                    break
        self.canvas.draw()

    def motion_notify_callback(self, event):
        """Move the dragged vertex with the mouse and redraw via blitting."""
        ignore = (not self.showverts or event.inaxes is None or
                  event.button != 1 or self._ind is None)
        if ignore:
            return
        x, y = event.xdata, event.ydata
        if self._ind == 0 or self._ind == self.last_vert_ind:
            # Keep the closing vertex glued to the first one.
            self.poly.xy[0] = x, y
            self.poly.xy[self.last_vert_ind] = x, y
        else:
            self.poly.xy[self._ind] = x, y
        self._update_line()

        self.canvas.restore_region(self.background)
        self.ax.draw_artist(self.poly)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)

    def _update_line(self):
        """Sync the cached vertices and the marker line with the polygon."""
        # save verts because polygon gets deleted when figure is closed
        self.verts = self.poly.xy
        self.last_vert_ind = len(self.poly.xy) - 1
        xy = np.asarray(self.poly.xy)
        self.line.set_data(xy[:, 0], xy[:, 1])

    def get_ind_under_cursor(self, event):
        """Return the index of the vertex nearest the cursor.

        Returns None when the nearest vertex is max_ds pixels or more away.
        """
        # display coords
        xy = np.asarray(self.poly.xy)
        xyt = self.poly.get_transform().transform(xy)
        d = np.hypot(xyt[:, 0] - event.x, xyt[:, 1] - event.y)
        ind = int(np.argmin(d))  # first index on ties
        if d[ind] >= self.max_ds:
            return None
        return ind

    @staticmethod
    def _dist_point_to_segment(p, s0, s1):
        """Return the distance from point p to the segment s0-s1."""
        p, s0, s1 = (np.asarray(v, dtype=float) for v in (p, s0, s1))
        seg = s1 - s0
        seg_len_sq = np.dot(seg, seg)
        if seg_len_sq == 0:
            # Degenerate segment: both end points coincide.
            return float(np.hypot(*(p - s0)))
        # Projection parameter clamped onto the segment.
        t = min(1.0, max(0.0, np.dot(p - s0, seg) / seg_len_sq))
        return float(np.hypot(*(p - (s0 + t * seg))))
def default_vertices(ax):
    """Default to rectangle that has a quarter-width/height border.

    Parameters
    ----------
    ax : object with ``get_xlim()`` and ``get_ylim()``
        Axes whose current limits define the outer rectangle.

    Returns
    -------
    verts : tuple of four (x, y) pairs
        Corners of the rectangle inset from each axis limit by a quarter of
        that axis' span, ordered (x1, y1), (x1, y2), (x2, y2), (x2, y1).
    """
    xlims = np.asarray(ax.get_xlim(), dtype=float)
    ylims = np.asarray(ax.get_ylim(), dtype=float)
    # True division: floor division would shrink the border to zero for
    # axes spanning fewer than 4 units (e.g. the default 0..1 limits).
    inset = np.array([1, -1]) / 4.0
    x1, x2 = xlims + (xlims[1] - xlims[0]) * inset
    y1, y2 = ylims + (ylims[1] - ylims[0]) * inset
    return ((x1, y1), (x1, y2), (x2, y2), (x2, y1))
def mask_creator_demo():
    """Draw a mask on a random image, then show the image darkened outside it."""
    image = np.random.uniform(0, 255, size=(100, 100))
    axes = plt.subplot(111)
    axes.imshow(image)
    creator = MaskCreator(axes)
    # The mask is read only after show() returns.
    plt.show()

    outside = ~creator.get_mask(image.shape)
    darkened = np.clip(image[outside] - 100., 0, 255)
    image[outside] = np.uint8(darkened)

    plt.imshow(image)
    plt.title('Region outside of mask is darkened')
    plt.show()
# Launch the interactive demo only when this file is executed as a script.
if __name__ == "__main__":
    mask_creator_demo()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment