Created
October 30, 2021 12:42
-
-
Save mauigna06/e6a1f177a9f6210809cf1a52079c06d3 to your computer and use it in GitHub Desktop.
Register two similar bones that differ in position, orientation, and size (scaling).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

# SVD registration.
# boneModel1 starts as an exact copy of the existing 'bone0' model and is
# drawn with bone0's color channels rotated (r, g, b) -> (b, r, g) so the two
# models can be told apart in the views.
boneModel0 = getNode('bone0')
boneModel1 = slicer.mrmlScene.AddNewNodeByClass('vtkMRMLModelNode', "bone1")
boneModel1.CreateDefaultDisplayNodes()
boneModel1.CopyContent(boneModel0)
color = boneModel0.GetDisplayNode().GetColor()
boneModel1DisplayNode = boneModel1.GetDisplayNode()
boneModel1DisplayNode.SetColor(*(color[2:] + color[:2]))
# Create a random similarity transform (rotation + uniform scale + translation)
# and apply it to boneModel1, so the registration below has something to recover.
import numpy as np
import random

# Random unit rotation axis; random.random() is in [0, 1), so every component
# is non-negative.
axisOfRotation = np.array([random.random() for _ in range(3)])
axisOfRotation = axisOfRotation / np.linalg.norm(axisOfRotation)
angleOfRotation = random.uniform(45, 315)  # degrees, as RotateWXYZ expects

# Per-axis (anisotropic) scale matrix.
# NOTE: this matrix is built but is NOT applied to randomTransform below,
# which uses the uniform scaleFactor instead.
scaleMatrix = vtk.vtkMatrix4x4()
scaleMatrix.SetElement(0, 0, random.uniform(1.3, 1.8))
scaleMatrix.SetElement(1, 1, random.uniform(1.3, 1.8))
# The z scale belongs on the diagonal element (2, 2); the original wrote to
# (1, 2), which would have added a y-from-z shear term instead.
scaleMatrix.SetElement(2, 2, random.uniform(1.3, 1.8))

# Random translation, each coordinate in [50, 150].
origin = np.array([random.uniform(50, 150) for _ in range(3)])

randomTransform = vtk.vtkTransform()
# PostMultiply: each following operation is applied after the previous ones,
# i.e. rotate, then scale, then translate.
randomTransform.PostMultiply()
randomTransform.RotateWXYZ(angleOfRotation, axisOfRotation)
scaleFactor = random.uniform(1.3, 2)
randomTransform.Scale(scaleFactor, scaleFactor, scaleFactor)
randomTransform.Translate(origin)

# Apply the transform to boneModel1's geometry.
transformer = vtk.vtkTransformFilter()
transformer.SetInputData(boneModel1.GetPolyData())
transformer.SetTransform(randomTransform)
transformer.Update()
boneModel1.SetAndObservePolyData(transformer.GetOutput())
def _samplePolyDataPoints(polyData, numberOfSampledPoints):
    """Return a new vtkPolyData with roughly numberOfSampledPoints points picked from polyData.

    Uses vtkMaskPoints in random mode with an on-ratio of
    (number of input points) // numberOfSampledPoints, clamped to at least 1
    so a mesh with fewer vertices than requested is kept whole.
    """
    ratio = max(1, polyData.GetNumberOfPoints() // numberOfSampledPoints)
    maskPointsFilter = vtk.vtkMaskPoints()
    maskPointsFilter.SetInputData(polyData)
    maskPointsFilter.SetOnRatio(ratio)
    # Random mode randomizes which points are kept, but the resulting subset is
    # not guaranteed to be spatially uniform over the surface.
    maskPointsFilter.RandomModeOn()
    maskPointsFilter.Update()
    # Copy the output so the result does not depend on the filter object.
    sampledPolyData = vtk.vtkPolyData()
    sampledPolyData.ShallowCopy(maskPointsFilter.GetOutput())
    return sampledPolyData


# Both models are subsampled the same way; boneModel0 is sampled first.
numberOfSampledPoints0 = 2000
polydata0 = _samplePolyDataPoints(boneModel0.GetPolyData(), numberOfSampledPoints0)
numberOfSampledPoints1 = 2000
polydata1 = _samplePolyDataPoints(boneModel1.GetPolyData(), numberOfSampledPoints1)
# Calculate the mean and the SVD of each sampled point cloud.
from vtk.util.numpy_support import vtk_to_numpy

model0Points = vtk_to_numpy(polydata0.GetPoints().GetData())
model1Points = vtk_to_numpy(polydata1.GetPoints().GetData())
# Mean of the points, i.e. the 'center' of each cloud.
model0PointsMean = model0Points.mean(axis=0)
model1PointsMean = model1Points.mean(axis=0)
# SVD of the mean-centered data. Rows of eigenvectors0/1 are the principal
# directions, ordered by decreasing spread.
# full_matrices=False: only the 3 singular values and the 3x3 right singular
# vectors are used, so the (nPoints x nPoints) U matrix is not built.
# uu0/uu1 therefore have shape (nPoints, 3).
uu0, dd0, eigenvectors0 = np.linalg.svd(model0Points - model0PointsMean, full_matrices=False)
uu1, dd1, eigenvectors1 = np.linalg.svd(model1Points - model1PointsMean, full_matrices=False)
# Singular values grow with sqrt(number of points). Convert them to the RMS
# spread along each principal direction, so that the ratio dd1 / dd0 used as
# the per-axis scale estimate is not biased when the two samples differ in size.
dd0 = dd0 / np.sqrt(len(model0Points))
dd1 = dd1 / np.sqrt(len(model1Points))
# Frame attached to boneModel0: its first two principal directions, their
# normalized cross product as the third axis (right-handed), and the centroid
# of the sampled points as the origin.
model0X, model0Y = eigenvectors0[0], eigenvectors0[1]
model0Z = np.cross(np.asarray(model0X, dtype=float), np.asarray(model0Y, dtype=float))
model0Z /= np.linalg.norm(model0Z)
model0Origin = model0PointsMean
def getAxes1ToWorldChangeOfFrameMatrix(axis1X, axis1Y, axis1Z, axisOrigin):
    """Build the homogeneous matrix that maps frame coordinates to world coordinates.

    Columns 0, 1 and 2 of the upper 3x3 part hold axis1X, axis1Y and axis1Z,
    and column 3 holds axisOrigin. The bottom row stays (0, 0, 0, 1).

    Returns:
        A new vtk.vtkMatrix4x4.
    """
    frameToWorld = vtk.vtkMatrix4x4()
    frameToWorld.Identity()
    for column, vector in enumerate((axis1X, axis1Y, axis1Z, axisOrigin)):
        for row in range(3):
            frameToWorld.SetElement(row, column, vector[row])
    return frameToWorld
model0ToWorldMatrix = getAxes1ToWorldChangeOfFrameMatrix(model0X, model0Y, model0Z, model0Origin)
# Homogeneous matrix that divides each boneModel0 frame axis by its spread (dd0).
scaleMatrix0 = np.diag(dd0)
invertedScaleMatrix0_np = np.identity(4)
invertedScaleMatrix0_np[:3, :3] = np.linalg.inv(scaleMatrix0)
invertedScaleMatrix0 = vtk.vtkMatrix4x4()
invertedScaleMatrix0.DeepCopy(invertedScaleMatrix0_np.ravel().tolist())
# Temporary frame for boneModel1. It is built the same way as boneModel0's
# frame; the signs of its axes are chosen later by trying getChangeOfSignMatrix.
model1X, model1Y = eigenvectors1[0], eigenvectors1[1]
model1Z = np.cross(np.asarray(model1X, dtype=float), np.asarray(model1Y, dtype=float))
model1Z /= np.linalg.norm(model1Z)
model1Origin = model1PointsMean
# Homogeneous matrix that multiplies each frame axis by boneModel1's spread (dd1).
scaleMatrix1_np = np.identity(4)
scaleMatrix1_np[:3, :3] = np.diag(dd1)
scaleMatrix1 = vtk.vtkMatrix4x4()
scaleMatrix1.DeepCopy(scaleMatrix1_np.ravel().tolist())
temporalModel1ToWorldMatrix = getAxes1ToWorldChangeOfFrameMatrix(model1X, model1Y, model1Z, model1Origin)
def getChangeOfSignMatrix(i, j):
    """Return a 4x4 diagonal matrix that flips the signs of the frame axes.

    The x sign is +1 when i == 0 (else -1), the y sign is +1 when j == 0
    (else -1), and the z sign is +1 when i == j (else -1). For i, j in {0, 1}
    the z sign equals the product of the x and y signs. The determinant is then
    +1, so the frame stays right-handed.

    Returns:
        A new vtk.vtkMatrix4x4.
    """
    xSign = 1 if i == 0 else -1
    ySign = 1 if j == 0 else -1
    zSign = 1 if j == i else -1
    signMatrix = vtk.vtkMatrix4x4()
    signMatrix.Identity()
    for axis, sign in enumerate((xSign, ySign, zSign)):
        signMatrix.SetElement(axis, axis, sign)
    return signMatrix
# Find the registration boneModel0ToBoneModel1Transform.
# Principal directions from the SVD are only defined up to sign, so the four
# sign flips produced by getChangeOfSignMatrix(i, j), i, j in {0, 1}, are each
# tried. Each candidate is scored by a closest-point distance.
from vtk.util.numpy_support import vtk_to_numpy

# These pieces do not depend on the sign flip, so build them once.
worldToBoneModel0Matrix = vtk.vtkMatrix4x4()
worldToBoneModel0Matrix.DeepCopy(model0ToWorldMatrix)
worldToBoneModel0Matrix.Invert()

scaleModel0ToModel1Transform = vtk.vtkTransform()
scaleModel0ToModel1Transform.PostMultiply()
scaleModel0ToModel1Transform.Concatenate(invertedScaleMatrix0)
scaleModel0ToModel1Transform.Concatenate(scaleMatrix1)


def _meanClosestPointDistance(sourcePolyData, targetPoints):
    """Return the mean distance from each row of targetPoints to its closest point in sourcePolyData.

    The metric is one-sided: it is measured only from targetPoints toward
    sourcePolyData.
    """
    sourcePoints = vtk_to_numpy(sourcePolyData.GetPoints().GetData())
    pointsLocator = vtk.vtkPointLocator()
    pointsLocator.SetDataSet(sourcePolyData)
    pointsLocator.BuildLocator()
    closestIds = [pointsLocator.FindClosestPoint(point) for point in targetPoints]
    differences = targetPoints - sourcePoints[closestIds]
    return np.linalg.norm(differences, axis=1).mean()


boneModel0ToBoneModel1TransformListWithScore = []
for i in range(2):
    for j in range(2):
        # world -> boneModel0 frame -> rescale axes -> sign flip -> boneModel1 frame -> world
        boneModel0ToBoneModel1Transform = vtk.vtkTransform()
        boneModel0ToBoneModel1Transform.PostMultiply()
        boneModel0ToBoneModel1Transform.Concatenate(worldToBoneModel0Matrix)
        boneModel0ToBoneModel1Transform.Concatenate(scaleModel0ToModel1Transform)

        model1ToWorldTransform = vtk.vtkTransform()
        model1ToWorldTransform.PostMultiply()
        model1ToWorldTransform.Concatenate(getChangeOfSignMatrix(i, j))
        model1ToWorldTransform.Concatenate(temporalModel1ToWorldMatrix)
        boneModel0ToBoneModel1Transform.Concatenate(model1ToWorldTransform)

        # Move the sampled boneModel0 points with this candidate and score it.
        bone0ToBone1Transformer = vtk.vtkTransformFilter()
        bone0ToBone1Transformer.SetInputData(polydata0)
        bone0ToBone1Transformer.SetTransform(boneModel0ToBoneModel1Transform)
        bone0ToBone1Transformer.Update()

        meanDistance = _meanClosestPointDistance(bone0ToBone1Transformer.GetOutput(), model1Points)
        boneModel0ToBoneModel1TransformListWithScore.append([boneModel0ToBoneModel1Transform, meanDistance])

# The best candidate (smallest mean distance) ends up first.
boneModel0ToBoneModel1TransformListWithScore.sort(key=lambda item: item[1])

bone0ToBone1RegistrationTransformNode = slicer.mrmlScene.AddNewNodeByClass(
    "vtkMRMLLinearTransformNode", "Bone0ToBone1RegistrationTransform")
bone0ToBone1RegistrationTransformNode.SetMatrixTransformToParent(
    boneModel0ToBoneModel1TransformListWithScore[0][0].GetMatrix())
boneModel0.SetAndObserveTransformNodeID(bone0ToBone1RegistrationTransformNode.GetID())

# meanError: mean distance from each sampled boneModel1 point to its closest
# sampled boneModel0 point after registration. It is one-sided and computed
# on the subsampled points only, so it is an approximate residual rather than
# a symmetric surface-to-surface distance.
meanError = boneModel0ToBoneModel1TransformListWithScore[0][1]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment