-
-
Save ekmett/5739676c452ada86af2a to your computer and use it in GitHub Desktop.
Hamiltonian Monte Carlo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- | Draw a sample from a normal distribution using the Box-Muller transform.
--
-- @rnorm mean stdev@ returns @mean + stdev * z@, where @z@ is a standard
-- normal variate. The second argument is a standard deviation, not a
-- variance.
rnorm :: Double -> Double -> Gen Double
rnorm mean stdev = do
  u1 <- positiveUniform
  u2 <- choose (0, 1)
  return $ mean + stdev * sqrt (-2 * log u1) * sin (2 * pi * u2)
  where
    -- 'choose' samples an inclusive range, so it can return exactly 0.
    -- log 0 = -Infinity would make the sample infinite (or NaN), so
    -- resample until the draw is strictly positive.
    positiveUniform :: Gen Double
    positiveUniform = do
      u <- choose (0, 1)
      if u > 0 then return u else positiveUniform
-- | @rep n a f@ applies @f@ to @a@ @n@ times, i.e. @f (f (... (f a)))@.
-- Each intermediate value is forced to weak head normal form before the
-- next application, so no chain of unevaluated applications builds up.
-- A count of zero or less returns @a@ unchanged.
rep :: Int -> a -> (a -> a) -> a
rep n0 a0 f = go n0 a0
  where
    -- Tail-recursive, so it runs in constant stack. The previous
    -- formulation (@f $! go (n - 1) a@) needed stack depth proportional
    -- to n and never terminated for negative n.
    go n a
      | n <= 0 = a
      | otherwise = a `seq` go (n - 1) (f a)
{-# INLINE rep #-}
-- | Take one Hamiltonian Monte Carlo step for a one-dimensional target.
--
-- The step proceeds as follows:
--
-- 1. Draw a fresh momentum @p0 ~ N(0, m)@.
-- 2. Draw a step size @eps@ from the supplied generator.
-- 3. Simulate @l@ leapfrog steps under @H(q, p) = U(q) + p^2 / (2m)@.
-- 4. Accept the trajectory's end point with Metropolis probability
--    @min 1 (exp (H(q0, p0) - H(qs, ps)))@. Otherwise the chain stays
--    at @q0@.
--
-- A non-positive @l@ is a zero-length trajectory and yields @q0@.
hmc :: (Double -> Double) -> -- unnormalized negative log pdf U
       (Double -> Double) -> -- grad of U
       Double ->            -- momentum variance m (the "mass"; K(p) = p^2/(2m))
       Gen Double ->        -- leapfrog epsilon
       Int ->               -- leapfrog steps
       Double ->            -- q
       Gen Double           -- result of one time step
hmc u gu m meps l q0
  | l < 1 = return q0
  | otherwise = do
      eps <- meps
      -- rnorm takes a standard deviation, while m is a variance.
      p0 <- rnorm 0 (sqrt m)
      let k p = p * p / (2 * m)
          halfStepP p q = p - eps * gu q / 2
          fullStepP p q = p - eps * gu q
          -- dq/dt = dK/dp = p / m
          fullStepQ q p = q + eps * p / m
          -- One inner leapfrog step on the state (position, momentum): a full
          -- position step followed by a full momentum step at the new position.
          -- Both components are forced so no thunk chain grows across steps.
          leap (qt, pt) =
            let qt' = fullStepQ qt pt
                pt' = fullStepP pt qt'
            in qt' `seq` pt' `seq` (qt', pt')
          -- The momentum starts half a step ahead. After l - 1 inner steps,
          -- a final position step and half momentum step complete l steps.
          (q', p') = rep (l - 1) (q0, halfStepP p0 q0) leap
          qs = fullStepQ q' p'
          ps = halfStepP p' qs
          -- Momentum is negated to make the proposal reversible. k is even,
          -- so this does not change the acceptance ratio.
          ps' = negate ps
          alpha = exp (u q0 - u qs + k p0 - k ps')
      v <- choose (0, 1)
      -- If the trajectory diverged and alpha is NaN, the comparison is False
      -- and the proposal is rejected.
      return $ if v <= alpha then qs else q0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment