Skip to content

Instantly share code, notes, and snippets.

@brikeats
Last active April 11, 2018 03:21
Show Gist options
  • Save brikeats/f664e548f7ba44fb71b3749756fbc2c1 to your computer and use it in GitHub Desktop.
Save brikeats/f664e548f7ba44fb71b3749756fbc2c1 to your computer and use it in GitHub Desktop.
import numpy as np
from os.path import basename
from glob import glob
from skimage.color import label2rgb
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt
from tomopy.misc.phantom import shepp3d
"""
TODO:
* add label showing which dim, slice you're on
* add label showing (i,j,k), (r,a,s), pixel value and label value
* add label_dict argument so that we can display the label string
"""
class VolumeComparer(object):
    """Interactive side-by-side slice viewer for two 3-D volumes.

    The left panel shows a slice of ``vol1`` and the right panel the same
    slice of ``vol2``.  Optional label volumes are blended on top with
    ``skimage.color.label2rgb`` (label 0 is background).  While the mouse is
    over one panel, a red ``+`` marks the same position on the other panel.

    Controls:
        up arrow / scroll up      next slice (wraps around)
        down arrow / scroll down  previous slice (wraps around)
        d                         cycle the slicing axis 0 -> 1 -> 2 -> 0
        q / escape                close the window

    Constructing an instance builds the figure, shows it and then blocks in
    the canvas event loop until the window is closed.

    Parameters
    ----------
    vol1, vol2 : array-like
        Volumes to compare; indexed with three-axis slicing.  Slice positions
        are bounded by ``vol1.shape``, so both volumes presumably need the
        same shape -- TODO confirm.
    labels1, labels2 : array-like or None
        Optional integer label volumes overlaid on ``vol1`` / ``vol2``.
    alpha : float
        Opacity of the label overlay passed to ``label2rgb``.
    """

    def __init__(self, vol1, vol2, labels1=None, labels2=None, alpha=0.2):
        self.vol1 = vol1
        self.vol2 = vol2
        self.labels1 = labels1
        self.labels2 = labels2
        self.alpha = alpha

        # Build the figure and show the middle slice along axis 0.
        self.figure = plt.figure(figsize=(7, 7))
        self.left_ax = plt.subplot(121)
        self.right_ax = plt.subplot(122)
        self.dim = 0
        self.slice_nums = [s // 2 for s in vol1.shape]
        self.update_images()

        canvas = self.figure.canvas
        # Disable matplotlib's default key bindings so they don't fire
        # alongside the shortcuts handled in _on_keypress.
        canvas.mpl_disconnect(canvas.manager.key_press_handler_id)

        # Event callbacks.
        self.cid = canvas.mpl_connect('key_press_event', self._on_keypress)
        self.cid2 = canvas.mpl_connect('button_press_event', self._on_mouse_click)
        self.cid3 = canvas.mpl_connect('axes_enter_event', self._on_mouse_enter_axes)
        self.cid4 = canvas.mpl_connect('axes_leave_event', self._on_mouse_leave_axes)
        self.cid5 = canvas.mpl_connect('motion_notify_event', self._on_mouse_motion)
        self.cid6 = canvas.mpl_connect('scroll_event', self._on_mouse_scroll)
        # Closing the window by any means (not just q/escape) should also end
        # the blocking loop below; otherwise the constructor may never return.
        self.cid7 = canvas.mpl_connect('close_event', self._on_close)

        plt.ion()
        plt.show()
        canvas.start_event_loop(timeout=0)  # runs until stop_event_loop()

    def slice_volumes(self):
        """Return the current 2-D slices ``(slc1, slc2)`` of both volumes.

        Each slice is taken at index ``self.slice_nums[self.dim]`` along axis
        ``self.dim``.  If a label volume is set, its matching slice is blended
        on top with ``label2rgb``, so that panel becomes an RGB image.
        """
        index = [slice(None)] * 3
        index[self.dim] = self.slice_nums[self.dim]
        # NumPy requires a tuple for multidimensional indexing; a list of
        # slices is deprecated (1.15) and no longer accepted (1.23+).
        index = tuple(index)

        slc1 = self.vol1[index]
        slc2 = self.vol2[index]
        if self.labels1 is not None:
            slc1 = label2rgb(self.labels1[index], image=slc1,
                             alpha=self.alpha, bg_label=0)
        if self.labels2 is not None:
            slc2 = label2rgb(self.labels2[index], image=slc2,
                             alpha=self.alpha, bg_label=0)
        return slc1, slc2

    def update_images(self):
        """Redraw both panels for the current axis and slice.

        The first call creates the image artists and crosshairs; later calls
        swap new data into the existing artists.
        """
        slc1, slc2 = self.slice_volumes()
        try:
            left_im, right_im = self.left_im_plot, self.right_im_plot
        except AttributeError:
            self._create_plots(slc1, slc2)
        else:
            self._set_image(self.left_ax, left_im, slc1)
            self._set_image(self.right_ax, right_im, slc2)
        self.figure.canvas.draw_idle()

    def _create_plots(self, slc1, slc2):
        """Create both image artists and the (initially hidden) crosshairs."""
        self.left_im_plot = self.left_ax.imshow(slc1, interpolation='none', cmap='gray')
        self.right_im_plot = self.right_ax.imshow(slc2, interpolation='none', cmap='gray')
        self.left_ax.axis('off')
        self.right_ax.axis('off')
        self.left_crosshair = self.left_ax.plot(0, 0, marker='+', markersize=20, color='r')[0]
        self.right_crosshair = self.right_ax.plot(0, 0, marker='+', markersize=20, color='r')[0]
        self.left_crosshair.set_visible(False)
        self.right_crosshair.set_visible(False)

    @staticmethod
    def _set_image(ax, im, slc):
        """Swap ``slc`` into image ``im`` on ``ax``, refitting it if its shape changed."""
        resized = tuple(im.get_size()) != tuple(slc.shape[:2])
        im.set_data(slc)
        if not resized:
            return
        # imshow pins the extent, and set_data does not update it.  Slices
        # along different axes of a non-cubic volume differ in shape, so
        # refit the extent and view limits to the new slice.
        rows, cols = slc.shape[:2]
        if im.origin == 'lower':
            extent = (-0.5, cols - 0.5, -0.5, rows - 0.5)
        else:
            extent = (-0.5, cols - 0.5, rows - 0.5, -0.5)
        im.set_extent(extent)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])

    @staticmethod
    def _scroll_steps(step):
        """Convert a scroll event's ``step`` into a whole number of slices.

        Some backends report ``step`` as a float, possibly with a magnitude
        below 1 for high-resolution wheels or touchpads.  Any nonzero scroll
        moves at least one slice in its direction.
        """
        whole = int(round(step))
        if whole == 0 and step:
            whole = 1 if step > 0 else -1
        return whole

    def _step_slice(self, delta):
        """Move ``delta`` slices along the current axis, wrapping around."""
        depth = self.vol1.shape[self.dim]
        self.slice_nums[self.dim] = (self.slice_nums[self.dim] + delta) % depth
        self.update_images()

    def _on_mouse_scroll(self, event):
        """Scroll through slices; scrolling up advances, down goes back."""
        steps = self._scroll_steps(event.step)
        if steps:
            self._step_slice(steps)

    def _on_mouse_motion(self, event):
        """Mirror the cursor position onto the opposite panel's crosshair."""
        if event.inaxes is self.left_ax:
            self.right_crosshair.set_data([event.xdata], [event.ydata])
        elif event.inaxes is self.right_ax:
            self.left_crosshair.set_data([event.xdata], [event.ydata])
        else:
            return
        self.figure.canvas.draw_idle()

    def _on_mouse_enter_axes(self, event):
        """Show the crosshair on the panel opposite the one the mouse entered."""
        if event.inaxes is self.left_ax:
            shown, hidden = self.right_crosshair, self.left_crosshair
        elif event.inaxes is self.right_ax:
            shown, hidden = self.left_crosshair, self.right_crosshair
        else:
            return
        shown.set_visible(True)
        hidden.set_visible(False)
        self.figure.canvas.draw_idle()

    def _on_mouse_leave_axes(self, event):
        """Hide both crosshairs when the mouse leaves a panel."""
        self.right_crosshair.set_visible(False)
        self.left_crosshair.set_visible(False)
        self.figure.canvas.draw_idle()

    def _on_mouse_click(self, event):
        """Placeholder: mouse clicks currently do nothing."""
        pass

    def _on_close(self, event):
        """End the blocking event loop once the window has been closed."""
        self.figure.canvas.stop_event_loop()

    def next(self):
        """Advance one slice along the current axis (wrapping)."""
        self._step_slice(1)

    def previous(self):
        """Go back one slice along the current axis (wrapping)."""
        self._step_slice(-1)

    def toggle_dimension(self):
        """Cycle the slicing axis 0 -> 1 -> 2 -> 0 and redraw."""
        self.dim = (self.dim + 1) % 3
        self.update_images()

    def _on_keypress(self, event):
        """Dispatch keyboard shortcuts; keys without a binding are ignored."""
        key = event.key
        if key is None:
            # matplotlib documents that KeyEvent.key may be None.
            return
        if key.lower() in ('q', 'escape'):
            plt.close(self.figure)
            self.figure.canvas.stop_event_loop()
        elif key == 'up':
            self.next()
        elif key == 'down':
            self.previous()
        elif key == 'd':
            self.toggle_dimension()
if __name__ == '__main__':
    # Demo: compare a Gaussian-noise volume (left) with tomopy's 3-D
    # Shepp-Logan phantom (right), overlaying the connected regions of the
    # phantom above 0.1 as labels on the right panel.
    from skimage import measure

    vol_shape = (128, 128, 128)
    vol1 = np.random.randn(*vol_shape)
    vol2 = shepp3d(vol_shape)
    labels2 = measure.label(vol2 > 0.1)
    comparer = VolumeComparer(vol1, vol2, labels2=labels2)  # blocks until quit
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment