# Gist: "example slicer transforms"
# Author: @pieper, created December 8, 2022 18:49
# https://gist.github.com/pieper/f33cfd6afef95a8b0818ac7b1dc00243
"""
exec(open("/opt/data/nac/neurosurgicalatlas/ss2-Ture-028-registration.py").read())
"""
import glob
from math import floor, ceil
import numpy
# Moving segment ID -> True if a matching fixed-segment centroid was found and
# used as a landmark, False if the segment was skipped (filled in further down).
segmentsToComposite = {}

# Input data: the moving subject (volume + segmentation) and the fixed
# Ture-028 subject (volume + segmentation).
falkPath = "/opt/data/nac/neurosurgicalatlas/sub-01_avg-04_T1w-1mm.nrrd"
falkSS2Path = "/opt/data/nac/neurosurgicalatlas/sub-01_avg-04_T1w_synthseg.nii.gz"
fixedPath = "/opt/data/nac/neurosurgicalatlas/Ture/Ture-subset3/Ture-028/Ture-028-T1C.nrrd"
fixedSS2Path = "/opt/data/nac/neurosurgicalatlas/Ture/Ture-028ssml.seg-2.nrrd"

movingVolume = slicer.util.loadVolume(falkPath)
movingSegmentation = slicer.util.loadSegmentation(falkSS2Path)
fixedVolume = slicer.util.loadVolume(fixedPath)
fixedSegmentation = slicer.util.loadSegmentation(fixedSS2Path)
movingVolume.SetName("Moving Volume")
movingSegmentation.SetName("Moving Segmentation")
fixedVolume.SetName("Fixed Volume")
fixedSegmentation.SetName("Fixed Segmentation")
# fixedSegmentation.CreateClosedSurfaceRepresentation()
# movingSegmentation.CreateClosedSurfaceRepresentation()
movingSegmentation.GetDisplayNode().SetVisibility(False)
fixedSegmentation.GetDisplayNode().SetVisibility(False)

# Neurosurgical atlas (NSA) segmentation, placed by its grid transform. The
# loaded transform is inverted before the segmentation observes it.
# Use the fully-qualified slicer.util functions (like the loads above) so the
# script does not depend on the Python console's `from slicer.util import *`.
nsaPath = "/opt/data/nac/neurosurgicalatlas/250um_to_Phantom"
nsaSegmentation = slicer.util.loadSegmentation(f"{nsaPath}/NSA_LabelMap_wSkinSkull.nrrd")
nsaTransform = slicer.util.loadTransform(f"{nsaPath}/Transform-grid_250um-to-Phantom.h5")
nsaTransform.Inverse()
nsaSegmentation.SetAndObserveTransformNodeID(nsaTransform.GetID())
# nsaSegmentation.CreateClosedSurfaceRepresentation()

# Landmark transform node backed by a thin-plate spline (R basis). Its
# source/target landmarks are filled from segment centroids computed below.
landmarkTransform = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLTransformNode")
landmarkTransform.SetName("Landmark Transform")
thinPlateTransform = vtk.vtkThinPlateSplineTransform()
thinPlateTransform.SetBasisToR()
landmarkTransform.SetAndObserveTransformToParent(thinPlateTransform)
movingPoints = vtk.vtkPoints()
fixedPoints = vtk.vtkPoints()
# NOTE(review): not used anywhere else in this script; kept in case other code relies on it.
transformForSegment = vtk.vtkGeneralTransform()
def getSegmentIDFromLabelValue(segmentation, labelValue):
    """Find the segment in a segmentation node that carries a given label value.

    Args:
        segmentation: segmentation node; only ``GetSegmentation()`` is used.
        labelValue: label value to match against each segment's ``GetLabelValue()``.

    Returns:
        The ID of the first matching segment (in ``GetSegmentIDs()`` order),
        or ``None`` if no segment has that label value.
    """
    segmentationObject = segmentation.GetSegmentation()
    matchingIds = (
        candidateId
        for candidateId in segmentationObject.GetSegmentIDs()
        if segmentationObject.GetSegment(candidateId).GetLabelValue() == labelValue
    )
    return next(matchingIds, None)
import SegmentStatistics

# Compute per-segment RAS centroids for the moving and the fixed segmentation
# with the SegmentStatistics logic, including segments that are not visible.
segStatLogic = SegmentStatistics.SegmentStatisticsLogic()
segStatParameters = segStatLogic.getParameterNode()
segStatParameters.SetParameter("LabelmapSegmentStatisticsPlugin.centroid_ras.enabled", str(True))
segStatParameters.SetParameter("visibleSegmentsOnly", str(False))

# Moving segmentation first; its statistics table is kept in movingStats.
segStatParameters.SetParameter("Segmentation", movingSegmentation.GetID())
segStatLogic.computeStatistics()
movingStats = segStatLogic.getStatistics()

# Then the fixed segmentation, kept in fixedStats.
segStatParameters.SetParameter("Segmentation", fixedSegmentation.GetID())
segStatLogic.computeStatistics()
fixedStats = segStatLogic.getStatistics()
# Pair every moving segment with the fixed segment that has the same label
# value and use the two centroids as a thin-plate-spline landmark pair. Each
# matched moving segment is renamed after its fixed counterpart.
centroidKey = "LabelmapSegmentStatisticsPlugin.centroid_ras"
for movingSegmentId in movingStats["SegmentIDs"]:
    movingSegment = movingSegmentation.GetSegmentation().GetSegment(movingSegmentId)
    movingLabelValue = movingSegment.GetLabelValue()
    fixedSegmentId = getSegmentIDFromLabelValue(fixedSegmentation, movingLabelValue)
    if fixedSegmentId is None:
        # No fixed segment has this label value. Skip it here: calling
        # GetSegment(None).GetName() would fail before the KeyError skip below.
        print(f"Skipping segment {movingSegmentId}, no fixed segment with label value {movingLabelValue}")
        segmentsToComposite[movingSegmentId] = False
        continue
    fixedSegmentName = fixedSegmentation.GetSegmentation().GetSegment(fixedSegmentId).GetName()
    movingSegment.SetName(fixedSegmentName)
    try:
        # A missing statistics entry raises KeyError; it is guarded on both
        # sides (presumably e.g. a segment with no centroid computed).
        movingCentroid = movingStats[movingSegmentId, centroidKey]
        fixedCentroid = fixedStats[fixedSegmentId, centroidKey]
    except KeyError:
        print(f"Skipping segment {movingSegmentId}, mapped to {fixedSegmentId}")
        segmentsToComposite[movingSegmentId] = False
        continue
    movingPoints.InsertNextPoint(movingCentroid)
    fixedPoints.InsertNextPoint(fixedCentroid)
    segmentsToComposite[movingSegmentId] = True

# Fit the spline to the matched centroids. Then put the moving volume, the
# moving segmentation and the NSA transform under the landmark transform.
thinPlateTransform.SetSourceLandmarks(movingPoints)
thinPlateTransform.SetTargetLandmarks(fixedPoints)
thinPlateTransform.Update()
movingVolume.SetAndObserveTransformNodeID(landmarkTransform.GetID())
movingSegmentation.SetAndObserveTransformNodeID(landmarkTransform.GetID())
nsaTransform.SetAndObserveTransformNodeID(landmarkTransform.GetID())
# Convert the current thin-plate-spline transform into a grid transform.
# The transform maps RAS to RAS, so the grid is the axis-aligned (unoriented)
# RAS box around the fixed (target) volume, snapped outward to whole units.
# Grid spacing is the fixed volume's largest voxel spacing times
# gridTransformDownsample. This is coarser than the image, which keeps the grid
# small; the spline displacement is smooth, so little accuracy is lost.
gridTransformDownsample = 3
rasBounds = [0] * 6
fixedVolume.GetRASBounds(rasBounds)
origin = [int(floor(lower)) for lower in rasBounds[::2]]
maxes = [int(ceil(upper)) for upper in rasBounds[1::2]]
boundSize = [upper - lower for upper, lower in zip(maxes, origin)]
spacing = [max(fixedVolume.GetSpacing()) * gridTransformDownsample] * 3
# Round the per-axis sample count UP so the last grid point reaches or passes
# the upper bound. The previous ceil(int(b / s)) truncated first, so the grid
# could stop up to one spacing short of the fixed volume's extent.
samples = [int(ceil(size / step)) for size, step in zip(boundSize, spacing)]
# VTK extents are inclusive index ranges: indices 0..samples on each axis.
extent = [0] * 6
extent[1::2] = samples

toGrid = vtk.vtkTransformToGrid()
toGrid.SetGridOrigin(origin)
toGrid.SetGridSpacing(spacing)
toGrid.SetGridExtent(extent)
toGrid.SetInput(landmarkTransform.GetTransformToParent())
toGrid.Update()

