Skip to content

Instantly share code, notes, and snippets.

@mauigna06
Created October 29, 2021 20:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mauigna06/cb2b81cd0e1945ea2165fb7f0841adf3 to your computer and use it in GitHub Desktop.
Save mauigna06/cb2b81cd0e1945ea2165fb7f0841adf3 to your computer and use it in GitHub Desktop.
This is an example of registering bone models via singular value decomposition (SVD) of their point data. It is useful for registering very similar bones.
#Posted example
import numpy as np
#SVD Registration
# boneModel0 is the bone that gets registered; boneModel1 starts as an exact copy
# of it and is moved by a random transform further down to act as the target.
boneModel0 = getNode('deformedBone')
boneModel1 = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLModelNode',"fibulaMoved")
boneModel1.CreateDefaultDisplayNodes()
boneModel1.CopyContent(boneModel0)
# Give the copy the source colour with its channels rotated (B, R, G) so the two
# models are visually distinguishable.
boneModel1DisplayNode = boneModel1.GetDisplayNode()
sourceDisplayNode = boneModel0.GetDisplayNode()
color = sourceDisplayNode.GetColor()
rotatedChannels = [color[channel] for channel in (2, 0, 1)]
boneModel1DisplayNode.SetColor(*rotatedChannels)
#Create random transformation matrix
import numpy as np
import random
# Random rotation axis: each component is drawn from [0, 1) and the vector is then
# normalised, so the axis always lies in the positive octant.
axisOfRotation = np.array([random.random() for _ in range(3)])
axisOfRotation = axisOfRotation/np.linalg.norm(axisOfRotation)
# Rotation angle in degrees (vtkTransform.RotateWXYZ expects degrees).
angleOfRotation = random.uniform(45, 315)
# Random per-axis scale factors on the diagonal. Bug fix: the third factor belongs at
# (2, 2); the original wrote it to (1, 2), which inserted a shear term instead.
# NOTE(review): scaleMatrix is never concatenated into randomTransform, so the
# simulated motion stays rigid (rotation + translation). Concatenate it if scaling
# should also be simulated.
scaleMatrix = vtk.vtkMatrix4x4()
for axis in range(3):
    scaleMatrix.SetElement(axis, axis, random.uniform(1.3, 1.8))
# Random translation, each component in [50, 150].
origin = np.array([random.uniform(50, 150) for _ in range(3)])
# PostMultiply: rotate first, then translate.
randomTransform = vtk.vtkTransform()
randomTransform.PostMultiply()
randomTransform.RotateWXYZ(angleOfRotation,axisOfRotation)
randomTransform.Translate(origin)
# Apply the random transform to boneModel1 by replacing its polydata with the
# transformed output.
transformer = vtk.vtkTransformFilter()
transformer.SetInputData(boneModel1.GetPolyData())
transformer.SetTransform(randomTransform)
transformer.Update()
boneModel1.SetAndObservePolyData(transformer.GetOutput())
def samplePolyDataPoints(polyData, numberOfSampledPoints):
    """Return a new vtkPolyData holding a random subset of the points of polyData.

    vtkMaskPoints is run in random mode with
    OnRatio = max(1, numberOfPoints // numberOfSampledPoints), so it keeps on average
    one of every OnRatio input points. The output therefore has approximately
    numberOfSampledPoints points; the exact count varies from run to run.
    """
    # Guard against a ratio of 0 when the mesh has fewer points than requested.
    ratio = max(1, polyData.GetNumberOfPoints() // numberOfSampledPoints)
    maskPointsFilter = vtk.vtkMaskPoints()
    maskPointsFilter.SetInputData(polyData)
    # NOTE(review): the original author noted this sampling "could be biased
    # spatially". Random masking follows point order, not position, so uneven meshes
    # may be sampled unevenly. vtkMaskPoints.SetRandomModeType offers other random
    # strategies that may be worth trying; confirm in the VTK docs.
    maskPointsFilter.SetOnRatio(ratio)
    maskPointsFilter.RandomModeOn()
    maskPointsFilter.Update()
    sampledPolyData = vtk.vtkPolyData()
    sampledPolyData.ShallowCopy(maskPointsFilter.GetOutput())
    return sampledPolyData

numberOfSampledPoints0 = 2000
numberOfSampledPoints1 = 2000
polydata0 = samplePolyDataPoints(boneModel0.GetPolyData(), numberOfSampledPoints0)
polydata1 = samplePolyDataPoints(boneModel1.GetPolyData(), numberOfSampledPoints1)
#Calculate the SVD and mean
from vtk.util.numpy_support import vtk_to_numpy
# Sampled coordinates as (N, 3) arrays. Cast to float64 so the mean and SVD are not
# computed in the (typically float32) precision of the VTK points.
model0Points = vtk_to_numpy(polydata0.GetPoints().GetData()).astype(np.float64)
model1Points = vtk_to_numpy(polydata1.GetPoints().GetData()).astype(np.float64)
# Calculate the mean of the points, i.e. the 'center' of the cloud
model0PointsMean = model0Points.mean(axis=0)
model1PointsMean = model1Points.mean(axis=0)
# SVD of the mean-centred clouds. The rows of the returned Vt ("eigenvectors*") are
# the principal axes, sorted by decreasing singular value. full_matrices=False avoids
# building the unused N x N left-singular matrix (N is about 2000 here).
uu0, dd0, eigenvectors0 = np.linalg.svd(model0Points - model0PointsMean, full_matrices=False)
uu1, dd1, eigenvectors1 = np.linalg.svd(model1Points - model1PointsMean, full_matrices=False)
# Create a frame for boneModel0 from the first two principal axes. The third axis is
# their cross product, which makes the frame right-handed.
model0X = eigenvectors0[0]
model0Y = eigenvectors0[1]
model0Z = np.cross(model0X, model0Y)
model0Z = model0Z/np.linalg.norm(model0Z)
model0Origin = model0PointsMean
def getAxes1ToWorldChangeOfFrameMatrix(axis1X,axis1Y,axis1Z,axisOrigin):
    """Build the homogeneous matrix that maps a local frame's coordinates to world.

    The returned vtkMatrix4x4 has these columns:
    - columns 0, 1 and 2: the 3-vectors axis1X, axis1Y and axis1Z;
    - column 3: axisOrigin.

    Its bottom row is (0, 0, 0, 1).
    """
    frameToWorld = vtk.vtkMatrix4x4()
    frameToWorld.Identity()
    columns = (axis1X, axis1Y, axis1Z, axisOrigin)
    for columnIndex, columnVector in enumerate(columns):
        for rowIndex in range(3):
            frameToWorld.SetElement(rowIndex, columnIndex, columnVector[rowIndex])
    return frameToWorld
model0ToWorldMatrix = getAxes1ToWorldChangeOfFrameMatrix(model0X,model0Y,model0Z,model0Origin)
# Temporary frame for boneModel1, built the same way as boneModel0's. Each SVD axis is
# only defined up to sign, so this frame is provisional; the registration loop below
# tries sign flips of it.
model1X = eigenvectors1[0]
model1Y = eigenvectors1[1]
# np.cross avoids relying on vtkMath.Cross writing into a numpy output argument.
model1Z = np.cross(model1X, model1Y)
model1Z = model1Z/np.linalg.norm(model1Z)
model1Origin = model1PointsMean
temporalModel1ToWorldMatrix = getAxes1ToWorldChangeOfFrameMatrix(model1X,model1Y,model1Z,model1Origin)
def getChangeOfSignMatrix(i,j):
    """Return a 4x4 diagonal matrix that flips the signs of frame axes.

    The diagonal is (signX, signY, signZ, 1), where:
    - signX = +1 if i == 0 else -1
    - signY = +1 if j == 0 else -1
    - signZ = +1 if i == j else -1

    For i, j in {0, 1} this makes signZ == signX * signY. Every such result
    therefore has determinant +1 and keeps a right-handed frame right-handed.
    """
    signs = (
        1 if i == 0 else -1,
        1 if j == 0 else -1,
        1 if j == i else -1,
    )
    changeOfSignMatrix = vtk.vtkMatrix4x4()
    changeOfSignMatrix.Identity()
    for axis, sign in enumerate(signs):
        changeOfSignMatrix.SetElement(axis, axis, sign)
    return changeOfSignMatrix
# find the registration boneModel0ToBoneModel1Transform
# The SVD axes are only defined up to sign. Try the four sign combinations of
# boneModel1's frame and score each candidate by the mean distance from every sampled
# boneModel1 point to its closest transformed boneModel0 sample point. Lower is better.

# Loop-invariant: maps world coordinates into boneModel0's principal frame.
worldToBoneModel0Matrix = vtk.vtkMatrix4x4()
worldToBoneModel0Matrix.DeepCopy(model0ToWorldMatrix)
worldToBoneModel0Matrix.Invert()

boneModel0ToBoneModel1TransformListWithScore = []
for i in range(2):
    for j in range(2):
        # model1ToWorld = temporalModel1ToWorld * changeOfSign(i, j)  (PostMultiply)
        model1ToWorldTransform = vtk.vtkTransform()
        model1ToWorldTransform.PostMultiply()
        model1ToWorldTransform.Concatenate(getChangeOfSignMatrix(i,j))
        model1ToWorldTransform.Concatenate(temporalModel1ToWorldMatrix)
        # candidate = model1ToWorld * worldToBoneModel0  (PostMultiply)
        boneModel0ToBoneModel1Transform = vtk.vtkTransform()
        boneModel0ToBoneModel1Transform.PostMultiply()
        boneModel0ToBoneModel1Transform.Concatenate(worldToBoneModel0Matrix)
        boneModel0ToBoneModel1Transform.Concatenate(model1ToWorldTransform)

        # Move the boneModel0 sample with the candidate transform.
        bone0ToBone1Transformer = vtk.vtkTransformFilter()
        bone0ToBone1Transformer.SetInputData(polydata0)
        bone0ToBone1Transformer.SetTransform(boneModel0ToBoneModel1Transform)
        bone0ToBone1Transformer.Update()
        transformedPolyData0 = bone0ToBone1Transformer.GetOutput()
        transformedModel0Points = vtk_to_numpy(transformedPolyData0.GetPoints().GetData())

        pointsLocator = vtk.vtkPointLocator()
        pointsLocator.SetDataSet(transformedPolyData0)
        pointsLocator.BuildLocator()

        # For each boneModel1 sample point, find the closest moved boneModel0 point,
        # then compute all distances at once.
        closestPointIDs = [pointsLocator.FindClosestPoint(point1) for point1 in model1Points]
        distanceArray = np.linalg.norm(
            model1Points - transformedModel0Points[closestPointIDs], axis=1)
        meanDistance = distanceArray.mean()

        boneModel0ToBoneModel1TransformListWithScore.append([boneModel0ToBoneModel1Transform,meanDistance])

# Lowest mean distance first; that candidate is the registration.
boneModel0ToBoneModel1TransformListWithScore.sort(key = lambda item : item[1])
bestBoneModel0ToBoneModel1Transform = boneModel0ToBoneModel1TransformListWithScore[0][0]
bone0ToBone1RegistrationTransformNode = slicer.vtkMRMLLinearTransformNode()
bone0ToBone1RegistrationTransformNode.SetName("Bone0ToBone1RegistrationTransform")
slicer.mrmlScene.AddNode(bone0ToBone1RegistrationTransformNode)
bone0ToBone1RegistrationTransformNode.SetMatrixTransformToParent(bestBoneModel0ToBoneModel1Transform.GetMatrix())
# Put boneModel0 under the registration transform node.
boneModel0.SetAndObserveTransformNodeID(bone0ToBone1RegistrationTransformNode.GetID())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment