Created
February 23, 2016 10:21
-
-
Save Ghazi-Bouabene/ce42c654c33cb432cd41 to your computer and use it in GitHub Desktop.
Evaluation method for the femur reconstruction contest of the FutureLearn course: Statistical Shape Modelling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package ch.unibas.cs.gravis.shapemodelling | |
import java.io.{PrintWriter, File} | |
import scalismo.geometry._ | |
import scalismo.io.MeshIO | |
import scalismo.mesh.{MeshMetrics, Mesh, TriangleMesh} | |
import scalismo.utils.MeshConversion | |
import vtk.{vtkPolyData, vtkPolyDataConnectivityFilter} | |
object ReconstructionEval {

  /**
   * Reconstruction evaluation function.
   *
   * Attention: this evaluates mesh distances only on the reconstructed patches and not on the entire shape.
   *
   * A ground-truth point counts as "given" when it lies within `tolerance` of `givenPartialMesh`. The
   * patches are obtained by clipping every given point (by point id) out of a mesh and then keeping the
   * biggest remaining connected component. The point ids collected on `groundTruth` are looked up on
   * `reconstruction` as well, so both meshes are assumed to share the same point numbering.
   * NOTE(review): presumably both are in correspondence with the project reference mesh — confirm.
   *
   * @param givenPartialMesh the partial mesh that was given to reconstruct
   * @param groundTruth      the complete shape from which givenPartialMesh was generated
   * @param reconstruction   the reconstruction to be evaluated. This mesh needs to be in correspondence with
   *                         the reference mesh provided in step 1 of the project
   * @param tolerance        maximal distance between a ground-truth point and the partial mesh for that point
   *                         to be considered given (i.e. not part of a reconstructed patch); defaults to 0.1
   * @return tuple containing first the average mesh distance (mean of both directions), then the Hausdorff
   *         distance between the 2 reconstructed patches
   */
  def eval(givenPartialMesh: TriangleMesh,
           groundTruth: TriangleMesh,
           reconstruction: TriangleMesh,
           tolerance: Double = 0.1): (Double, Double) = {
    // Ids of the ground-truth points that were given in the partial mesh. Clipping these away leaves exactly
    // the patches that had to be reconstructed. A Set keeps the lookups in the clipping predicates O(1);
    // an IndexedSeq.contains there made each clip quadratic in the number of points.
    val givenPointIds = groundTruth.pointsWithId.collect {
      case (pt, id) if (givenPartialMesh.findClosestPoint(pt).point - pt).norm < tolerance => id
    }.toSet

    // True when the mesh point closest to `pt` has one of the given ids, i.e. it must be clipped away.
    def isGivenPoint(mesh: TriangleMesh, pt: Point[_3D]): Boolean =
      givenPointIds.contains(mesh.findClosestPoint(pt).id)

    // cut the reconstructed patches from the ground truth and the reconstruction
    val truePatch = getBiggestConnectedComponent(
      Mesh.clipMesh(groundTruth, (pt: Point[_3D]) => isGivenPoint(groundTruth, pt)))
    val reconstructedPatch = getBiggestConnectedComponent(
      Mesh.clipMesh(reconstruction, (pt: Point[_3D]) => isGivenPoint(reconstruction, pt)))

    // compute average mesh distance in both directions
    val averageOneWay = MeshMetrics.avgDistance(truePatch, reconstructedPatch)
    val averageOtherWay = MeshMetrics.avgDistance(reconstructedPatch, truePatch)
    // Hausdorff distance (symmetric)
    val hausdorff = MeshMetrics.hausdorffDistance(reconstructedPatch, truePatch)

    ((averageOneWay + averageOtherWay) * 0.5, hausdorff)
  }

  /**
   * Extracts the biggest connected component in a mesh.
   *
   * We use this to be safe and guarantee that no disconnected points or cells are used when comparing the
   * 2 patches. If several components share the maximal size, the one with the highest region index is kept.
   *
   * @param mesh the mesh to filter; must contain at least one connected region
   * @return a new mesh holding only the biggest connected component of `mesh`
   * @throws IllegalArgumentException if the mesh has no connected region (e.g. it is empty)
   */
  def getBiggestConnectedComponent(mesh: TriangleMesh): TriangleMesh = {
    val polydata = MeshConversion.meshToVtkPolyData(mesh)
    val connectivity = new vtkPolyDataConnectivityFilter()
    try {
      connectivity.SetExtractionModeToSpecifiedRegions()
      connectivity.SetInputData(polydata)
      connectivity.Update()

      val count = connectivity.GetNumberOfExtractedRegions()
      require(count > 0, "getBiggestConnectedComponent: the mesh has no connected region (is it empty?)")

      val regionSizes = {
        val vtkSizes = connectivity.GetRegionSizes()
        val sizes = (0 until count).map(i => vtkSizes.GetValue(i))
        vtkSizes.Delete()
        sizes
      }
      // Linear scan for the largest region. Using `>=` makes the last maximal region win on ties, which is
      // the same choice the former sortBy(...).reverse.head made.
      val (_, largestId) = regionSizes.zipWithIndex.reduceLeft { (best, candidate) =>
        if (candidate._1 >= best._1) candidate else best
      }

      connectivity.InitializeSpecifiedRegionList()
      connectivity.AddSpecifiedRegion(largestId)
      connectivity.Update()

      // Deep-copy the filter output so it no longer depends on the filter, which is released below.
      val out = new vtkPolyData()
      out.DeepCopy(connectivity.GetOutput())
      MeshConversion.vtkPolyDataToCorrectedTriangleMesh(out).get
    } finally {
      // Release the native VTK objects explicitly, as is already done for the region-size array.
      connectivity.Delete()
      polydata.Delete()
    }
  }

  /**
   * Command-line entry point.
   *
   * Reads the ground truth, the partial mesh and the reconstruction, evaluates the reconstruction with
   * [[eval]], and writes `{"hausdorff":..., "average":...}` to the output file.
   *
   * usage: run <path to ground truth> <path to partialMesh> <path to reconstructedMesh> <output file>
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 4) {
      System.err.println("usage: run <path to ground truth> <path to partialMesh> <path to reconstructedMesh> <output file>")
      System.exit(-1)
    }
    scalismo.initialize()

    val groundTruth = MeshIO.readMesh(new File(args(0))).get
    val partialMesh = MeshIO.readMesh(new File(args(1))).get
    val reconstruction = MeshIO.readMesh(new File(args(2))).get

    val (average, hausdorff) = eval(partialMesh, groundTruth, reconstruction)

    val pw = new PrintWriter(args(3))
    try {
      pw.println(s"""{"hausdorff":$hausdorff, "average":$average}""")
      // PrintWriter never throws on write failures, so check explicitly to avoid a silently missing result.
      if (pw.checkError()) throw new java.io.IOException(s"failed to write evaluation results to ${args(3)}")
    } finally {
      pw.close()
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment