Skip to content

Instantly share code, notes, and snippets.

@jtobin
Created October 1, 2016 11:13
Show Gist options
  • Save jtobin/0c097884f6340f29bd23ba52564e82f8 to your computer and use it in GitHub Desktop.
Save jtobin/0c097884f6340f29bd23ba52564e82f8 to your computer and use it in GitHub Desktop.
A Metropolis sampler.
module Metropolis where
import Control.Monad
import Control.Monad.Primitive
import System.Random.MWC as MWC
import System.Random.MWC.Distributions as MWC
-- | Propose a new location by perturbing every coordinate of @location@
-- independently: each coordinate @m@ is replaced by a draw from
-- @MWC.normal m 1@ using the supplied generator.
propose :: [Double] -> Gen RealWorld -> IO [Double]
propose location gen = traverse (\m -> MWC.normal m 1 gen) location
-- | Probability of moving from @current@ to @proposed@.
--
-- Computes @exp (min 0 (altitude proposed - altitude current))@, so a move
-- that does not lower the altitude has probability 1 and a move that lowers
-- it is exponentially less likely. A NaN result (e.g. when both altitudes
-- are negative infinity) is mapped to probability 0.
--
-- NOTE(review): @altitude@ is presumably a log-density; confirm with callers.
moveProbability :: ([Double] -> Double) -> [Double] -> [Double] -> Double
moveProbability altitude current proposed
  | isNaN ratio = 0
  | otherwise = ratio
  where
    -- Change in altitude incurred by accepting the proposal.
    climb = altitude proposed - altitude current
    -- Capping the exponent at 0 caps the probability at 1.
    ratio = exp (min 0 climb)
-- | Choose between the current and the proposed location.
--
-- Draws a boolean via @MWC.bernoulli prob rng@; the result is @proposed@
-- when the draw is 'True' and @current@ otherwise.
decide :: [Double] -> [Double] -> Double -> Gen RealWorld -> IO [Double]
decide current proposed prob rng = fmap pick (MWC.bernoulli prob rng)
  where
    pick accepted
      | accepted = proposed
      | otherwise = current
-- | Perform a single Metropolis transition from @current@.
--
-- A candidate is generated with 'propose', its acceptance probability is
-- computed with 'moveProbability', and 'decide' picks either the candidate
-- or the unchanged current location.
metropolis :: ([Double] -> Double) -> [Double] -> Gen RealWorld -> IO [Double]
metropolis altitude current rng =
  propose current rng >>= \candidate ->
    decide current candidate (moveProbability altitude current candidate) rng
-- | Run @n@ Metropolis transitions starting from @origin@.
--
-- The result is the full history, newest state first: the head is the final
-- location and the last element is @origin@. For @n <= 0@ no transition is
-- made and the result is @[origin]@; otherwise it holds @n + 1@ states.
chain :: Int -> ([Double] -> Double) -> [Double] -> Gen RealWorld -> IO [[Double]]
chain n altitude origin rng = loop n origin [origin]
  where
    -- The current state is threaded as its own argument instead of being
    -- recovered by pattern matching on the history, so the loop is total
    -- (no missing case for an empty history).
    loop j current history
      | j <= 0 = return history
      | otherwise = do
          next <- metropolis altitude current rng
          loop (j - 1) next (next : history)
---
-- | Example altitude function over the first two coordinates of a point:
--
-- @-0.5 * (x0^2 * x1^2 + x0^2 + x1^2 - 8 * x0 - 8 * x1)@
--
-- Coordinates beyond the second are ignored. A point with fewer than two
-- coordinates is a caller error and fails with a descriptive message
-- rather than an opaque list-index failure.
landscape :: [Double] -> Double
landscape (x0 : x1 : _) =
  -0.5 * (x0 ^ 2 * x1 ^ 2 + x0 ^ 2 + x1 ^ 2 - 8 * x0 - 8 * x1)
landscape _ =
  error "landscape: expected a point with at least two coordinates"
-- | Seed a generator from the system, run a 1000-step chain over
-- 'landscape' starting at @[-0.2, 0.3]@, and print every state in the
-- order 'chain' returns them, one per line.
main :: IO ()
main =
  MWC.createSystemRandom
    >>= chain 1000 landscape [-0.2, 0.3]
    >>= mapM_ print
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment