Skip to content

Instantly share code, notes, and snippets.

@tonysyu
Created July 11, 2012 14:25
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save tonysyu/3090704 to your computer and use it in GitHub Desktop.
Save tonysyu/3090704 to your computer and use it in GitHub Desktop.
Tool to create polygon mask in Matplotlib
"""
Interactive tool to draw mask on an image or image-like array.
Adapted from matplotlib/examples/event_handling/poly_editor.py
"""
import numpy as np
# import matplotlib as mpl
# mpl.use('tkagg')
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.mlab import dist_point_to_segment
from matplotlib import nxutils
class MaskCreator(object):
    """An interactive polygon editor.

    Parameters
    ----------
    ax : matplotlib axes
        Axes the polygon and its vertex markers are added to.
    poly_xy : list of (float, float)
        List of (x, y) coordinates used as vertices of the polygon.
        When omitted, ``default_vertices(ax)`` is used.
    max_ds : float
        Max pixel distance to count as a vertex hit.

    Key-bindings
    ------------
    't' : toggle vertex markers on and off. When vertex markers are on,
          you can move them, delete them
    'd' : delete the vertex under point
    'i' : insert a vertex at point. You must be within max_ds of the
          line connecting two existing vertices
    """

    def __init__(self, ax, poly_xy=None, max_ds=10):
        self.showverts = True
        self.max_ds = max_ds
        if poly_xy is None:
            poly_xy = default_vertices(ax)
        self.poly = Polygon(poly_xy, animated=True,
                            fc='y', ec='none', alpha=0.4)

        ax.add_patch(self.poly)
        ax.set_clip_on(False)
        ax.set_title("Click and drag a point to move it; "
                     "'i' to insert; 'd' to delete.\n"
                     "Close figure when done.")
        self.ax = ax

        # Separate marker-only line that shows the draggable vertices.
        x, y = zip(*self.poly.xy)
        self.line = plt.Line2D(x, y, color='none', marker='o', mfc='r',
                               alpha=0.2, animated=True)
        self._update_line()
        self.ax.add_line(self.line)

        self.poly.add_callback(self.poly_changed)
        self._ind = None  # index of the vertex being dragged, if any

        canvas = self.poly.figure.canvas
        canvas.mpl_connect('draw_event', self.draw_callback)
        canvas.mpl_connect('button_press_event', self.button_press_callback)
        canvas.mpl_connect('button_release_event',
                           self.button_release_callback)
        canvas.mpl_connect('key_press_event', self.key_press_callback)
        canvas.mpl_connect('motion_notify_event',
                           self.motion_notify_callback)
        self.canvas = canvas

    def get_mask(self, shape):
        """Return image mask given by mask creator.

        Parameters
        ----------
        shape : sequence of int
            Image shape. Only the first two entries (height, width) are
            used, so ``(h, w)`` and ``(h, w, channels)`` both work.

        Returns
        -------
        mask : array of shape (h, w)
            Result of the point-in-polygon test for every pixel position.
        """
        h, w = shape[:2]
        y, x = np.mgrid[:h, :w]
        points = np.transpose((x.ravel(), y.ravel()))
        # NOTE: nxutils was deprecated in later matplotlib releases;
        # matplotlib.path.Path(verts).contains_points(points) is the
        # suggested replacement there.
        mask = nxutils.points_inside_poly(points, self.verts)
        return mask.reshape(h, w)

    def poly_changed(self, poly):
        """Callback registered on the polygon via ``add_callback``."""
        # only copy the artist props to the line (except visibility)
        vis = self.line.get_visible()
        #Artist.update_from(self.line, poly)
        self.line.set_visible(vis)  # don't use the poly visibility state

    def draw_callback(self, event):
        """Cache the canvas background, then draw polygon and markers."""
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.ax.draw_artist(self.poly)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)

    def button_press_callback(self, event):
        """Start dragging the vertex under the cursor (left button only)."""
        ignore = (not self.showverts or event.inaxes is None
                  or event.button != 1)
        if ignore:
            return
        self._ind = self.get_ind_under_cursor(event)

    def button_release_callback(self, event):
        """Stop dragging when the left button is released."""
        ignore = not self.showverts or event.button != 1
        if ignore:
            return
        self._ind = None

    def key_press_callback(self, event):
        """Handle the 't', 'd' and 'i' key bindings."""
        if not event.inaxes:
            return
        if event.key == 't':
            self.showverts = not self.showverts
            self.line.set_visible(self.showverts)
            if not self.showverts:
                self._ind = None
        elif event.key == 'd':
            ind = self.get_ind_under_cursor(event)
            if ind is None:
                return
            if ind == 0 or ind == self.last_vert_ind:
                # Parenthesised single argument prints the same text under
                # both Python 2 and Python 3.
                print("Cannot delete root node")
                return
            self.poly.xy = [tup for i, tup in enumerate(self.poly.xy)
                            if i != ind]
            self._update_line()
        elif event.key == 'i':
            # Compare in display (pixel) coordinates so max_ds is in pixels.
            xys = self.poly.get_transform().transform(self.poly.xy)
            p = event.x, event.y  # cursor coords
            for i in range(len(xys) - 1):
                s0 = xys[i]
                s1 = xys[i + 1]
                d = dist_point_to_segment(p, s0, s1)
                if d <= self.max_ds:
                    self.poly.xy = np.array(
                        list(self.poly.xy[:i + 1]) +
                        [(event.xdata, event.ydata)] +
                        list(self.poly.xy[i + 1:]))
                    self._update_line()
                    break
        self.canvas.draw()

    def motion_notify_callback(self, event):
        """Move the active vertex with the mouse and redraw via blitting."""
        ignore = (not self.showverts or event.inaxes is None or
                  event.button != 1 or self._ind is None)
        if ignore:
            return
        x, y = event.xdata, event.ydata

        if self._ind == 0 or self._ind == self.last_vert_ind:
            # First and last entries are moved together.
            self.poly.xy[0] = x, y
            self.poly.xy[self.last_vert_ind] = x, y
        else:
            self.poly.xy[self._ind] = x, y
        self._update_line()

        self.canvas.restore_region(self.background)
        self.ax.draw_artist(self.poly)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)

    def _update_line(self):
        """Sync the cached vertices and the marker line with the polygon."""
        # save verts because polygon gets deleted when figure is closed
        self.verts = self.poly.xy
        self.last_vert_ind = len(self.poly.xy) - 1
        self.line.set_data(zip(*self.poly.xy))

    def get_ind_under_cursor(self, event):
        """Return the index of the vertex nearest the cursor.

        Distances are measured in display coordinates. ``None`` is
        returned when the nearest vertex is ``max_ds`` or more away.
        """
        xy = np.asarray(self.poly.xy)
        xyt = self.poly.get_transform().transform(xy)
        xt, yt = xyt[:, 0], xyt[:, 1]
        d = np.sqrt((xt - event.x)**2 + (yt - event.y)**2)
        ind = np.argmin(d)  # first index wins on ties
        if d[ind] >= self.max_ds:
            ind = None
        return ind
def default_vertices(ax):
    """Default to rectangle that has a quarter-width/height border.

    Parameters
    ----------
    ax : object with ``get_xlim()`` and ``get_ylim()``
        Axes whose current limits define the full extent.

    Returns
    -------
    tuple of four (x, y) tuples
        Rectangle corners inset from the axis limits by a quarter of the
        x-span and y-span respectively.
    """
    x_start, x_end = ax.get_xlim()
    y_start, y_end = ax.get_ylim()
    # True division: floor division gives no border at all for spans
    # below 4 and overshoots for negative (inverted-axis) fractional spans.
    dx = (x_end - x_start) / 4.0
    dy = (y_end - y_start) / 4.0
    x1, x2 = x_start + dx, x_end - dx
    y1, y2 = y_start + dy, y_end - dy
    return ((x1, y1), (x1, y2), (x2, y2), (x2, y1))
def mask_creator_demo():
    """Draw a mask on a random image, then show the image darkened
    outside of the mask."""
    img = np.random.uniform(0, 255, size=(100, 100))
    ax = plt.subplot(111)
    ax.imshow(img)

    mc = MaskCreator(ax)
    plt.show()

    mask = mc.get_mask(img.shape)
    # Subtract 100 from every pixel outside the mask, clipped to [0, 255].
    img[~mask] = np.uint8(np.clip(img[~mask] - 100., 0, 255))
    plt.imshow(img)
    plt.title('Region outside of mask is darkened')
    plt.show()
# Run the interactive demo only when executed as a script.
if __name__ == '__main__':
    mask_creator_demo()
@remivanbel
Copy link

Hi,

Thank you for making this code available!

Because the imports

`from matplotlib.mlab import dist_point_to_segment` and `from matplotlib import nxutils`

were deprecated and have since been removed from matplotlib, I made two minor adjustments to make it work again.

  1. new implementation for dist_point_to_segment taken from (OSGeo/grass#645)
  2. use of Path instead of nxutils for creating the mask as suggested in (https://stackoverflow.com/questions/21339448/how-to-get-list-of-points-inside-a-polygon-in-python)

I hope it might help others

"""
Interactive tool to draw mask on an image or image-like array.
Adapted from matplotlib/examples/event_handling/poly_editor.py
"""
import numpy as np

# import matplotlib as mpl
# mpl.use('tkagg')
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.path import Path
import numpy as np


# from https://matplotlib.org/3.0.0/_modules/matplotlib/mlab.html#dist
def dist(x, y):
    """
    Return the Euclidean distance between two points.

    *x* and *y* may be arrays or plain sequences (tuples, lists) of
    coordinates with matching length.
    """
    # Coerce first so that tuples/lists do not fail on the subtraction.
    d = np.asarray(x) - np.asarray(y)
    return np.sqrt(np.dot(d, d))


# from https://matplotlib.org/3.0.0/_modules/matplotlib/mlab.html#dist_point_to_segment
def dist_point_to_segment(p, s0, s1):
    """
    Return the distance from point *p* to the segment *s0*-*s1*.

      *p*, *s0*, *s1* are *xy* sequences

    The point is projected onto the segment's direction; when the
    projection falls before *s0* or past *s1*, the distance to that
    endpoint is returned instead.

    This algorithm from
    http://geomalgorithms.com/a02-_lines.html
    """
    point = np.asarray(p, float)
    start = np.asarray(s0, float)
    end = np.asarray(s1, float)
    direction = end - start

    # Projection of (point - start) onto the segment direction.
    along = np.dot(point - start, direction)
    if along <= 0:
        nearest = start
    else:
        length_sq = np.dot(direction, direction)
        if length_sq <= along:
            nearest = end
        else:
            nearest = start + (along / length_sq) * direction
    return dist(point, nearest)

class MaskCreator(object):
    """An interactive polygon editor.

    Parameters
    ----------
    ax : matplotlib axes
        Axes the polygon and its vertex markers are added to.
    poly_xy : list of (float, float)
        List of (x, y) coordinates used as vertices of the polygon.
        When omitted, ``default_vertices(ax)`` is used.
    max_ds : float
        Max pixel distance to count as a vertex hit.

    Key-bindings
    ------------
    't' : toggle vertex markers on and off.  When vertex markers are on,
          you can move them, delete them
    'd' : delete the vertex under point
    'i' : insert a vertex at point.  You must be within max_ds of the
          line connecting two existing vertices
    """

    def __init__(self, ax, poly_xy=None, max_ds=10):
        self.showverts = True
        self.max_ds = max_ds
        if poly_xy is None:
            poly_xy = default_vertices(ax)
        self.poly = Polygon(poly_xy, animated=True,
                            fc='y', ec='none', alpha=0.4)

        ax.add_patch(self.poly)
        ax.set_clip_on(False)
        ax.set_title("Click and drag a point to move it; "
                     "'i' to insert; 'd' to delete.\n"
                     "Close figure when done.")
        self.ax = ax

        # Separate marker-only line that shows the draggable vertices.
        x, y = zip(*self.poly.xy)
        self.line = plt.Line2D(x, y, color='none', marker='o', mfc='r',
                               alpha=0.2, animated=True)
        self._update_line()
        self.ax.add_line(self.line)

        self.poly.add_callback(self.poly_changed)
        self._ind = None  # index of the vertex being dragged, if any

        canvas = self.poly.figure.canvas
        canvas.mpl_connect('draw_event', self.draw_callback)
        canvas.mpl_connect('button_press_event', self.button_press_callback)
        canvas.mpl_connect('button_release_event',
                           self.button_release_callback)
        canvas.mpl_connect('key_press_event', self.key_press_callback)
        canvas.mpl_connect('motion_notify_event',
                           self.motion_notify_callback)
        self.canvas = canvas

    def get_mask(self, shape):
        """Return image mask given by mask creator.

        Parameters
        ----------
        shape : sequence of int
            Image shape. Only the first two entries (height, width) are
            used, so ``(h, w)`` and ``(h, w, channels)`` both work.

        Returns
        -------
        mask : array of shape (h, w)
            Result of ``Path.contains_points`` for every pixel position
            ``(column, row)``, tested against the saved polygon vertices.
        """
        h, w = shape[:2]
        y, x = np.mgrid[:h, :w]
        points = np.transpose((x.ravel(), y.ravel()))
        inside = Path(self.verts).contains_points(points)
        return inside.reshape(h, w)

    def poly_changed(self, poly):
        """Callback registered on the polygon via ``add_callback``."""
        # only copy the artist props to the line (except visibility)
        vis = self.line.get_visible()
        #Artist.update_from(self.line, poly)
        self.line.set_visible(vis)  # don't use the poly visibility state

    def draw_callback(self, event):
        """Cache the canvas background, then draw polygon and markers."""
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.ax.draw_artist(self.poly)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)

    def button_press_callback(self, event):
        """Start dragging the vertex under the cursor (left button only)."""
        ignore = (not self.showverts or event.inaxes is None
                  or event.button != 1)
        if ignore:
            return
        self._ind = self.get_ind_under_cursor(event)

    def button_release_callback(self, event):
        """Stop dragging when the left button is released."""
        ignore = not self.showverts or event.button != 1
        if ignore:
            return
        self._ind = None

    def key_press_callback(self, event):
        """Handle the 't', 'd' and 'i' key bindings."""
        if not event.inaxes:
            return
        if event.key == 't':
            self.showverts = not self.showverts
            self.line.set_visible(self.showverts)
            if not self.showverts:
                self._ind = None
        elif event.key == 'd':
            ind = self.get_ind_under_cursor(event)
            if ind is None:
                return
            # The first and last entries are never deleted.
            if ind == 0 or ind == self.last_vert_ind:
                return
            self.poly.xy = np.delete(self.poly.xy, ind, axis=0)
            self._update_line()
        elif event.key == 'i':
            # Compare in display (pixel) coordinates so max_ds is in pixels.
            xys = self.poly.get_transform().transform(self.poly.xy)
            p = event.x, event.y  # cursor coords
            for i, (s0, s1) in enumerate(zip(xys[:-1], xys[1:])):
                if dist_point_to_segment(p, s0, s1) <= self.max_ds:
                    # Insert the new vertex between the segment endpoints.
                    self.poly.xy = np.insert(
                        self.poly.xy, i + 1,
                        (event.xdata, event.ydata), axis=0)
                    self._update_line()
                    break
        self.canvas.draw()

    def motion_notify_callback(self, event):
        """Move the active vertex with the mouse and redraw via blitting."""
        ignore = (not self.showverts or event.inaxes is None or
                  event.button != 1 or self._ind is None)
        if ignore:
            return
        x, y = event.xdata, event.ydata

        if self._ind == 0 or self._ind == self.last_vert_ind:
            # First and last entries are moved together.
            self.poly.xy[0] = x, y
            self.poly.xy[self.last_vert_ind] = x, y
        else:
            self.poly.xy[self._ind] = x, y
        self._update_line()

        self.canvas.restore_region(self.background)
        self.ax.draw_artist(self.poly)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)

    def _update_line(self):
        """Sync the cached vertices and the marker line with the polygon."""
        # save verts because polygon gets deleted when figure is closed
        self.verts = self.poly.xy
        self.last_vert_ind = len(self.poly.xy) - 1
        self.line.set_data(zip(*self.poly.xy))

    def get_ind_under_cursor(self, event):
        """Return the index of the vertex nearest the cursor.

        Distances are measured in display coordinates. ``None`` is
        returned when the nearest vertex is ``max_ds`` or more away.
        """
        xy = np.asarray(self.poly.xy)
        xyt = self.poly.get_transform().transform(xy)
        xt, yt = xyt[:, 0], xyt[:, 1]
        d = np.sqrt((xt - event.x)**2 + (yt - event.y)**2)
        ind = np.argmin(d)  # first index wins on ties
        if d[ind] >= self.max_ds:
            return None
        return ind


def default_vertices(ax):
    """Default to rectangle that has a quarter-width/height border.

    Parameters
    ----------
    ax : object with ``get_xlim()`` and ``get_ylim()``
        Axes whose current limits define the full extent.

    Returns
    -------
    tuple of four (x, y) tuples
        Rectangle corners inset from the axis limits by a quarter of the
        x-span and y-span respectively.
    """
    x_start, x_end = ax.get_xlim()
    y_start, y_end = ax.get_ylim()
    # True division: floor division gives no border at all for spans
    # below 4 and overshoots for negative (inverted-axis) fractional spans.
    dx = (x_end - x_start) / 4.0
    dy = (y_end - y_start) / 4.0
    x1, x2 = x_start + dx, x_end - dx
    y1, y2 = y_start + dy, y_end - dy
    return ((x1, y1), (x1, y2), (x2, y2), (x2, y1))


def mask_creator_demo():
    """Draw a mask on a random image, then show the image darkened
    outside of the mask."""
    image = np.random.uniform(0, 255, size=(100, 100))
    axes = plt.subplot(111)
    axes.imshow(image)

    creator = MaskCreator(axes)
    plt.show()

    # Everything the polygon does not cover gets 100 subtracted,
    # clipped to the [0, 255] range.
    outside = ~creator.get_mask(image.shape)
    darkened = np.clip(image[outside] - 100., 0, 255)
    image[outside] = np.uint8(darkened)

    plt.imshow(image)
    plt.title('Region outside of mask is darkened')
    plt.show()


# Run the interactive demo only when executed as a script.
if __name__ == "__main__":
    mask_creator_demo()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment