Skip to content

Instantly share code, notes, and snippets.

@reisepass
Created June 30, 2015 13:32
Show Gist options
  • Save reisepass/e1b3b48bbbe8217fbfe6 to your computer and use it in GitHub Desktop.
Save reisepass/e1b3b48bbbe8217fbfe6 to your computer and use it in GitHub Desktop.
Core portion of the mean-field algorithm; it is pretty much a one-to-one port of the MATLAB code.
// Mean-field estimates: rows indexed by region, columns by class; every entry starts at 1.0.
val Q: DenseMatrix[Double] = DenseMatrix.ones[Double](numRegions, numClasses)
// Cross-iteration bookkeeping. These are not read in the visible excerpt;
// presumably used by convergence logic later in the loop — confirm.
var numNoChange: Int = 0
var lastMaxE, lastMinE = 0.0
for (iter <- 0 until maxIterations) {
// Counter of regions whose estimate did not change this iteration. It is not updated in the
// visible excerpt; presumably the convergence check later in the loop uses it — confirm.
var numUnchangedQs = 0
// Snapshot of the previous iteration's estimates. Q is a mutable Breeze DenseMatrix whose
// columns are overwritten in place after this step. A plain `val lastQ = Q` (MATLAB-style
// assignment) would only alias it here, so any later lastQ-vs-Q comparison would always see
// "no change". Take a real copy instead.
val lastQ = Q.copy
// Loop-invariant values, hoisted out of the per-region / per-label work.
val allClasses = (0 until numClasses).toList
// 1.0 (not 1) keeps this a floating-point quotient even if `temp` is integral;
// for a Double `temp` it is exactly the value `1 / temp` produced.
val invTemp = 1.0 / temp
val xiLab = (0 until numClasses).par
// For each label (computed in parallel; result order follows xiLab, so entry i is label i):
// (label, estimate for every region 0 until graph.size).
val allXiperLabel = xiLab.map { curLab =>
  val perRegion = for (xi <- 0 until graph.size) yield {
    val neigh = graph.getC(xi).toArray
    // The unary factor depends only on (xi, curLab): compute it once per region,
    // not once per (neighbour, class) term. Multiplication order is unchanged.
    val unaryFactor = Math.exp(invTemp * thetaUnary(xi, curLab))
    val newQest = neigh.toList.map { neighIdx =>
      allClasses.foldLeft(0.0) { (running, curClass) =>
        val pairwise = if (DISABLE_PAIRWISE) 0.0 else thetaPairwise(curClass, curLab)
        running + Math.exp(lastQ(neighIdx, curClass) * pairwise) * unaryFactor
      }
    }.sum
    invTemp * newQest
  }
  (curLab, perRegion)
}
// Write each label's per-region estimates back into the matching column of Q.
// allXiperLabel preserves the order of the parallel label range it was mapped from, so its
// i-th entry carries label i. Walk it sequentially (.seq) so the column writes happen one at a
// time, in label order.
allXiperLabel.seq.foreach { case (label, perRegion) =>
  val column = perRegion.toArray
  Q(::, label) := DenseVector(column)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment