Created
October 29, 2021 20:42
-
-
Save mauigna06/cb2b81cd0e1945ea2165fb7f0841adf3 to your computer and use it in GitHub Desktop.
This is an example of registration by singular value decomposition (SVD) of the data. It is useful for registering very similar bones.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Posted example: SVD-based registration of two very similar bone models.
# This script uses `slicer`, `vtk` and `getNode` without importing them, so it
# expects to run where they are predefined (e.g. the 3D Slicer Python console).
import numpy as np

# --- SVD registration setup --------------------------------------------------
# boneModel0 is the existing model node 'deformedBone'; boneModel1 is a new
# model node "fibulaMoved" holding a copy of boneModel0's content.
boneModel0 = getNode('deformedBone')
boneModel1 = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLModelNode', "fibulaMoved")
boneModel1.CreateDefaultDisplayNodes()
boneModel1.CopyContent(boneModel0)

# Give the copy a distinguishable color by rotating the original's RGB channels.
movedBoneDisplayNode = boneModel1.GetDisplayNode()
red, green, blue = boneModel0.GetDisplayNode().GetColor()
movedBoneDisplayNode.SetColor(blue, red, green)
# --- Create a random test transformation -------------------------------------
import random

# Random rotation axis: three components drawn from [0, 1), normalized to unit length.
axisOfRotation = np.array([random.random() for _ in range(3)])
axisOfRotation = axisOfRotation / np.linalg.norm(axisOfRotation)
# Rotation angle; vtkTransform.RotateWXYZ interprets it in degrees.
angleOfRotation = random.uniform(45, 315)

# Diagonal scale matrix with an independent factor per axis. All three factors
# sit on the diagonal, (0,0), (1,1) and (2,2); element (1,2) would be a shear term.
# NOTE(review): scaleMatrix is never concatenated into randomTransform below, so
# the test transform is rigid (rotation + translation). The registration further
# down builds its transforms only from orthonormal frames, so it could not undo
# a non-uniform scale anyway.
scaleMatrix = vtk.vtkMatrix4x4()
for diagonalIndex in range(3):
    scaleMatrix.SetElement(diagonalIndex, diagonalIndex, random.uniform(1.3, 1.8))

# Random translation, each component drawn from [50, 150].
origin = np.array([random.uniform(50, 150) for _ in range(3)])

# PostMultiply: the rotation (about the world origin) is applied first, then the translation.
randomTransform = vtk.vtkTransform()
randomTransform.PostMultiply()
randomTransform.RotateWXYZ(angleOfRotation, axisOfRotation)
randomTransform.Translate(origin)
# --- Move boneModel1 by the random transform ---------------------------------
# Run boneModel1's mesh through randomTransform and make the transformed output
# the model's geometry.
movedBoneFilter = vtk.vtkTransformFilter()
movedBoneFilter.SetTransform(randomTransform)
movedBoneFilter.SetInputData(boneModel1.GetPolyData())
movedBoneFilter.Update()
movedBonePolyData = movedBoneFilter.GetOutput()
boneModel1.SetAndObservePolyData(movedBonePolyData)
# --- Subsample both meshes ---------------------------------------------------
def _samplePolyDataPoints(inputPolyData, numberOfSampledPoints):
    """Return a detached vtkPolyData with a random subset of inputPolyData's points.

    The subset is drawn with vtkMaskPoints in random mode, using a stride of
    GetNumberOfPoints() // numberOfSampledPoints. The number of points kept is
    therefore only approximately numberOfSampledPoints.

    Args:
        inputPolyData: vtkPolyData to sample.
        numberOfSampledPoints: target number of points (positive int).

    Returns:
        A new vtkPolyData shallow-copied from the filter output, so it does not
        depend on the filter staying alive.
    """
    # Clamp the stride to >= 1 so meshes smaller than the target are kept whole
    # instead of requesting a stride of 0.
    ratio = max(1, inputPolyData.GetNumberOfPoints() // numberOfSampledPoints)
    maskPointsFilter = vtk.vtkMaskPoints()
    maskPointsFilter.SetInputData(inputPolyData)
    # This works, but the sampling could be spatially biased.
    maskPointsFilter.SetOnRatio(ratio)
    maskPointsFilter.RandomModeOn()
    maskPointsFilter.Update()
    sampledPolyData = vtk.vtkPolyData()
    sampledPolyData.ShallowCopy(maskPointsFilter.GetOutput())
    return sampledPolyData


numberOfSampledPoints0 = 2000
numberOfSampledPoints1 = 2000
polydata0 = _samplePolyDataPoints(boneModel0.GetPolyData(), numberOfSampledPoints0)
polydata1 = _samplePolyDataPoints(boneModel1.GetPolyData(), numberOfSampledPoints1)
# --- Mean and principal directions of each sampled point cloud ---------------
from vtk.util.numpy_support import vtk_to_numpy

model0Points = vtk_to_numpy(polydata0.GetPoints().GetData())
model1Points = vtk_to_numpy(polydata1.GetPoints().GetData())

# Mean of the points, i.e. the 'center' of each cloud.
model0PointsMean = model0Points.mean(axis=0)
model1PointsMean = model1Points.mean(axis=0)

# SVD of the mean-centered data. The rows of Vt are the principal directions,
# ordered by decreasing singular value. full_matrices=False skips building the
# unused N x N left-singular matrix (N = number of sampled points) and yields
# the same Vt.
_, _, eigenvectors0 = np.linalg.svd(model0Points - model0PointsMean, full_matrices=False)
_, _, eigenvectors1 = np.linalg.svd(model1Points - model1PointsMean, full_matrices=False)

# Frame for boneModel0: X and Y are the first two principal directions,
# Z = X cross Y (normalized), and the origin is the cloud's mean.
model0Z = np.zeros(3)
model0X = eigenvectors0[0]
model0Y = eigenvectors0[1]
vtk.vtkMath.Cross(model0X, model0Y, model0Z)
model0Z = model0Z / np.linalg.norm(model0Z)
model0Origin = model0PointsMean
def getAxes1ToWorldChangeOfFrameMatrix(axis1X, axis1Y, axis1Z, axisOrigin):
    """Build the 4x4 homogeneous matrix that maps frame-1 coordinates to world.

    axis1X, axis1Y and axis1Z become columns 0-2 of the upper-left 3x3 block,
    and axisOrigin becomes the translation column 3. The bottom row stays
    (0, 0, 0, 1).

    Args:
        axis1X, axis1Y, axis1Z: 3-component frame axes in world coordinates.
        axisOrigin: 3-component frame origin in world coordinates.

    Returns:
        A new vtk.vtkMatrix4x4.
    """
    frameToWorld = vtk.vtkMatrix4x4()
    frameToWorld.Identity()
    # Fill column by column, rows 0-2 of each column.
    for columnIndex, columnVector in enumerate((axis1X, axis1Y, axis1Z, axisOrigin)):
        for rowIndex in range(3):
            frameToWorld.SetElement(rowIndex, columnIndex, columnVector[rowIndex])
    return frameToWorld
# Frame-to-world matrix for boneModel0.
model0ToWorldMatrix = getAxes1ToWorldChangeOfFrameMatrix(model0X, model0Y, model0Z, model0Origin)

# Temporary frame for boneModel1, built like boneModel0's: the first two principal
# directions, their normalized cross product, and the cloud mean as origin.
# The registration step below tries sign flips of this frame's axes.
model1X, model1Y = eigenvectors1[0], eigenvectors1[1]
model1Z = np.zeros(3)
vtk.vtkMath.Cross(model1X, model1Y, model1Z)
model1Z = model1Z / np.linalg.norm(model1Z)
model1Origin = model1PointsMean
temporalModel1ToWorldMatrix = getAxes1ToWorldChangeOfFrameMatrix(
    model1X, model1Y, model1Z, model1Origin
)
def getChangeOfSignMatrix(i, j):
    """Return a diagonal 4x4 matrix that flips the signs of frame axes.

    Diagonal entries:
      x: +1 when i == 0, else -1
      y: +1 when j == 0, else -1
      z: +1 when i == j, else -1
    The homogeneous entry stays 1.

    For i, j in {0, 1} this gives the four sign patterns with determinant +1,
    i.e. flips that keep a right-handed frame right-handed.

    Returns:
        A new vtk.vtkMatrix4x4.
    """
    xSign = 1 if i == 0 else -1
    ySign = 1 if j == 0 else -1
    zSign = 1 if i == j else -1
    signMatrix = vtk.vtkMatrix4x4()
    signMatrix.Identity()
    for axisIndex, sign in enumerate((xSign, ySign, zSign)):
        signMatrix.SetElement(axisIndex, axisIndex, sign)
    return signMatrix
# --- Find the registration boneModel0ToBoneModel1Transform -------------------
def _meanClosestPointDistance(movingPolyData, fixedPoints):
    """Return the mean distance from each fixed point to its closest point in movingPolyData.

    Args:
        movingPolyData: vtkPolyData whose points are searched (via vtkPointLocator).
        fixedPoints: array of query points, one 3-component point per row.
    """
    from vtk.util.numpy_support import vtk_to_numpy

    movingPoints = vtk_to_numpy(movingPolyData.GetPoints().GetData())
    pointsLocator = vtk.vtkPointLocator()
    pointsLocator.SetDataSet(movingPolyData)
    pointsLocator.BuildLocator()
    closestIds = [pointsLocator.FindClosestPoint(point) for point in fixedPoints]
    differences = fixedPoints - movingPoints[closestIds]
    return np.linalg.norm(differences, axis=1).mean()


# World -> boneModel0 frame. It is identical for every candidate, so it is
# computed once. vtkTransform.Concatenate copies a matrix's elements, so sharing
# it across candidates is safe.
worldToBoneModel0Matrix = vtk.vtkMatrix4x4()
worldToBoneModel0Matrix.DeepCopy(model0ToWorldMatrix)
worldToBoneModel0Matrix.Invert()

boneModel0ToBoneModel1TransformListWithScore = []
for i in range(2):
    for j in range(2):
        # boneModel1 frame with the (i, j) axis sign flip, expressed to world.
        model1ToWorldTransform = vtk.vtkTransform()
        model1ToWorldTransform.PostMultiply()
        model1ToWorldTransform.Concatenate(getChangeOfSignMatrix(i, j))
        model1ToWorldTransform.Concatenate(temporalModel1ToWorldMatrix)

        # Candidate: world -> boneModel0 frame -> (flipped) boneModel1 frame -> world.
        boneModel0ToBoneModel1Transform = vtk.vtkTransform()
        boneModel0ToBoneModel1Transform.PostMultiply()
        boneModel0ToBoneModel1Transform.Concatenate(worldToBoneModel0Matrix)
        boneModel0ToBoneModel1Transform.Concatenate(model1ToWorldTransform)

        # Move the sampled boneModel0 points with the candidate transform.
        bone0ToBone1Transformer = vtk.vtkTransformFilter()
        bone0ToBone1Transformer.SetInputData(polydata0)
        bone0ToBone1Transformer.SetTransform(boneModel0ToBoneModel1Transform)
        bone0ToBone1Transformer.Update()

        # Score: mean distance from each sampled boneModel1 point to the nearest moved boneModel0 point.
        meanDistance = _meanClosestPointDistance(bone0ToBone1Transformer.GetOutput(), model1Points)
        boneModel0ToBoneModel1TransformListWithScore.append([boneModel0ToBoneModel1Transform, meanDistance])

# Ascending by mean distance: entry [0] is the best-scoring candidate.
boneModel0ToBoneModel1TransformListWithScore.sort(key=lambda item: item[1])
# --- Publish the best registration and apply it to boneModel0 ----------------
# The list is sorted by ascending mean distance, so entry [0] is the best candidate.
bestBone0ToBone1Transform = boneModel0ToBoneModel1TransformListWithScore[0][0]

registrationTransformNode = slicer.vtkMRMLLinearTransformNode()
registrationTransformNode.SetName("Bone0ToBone1RegistrationTransform")
slicer.mrmlScene.AddNode(registrationTransformNode)
registrationTransformNode.SetMatrixTransformToParent(bestBone0ToBone1Transform.GetMatrix())

# Make boneModel0 observe the new transform node.
boneModel0.SetAndObserveTransformNodeID(registrationTransformNode.GetID())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment