Last active
February 13, 2017 10:35
-
-
Save Ghazi-Bouabene/6c87642b8ad9da65a34da65433565285 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.example | |
import java.awt.Color | |
import scalismo.io.{LandmarkIO, MeshIO, StatismoIO} | |
import java.io.File | |
import breeze.linalg.{DenseMatrix, DenseVector} | |
import scalismo.common._ | |
import scalismo.geometry._ | |
import scalismo.kernels.{DiagonalKernel, GaussianKernel, MatrixValuedPDKernel, PDKernel} | |
import scalismo.numerics.UniformMeshSampler3D | |
import scalismo.statisticalmodel._ | |
import scalismo.ui.api.{LandmarkView, ScalismoUI} | |
/**
 * Shape-completion tutorial.
 *
 * Augments a small statistical shape model with a zero-mean Gaussian-process deformation model
 * built from a symmetrized Gaussian kernel. It then conditions the augmented model on
 * corresponding model/target landmarks and shows the resulting posteriors in Scalismo UI. Finally
 * it shows a marginal of the better posterior, restricted to a region around one reference-mesh
 * point.
 *
 * All input files are read from `datasets/` relative to the working directory. A missing or
 * unreadable file makes the corresponding `.get` throw.
 */
object ShapeCompletionTutorial {

  /** Id of the reference-mesh point around which the nose region is selected (dataset-specific). */
  private val NoseCenterId = PointId(8152)

  /** Radius, in mesh units, of the region around [[NoseCenterId]] kept in the nose marginal. */
  private val NoseRadius = 42

  /**
   * Scalar kernel that evaluates `ker` with the x-coordinate of its first argument negated,
   * i.e. k(x, y) = ker((-x0, x1, x2), y).
   */
  private final case class XmirroredKernel(ker: PDKernel[_3D]) extends PDKernel[_3D] {
    override def domain = RealSpace[_3D]
    override def k(x: Point[_3D], y: Point[_3D]) = ker(Point(x(0) * -1f, x(1), x(2)), y)
  }

  /**
   * Builds the 3x3 matrix-valued kernel diag(ker, ker, ker) + diag(-ker_m, ker_m, ker_m).
   *
   * Here ker_m is `ker` mirrored at the plane x = 0 (see [[XmirroredKernel]]). The x component
   * of the mirrored part is negated, which is what couples a point with its mirror image
   * symmetrically.
   *
   * @param ker scalar-valued positive-definite kernel on 3D points
   * @return the symmetrized matrix-valued kernel
   */
  private def symmetrizeKernel(ker: PDKernel[_3D]): MatrixValuedPDKernel[_3D] = {
    val xMirrored = XmirroredKernel(ker)
    val k1 = DiagonalKernel(ker, 3)
    val k2 = DiagonalKernel(xMirrored * -1f, xMirrored, xMirrored)
    k1 + k2
  }

  /**
   * Pairs model and target landmark positions, by their position in the two sequences, into
   * the observation triples passed to `StatisticalMeshModel.posterior`.
   *
   * Each model point is replaced by the id of the closest point of the model's mean mesh.
   *
   * @param model     model whose mean mesh is used to look up point ids
   * @param modelPts  landmark positions on the model, in correspondence order
   * @param targetPts landmark positions on the target, in the same order as `modelPts`
   * @param noise     observation noise attached to every pair
   * @return one `(pointId, targetPoint, noise)` triple per landmark pair
   * @throws IllegalArgumentException if the two sequences differ in length. A plain `zip` would
   *                                  silently drop the surplus and mis-pair the rest.
   */
  private def landmarkObservations(
      model: StatisticalMeshModel,
      modelPts: Seq[Point[_3D]],
      targetPts: Seq[Point[_3D]],
      noise: MultivariateNormalDistribution
  ): IndexedSeq[(PointId, Point[_3D], MultivariateNormalDistribution)] = {
    require(
      modelPts.size == targetPts.size,
      s"expected the same number of model and target landmarks, got ${modelPts.size} and ${targetPts.size}"
    )
    // Obtain the mean's point set once instead of once per landmark.
    val meanPoints = model.mean.pointSet
    (modelPts zip targetPts).map { case (mPt, tPt) =>
      (meanPoints.findClosestPoint(mPt).id, tPt, noise)
    }.toIndexedSeq
  }

  def main(args: Array[String]): Unit = {
    // required to initialize native libraries (VTK, HDF5 ..)
    scalismo.initialize()
    val ui = ScalismoUI()

    // Target mesh to be completed.
    val noseless = MeshIO.readMesh(new File("datasets/noseless.stl")).get
    val targetGroup = ui.createGroup("noseless")
    ui.show(targetGroup, noseless, "noseless")

    // Small shape model that gets augmented below.
    val littleModel = StatismoIO.readStatismoMeshModel(new File("datasets/model.h5")).get
    val littleModelGroup = ui.createGroup("littleModel group")
    ui.show(littleModelGroup, littleModel, "littleModel")

    // Zero-mean GP with a symmetrized Gaussian kernel. It is approximated by a low-rank GP
    // (500 points sampled on the reference mesh, 50 basis functions) and added to the model.
    val zeroMean = VectorField(RealSpace[_3D], (pt: Point[_3D]) => Vector(0, 0, 0))
    val scalarValuedKernel = GaussianKernel[_3D](30) * 10
    val gp = GaussianProcess(zeroMean, symmetrizeKernel(scalarValuedKernel))
    val sampler = UniformMeshSampler3D(littleModel.referenceMesh, 500)
    val lowRankGP = LowRankGaussianProcess.approximateGP(gp, sampler, 50)
    val model = StatisticalMeshModel.augmentModel(littleModel, lowRankGP)
    val modelGroup = ui.createGroup("model group")
    ui.show(modelGroup, model, "model")
    littleModelGroup.remove()

    val lms = LandmarkIO.readLandmarksJson[_3D](new File("datasets/noseFittingLandmarks.json")).get
    ui.show(targetGroup, lms, "noseless")
    // Load the model landmarks from file only if you were unable to click them in the UI.
    val lm = LandmarkIO.readLandmarksJson[_3D](new File("datasets/noseFittingModelLms.json")).get
    ui.show(modelGroup, lm, "model")

    // Every landmark currently shown in each group is used; pairs are formed by UI order.
    val modelPts: Seq[Point[_3D]] =
      ui.filter(modelGroup, (lv: LandmarkView) => true).map(lv => lv.landmark.point)
    val noselessPts =
      ui.filter(targetGroup, (lv: LandmarkView) => true).map(lv => lv.landmark.point)
    val littleNoise = MultivariateNormalDistribution(
      DenseVector(0.0, 0.0, 0.0),
      DenseMatrix((0.5, 0.0, 0.0), (0.0, 0.5, 0.0), (0.0, 0.0, 0.5))
    )
    val posterior = model.posterior(landmarkObservations(model, modelPts, noselessPts, littleNoise))
    val posteriorGroup = ui.createGroup("posterior")
    ui.show(posteriorGroup, posterior, "posterior")

    val modelLMs = LandmarkIO.readLandmarksJson[_3D](new File("datasets/modelLandmarks.json")).get
    ui.show(modelGroup, modelLMs, "model")
    val noselessLMs = LandmarkIO.readLandmarksJson[_3D](new File("datasets/noselessLandmarks.json")).get
    ui.show(targetGroup, noselessLMs, "noseless")

    // The groups still contain the landmarks shown earlier, so those are included here too.
    val modelLandmarkPts =
      ui.filter(modelGroup, (lv: LandmarkView) => true).map(lv => lv.landmark.point)
    val noselessLandmarkPts =
      ui.filter(targetGroup, (lv: LandmarkView) => true).map(lv => lv.landmark.point)
    val betterPosterior =
      model.posterior(landmarkObservations(model, modelLandmarkPts, noselessLandmarkPts, littleNoise))
    val betterPosteriorGroup = ui.createGroup("better posterior")
    ui.show(betterPosteriorGroup, betterPosterior, "betterPosterior")

    // Keep only reference points within NoseRadius of the nose center. The center is looked up
    // once rather than for every vertex.
    val referencePoints = model.referenceMesh.pointSet
    val noseCenter = referencePoints.point(NoseCenterId)
    val nosePtIDs = referencePoints.pointIds.filter { id =>
      (referencePoints.point(id) - noseCenter).norm <= NoseRadius
    }
    val posteriorNoseModel = betterPosterior.marginal(nosePtIDs.toIndexedSeq)
    val posteriorNoseGroup = ui.createGroup("posterior nose model")
    ui.show(posteriorNoseGroup, posteriorNoseModel, "posteriorNoseModel")
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment