Skip to content

Instantly share code, notes, and snippets.

@Ghazi-Bouabene
Last active February 13, 2017 10:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ghazi-Bouabene/6c87642b8ad9da65a34da65433565285 to your computer and use it in GitHub Desktop.
Save Ghazi-Bouabene/6c87642b8ad9da65a34da65433565285 to your computer and use it in GitHub Desktop.
package com.example
import java.awt.Color
import scalismo.io.{LandmarkIO, MeshIO, StatismoIO}
import java.io.File
import breeze.linalg.{DenseMatrix, DenseVector}
import scalismo.common._
import scalismo.geometry._
import scalismo.kernels.{DiagonalKernel, GaussianKernel, MatrixValuedPDKernel, PDKernel}
import scalismo.numerics.UniformMeshSampler3D
import scalismo.statisticalmodel._
import scalismo.ui.api.{LandmarkView, ScalismoUI}
/** Shape-completion tutorial: fits a statistical shape model to a face scan whose nose is
  * missing and extracts a posterior model of the nose region.
  *
  * Steps (results are displayed in a ScalismoUI window):
  *  1. load the noseless target mesh and a small statistical shape model (`littleModel`);
  *  2. augment that model with a low-rank Gaussian process built on a symmetrized kernel;
  *  3. compute posterior models conditioned on model/target landmark correspondences;
  *  4. take the marginal of the second posterior over the reference points near the nose.
  *
  * All input files are read from the relative `datasets/` directory; every `.get` on a
  * failed read rethrows the underlying I/O exception.
  */
object ShapeCompletionTutorial {

  /** Reference-mesh point id around which the nose region is selected.
    * NOTE(review): presumably a point on the nose of `datasets/model.h5` — confirm.
    */
  private val NoseCenterPointId = PointId(8152)

  /** Reference points within this distance of the nose center point form the nose region.
    * NOTE(review): in mesh coordinate units (likely mm) — confirm.
    */
  private val NoseRegionRadius = 42.0

  /** Scalar kernel that evaluates `ker` with its first argument mirrored across x = 0,
    * i.e. `k(x, y) = ker((-x0, x1, x2), y)`.
    */
  private final case class XmirroredKernel(ker: PDKernel[_3D]) extends PDKernel[_3D] {
    override def domain = RealSpace[_3D]
    override def k(x: Point[_3D], y: Point[_3D]): Double = ker(Point(x(0) * -1f, x(1), x(2)), y)
  }

  /** Turns a scalar kernel into a 3-D matrix-valued kernel that also couples x-mirrored points.
    *
    * Result: `DiagonalKernel(ker, 3) + DiagonalKernel(-m, m, m)`, where `m` is the x-mirrored
    * kernel. The x component of the mirrored term enters with a negative sign (mirroring flips
    * x); the y and z components enter positively.
    */
  private def symmetrizeKernel(ker: PDKernel[_3D]): MatrixValuedPDKernel[_3D] = {
    val mirrored = XmirroredKernel(ker)
    DiagonalKernel(ker, 3) + DiagonalKernel(mirrored * -1f, mirrored, mirrored)
  }

  def main(args: Array[String]): Unit = {
    // required to initialize native libraries (VTK, HDF5 ..)
    scalismo.initialize()
    val ui = ScalismoUI()

    // Target scan without a nose.
    val noseless = MeshIO.readMesh(new File("datasets/noseless.stl")).get
    val targetGroup = ui.createGroup("noseless")
    ui.show(targetGroup, noseless, "noseless")

    // Small statistical shape model that gets augmented below.
    val littleModel = StatismoIO.readStatismoMeshModel(new File("datasets/model.h5")).get
    val littleModelGroup = ui.createGroup("littleModel group")
    ui.show(littleModelGroup, littleModel, "littleModel")

    // Zero-mean GP whose covariance is the symmetrized version of a Gaussian kernel
    // (GaussianKernel(30) scaled by 10).
    val zeroMean = VectorField(RealSpace[_3D], (pt: Point[_3D]) => Vector(0, 0, 0))
    val scalarValuedKernel = GaussianKernel[_3D](30) * 10
    val symmetricKernel = symmetrizeKernel(scalarValuedKernel)
    val gp = GaussianProcess(zeroMean, symmetricKernel)

    // Low-rank approximation from 500 points sampled uniformly on the reference mesh
    // (50 = number of basis functions), then added to the small model's variability.
    val sampler = UniformMeshSampler3D(littleModel.referenceMesh, 500)
    val lowRankGP = LowRankGaussianProcess.approximateGP(gp, sampler, 50)
    val model = StatisticalMeshModel.augmentModel(littleModel, lowRankGP)
    val modelGroup = ui.createGroup("model group")
    ui.show(modelGroup, model, "model")
    littleModelGroup.remove()

    val targetNoseLms = LandmarkIO.readLandmarksJson[_3D](new File("datasets/noseFittingLandmarks.json")).get
    ui.show(targetGroup, targetNoseLms, "noseless")
    // execute this only if you were unable to click the landmarks
    // (as written, this script always loads them from file)
    val modelNoseLms = LandmarkIO.readLandmarksJson[_3D](new File("datasets/noseFittingModelLms.json")).get
    ui.show(modelGroup, modelNoseLms, "model")

    // Isotropic observation noise (variance 0.5 per axis) attached to every correspondence.
    val littleNoise = MultivariateNormalDistribution(
      DenseVector(0.0, 0.0, 0.0),
      DenseMatrix((0.5, 0.0, 0.0), (0.0, 0.5, 0.0), (0.0, 0.0, 0.5)))

    // Hoisted: the mean's point set is the same for every landmark lookup.
    val meanPoints = model.mean.pointSet

    // Builds posterior training data from corresponding landmark positions. Each model
    // landmark is snapped to the closest vertex of the model mean and paired with the
    // target position at the same index.
    // NOTE(review): correspondence is purely by list position, so this assumes the two
    // sequences come back from ui.filter in matching order — confirm; matching by landmark
    // id would be more robust.
    def toTrainingData(modelPts: Seq[Point[_3D]], targetPts: Seq[Point[_3D]]) = {
      // `zip` would silently drop unmatched landmarks and bias the fit; fail loudly instead.
      require(
        modelPts.size == targetPts.size,
        s"landmark count mismatch: ${modelPts.size} on the model vs ${targetPts.size} on the target")
      (modelPts zip targetPts).map { case (modelPt, targetPt) =>
        (meanPoints.findClosestPoint(modelPt).id, targetPt, littleNoise)
      }.toIndexedSeq
    }

    // First fit: every landmark currently shown in each group.
    val modelPts: Seq[Point[_3D]] = ui.filter(modelGroup, (lv: LandmarkView) => true).map(lv => lv.landmark.point)
    val noselessPts = ui.filter(targetGroup, (lv: LandmarkView) => true).map(lv => lv.landmark.point)
    val posterior = model.posterior(toTrainingData(modelPts, noselessPts))
    val posteriorGroup = ui.createGroup("posterior")
    ui.show(posteriorGroup, posterior, "posterior")

    // Second fit with an additional landmark set.
    val modelLMs = LandmarkIO.readLandmarksJson[_3D](new File("datasets/modelLandmarks.json")).get
    ui.show(modelGroup, modelLMs, "model")
    val noselessLMs = LandmarkIO.readLandmarksJson[_3D](new File("datasets/noselessLandmarks.json")).get
    ui.show(targetGroup, noselessLMs, "noseless")
    // NOTE(review): ui.filter returns every LandmarkView in the group, so this also includes
    // the landmarks shown for the first fit (nothing removes them) — confirm this is intended.
    val modelLandmarkPts = ui.filter(modelGroup, (lv: LandmarkView) => true).map(_.landmark.point)
    val noselessLandmarkPts = ui.filter(targetGroup, (lv: LandmarkView) => true).map(_.landmark.point)
    val betterPosterior = model.posterior(toTrainingData(modelLandmarkPts, noselessLandmarkPts))
    val betterPosteriorGroup = ui.createGroup("better posterior")
    ui.show(betterPosteriorGroup, betterPosterior, "betterPosterior")

    // Nose region: reference points within NoseRegionRadius of the (loop-invariant) center point.
    val referencePoints = model.referenceMesh.pointSet
    val noseCenter = referencePoints.point(NoseCenterPointId)
    val nosePtIDs = referencePoints.pointIds.filter { id =>
      (referencePoints.point(id) - noseCenter).norm <= NoseRegionRadius
    }
    val posteriorNoseModel = betterPosterior.marginal(nosePtIDs.toIndexedSeq)
    val posteriorNoseGroup = ui.createGroup("posterior nose model")
    ui.show(posteriorNoseGroup, posteriorNoseModel, "posteriorNoseModel")
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment