# March 1, 2022
# This is code for a simple 3D Slicer module that stitches multiple image volumes together into one
# larger image volume. It is intended for cases where images were acquired at multiple stations along
# the same scanner axis.
import os
import unittest
import logging
import vtk, qt, ctk, slicer
import numpy as np
from slicer.ScriptedLoadableModule import *
from slicer.util import VTKObservationMixin
# StitchVolumes
class StitchVolumes(ScriptedLoadableModule):
    """Module descriptor for "Stitch Volumes".

    Uses ScriptedLoadableModule base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def __init__(self, parent):
        """Register the module title, category, contributors, help and acknowledgement text."""
        ScriptedLoadableModule.__init__(self, parent)
        self.parent.title = "Stitch Volumes"
        self.parent.categories = ["MikeTools"]  # folder(s) where the module shows up in the module selector
        self.parent.dependencies = []
        self.parent.contributors = ["Mike Bindschadler (Seattle Children's Hospital)"]
        self.parent.helpText = """
This module allows a user to stitch together two or more volumes. A set of volumes to stitch, as well
as a rectangular ROI (to define the output geometry) is supplied, and this module produces an output
volume which represents all the input volumes cropped, resampled, and stitched together. Areas of overlap
between original volumes are handled by finding the center of the overlap region, and assigning each half
of the overlap to the closer original volume.
"""
        # TODO: verify that the default URL is correct or change it to the actual documentation
        self.parent.helpText += self.getDefaultModuleDocumentationLink()
        self.parent.acknowledgementText = """
This work was funded by Seattle Children's Hospital.
"""
# StitchVolumesWidget
class StitchVolumesWidget(ScriptedLoadableModuleWidget, VTKObservationMixin):
    """GUI for the Stitch Volumes module.

    Keeps the node selectors and the output-name line edit in sync with a
    scripted-module parameter node, and runs StitchVolumesLogic on Apply.

    Uses ScriptedLoadableModuleWidget base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    # (parameter node reference role, attribute name of the matching node selector in self.ui)
    _SELECTOR_ROLES = (
        ('StitchedVolumeROI', 'roiSelector'),
        ('InputVol1', 'volumeSelector1'),
        ('InputVol2', 'volumeSelector2'),
        ('InputVol3', 'volumeSelector3'),
        ('InputVol4', 'volumeSelector4'),
        ('InputVol5', 'volumeSelector5'),
    )

    def __init__(self, parent=None):
        """Called when the user opens the module the first time and the widget is initialized."""
        ScriptedLoadableModuleWidget.__init__(self, parent)
        VTKObservationMixin.__init__(self)  # needed for parameter node observation
        self.logic = None
        self._parameterNode = None

    def setup(self):
        """Build the GUI from the .ui file, create the logic, and connect signals."""
        ScriptedLoadableModuleWidget.setup(self)

        # Load widget from .ui file (created by Qt Designer)
        uiWidget = slicer.util.loadUI(self.resourcePath('UI/StitchVolumes.ui'))
        self.layout.addWidget(uiWidget)
        self.ui = slicer.util.childWidgetVariables(uiWidget)

        # Set scene in MRML widgets. Make sure that in Qt designer the
        # "mrmlSceneChanged(vtkMRMLScene*)" signal is connected to each MRML widget's
        # "setMRMLScene(vtkMRMLScene*)" slot.
        uiWidget.setMRMLScene(slicer.mrmlScene)

        # The parameter node stores all user choices (node selections, output name) so that
        # they are restored when the scene is saved and reloaded.
        self.logic = StitchVolumesLogic()
        self.ui.parameterNodeSelector.addAttribute("vtkMRMLScriptedModuleNode", "ModuleName", self.moduleName)
        self.setParameterNode(self.logic.getParameterNode())

        # Connections
        self.ui.parameterNodeSelector.connect('currentNodeChanged(vtkMRMLNode*)', self.setParameterNode)
        self.ui.applyButton.connect('clicked(bool)', self.onApplyButton)

        # Whenever the user changes a setting in the GUI, save it into the selected parameter node
        for _role, selectorName in self._SELECTOR_ROLES:
            getattr(self.ui, selectorName).connect("currentNodeChanged(vtkMRMLNode*)", self.updateParameterNodeFromGUI)
        self.ui.stitchVolNameLineEdit.connect("textChanged(QString)", self.updateParameterNodeFromGUI)

        # Initial GUI update
        self.updateGUIFromParameterNode()

    def cleanup(self):
        """Called when the application closes and the module widget is destroyed."""
        self.removeObservers()

    def setParameterNode(self, inputParameterNode):
        """Select a parameter node and observe it.

        Observation is needed because when the parameter node is changed
        (by this GUI, a script, or another module) the GUI must be updated immediately.
        """
        if inputParameterNode:
            self.logic.setDefaultParameters(inputParameterNode)

        # Show the node in the selector without re-triggering this slot
        wasBlocked = self.ui.parameterNodeSelector.blockSignals(True)
        self.ui.parameterNodeSelector.setCurrentNode(inputParameterNode)
        self.ui.parameterNodeSelector.blockSignals(wasBlocked)

        if inputParameterNode == self._parameterNode:
            # No change
            return

        # Unobserve the previously selected parameter node and observe the newly selected one
        if self._parameterNode is not None:
            self.removeObserver(self._parameterNode, vtk.vtkCommand.ModifiedEvent, self.updateGUIFromParameterNode)
        if inputParameterNode is not None:
            self.addObserver(inputParameterNode, vtk.vtkCommand.ModifiedEvent, self.updateGUIFromParameterNode)
        self._parameterNode = inputParameterNode

        # Initial GUI update
        self.updateGUIFromParameterNode()

    def updateGUIFromParameterNode(self, caller=None, event=None):
        """Update every widget to show the current state of the parameter node."""
        # Disable all sections if no parameter node is selected
        self.ui.basicCollapsibleButton.enabled = self._parameterNode is not None
        if self._parameterNode is None:
            return

        # Signals are blocked while setting widget values to prevent infinite recursion
        # (MRML node update -> GUI update -> MRML node update -> ...)
        for role, selectorName in self._SELECTOR_ROLES:
            selector = getattr(self.ui, selectorName)
            wasBlocked = selector.blockSignals(True)
            selector.setCurrentNode(self._parameterNode.GetNodeReference(role))
            selector.blockSignals(wasBlocked)

        outputName = self._parameterNode.GetParameter('OutputVolName')
        # Only touch the line edit when the text differs, so the cursor does not jump while typing
        if self.ui.stitchVolNameLineEdit.text != outputName:
            wasBlocked = self.ui.stitchVolNameLineEdit.blockSignals(True)
            self.ui.stitchVolNameLineEdit.text = outputName
            self.ui.stitchVolNameLineEdit.blockSignals(wasBlocked)

        # Enable the Stitch Volumes button if there is an ROI, at least two original volumes,
        # and a name for the output volume
        if (self._parameterNode.GetNodeReference('StitchedVolumeROI')
                and self._parameterNode.GetNodeReference('InputVol1')
                and self._parameterNode.GetNodeReference('InputVol2')
                and outputName):
            self.ui.applyButton.toolTip = 'Compute stitched volume'
            self.ui.applyButton.enabled = True
        else:
            self.ui.applyButton.toolTip = 'Enter inputs to enable stitching'
            self.ui.applyButton.enabled = False

    def updateParameterNodeFromGUI(self, caller=None, event=None):
        """Save the current GUI state into the parameter node.

        This is called whenever the user changes a selector or the output name, so the
        choices are restored when the scene is saved and loaded.
        """
        if self._parameterNode is None:
            return

        # Batch the changes so observers receive a single Modified event
        wasModified = self._parameterNode.StartModify()
        for role, selectorName in self._SELECTOR_ROLES:
            self._parameterNode.SetNodeReferenceID(role, getattr(self.ui, selectorName).currentNodeID)
        self._parameterNode.SetParameter('OutputVolName', self.ui.stitchVolNameLineEdit.text)
        self._parameterNode.EndModify(wasModified)

    def onApplyButton(self):
        """Run the stitching when the user clicks the "Apply" button."""
        try:
            orig_nodes = self.gather_original_nodes()
            roi_node = self.ui.roiSelector.currentNode()
            stitched_vol_name = self.ui.stitchVolNameLineEdit.text
            self.logic.stitch_volumes(orig_nodes, roi_node, stitched_vol_name, keep_intermediate_volumes=False)
        except Exception as e:
            # GUI boundary: report any failure to the user and keep the traceback in the log
            slicer.util.errorDisplay("Failed to compute results: " + str(e))
            import traceback
            traceback.print_exc()

    def gather_original_nodes(self):
        """Return the volumes chosen in selectors 1-5, in selector order, skipping empty selectors."""
        selectors = (
            self.ui.volumeSelector1,
            self.ui.volumeSelector2,
            self.ui.volumeSelector3,
            self.ui.volumeSelector4,
            self.ui.volumeSelector5,
        )
        return [selector.currentNode() for selector in selectors if selector.currentNode()]
# StitchVolumesLogic
class StitchVolumesLogic(ScriptedLoadableModuleLogic):
    """Computation for the Stitch Volumes module.

    Other Python code can import this class and use stitch_volumes() without
    requiring an instance of the Widget.

    Uses ScriptedLoadableModuleLogic base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def setDefaultParameters(self, parameterNode):
        """Initialize parameter node with default settings (only fills values that are unset)."""
        if not parameterNode.GetParameter('OutputVolName'):
            parameterNode.SetParameter('OutputVolName', 'StitchedVolume')

    def stitch_volumes(self, orig_nodes, roi_node, stitched_vol_name, keep_intermediate_volumes=False):
        """Stitch the original volumes into one volume whose geometry is defined by an ROI.

        The first original volume is cropped/resampled to the ROI to create a reference
        grid. Every original volume is then resampled onto that grid, and the resampled
        arrays are combined along the image axis best aligned with the direction
        between the volumes (see _combine_arrays).

        Parameters
        ----------
        orig_nodes : list of vtkMRMLScalarVolumeNode
            Volumes to stitch (at least one).
        roi_node : MRML ROI node
            Box defining the output geometry.
        stitched_vol_name : str
            Name of the new output volume node.
        keep_intermediate_volumes : bool
            If False, the reference and resampled intermediate volumes are removed from the scene.

        Returns
        -------
        vtkMRMLScalarVolumeNode
            The stitched volume node.

        Raises
        ------
        ValueError
            If no volumes or no ROI are given, or a volume has no non-zero data inside the ROI.
        """
        if not orig_nodes:
            raise ValueError('At least one volume to stitch is required')
        if roi_node is None:
            raise ValueError('An ROI defining the output geometry is required')

        # Crop/resample the first original volume to the ROI to define the output grid
        ref_vol_node = resample_volume(roi_node, orig_nodes[0], 'ReferenceVolume')

        # Resample every original volume onto the reference grid
        resamp_vol_nodes = []
        for orig_node in orig_nodes:
            resamp_node = createOrReplaceNode('Resamp_' + orig_node.GetName())
            resamp_vol_nodes.append(resample(orig_node, ref_vol_node, resamp_node))
        imArrays = [slicer.util.arrayFromVolume(resamp_vol_node) for resamp_vol_node in resamp_vol_nodes]

        # Create the output node by cloning a resampled node (any of them: it only
        # provides orientation and spacing; its voxels are overwritten below)
        stitched_vol_node = slicer.vtkSlicerVolumesLogic().CloneVolume(
            slicer.mrmlScene, resamp_vol_nodes[0], stitched_vol_name)

        # 0, 1, or 2 for K, J, or I (numpy arrays are KJI)
        dim_to_stitch = find_dim_to_stitch(orig_nodes, resamp_vol_nodes[0])

        imCombined = self._combine_arrays(imArrays, dim_to_stitch, [node.GetName() for node in orig_nodes])
        slicer.util.updateVolumeFromArray(stitched_vol_node, imCombined)

        if not keep_intermediate_volumes:
            for resamp_vol_node in resamp_vol_nodes:
                slicer.mrmlScene.RemoveNode(resamp_vol_node)
            slicer.mrmlScene.RemoveNode(ref_vol_node)

        return stitched_vol_node

    @staticmethod
    def _combine_arrays(imArrays, dim_to_stitch, names):
        """Combine same-shape arrays along axis dim_to_stitch into one array.

        Voxels outside each source volume are expected to be 0 (the resampling fill
        value). Arrays are ordered by the first slice containing data. Each array
        contributes the slices from its data start to its data end. Where an array
        overlaps the next one, the overlap is split at its center: the first half comes
        from the earlier array and the rest from the later one.
        """
        other_dims = tuple({0, 1, 2} - {dim_to_stitch})
        # A slice holds data if any voxel is non-zero (a sum could cancel positive and
        # negative intensities, e.g. CT values)
        dataSlices = [np.any(imArray != 0, axis=other_dims) for imArray in imArrays]
        for dataSlice, name in zip(dataSlices, names):
            if not dataSlice.any():
                raise ValueError('Volume "%s" has no non-zero data inside the ROI' % name)
        dataStartIdxs = [int(np.flatnonzero(dataSlice)[0]) for dataSlice in dataSlices]
        dataEndIdxs = [int(np.flatnonzero(dataSlice)[-1]) for dataSlice in dataSlices]

        # Order by data start index (the key avoids ever comparing arrays)
        ordered = sorted(zip(dataStartIdxs, dataEndIdxs, imArrays), key=lambda item: item[0])

        # Common dtype of all inputs, so no input values are truncated
        imCombined = np.zeros(imArrays[0].shape, dtype=np.result_type(*imArrays))
        nextStartIdx = None  # start forced on the next array by an overlap split
        for imIdx, (dataStart, dataEnd, imArray) in enumerate(ordered):
            start1 = dataStart if nextStartIdx is None else nextStartIdx
            end1 = dataEnd + 1  # exclusive end for python slicing
            nextStartIdx = None
            if imIdx + 1 < len(ordered):
                start2 = ordered[imIdx + 1][0]
                if start2 < end1:
                    # Overlap is slices [start2, end1); split it at its center
                    end1 = int(np.ceil((start2 + end1) / 2.0))
                    nextStartIdx = end1
            sliceIndexTuple = getSliceIndexTuple(start1, end1, dim_to_stitch)
            imCombined[sliceIndexTuple] = imArray[sliceIndexTuple]
        return imCombined
# StitchVolumesTest
class StitchVolumesTest(ScriptedLoadableModuleTest):
    """Test case for the Stitch Volumes module.

    Uses ScriptedLoadableModuleTest base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def setUp(self):
        """Reset the state by clearing the scene."""
        slicer.mrmlScene.Clear()

    def runTest(self):
        """Run all tests of this module."""
        self.setUp()
        self.test_StitchVolumes1()

    def test_StitchVolumes1(self):
        """Stitch two synthetic volumes that overlap along S and check the result.

        Volume A (all 1s) and volume B (all 2s) are 20 x 10 x 10 voxel arrays
        (numpy KJI order) with 1 mm spacing and identity directions. B is shifted
        15 mm along S, so the two overlap by 5 slices. The stitched volume must
        contain both values and only 0/1/2, and the 1s and 2s must lie on opposite
        sides of a single split plane.
        """
        self.delayDisplay("Starting the test")

        volA = slicer.util.addVolumeFromArray(np.ones((20, 10, 10), dtype=np.int16), name='StitchTestA')
        volB = slicer.util.addVolumeFromArray(np.full((20, 10, 10), 2, dtype=np.int16), name='StitchTestB')
        for vol in (volA, volB):
            vol.SetSpacing(1.0, 1.0, 1.0)
            vol.SetIJKToRASDirections(1, 0, 0, 0, 1, 0, 0, 0, 1)
        volA.SetOrigin(0.0, 0.0, 0.0)
        volB.SetOrigin(0.0, 0.0, 15.0)

        # NOTE(review): assumes the ROI selector / Crop Volume accept markups ROI nodes
        # (Slicer >= 4.13); older Slicer versions used vtkMRMLAnnotationROINode.
        roi = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLMarkupsROINode', 'StitchTestROI')
        roi.SetCenter([4.5, 4.5, 17.0])
        roi.SetSize([10.0, 10.0, 35.0])

        logic = StitchVolumesLogic()
        stitched = logic.stitch_volumes([volA, volB], roi, 'StitchTestOutput')
        self.delayDisplay('Finished stitching')

        out = slicer.util.arrayFromVolume(stitched)
        self.assertTrue(set(np.unique(out)).issubset({0, 1, 2}))
        self.assertIn(1, out)
        self.assertIn(2, out)
        # The 1s and 2s must be separated along some array axis (the stitch axis)
        ones = np.nonzero(out == 1)
        twos = np.nonzero(out == 2)
        self.assertTrue(any(
            ones[ax].max() < twos[ax].min() or twos[ax].max() < ones[ax].min()
            for ax in range(3)))

        self.delayDisplay('Test passed')
# Subfunctions
def get_RAS_center(vol_node):
    """Return the center [R, A, S] of a node's RAS bounding box.

    vol_node.GetRASBounds fills a 6-element list as min/max pairs per axis
    (b[0:2], b[2:4], b[4:6]). The center on each axis is the mean of its pair.
    """
    b = [0.0] * 6
    vol_node.GetRASBounds(b)
    return [np.mean(b[0:2]), np.mean(b[2:4]), np.mean(b[4:6])]
def ras_to_ijk(point_ras, vol_node, return_ints_flag=False, use_volume_transform_flag=True):
    """Return the IJK voxel coordinates of a world RAS point in the given volume.

    Parameters
    ----------
    point_ras : sequence of float
        Point in RAS coordinates; only the first three values are used.
    vol_node : vtkMRMLVolumeNode
        Volume whose voxel grid defines IJK.
    return_ints_flag : bool
        If True, round each coordinate to the nearest integer.
    use_volume_transform_flag : bool
        If True, first map the point through the transform between world and the
        volume's parent transform node.

    Returns
    -------
    list
        [I, J, K] as floats, or ints when return_ints_flag is True.
    """
    if use_volume_transform_flag:
        # If the volume node is transformed, apply that transform to get the volume's RAS coordinates
        transformRasToVolumeRas = vtk.vtkGeneralTransform()
        slicer.vtkMRMLTransformNode.GetTransformBetweenNodes(
            None, vol_node.GetParentTransformNode(), transformRasToVolumeRas)
        point_VolumeRas = transformRasToVolumeRas.TransformPoint(point_ras[0:3])
    else:
        point_VolumeRas = point_ras[0:3]

    # Get voxel coordinates from physical coordinates
    volumeRasToIjk = vtk.vtkMatrix4x4()
    vol_node.GetRASToIJKMatrix(volumeRasToIjk)
    point_Ijk = [0, 0, 0, 1]
    volumeRasToIjk.MultiplyPoint(np.append(point_VolumeRas, 1.0), point_Ijk)

    # Trim the homogeneous coordinate
    point_ijk = point_Ijk[0:3]
    if return_ints_flag:
        point_ijk = [int(round(c)) for c in point_ijk]
    return point_ijk
def find_dim_to_stitch(orig_nodes, resamp_node):
    """Return the numpy array axis (0=K, 1=J, 2=I) along which to stitch.

    The stitch direction is the vector between the RAS bounding-box centers of the
    first original volume and the original volume farthest from it. The chosen axis
    is the one whose IJK direction in resamp_node (the resampled grid, whose
    directions may differ from the originals) has the largest absolute dot product
    with that direction. The IJK index is converted to numpy KJI order.

    Raises
    ------
    ValueError
        If orig_nodes is empty, or two or more volumes all share the same center.
    """
    RAS_centers = [get_RAS_center(vol) for vol in orig_nodes]
    if not RAS_centers:
        raise ValueError('No volumes supplied to determine the stitching direction')
    if len(RAS_centers) == 1:
        # Nothing to stitch: every axis gives the same output, so return the I axis
        return 2

    dists = [np.linalg.norm(np.subtract(center, RAS_centers[0])) for center in RAS_centers]
    furthest_from_first = int(np.argmax(dists))
    stitch_vect = np.subtract(RAS_centers[0], RAS_centers[furthest_from_first])
    stitch_len = np.linalg.norm(stitch_vect)
    if stitch_len == 0:
        raise ValueError('All volumes have the same center; cannot determine the stitching direction')
    stitch_vect = stitch_vect / stitch_len

    # Per-axis direction vectors (I, J, K) of the resampled grid. The per-axis getters are
    # used because GetIJKToRASDirections returns the matrix row by row, not axis by axis.
    ijkdirs = []
    for getDirection in (resamp_node.GetIToRASDirection,
                         resamp_node.GetJToRASDirection,
                         resamp_node.GetKToRASDirection):
        direction = [0.0, 0.0, 0.0]
        getDirection(direction)
        ijkdirs.append(direction)

    absDotsIJK = [np.abs(np.dot(d, stitch_vect)) for d in ijkdirs]
    IJKmatchIdx = int(np.argmax(absDotsIJK))
    # numpy arrays from volumes are KJI, so reverse the index
    return 2 - IJKmatchIdx
def createOrReplaceNode(name, nodeClass='vtkMRMLScalarVolumeNode'):
    """Add a new node of class nodeClass named name.

    Any existing node with exactly that name is removed first. The name is matched
    literally; slicer.util.getNode would treat it as a wildcard pattern, and would
    raise when no node exists.
    """
    existing = slicer.mrmlScene.GetFirstNodeByName(name)
    if existing is not None:
        slicer.mrmlScene.RemoveNode(existing)
    return slicer.mrmlScene.AddNewNodeByClass(nodeClass, name)
def resample_volume(roi_node, input_vol_node, output_vol_name):
    """Crop/resample input_vol_node to the box defined by roi_node using the Crop Volume module.

    Nearest-neighbor interpolation is used, and voxels outside the input are filled
    with 0. The output goes into a (re)created scalar volume node named
    output_vol_name, which is returned.

    Raises
    ------
    RuntimeError
        If the Crop Volume logic reports failure.
    """
    # The parameter node must be in the scene so its node references resolve;
    # it is removed again once the crop is done.
    cropVolumeNode = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLCropVolumeParametersNode')
    try:
        cropVolumeNode.SetInterpolationMode(cropVolumeNode.InterpolationNearestNeighbor)  # avoid resampling artifacts
        cropVolumeNode.SetFillValue(0)  # must be zero so that slices outside the data are all zero
        cropVolumeNode.SetROINodeID(roi_node.GetID())
        output_vol_node = createOrReplaceNode(output_vol_name, 'vtkMRMLScalarVolumeNode')
        cropVolumeNode.SetInputVolumeNodeID(input_vol_node.GetID())
        cropVolumeNode.SetOutputVolumeNodeID(output_vol_node.GetID())
        if not slicer.modules.cropvolume.logic().Apply(cropVolumeNode):
            raise RuntimeError('Cropping volume "%s" to the ROI failed' % input_vol_node.GetName())
    finally:
        slicer.mrmlScene.RemoveNode(cropVolumeNode)
    return output_vol_node
def resample(vol_node_to_resample, reference_vol_node, output_vol_node=None, interpolationMode='NearestNeighbor'):
    """Resample a volume onto the voxel grid of a reference volume with BRAINSResample.

    Parameters
    ----------
    vol_node_to_resample : vtkMRMLScalarVolumeNode
    reference_vol_node : vtkMRMLScalarVolumeNode
        Volume whose geometry defines the output grid.
    output_vol_node : vtkMRMLScalarVolumeNode or None
        Node receiving the result; a new scalar volume node is created when None.
    interpolationMode : str
        BRAINSResample interpolation mode. 'NearestNeighbor' is switched to 'Linear'
        (with a warning) when the two volumes' spacings differ by more than 1e-4 on any axis.

    Returns
    -------
    vtkMRMLScalarVolumeNode
        The output volume node.

    Raises
    ------
    RuntimeError
        If the CLI reports completion with errors.
    """
    if interpolationMode == 'NearestNeighbor':
        maxVoxDimDiff = np.max(np.abs(np.subtract(reference_vol_node.GetSpacing(), vol_node_to_resample.GetSpacing())))
        if maxVoxDimDiff > 1e-4:
            logging.warning('Automatically switching from NearestNeighbor interpolation to Linear interpolation because the volume to resample (%s) has a different resolution (%0.2fmm x %0.2fmm x %0.2fmm) than the first original volume (%s, %0.2fmm x %0.2fmm x %0.2fmm)'%(
                vol_node_to_resample.GetName(), *vol_node_to_resample.GetSpacing(), reference_vol_node.GetName(), *reference_vol_node.GetSpacing()))
            # Actually perform the switch announced by the warning
            interpolationMode = 'Linear'

    if output_vol_node is None:
        output_vol_node = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLScalarVolumeNode')
    params = {
        'inputVolume': vol_node_to_resample.GetID(),
        'referenceVolume': reference_vol_node.GetID(),
        'outputVolume': output_vol_node.GetID(),
        'interpolationMode': interpolationMode,
        'defaultValue': 0,
    }
    cliNode = slicer.cli.runSync(slicer.modules.brainsresample, None, params)
    try:
        if cliNode.GetStatus() & cliNode.ErrorsMask:
            raise RuntimeError('BRAINSResample failed for volume "%s"' % vol_node_to_resample.GetName())
    finally:
        # Do not accumulate one CLI parameter node per call in the scene
        slicer.mrmlScene.RemoveNode(cliNode)
    return output_vol_node
def getSliceIndexTuple(start, end, dim_to_stitch, nDims=3):
    """Build an index tuple selecting start:end along one axis and everything along the others.

    For example, with dim_to_stitch=1 the result is
    (slice(None), slice(start, end), slice(None)), so arr[result] is
    equivalent to arr[:, start:end, :] for a 3D array.
    """
    return tuple(slice(start, end) if dim == dim_to_stitch else slice(None)
                 for dim in range(nDims))
def rename_dixon_dicom_volumes(volNodes=None):
    """Replace "imageType N" in DICOM volume names with the Dixon type ("F", "W", "OP", "IP").

    The Dixon type is the third backslash-separated entry of the DICOM ImageType
    field (0008,0008) of the volume's first instance. Volumes without DICOM instance
    UIDs, or whose name does not contain "imageType <digit>", are left unchanged.

    Parameters
    ----------
    volNodes : list of volume nodes or None
        Volumes to process; when None, all scalar volumes in the scene are used.
    """
    import re
    namePattern = re.compile(r'imageType \d')
    imageTypeField = '0008,0008'  # DICOM tag of ImageType

    if volNodes is None:
        # Gather all scalar volumes in the scene
        volNodes = []
        shNode = slicer.mrmlScene.GetSubjectHierarchyNode()
        sceneItemID = shNode.GetSceneItemID()
        c = vtk.vtkCollection()
        shNode.GetDataNodesInBranch(sceneItemID, c, 'vtkMRMLScalarVolumeNode')
        for idx in range(c.GetNumberOfItems()):
            volNodes.append(c.GetItemAsObject(idx))

    for volNode in volNodes:
        uids = volNode.GetAttribute('DICOM.instanceUIDs')  # empty for non-DICOM volumes
        if not uids:
            continue
        origVolName = volNode.GetName()
        if not namePattern.search(origVolName):
            continue
        # All instances of a series share the same ImageType (observed so far), so use the first
        uid = uids.split()[0]
        filename = slicer.dicomDatabase.fileForInstance(uid)
        imageType = slicer.dicomDatabase.fileValue(filename, imageTypeField)  # e.g. "DERIVED\PRIMARY\OP\OP\DERIVED"
        imageTypeParts = imageType.split('\\')
        if len(imageTypeParts) < 3:
            logging.warning('Cannot determine Dixon type of "%s" from ImageType "%s"', origVolName, imageType)
            continue
        dixonType = imageTypeParts[2]
        # A function replacement keeps dixonType literal (no backslash/group expansion)
        volNode.SetName(namePattern.sub(lambda match: dixonType, origVolName))
