Skip to content

Instantly share code, notes, and snippets.

@brikeats
Last active February 22, 2017 01:52
Show Gist options
  • Save brikeats/ff85b259364f59ca6bee0c4fee4f9836 to your computer and use it in GitHub Desktop.
Save brikeats/ff85b259364f59ca6bee0c4fee4f9836 to your computer and use it in GitHub Desktop.
Simple matplotlib-based GUI for labelling images
from os.path import basename
from glob import glob
from matplotlib import pyplot as plt
class ImageLabeller(object):
    """
    Simple matplotlib-based GUI for assigning a label to each image in a list.

    Keys: left/right arrows move to the previous/next image (wrapping
    around), any key of ``key_label_dict`` assigns the corresponding label
    to the current image, and 'q'/'Q' or 'escape' quits. Closing the window
    also quits. The constructor blocks until the user quits; the results
    are then available in ``labels``. The current image's file name and its
    label (if any) are shown in the window title.

    Labels can be basically any type that you can cast to str.

    Args:
        im_fns (list of str): a non-empty list of matplotlib-readable image
            file names.
        key_label_dict (dict): a string-to-label dict that controls how the
            user assigns labels with the keyboard. The keys of the dict are
            the keystrokes (matplotlib key names) that the user can press to
            assign the corresponding labels.
        label_dict (dict, optional): a string-to-label dict containing
            already-existing labels (e.g., labels from a previous session).
            The keys are im_fns, the values are labels (which should agree
            with ``key_label_dict``). If given, this dict is used as
            ``labels`` and is updated in place; images missing from it are
            added as unlabelled (``None``).

    Raises:
        ValueError: if ``im_fns`` is empty.

    Attributes:
        labels (dict): a string-to-label dict, mapping im_fns to labels.
            Unlabelled images have ``None`` for a label.
    """

    def __init__(self, im_fns, key_label_dict, label_dict=None):
        # TODO: validate that the values of label_dict agree with key_label_dict
        # TODO: print and/or display key_label_dict for reference
        if len(im_fns) == 0:
            raise ValueError('im_fns must contain at least one image file name')
        self.im_fns = im_fns
        self.key_label_dict = key_label_dict
        self.im_num = 0
        # Reuse the caller's dict (so a previous session can be resumed and
        # the caller sees the updates), and make sure every image has an
        # entry so show_image() never hits a KeyError.
        self.labels = label_dict if label_dict is not None else {}
        for fn in im_fns:
            self.labels.setdefault(fn, None)
        self.im_plot = None  # AxesImage, created on the first show_image()

        # create initial plot on an explicit Axes so that pyplot's notion of
        # the "current" figure/axes cannot redirect our drawing elsewhere
        self.figure = plt.figure()
        self.ax = self.figure.add_subplot(1, 1, 1)
        self.show_image(0)

        # disable matplotlib's default keybindings so they do not fire when
        # the user presses a labelling key
        canvas = self.figure.canvas
        canvas.mpl_disconnect(canvas.manager.key_press_handler_id)

        # set callbacks
        self.cid = canvas.mpl_connect('key_press_event', self._on_keypress)
        self.cid2 = canvas.mpl_connect('button_press_event', self._on_mouse_click)
        # Closing the window must also end the blocking event loop below;
        # otherwise the constructor would never return.
        self.cid3 = canvas.mpl_connect('close_event', self._on_close)

        plt.ion()
        plt.show()
        # Block until stop_event_loop() is called (a timeout <= 0 never expires).
        canvas.start_event_loop(timeout=-1)

    def _on_mouse_click(self, event):
        """Debug hook: print the data coordinates of a mouse click."""
        # Single-argument form prints identically under Python 2 and 3.
        print('clicked %s %s' % (event.xdata, event.ydata))

    def _on_close(self, event):
        """Stop the blocking event loop when the figure window is closed."""
        self.figure.canvas.stop_event_loop()

    def next(self):
        """Show the next image, wrapping around after the last one."""
        self.im_num = (self.im_num + 1) % len(self.im_fns)
        self.show_image(self.im_num)

    def previous(self):
        """Show the previous image, wrapping around before the first one."""
        self.im_num = (self.im_num - 1) % len(self.im_fns)
        self.show_image(self.im_num)

    @staticmethod
    def _window_title(im_fn, label):
        """Return the window title for ``im_fn``: its base name, followed by
        ``': <label>'`` when the image is labelled (label is not None)."""
        if label is None:
            return basename(im_fn)
        # `is None` rather than truthiness so labels like 0 or '' still show
        return '%s: %s' % (basename(im_fn), str(label))

    def show_image(self, im_num):
        """Display image number ``im_num`` and update the window title."""
        im_fn = self.im_fns[im_num]
        im = plt.imread(im_fn)
        height, width = im.shape[0], im.shape[1]
        if self.im_plot is None:
            self.im_plot = self.ax.imshow(im, interpolation='none')
        else:
            self.im_plot.set_data(im)
        # set_data() keeps the previous image's extent, so refit it to this
        # image's size; otherwise images of different sizes get stretched.
        self.im_plot.set_extent((-0.5, width - 0.5, height - 0.5, -0.5))
        self.ax.axis([0, width, height, 0])
        self.ax.axis('off')
        self.figure.canvas.manager.set_window_title(
            self._window_title(im_fn, self.labels.get(im_fn)))
        # request a redraw so the new image appears
        self.figure.canvas.draw_idle()

    def _on_keypress(self, event):
        """Handle quit (q/escape), navigation (left/right) and label keys."""
        key = event.key
        if key is None:
            # matplotlib may report a key event without a key name
            return
        if key.lower() in ('q', 'escape'):
            plt.close(self.figure)  # close *this* figure, not the current one
            self.figure.canvas.stop_event_loop()
        elif key == 'right':
            self.next()
        elif key == 'left':
            self.previous()
        elif key in self.key_label_dict:
            im_fn = self.im_fns[self.im_num]
            self.labels[im_fn] = self.key_label_dict[key]
            self.show_image(self.im_num)  # refresh the window title
        # any other key is ignored
if __name__ == '__main__':
    # Example usage: label the images matching a glob pattern given as the
    # first command-line argument (defaults to the author's original folder).
    import sys

    pattern = (sys.argv[1] if len(sys.argv) > 1
               else '/home/brian/Pictures/pics to paint/*.jpg')
    # glob() does not guarantee any particular order; sort so the browsing
    # order is stable between sessions
    im_fns = sorted(glob(pattern))
    if not im_fns:
        sys.exit('No images match %r' % pattern)
    # keystroke -> label
    key_label_dict = {'1': 'pad', '2': 'not-pad', '3': 'unclear', 'd': 'delete'}
    labeller = ImageLabeller(im_fns, key_label_dict)  # blocks until quit
    # labeller.labels is a filename -> label dict
    print(list(labeller.labels.values()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment