# Training data augmentation with random volume translation, rotation, and deformation
#
# This script randomly warps a 3D volume, adds random translations and rotations,
# and saves each resulting 3D volume (and a screenshot for quick overview).
#
# The script can be executed by copy-pasting into 3D Slicer's Python console
# or in a Jupyter notebook running a 3D Slicer kernel (provided by SlicerJupyter extension).
#
# Prerequisites:
# - Recent Slicer-4.11 version
# - SlicerIGT extension installed (for random deformations)
import SampleData | |
import ScreenCapture | |
import numpy as np | |
import os | |
#############################
# Set inputs

# Start from an empty scene so that only the nodes created below are present.
slicer.mrmlScene.Clear()

# Input volume that will be deformed. To load your own volume from file, use:
# slicer.util.loadVolume("c:/path/to/myvolume.nrrd")
volumeNode = SampleData.SampleDataLogic().downloadMRBrainTumor1()

# Sample segmentation: download it, then load the first returned item as a segmentation node
segmentationDownload = SampleData.downloadFromURL(
    fileNames='MRBrainTumor1.seg.nrrd',
    loadFileTypes=['SegmentationFile'],
    uris='https://github.com/Slicer/SlicerTestingData/releases/download/SHA256/dfd34fe31b48e605a8419efb7427ed42030c90d163b9785fe92692474e664310',
    checksums='SHA256:dfd34fe31b48e605a8419efb7427ed42030c90d163b9785fe92692474e664310')
segmentationNode = slicer.util.loadSegmentation(segmentationDownload[0])

# Export all segments into a label volume node; the segmentation node itself
# is not needed afterwards, so it is removed from the scene.
labelVolumeNode = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLLabelMapVolumeNode')
slicer.modules.segmentations.logic().ExportAllSegmentsToLabelmapNode(segmentationNode, labelVolumeNode)
slicer.mrmlScene.RemoveNode(segmentationNode)

# Transformation parameters (standard deviations of the random distributions
# and spacing of the warping control point grid)
translationStDev = 10.0
rotationDegStDev = 5.0
warpingControlPointsSpacing = 70
warpingDisplacementStdDev = 5.0

# Output parameters. The filename patterns keep a "%04d" placeholder that is
# filled in with the output volume index when files are written.
numberOfOutputVolumesToCreate = 24
outputVolumeFilenamePattern = "{0}/volumes/transformedVolume_%04d.nrrd".format(slicer.app.temporaryPath)
outputLabelVolumeFilenamePattern = "{0}/volumes/transformedVolume_%04d-label.nrrd".format(slicer.app.temporaryPath)
outputScreenshotsFilenamePattern = "{0}/screenshots/transformedVolume_%04d.png".format(slicer.app.temporaryPath)
#############################
# Processing

# Create output folders: the parent directory of every output filename pattern.
# exist_ok=True makes the script safe to re-run and avoids the race between
# checking for the folder and creating it.
for filepath in [outputVolumeFilenamePattern, outputLabelVolumeFilenamePattern, outputScreenshotsFilenamePattern]:
    filedir = os.path.dirname(filepath)
    # A pattern without a directory part has an empty dirname; there is nothing
    # to create in that case (and os.makedirs("") would raise).
    if filedir:
        os.makedirs(filedir, exist_ok=True)
# Set up warping transform computation

# Two markups point lists drive the warp: "PointsFrom" and "PointsTo".
# They are filled with control point coordinates later, once per output volume.
pointsFrom = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", "PointsFrom")
pointsFrom.SetLocked(True)
pointsFromDisplayNode = pointsFrom.GetDisplayNode()
pointsFromDisplayNode.SetPointLabelsVisibility(False)
pointsFromDisplayNode.SetSelectedColor(0, 1, 0)

pointsTo = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", "PointsTo")
pointsTo.GetDisplayNode().SetPointLabelsVisibility(False)

# Six-element list filled in by GetBounds(); consumed later as three (min, max) pairs
volumeBounds = [0] * 6
volumeNode.GetBounds(volumeBounds)

# The warping transform is computed by the fiducial registration wizard module
# (SlicerIGT extension). Without it, warpingTransformNode stays None and an
# error message is displayed.
if not hasattr(slicer.modules, "fiducialregistrationwizard"):
    warpingTransformNode = None
    slicer.util.errorDisplay("SlicerIGT extension is required for applying warping transform")
else:
    warpingTransformNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLTransformNode", "WarpingTransform")
    fidReg = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLFiducialRegistrationWizardNode")
    fidReg.SetRegistrationModeToWarping()
    fidReg.SetAndObserveFromFiducialListNodeId(pointsFrom.GetID())
    fidReg.SetAndObserveToFiducialListNodeId(pointsTo.GetID())
    fidReg.SetOutputTransformNodeId(warpingTransformNode.GetID())
# Set up linear transform computation
# The transform node starts out with a freshly constructed 4x4 matrix and the
# input volume is placed under it.
fullTransformNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLTransformNode", "FullTransform")
initialMatrix = vtk.vtkMatrix4x4()
fullTransformNode.SetAndObserveMatrixTransformToParent(initialMatrix)
volumeNode.SetAndObserveTransformNodeID(fullTransformNode.GetID())

# Set up transformation chain: volume is warped, then translated&rotated
# The resampling output always goes into this single node.
transformedVolumeNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLScalarVolumeNode")

# Parameters of the resampling module; "inputVolume" (and later
# "interpolationType") are changed before each subsequent run.
parameters = {}
parameters["inputVolume"] = volumeNode.GetID()
parameters["outputVolume"] = transformedVolumeNode.GetID()
parameters["referenceVolume"] = volumeNode.GetID()
parameters["transformationFile"] = fullTransformNode.GetID()

# Initial resampling (without transformation)
# The returned parameter node is passed back in for all later runs.
resampleParameterNode = slicer.cli.runSync(slicer.modules.resamplescalarvectordwivolume, None, parameters)
# Set up visualization for screenshots
layoutManager = slicer.app.layoutManager()
layoutManager.setLayout(slicer.vtkMRMLLayoutNode.SlicerLayoutFourUpView)
slicer.util.setSliceViewerLayers(background=transformedVolumeNode, fit=True)

# Hide the warping control points so they do not appear in the screenshots
for markupsNode in (pointsFrom, pointsTo):
    markupsNode.GetDisplayNode().SetVisibility(False)

# Set both background colors of the first 3D view to black
threeDViewNode = layoutManager.threeDWidget(0).mrmlViewNode()
threeDViewNode.SetBackgroundColor(0, 0, 0)
threeDViewNode.SetBackgroundColor2(0, 0, 0)

# Volume rendering of the resampled volume
volRenLogic = slicer.modules.volumerendering.logic()
displayNode = volRenLogic.CreateDefaultVolumeRenderingNodes(transformedVolumeNode)
displayNode.SetVisibility(True)

# Choose the rendering preset from the intensity range:
# small dynamic range is probably MRI, larger dynamic range is probably CT
scalarRange = transformedVolumeNode.GetImageData().GetScalarRange()
presetName = 'MR-Default' if scalarRange[1] - scalarRange[0] < 1500 else 'CT-Chest-Contrast-Enhanced'
displayNode.GetVolumePropertyNode().Copy(volRenLogic.GetPresetByName(presetName))
# Generate as many deformed volumes as requested

# The warping control point grid depends only on the volume bounds and the grid
# spacing, so it is computed once instead of in every iteration.
if warpingTransformNode:
    controlPointCoordsSplit = np.mgrid[
        volumeBounds[0]:volumeBounds[1]:warpingControlPointsSpacing,
        volumeBounds[2]:volumeBounds[3]:warpingControlPointsSpacing,
        volumeBounds[4]:volumeBounds[5]:warpingControlPointsSpacing]
    # One row per control point, one column per coordinate axis
    controlPointCoords = np.vstack([axisCoords.ravel() for axisCoords in controlPointCoordsSplit]).T

# Created once, before the loop: the same object is reused for every screenshot
# and is also available after the loop even if no volume was requested.
cap = ScreenCapture.ScreenCaptureLogic()

# (description, input node, interpolation type, output filename pattern)
# The label volume must be processed first: both results are written into
# transformedVolumeNode, and the screenshot should show the grayscale volume.
resampleJobs = [
    ("label volume", labelVolumeNode, "nn", outputLabelVolumeFilenamePattern),  # nearest neighbor to preserve label values
    ("volume", volumeNode, "linear", outputVolumeFilenamePattern),
]

for outputVolumeIndex in range(numberOfOutputVolumesToCreate):

    # Compose the full transform: optional random warp, then random translation
    # and rotations, concatenated in this order.
    fullTransform = vtk.vtkGeneralTransform()
    if warpingTransformNode:
        # warping: displace each control point of the grid by a random offset
        controlPointDisplacement = np.random.normal(0, warpingDisplacementStdDev, size=controlPointCoords.shape)
        slicer.util.updateMarkupsControlPointsFromArray(pointsFrom, controlPointCoords)
        slicer.util.updateMarkupsControlPointsFromArray(pointsTo, controlPointCoords + controlPointDisplacement)
        fullTransform.Concatenate(warpingTransformNode.GetTransformFromParent())
    # translation and rotation
    fullTransform.Translate(np.random.normal(0, translationStDev, 3))
    fullTransform.RotateX(np.random.normal(0, rotationDegStDev))
    fullTransform.RotateY(np.random.normal(0, rotationDegStDev))
    fullTransform.RotateZ(np.random.normal(0, rotationDegStDev))
    fullTransformNode.SetAndObserveTransformFromParent(fullTransform)

    # Compute transformed label volume and transformed volume, save each to file
    for description, inputNode, interpolationType, filenamePattern in resampleJobs:
        parameters["inputVolume"] = inputNode.GetID()
        parameters["interpolationType"] = interpolationType
        resampleParameterNode = slicer.cli.runSync(slicer.modules.resamplescalarvectordwivolume, resampleParameterNode, parameters)
        # Save result volume
        outputFilename = filenamePattern % outputVolumeIndex
        print("Save transformed {0} {1}/{2} as {3}".format(description, outputVolumeIndex+1, numberOfOutputVolumesToCreate, outputFilename))
        # Report a failed save instead of silently producing an incomplete data set
        if not slicer.util.saveNode(transformedVolumeNode, outputFilename):
            print("Failed to save transformed {0} as {1}".format(description, outputFilename))

    # Save result screenshot (view controllers are hidden only while capturing
    # and are restored even if the capture fails)
    outputFilename = outputScreenshotsFilenamePattern % outputVolumeIndex
    cap.showViewControllers(False)
    try:
        cap.captureImageFromView(None, outputFilename)
    finally:
        cap.showViewControllers(True)
# Create gallery view of all augmented images
# A dedicated ScreenCaptureLogic instance is used so that this step does not
# rely on a variable that is only assigned while generating volumes. With no
# output volumes there are no screenshots to assemble, so the gallery is skipped.
if numberOfOutputVolumesToCreate > 0:
    screenshotsDir = os.path.dirname(outputScreenshotsFilenamePattern)
    ScreenCapture.ScreenCaptureLogic().createLightboxImage(
        8,  # presumably the number of columns in the gallery - confirm against ScreenCaptureLogic.createLightboxImage
        screenshotsDir,
        os.path.basename(outputScreenshotsFilenamePattern),
        numberOfOutputVolumesToCreate,
        screenshotsDir+"/gallery.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment