Skip to content

Instantly share code, notes, and snippets.

@brikeats
Last active June 13, 2023 11:02
Show Gist options
  • Save brikeats/063e6873acfa3fc280e8 to your computer and use it in GitHub Desktop.
Save brikeats/063e6873acfa3fc280e8 to your computer and use it in GitHub Desktop.
Show slices through a 3D numpy array (volume) with optional mask overlay. Uses only numpy and matplotlib.
def volshow(vol, sl_dim=0, mask=None, **kwargs):
    """
    Interactively display slices through a 3-dimensional numpy array.

    Optionally, a mask of the same shape can be passed; its non-zero voxels
    are drawn as a semi-transparent overlay on top of the volume. This is
    intended to be a 3D equivalent of pyplot.imshow.

    Args:
        vol: 3D array-like with a `.shape` attribute that supports numpy-style
            tuple indexing.
        sl_dim: axis to slice along initially. Negative values count from the
            last axis, as in numpy.
        mask: optional array with the same shape as `vol`; values > 0 are
            treated as "inside the mask".
        **kwargs: `alpha` (default 0.5) and `color` (default 'red') control
            the mask overlay; every other keyword is forwarded to `imshow`
            for the volume slice (eg, `cmap='gray'`).

    Raises:
        IndexError: if `sl_dim` is not a valid axis of `vol`.
        ValueError: if `mask` does not have the same shape as `vol`.

    Usage:
        the up and down arrows flip through the slices
        'd' toggles the slice dimension (eg, axial->sagittal->coronal)
        'q' or escape closes the figure and exits the function

    This function expects `np` (numpy), `plt` (matplotlib.pyplot) and
    `ListedColormap` (matplotlib.colors) to be available at module level.

    In a script, put the lines `import matplotlib; matplotlib.use('TkAgg')`
    before any other imports.
    In an ipython notebook session, use the line `%matplotlib tk` to
    get the interactivity to work correctly.
    """

    class Slicer:
        """Holds the figure, the current slice position and the key bindings."""

        def __init__(self, vol, sl_dim, mask=None, alpha=0.5, color='red', **kwargs):
            ndim = len(vol.shape)
            if not -ndim <= sl_dim < ndim:
                raise IndexError(
                    'sl_dim %i is out of range for a %i-dimensional volume'
                    % (sl_dim, ndim))

            self.vol = vol
            # Normalize negative axes so that cycling with 'd' and the
            # per-axis slice bookkeeping both work on non-negative indices.
            self.slice_dim = sl_dim % ndim
            self.color = color
            self.alpha = alpha
            # Remembered so the display stays the same after toggling the axis.
            self.imshow_kwargs = kwargs
            # Start in the middle of every axis. Floor division keeps the
            # indices integral (true division yields e.g. 149.5 on Python 3).
            self.slice_nums = [sz // 2 for sz in vol.shape]
            self.slice_num = self.slice_nums[self.slice_dim]

            self.has_mask = mask is not None
            if self.has_mask:
                mask = np.asarray(mask) > 0
                if mask.shape != tuple(vol.shape):
                    raise ValueError(
                        'mask shape %s does not match volume shape %s'
                        % (mask.shape, tuple(vol.shape)))
                # Masked (False) voxels are not drawn, so only the mask
                # interior is overlaid on the volume.
                self.mask_vol = np.ma.masked_where(~mask, mask)

            self.fig, self.ax = plt.subplots()
            self.cid = self.fig.canvas.mpl_connect('key_press_event', self.on_key)
            self.plot = None
            self.mask_vol_plot = None
            self.replot()
            self.ax.axis('off')
            plt.show()
            try:
                self.fig.canvas.start_event_loop(timeout=-1)
            except Exception:
                # Best effort: depending on the backend this may raise (eg, a
                # TclError once the window has already been closed).
                pass

        def _index(self):
            # Full slices on every axis except the one being stepped through.
            # This must be a tuple: numpy treats a list as fancy indexing.
            idx = [slice(None)] * len(self.vol.shape)
            idx[self.slice_dim] = self.slice_num
            return tuple(idx)

        def _update_title(self):
            # Slice numbers are shown 1-based.
            self.ax.set_title('Slice %i of %i' % (
                self.slice_num + 1, self.vol.shape[self.slice_dim]))

        def replot(self):
            # create a new imshow plot. Used for the first draw and when
            # toggling slice dimension (the image shape may change).
            idx = self._index()
            if self.plot is not None:
                self.plot.remove()
            self.plot = self.ax.imshow(self.vol[idx], **self.imshow_kwargs)
            if self.has_mask:
                if self.mask_vol_plot is not None:
                    self.mask_vol_plot.remove()
                self.mask_vol_plot = self.ax.imshow(
                    self.mask_vol[idx], alpha=self.alpha,
                    cmap=ListedColormap([self.color]))
            self._update_title()
            self.fig.canvas.draw_idle()

        def redraw(self):
            # refresh the data to display a new slice (along the same dimension)
            idx = self._index()
            self.plot.set_data(self.vol[idx])
            if self.has_mask:
                self.mask_vol_plot.set_data(self.mask_vol[idx])
            self._update_title()
            self.fig.canvas.draw_idle()

        def prev_slice(self, _):
            self.slice_num = max(0, self.slice_num - 1)
            self.redraw()

        def next_slice(self, _):
            last = self.vol.shape[self.slice_dim] - 1
            self.slice_num = min(self.slice_num + 1, last)
            self.redraw()

        def toggle_slice_dim(self, _):
            # save the current slice index so it is restored when we cycle back
            self.slice_nums[self.slice_dim] = self.slice_num
            self.slice_dim = (self.slice_dim + 1) % len(self.vol.shape)
            self.slice_num = self.slice_nums[self.slice_dim]
            self.replot()

        def on_key(self, event):
            if event.key in ['q', 'Q', 'escape']:
                # Close this viewer's figure, not whichever one is current.
                plt.close(self.fig)
            elif event.key == 'up':
                self.next_slice(event)
            elif event.key == 'down':
                self.prev_slice(event)
            elif event.key == 'd':
                self.toggle_slice_dim(event)

    Slicer(vol, sl_dim, mask, **kwargs)
@mikbuch
Copy link

mikbuch commented Jun 13, 2023

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[23], line 1
----> 1 volshow(segmentation_arr)

Cell In[22], line 107, in volshow(vol, sl_dim, mask, **kwargs)
    104         elif event.key == 'd':
    105             self.toggle_slice_dim(event)
--> 107 slicer = Slicer(vol, sl_dim, mask, **kwargs)

Cell In[22], line 31, in volshow.<locals>.Slicer.__init__(self, vol, sl_dim, mask, alpha, color, **kwargs)
     29 indx = [Ellipsis, Ellipsis, Ellipsis]
     30 indx[self.slice_dim] = self.slice_num            
---> 31 sl = np.squeeze(vol[indx])
     33 self.fig, self.ax = plt.subplots()
     34 self.cid = self.fig.canvas.mpl_connect('key_press_event', self.on_key)

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

index in my case is 149.5

@mikbuch
Copy link

mikbuch commented Jun 13, 2023

This is actually quite nice: although it is not a slicer, it can visualize 3D masks quite well: https://matplotlib.org/stable/gallery/mplot3d/voxels.html

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment