Skip to content

Instantly share code, notes, and snippets.

@jonnylaw
Created June 1, 2016 16:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jonnylaw/0d01e2b67a5d4499ecd3674f97092aba to your computer and use it in GitHub Desktop.
Save jonnylaw/0d01e2b67a5d4499ecd3674f97092aba to your computer and use it in GitHub Desktop.
import breeze.stats.distributions.Gaussian
import Stream._
import java.io.{PrintWriter, File}
/**
 * Random-walk Metropolis-Hastings sampler for the mean and standard deviation
 * of a Gaussian, plus a small driver that simulates data, runs the chain and
 * writes the sampled parameters to a CSV file.
 */
object MetropolisHastings {

  /**
   * Parameters of a univariate Gaussian.
   *
   * @param mu    mean
   * @param sigma standard deviation
   */
  case class Parameter(mu: Double, sigma: Double) {
    /** Renders as `mu, sigma`, i.e. one CSV row. */
    override def toString = s"$mu, $sigma"
  }

  /**
   * Draws `n` independent observations from `Gaussian(params.mu, params.sigma)`.
   *
   * @param n      number of observations
   * @param params mean and standard deviation of the generating distribution
   * @return the simulated observations
   */
  def sims(n: Int, params: Parameter): List[Double] = {
    // Build the distribution once rather than once per draw.
    val dist = Gaussian(params.mu, params.sigma)
    List.fill(n)(dist.draw())
  }

  /**
   * Proposes a new parameter from the current one.
   *
   * `mu` takes a Gaussian random-walk step with standard deviation `delta`.
   * `sigma` is multiplied by `exp(N(0, delta))`, i.e. it takes the same kind
   * of step on the log scale, which keeps it positive.
   *
   * @param p     current parameter
   * @param delta proposal step size for both components
   * @return the proposed parameter
   */
  def propose(p: Parameter, delta: Double): Parameter = {
    val mu = Gaussian(p.mu, delta).draw()
    val sigma = p.sigma * math.exp(Gaussian(0, delta).draw())
    Parameter(mu, sigma)
  }

  /**
   * Log-density of a single observation.
   *
   * @param y observation
   * @return function mapping a parameter to `log N(y | p.mu, p.sigma)`
   */
  def likelihood(y: Double): Parameter => Double = p => {
    Gaussian(p.mu, p.sigma).logPdf(y)
  }

  /**
   * Log-likelihood of a data set: the sum of per-observation log-densities.
   *
   * @param data observations, assumed independent given the parameter
   * @return function mapping a parameter to the total log-likelihood of `data`
   */
  def allLikelihood(data: List[Double]): Parameter => Double = p => {
    // One distribution per evaluated parameter, not one per data point.
    val dist = Gaussian(p.mu, p.sigma)
    data.foldLeft(0.0)((acc, y) => acc + dist.logPdf(y))
  }

  /**
   * State of the Markov chain.
   *
   * @param ll log-likelihood of `p`
   * @param p  current parameter
   * @param a  number of proposals accepted so far
   */
  case class MetropState(ll: Double, p: Parameter, a: Int)

  /**
   * Infinite random-walk Metropolis-Hastings chain starting at `p`.
   *
   * Each step draws a proposal from the *current* state via [[propose]] and
   * accepts it with probability `min(1, exp(ll' - ll))`. No Hastings
   * correction is applied for the log-scale move on sigma. The proposal is
   * symmetric in (mu, log sigma), so the stationary density is proportional
   * to `exp(likelihood(mu, sigma)) / sigma`, i.e. a flat prior on mu and on
   * log(sigma).
   * NOTE(review): if a flat prior on sigma itself is intended, add
   * `log(sigma') - log(sigma)` to the log acceptance ratio.
   *
   * @param likelihood log-likelihood of a parameter
   * @param p          initial parameter; the first element of the stream
   * @param delta      proposal step size passed to [[propose]]
   * @return lazily evaluated stream of chain states, beginning with `p`
   */
  def mh(likelihood: Parameter => Double, p: Parameter, delta: Double): Stream[MetropState] = {
    val rand = new java.util.Random
    // Score the starting point instead of using a -1e99 sentinel, which forced
    // the first proposal to be accepted and counted in `a`.
    val initState = MetropState(likelihood(p), p, 0)
    Stream.iterate(initState) { s =>
      // Random walk: propose relative to the current state, not the start.
      val propParam = propose(s.p, delta)
      val propll = likelihood(propParam)
      val logAccept = propll - s.ll
      if (math.log(rand.nextDouble()) < logAccept) MetropState(propll, propParam, s.a + 1)
      else s
    }
  }

  /**
   * Demo entry point.
   *
   * Simulates 100 draws from N(0, 1) and runs 10000 states of [[mh]] from the
   * true parameter with step size 0.1. It prints the acceptance rate and
   * writes one `mu, sigma` row per state to `MetropOut.csv`.
   */
  def main(args: Array[String]): Unit = {
    val n = 10000
    val p = Parameter(0.0, 1.0)
    val data = sims(100, p)
    val iters = mh(allLikelihood(data), p, 0.1).take(n)
    // The first state is the starting point, so n states contain n - 1 proposals.
    println(s"Accepted: ${iters.last.a.toDouble / (n - 1)}")
    val pw = new PrintWriter(new File("MetropOut.csv"))
    try pw.write(iters.map(_.p).mkString("\n"))
    finally pw.close()
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment