public
Created

Tool to create a polygon mask in Matplotlib

  • Download Gist
mask_creator.py
Python
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
"""
Interactive tool to draw mask on an image or image-like array.
 
Adapted from matplotlib/examples/event_handling/poly_editor.py
 
"""
import numpy as np
 
# import matplotlib as mpl
# mpl.use('tkagg')
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.mlab import dist_point_to_segment
from matplotlib import nxutils
 
 
class MaskCreator(object):
    """An interactive polygon editor for drawing an image mask.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes on which the editable polygon is drawn (e.g. one showing an
        image).
    poly_xy : list of (float, float), optional
        List of (x, y) data coordinates used as vertices of the polygon.
        When omitted, ``default_vertices(ax)`` supplies an inset rectangle.
    max_ds : float
        Max pixel (display-coordinate) distance to count as a vertex hit.

    Key-bindings
    ------------
    't' : toggle vertex markers on and off. When vertex markers are on,
          you can move them, delete them
    'd' : delete the vertex under point
    'i' : insert a vertex at point. You must be within max_ds of the
          line connecting two existing vertices
    """

    def __init__(self, ax, poly_xy=None, max_ds=10):
        self.showverts = True
        self.max_ds = max_ds
        if poly_xy is None:
            poly_xy = default_vertices(ax)
        # animated=True: the polygon and markers are drawn explicitly by the
        # callbacks below (blitting) rather than by ordinary redraws.
        self.poly = Polygon(poly_xy, animated=True,
                            fc='y', ec='none', alpha=0.4)

        ax.add_patch(self.poly)
        ax.set_clip_on(False)
        ax.set_title("Click and drag a point to move it; "
                     "'i' to insert; 'd' to delete.\n"
                     "Close figure when done.")
        self.ax = ax

        # Vertex markers only: color='none' hides the connecting line.
        x, y = zip(*self.poly.xy)
        self.line = plt.Line2D(x, y, color='none', marker='o', mfc='r',
                               alpha=0.2, animated=True)
        self._update_line()
        self.ax.add_line(self.line)

        self.poly.add_callback(self.poly_changed)
        self._ind = None  # index of the vertex being dragged, if any
        # Saved by draw_callback; motion events need it to blit onto.
        self.background = None

        canvas = self.poly.figure.canvas
        canvas.mpl_connect('draw_event', self.draw_callback)
        canvas.mpl_connect('button_press_event', self.button_press_callback)
        canvas.mpl_connect('button_release_event',
                           self.button_release_callback)
        canvas.mpl_connect('key_press_event', self.key_press_callback)
        canvas.mpl_connect('motion_notify_event',
                           self.motion_notify_callback)
        self.canvas = canvas

    def get_mask(self, shape):
        """Return the image mask given by the current polygon.

        Parameters
        ----------
        shape : tuple of int
            Image shape. Only the first two entries (rows, columns) are
            used, so the ``shape`` of a colour image works as well.

        Returns
        -------
        mask : ndarray of bool, shape (rows, columns)
            True where the integer grid point (x=column, y=row) lies inside
            the polygon.
        """
        # matplotlib.nxutils.points_inside_poly was removed from Matplotlib;
        # Path.contains_points is its replacement.
        from matplotlib.path import Path

        h, w = shape[:2]
        y, x = np.mgrid[:h, :w]
        points = np.column_stack((x.ravel(), y.ravel()))
        mask = Path(self.verts).contains_points(points)
        return mask.reshape(h, w)

    def poly_changed(self, poly):
        """Called whenever the polygon object is changed.

        Only the marker line's own visibility is preserved; the polygon's
        visibility state is deliberately not copied to the line.
        """
        vis = self.line.get_visible()
        self.line.set_visible(vis)  # don't use the poly visibility state

    def draw_callback(self, event):
        """On a full redraw, save the background and draw the animated artists."""
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.ax.draw_artist(self.poly)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)

    def button_press_callback(self, event):
        """On a left click inside the axes, pick the vertex under the cursor."""
        if not self.showverts or event.inaxes is None or event.button != 1:
            return
        self._ind = self.get_ind_under_cursor(event)

    def button_release_callback(self, event):
        """On left-button release, stop dragging."""
        if not self.showverts or event.button != 1:
            return
        self._ind = None

    def key_press_callback(self, event):
        """Handle the 't' (toggle), 'd' (delete) and 'i' (insert) keys."""
        if not event.inaxes:
            return
        if event.key == 't':
            self.showverts = not self.showverts
            self.line.set_visible(self.showverts)
            if not self.showverts:
                self._ind = None
        elif event.key == 'd':
            ind = self.get_ind_under_cursor(event)
            if ind is None:
                return
            # Index 0 and the last index both hold the polygon's closing
            # vertex, which is never deleted.
            if ind == 0 or ind == self.last_vert_ind:
                print("Cannot delete root node")
                return
            self.poly.xy = [tup for i, tup in enumerate(self.poly.xy)
                            if i != ind]
            self._update_line()
        elif event.key == 'i':
            # Compare in display (pixel) coordinates so max_ds is in pixels.
            xys = self.poly.get_transform().transform(self.poly.xy)
            p = event.x, event.y  # cursor coords
            for i, (s0, s1) in enumerate(zip(xys[:-1], xys[1:])):
                if self._dist_point_to_segment(p, s0, s1) <= self.max_ds:
                    self.poly.xy = np.array(
                        list(self.poly.xy[:i + 1]) +
                        [(event.xdata, event.ydata)] +
                        list(self.poly.xy[i + 1:]))
                    self._update_line()
                    break
        self.canvas.draw()

    def motion_notify_callback(self, event):
        """While dragging with the left button, move the active vertex."""
        if (not self.showverts or event.inaxes is None or
                event.button != 1 or self._ind is None):
            return
        x, y = event.xdata, event.ydata

        if self._ind == 0 or self._ind == self.last_vert_ind:
            # The first and last entries are the same (closing) vertex:
            # move both so the polygon stays closed.
            self.poly.xy[0] = x, y
            self.poly.xy[self.last_vert_ind] = x, y
        else:
            self.poly.xy[self._ind] = x, y
        self._update_line()

        if self.background is None:
            # No full draw has happened yet, so there is no saved background
            # to blit onto; a full draw triggers draw_callback instead.
            self.canvas.draw()
            return
        self.canvas.restore_region(self.background)
        self.ax.draw_artist(self.poly)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)

    def _update_line(self):
        """Sync the vertex markers and cached vertex data with the polygon."""
        # save verts because polygon gets deleted when figure is closed
        self.verts = self.poly.xy
        self.last_vert_ind = len(self.poly.xy) - 1
        xy = np.asarray(self.poly.xy)
        self.line.set_data(xy[:, 0], xy[:, 1])

    def get_ind_under_cursor(self, event):
        """Return the index of the vertex under the cursor, or None.

        The nearest vertex (first one on ties) is returned only if it lies
        strictly within ``max_ds`` display pixels of ``(event.x, event.y)``.
        """
        xy = np.asarray(self.poly.xy)
        xyt = self.poly.get_transform().transform(xy)  # display coords
        d = np.hypot(xyt[:, 0] - event.x, xyt[:, 1] - event.y)
        ind = int(np.argmin(d))
        if d[ind] >= self.max_ds:
            return None
        return ind

    @staticmethod
    def _dist_point_to_segment(p, s0, s1):
        """Return the distance from point *p* to the segment (*s0*, *s1*).

        All arguments are (x, y) pairs in the same coordinate system. This
        replaces ``matplotlib.mlab.dist_point_to_segment``, which was
        removed from Matplotlib.
        """
        p = np.asarray(p, dtype=float)
        s0 = np.asarray(s0, dtype=float)
        s1 = np.asarray(s1, dtype=float)
        seg = s1 - s0
        seg_len2 = np.dot(seg, seg)
        if seg_len2 == 0:
            # Degenerate segment: both ends coincide.
            return float(np.hypot(*(p - s0)))
        # Project p onto the segment's line, clamped to the segment's ends.
        t = np.clip(np.dot(p - s0, seg) / seg_len2, 0.0, 1.0)
        return float(np.hypot(*(p - (s0 + t * seg))))
 
 
def default_vertices(ax):
    """Default to rectangle that has a quarter-width/height border.

    Parameters
    ----------
    ax : matplotlib Axes
        Only ``ax.get_xlim()`` and ``ax.get_ylim()`` are read.

    Returns
    -------
    tuple of four (x, y) float pairs
        Rectangle corners inset from each axis limit by one quarter of that
        axis' span. The signed span is used, so inverted axes (e.g. the
        y-axis of ``imshow``) are inset correctly too.
    """
    x_start, x_end = ax.get_xlim()
    y_start, y_end = ax.get_ylim()
    # True division: floor division (`//`) collapses the border to zero for
    # spans smaller than 4 (e.g. the default (0, 1) limits) and rounds
    # negative (inverted-axis) spans asymmetrically.
    dx = (x_end - x_start) / 4.0
    dy = (y_end - y_start) / 4.0
    x1, x2 = float(x_start + dx), float(x_end - dx)
    y1, y2 = float(y_start + dy), float(y_end - dy)
    return ((x1, y1), (x1, y2), (x2, y2), (x2, y1))
 
 
def mask_creator_demo():
    """Let the user draw a mask over a random image, then show the result.

    After the editing figure is closed, every pixel outside the drawn
    polygon is reduced by 100, clipped to [0, 255] and truncated via uint8,
    and the modified image is displayed.
    """
    image = np.random.uniform(0, 255, size=(100, 100))
    axes = plt.subplot(111)
    axes.imshow(image)

    creator = MaskCreator(axes)
    plt.show()  # the mask is read only after this editing window is done

    outside = ~creator.get_mask(image.shape)
    darkened = np.clip(image[outside] - 100., 0, 255)
    image[outside] = np.uint8(darkened)

    plt.imshow(image)
    plt.title('Region outside of mask is darkened')
    plt.show()
 
 
if __name__ == '__main__':
    # Launch the interactive demo only when run as a script, not on import.
    mask_creator_demo()

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.