# Wrap the sampled displacement field in a grid transform node named after
# the landmark transform.
gridTransform = slicer.vtkOrientedGridTransform()
gridTransform.SetDisplacementGridData(toGrid.GetOutput())
gridNode = slicer.vtkMRMLGridTransformNode()
gridNode.SetAndObserveTransformToParent(gridTransform)
gridNode.SetName(landmarkTransform.GetName() + "-grid")
slicer.mrmlScene.AddNode(gridNode)
def showOnly(segmentationNodes, segments):
    """Show only the requested segments, in 2D and 3D, in each segmentation node.

    Args:
        segmentationNodes: iterable of segmentation nodes to update.
        segments: collection of segment IDs and/or segment names. A segment is
            shown if its ID or its name is in this collection; every other
            segment is hidden.
    """
    for node in segmentationNodes:
        segmentationObject = node.GetSegmentation()
        displayNode = node.GetDisplayNode()
        for segmentId in segmentationObject.GetSegmentIDs():
            # The name is looked up only when the ID itself does not match.
            show = segmentId in segments or segmentationObject.GetSegment(segmentId).GetName() in segments
            displayNode.SetSegmentVisibility(segmentId, show)
            displayNode.SetSegmentVisibility3D(segmentId, show)
# Each showOnly call resets visibility for every segment, so only the last
# call's selection remains visible. The earlier calls are alternative views.
# "ctx-rh-superiortemporal" is the FreeSurfer-style label name; the previous
# "superiotemporal" spelling matched no segment. NOTE(review): confirm against
# the segment names in the loaded segmentations.
showOnly([fixedSegmentation, movingSegmentation], ["ctx-rh-insula", "ctx-rh-superiortemporal", "ctx-rh-lateralorbitofrontal"])
showOnly([fixedSegmentation, movingSegmentation], ["ctx-rh-lateralorbitofrontal"])
showOnly([fixedSegmentation, movingSegmentation], ["ctx-rh-insula"])
# Disabled snippet, kept inside a string literal so it never executes.
# NOTE(review): registerSegment is not defined in this script; it must be
# provided before this can be re-enabled.
"""
transform = registerSegment(fixedSegmentation, movingSegmentation, "Segment_10")
showOnly([fixedSegmentation, movingSegmentation], fixedSegmentation.GetSegmentation().GetSegmentIDs())
for segment in movingSegmentation.GetSegmentation().GetSegmentIDs():
color = fixedSegmentation.GetSegmentation().GetSegment(segment).GetColor()
movingSegmentation.GetSegmentation().GetSegment(segment).SetColor(color)
"""
# Disabled snippet, also kept inside a string literal so it never executes.
# It creates markups point lists for the moving, fixed and transformed landmarks,
# then relabels an atlas label map and reloads it as a segmentation.
# NOTE(review): it refers to names not defined in this script (segmentName,
# nsPath, indices, load), so they must be supplied before re-enabling it.
"""
movingPointList = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode")
movingPointList.CreateDefaultDisplayNodes()
movingPointList.SetName(f"{movingVolume.GetName()}-landmarks")
fixedPointList = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode")
fixedPointList.CreateDefaultDisplayNodes()
fixedPointList.SetName(f"{fixedVolume.GetName()}-landmarks")
transformedPointList = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode")
transformedPointList.CreateDefaultDisplayNodes()
transformedPointList.SetName(f"{movingVolume.GetName()}-transformed-landmarks")
fixedPointList.AddFiducialFromArray(fixedCentroid, segmentName)
transformedPointList.AddFiducialFromArray(fixedCentroid, segmentName)
movingPointList.AddFiducialFromArray(movingCentroid, segmentName)
# for:
# nsPath = "/opt/data/nac/neurosurgicalatlas/data-2022-02-22/Pathfinder_PhantomBrain_SynthSeg_Labels_v05/NSA_Phantom_LabellMap_v05-RAS.nrrd"
movingLabel = slicer.util.loadLabelVolume(nsPath)
movingArray = slicer.util.arrayFromVolume(movingLabel)
backwardRange = list(range(len(indices)))
backwardRange.reverse()
for index in backwardRange:
print(f"asigning {index+1} to be {indices[index]}")
movingArray[movingArray == index+1] = indices[index]
slicer.util.arrayFromVolumeModified(movingLabel)
nsLabelPath = f"{slicer.app.temporaryPath}/nsLabel.nrrd"
slicer.util.saveNode(movingLabel, nsLabelPath)
#slicer.mrmlScene.RemoveNode(movingLabel)
movingSegmentation = load(nsLabelPath, "nsHeadSeg", slicer.util.loadSegmentation)
for segmentID in movingSegmentation.GetSegmentation().GetSegmentIDs():
segment = movingSegmentation.GetSegmentation().GetSegment(segmentID)
segment.SetName(f"Segment_{indices[int(segment.GetName())-1]}")
"""
# End of gist.