Last active
June 13, 2023 11:02
-
-
Save brikeats/063e6873acfa3fc280e8 to your computer and use it in GitHub Desktop.
Show slices through a 3D NumPy array (volume) with an optional mask overlay. Uses only NumPy and Matplotlib.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def volshow(vol, sl_dim=0, mask=None, **kwargs):
    """
    Display slices through a 3-dimensional numpy array, optionally overlaying a
    binary mask of the same shape. This is intended to be a 3D equivalent of
    pyplot.imshow.

    Parameters
    ----------
    vol : array_like
        Non-empty 3-D array to display.
    sl_dim : int, optional
        Axis along which the volume is sliced (default 0). Negative values
        count from the end, as in numpy, so -1 is the last axis.
    mask : array_like, optional
        Array with the same shape as `vol`. Voxels with values > 0 are drawn
        as a semi-transparent overlay; all other voxels are left transparent.
    **kwargs
        `alpha` (default 0.5) and `color` (default 'red') style the mask
        overlay. Every other keyword argument is forwarded to `Axes.imshow`
        for the volume slices. Unless overridden, the colormap is 'gray' and,
        when no `norm` is given, the color limits span the whole volume so
        that brightness is comparable between slices and between views.

    Returns
    -------
    Slicer
        The viewer object. Matplotlib keeps only weak references to event
        handlers, so keep this object alive if the call returns while the
        figure is still open (e.g. with a non-blocking backend).

    Raises
    ------
    ValueError
        If `vol` is not a non-empty 3-D array, `sl_dim` is out of range, or
        `mask` does not have the same shape as `vol`.

    Usage:
        the up and down arrows flip through the slices
        'd' toggles the slice dimension (eg, axial->sagittal->coronal)
        'q' or escape closes the figure and exits the function

    In a script, put the lines `import matplotlib; matplotlib.use('TkAgg')`
    before any other imports.
    In an ipython notebook session, use the line `%matplotlib tk` to get the
    interactivity to work correctly.
    """
    # Imported locally so the function works even when the calling module has
    # not brought these names into scope itself.
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap

    class Slicer:
        """Interactive figure showing one 2-D slice of a 3-D volume at a time."""

        def __init__(self, vol, sl_dim, mask=None, alpha=0.5, color='red', **kwargs):
            self.vol = vol
            # Normalize negative axes (e.g. -1) so they index `slice_nums` and
            # cycle in `toggle_slice_dim` exactly like their positive forms.
            self.slice_dim = sl_dim % vol.ndim
            self.color = color
            self.alpha = alpha
            # Start every axis at its middle slice. Integer division keeps the
            # values usable as indices.
            self.slice_nums = [size // 2 for size in vol.shape]
            self.slice_num = self.slice_nums[self.slice_dim]

            # Keep the user's imshow options so every re-created image
            # (see `replot`) looks the same as the first one.
            self.imshow_kwargs = dict(kwargs)
            self.imshow_kwargs.setdefault('cmap', 'gray')
            if 'norm' not in self.imshow_kwargs:
                # Volume-wide limits; otherwise each new image would autoscale
                # to whichever single slice it was created from.
                self.imshow_kwargs.setdefault('vmin', float(np.nanmin(vol)))
                self.imshow_kwargs.setdefault('vmax', float(np.nanmax(vol)))

            self.has_mask = mask is not None
            if self.has_mask:
                mask = mask > 0
                # Masked (background) entries are drawn with the colormap's
                # "bad" color, which is transparent by default.
                self.mask_vol = np.ma.masked_where(~mask, mask)
                self.mask_cmap = ListedColormap([color])

            self.fig, self.ax = plt.subplots()
            self.ax.set_axis_off()
            self.fig.canvas.mpl_connect('key_press_event', self.on_key)
            # Closing the window by other means must also end the loop below.
            self.fig.canvas.mpl_connect('close_event', self.on_close)
            self.plot = None
            self.mask_vol_plot = None
            self.replot()
            plt.show()
            try:
                self.fig.canvas.start_event_loop(timeout=-1)
            except Exception:  # this may throw a TclError, depending on the backend
                pass

        def _current_slices(self):
            """Return the selected 2-D volume slice and mask slice (None without a mask)."""
            index = [slice(None)] * self.vol.ndim
            index[self.slice_dim] = self.slice_num
            # A tuple of slices plus one int is basic indexing: it always
            # yields a 2-D view, so no squeeze is needed.
            index = tuple(index)
            mask_sl = self.mask_vol[index] if self.has_mask else None
            return self.vol[index], mask_sl

        def _update_title(self):
            # Slice numbers are shown 1-based to the user.
            self.ax.set_title('Slice %i of %i' % (self.slice_num + 1, self.vol.shape[self.slice_dim]))
            self.fig.canvas.draw_idle()

        def replot(self):
            """Create fresh image artists. Used when the slice dimension (and so the slice shape) changes."""
            sl, mask_sl = self._current_slices()
            if self.plot is not None:
                self.plot.remove()
            if self.mask_vol_plot is not None:
                self.mask_vol_plot.remove()
            self.plot = self.ax.imshow(sl, **self.imshow_kwargs)
            if self.has_mask:
                # Fixed 0..1 limits so the single overlay color is used even
                # if the first slice shown contains no foreground voxels.
                self.mask_vol_plot = self.ax.imshow(
                    mask_sl, alpha=self.alpha, cmap=self.mask_cmap, vmin=0, vmax=1)
            self._update_title()

        def redraw(self):
            """Show another slice along the current dimension by updating the existing images."""
            sl, mask_sl = self._current_slices()
            self.plot.set_data(sl)
            if self.has_mask:
                self.mask_vol_plot.set_data(mask_sl)
            self._update_title()

        def prev_slice(self, _):
            self.slice_num = max(0, self.slice_num - 1)
            self.redraw()

        def next_slice(self, _):
            self.slice_num = min(self.slice_num + 1, self.vol.shape[self.slice_dim] - 1)
            self.redraw()

        def toggle_slice_dim(self, _):
            self.slice_nums[self.slice_dim] = self.slice_num  # remember the position on this axis
            self.slice_dim = (self.slice_dim + 1) % self.vol.ndim
            self.slice_num = self.slice_nums[self.slice_dim]
            self.replot()

        def on_key(self, event):
            if event.key in ('q', 'Q', 'escape'):
                # End the blocking loop explicitly: not every backend stops it
                # just because the figure was closed.
                self.fig.canvas.stop_event_loop()
                plt.close(self.fig)
            elif event.key == 'up':
                self.next_slice(event)
            elif event.key == 'down':
                self.prev_slice(event)
            elif event.key == 'd':
                self.toggle_slice_dim(event)

        def on_close(self, _event):
            self.fig.canvas.stop_event_loop()

    vol = np.asanyarray(vol)
    if vol.ndim != 3 or vol.size == 0:
        raise ValueError('vol must be a non-empty 3-dimensional array, got shape %r' % (vol.shape,))
    if not -vol.ndim <= sl_dim < vol.ndim:
        raise ValueError('sl_dim must be in [-3, 2], got %r' % (sl_dim,))
    if mask is not None:
        mask = np.asarray(mask)
        if mask.shape != vol.shape:
            raise ValueError('mask shape %r does not match vol shape %r' % (mask.shape, vol.shape))
    return Slicer(vol, sl_dim, mask, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is actually quite nice. Although it is not a slicer, it can visualize 3D masks quite well: https://matplotlib.org/stable/gallery/mplot3d/voxels.html