# dermesser/hmc.jl

# Last active September 6, 2022 09:06
# Primitive Hamiltonian Monte Carlo (HMC) sampler.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below.
 using Plots
using Random
using Distributions, DistributionsAD
using LinearAlgebra
import Zygote: gradient

"""
    HMC{T,F}

Configuration of a Hamiltonian Monte Carlo sampler.

# Fields
- `sup`: per-dimension `(lower, upper)` bounds; proposals ending outside are rejected.
- `invM`: inverse of the mass matrix `M`.
- `pdist`: momentum distribution (built by `HMCnew` as `MvNormal(zeros(dim), M)`).
- `L`: number of leapfrog steps per proposal.
- `logpdf`: log-density of the target evaluated at a position vector; only differences
  enter the acceptance test, so an additive constant does not matter.
- `ΔT`: leapfrog step size.
"""
struct HMC{T,F}
    sup::AbstractArray{Tuple{T,T}}
    invM::Matrix{T}
    pdist::AbstractMvNormal
    L::Int64
    logpdf::F
    ΔT::T
end

"""
    HMCnew(logpdf; L=10, ΔT=0.1, sup=[(-10., 10.)], M=diagm(ones(length(sup))))

Construct an [`HMC`](@ref) sampler for the log-density `logpdf`.

# Keywords
- `L`: leapfrog steps per proposal.
- `ΔT`: leapfrog step size; its type fixes the sampler's element type `T`.
- `sup`: vector of `(lower, upper)` bounds, one per dimension.
- `M`: mass matrix (identity by default), of size `length(sup) × length(sup)`.

# Throws
- `DimensionMismatch` if `M` does not have size `(length(sup), length(sup))`.
"""
function HMCnew(
    logpdf::F; L=10, ΔT::T=0.1, sup=[(-10., 10.)], M=(diagm(ones(length(sup))))
)::HMC{T,F} where {T<:Real,F<:Function}
    dim = length(sup)
    size(M) == (dim, dim) || throw(DimensionMismatch(
        "mass matrix M must be $dim×$dim to match sup, got size $(size(M))"))
    # TODO: Adapt mass.
    return HMC(sup, inv(M), MvNormal(zeros(dim), M), L, logpdf, ΔT)
end

"""
    HMCState{T}

Mutable state of an HMC chain: position `x`, momentum `p` and accumulated integration
time `t` (advanced by `ΔT` per leapfrog step).
"""
mutable struct HMCState{T}
    x::AbstractArray{T}
    p::AbstractArray{T}
    t::T
end

"""
    HMCState(hmc::HMC)

Initial chain state for `hmc`: a position drawn uniformly from the box `hmc.sup`, zero
momentum and `t = 0`.
"""
function HMCState(hmc::HMC{T,F}) where {T<:Real,F<:Function}
    x0 = T[rand(Uniform(lo, hi)) for (lo, hi) in hmc.sup]
    p0 = zeros(T, length(hmc.sup))
    return HMCState(x0, p0, zero(T))
end

"""
    copy(s::HMCState)

Return an independent copy of `s`; the position and momentum arrays are copied.
"""
function Base.copy(s::HMCState{T})::HMCState{T} where {T}
    # Extends Base.copy rather than shadowing it, so `copy` keeps working for arrays too.
    return HMCState(copy(s.x), copy(s.p), s.t)
end

"""
    H(logpdf, invM, p, x)

Hamiltonian: potential energy `-logpdf(x)` plus kinetic energy `½ pᵀ M⁻¹ p`.
"""
function H(
    logpdf::F, invM::Matrix{T}, p::A, x::A
)::T where {F<:Function,T<:Real,A<:AbstractArray{T}}
    return -logpdf(x) + 1/2 * p' * invM * p
end

"""
    in_support(x, sup)

Return `true` if every coordinate `x[i]` lies in the closed interval given by
`sup[i] = (lower, upper)`.
"""
function in_support(
    x::A, sup::AbstractArray{Tuple{T,T}}
)::Bool where {T<:Real,A<:AbstractArray{T}}
    return all(lo <= x[i] <= hi for (i, (lo, hi)) in enumerate(sup))
end

"""
    transition_probability(hmc, s0, s1)

Metropolis acceptance probability `min(1, exp(H(s0) - H(s1)))` for moving from `s0` to the
proposal `s1`, or zero if `s1.x` lies outside `hmc.sup`. `s0.p` must be the momentum the
trajectory started from.
"""
function transition_probability(
    hmc::HMC{T,F}, s0::HMCState{T}, s1::HMCState{T}
)::T where {T<:Real,F}
    in_support(s1.x, hmc.sup) || return zero(T)
    h_new = H(hmc.logpdf, hmc.invM, s1.p, s1.x)
    h_old = H(hmc.logpdf, hmc.invM, s0.p, s0.x)
    return min(one(T), exp(-(h_new - h_old)))
end

"""
    leapfrog_step(hmc, s)

Advance `s` in place by one leapfrog step of size `hmc.ΔT` (half momentum step, full
position step, half momentum step) and return it.
"""
function leapfrog_step(
    hmc::HMC{T,F}, s::HMCState{T}
)::HMCState{T} where {T<:Real,F<:Function}
    # Hamilton's equations for H above: dp/dt = ∇logpdf(x), dx/dt = M⁻¹ p.
    s.p = s.p + hmc.ΔT/2 * gradient(hmc.logpdf, s.x)[1]
    s.x = s.x + hmc.ΔT * hmc.invM * s.p
    s.p = s.p + hmc.ΔT/2 * gradient(hmc.logpdf, s.x)[1]
    s.t += hmc.ΔT
    return s
end

"""
    sample(hmc, s)

Perform one HMC transition from `s`: draw a fresh momentum from `hmc.pdist`, integrate
`hmc.L` leapfrog steps, and accept the proposal with probability
[`transition_probability`](@ref). Return the proposal if accepted, otherwise the starting
position with the refreshed momentum. `s` itself is not modified.
"""
function sample(hmc::HMC{T,F}, s::HMCState{T})::HMCState{T} where {T,F}
    # Refresh the momentum *before* snapshotting the start state: the old Hamiltonian must
    # include the kinetic energy of the momentum the trajectory actually starts with.
    s0 = copy(s)
    rand!(hmc.pdist, s0.p)
    s1 = copy(s0)
    for _ in 1:hmc.L
        s1 = leapfrog_step(hmc, s1)
    end
    α = transition_probability(hmc, s0, s1)
    if rand(Uniform(0, 1)) <= α
        # Accept!
        return s1
    else
        return s0
    end
end

"""
    test_sample_snd()

Visual check: draw 10 000 HMC samples from a standard normal and plot their normalized
histogram against the exact density. Returns the current plot.
"""
function test_sample_snd()
    nd = Normal(0, 1)
    # The support must cover the plotted range [-4, 4]; a narrow box such as [0, 1] would
    # sample a truncated normal that cannot match `pdf(nd, x)`.
    hmc = HMCnew(x -> logpdf(nd, x[1]); sup=[(-10., 10.)])
    hmcs = HMCState(hmc)
    N = 10000
    samples = zeros(N)
    for i in 1:N
        hmcs = sample(hmc, hmcs)
        samples[i] = hmcs.x[1]
    end
    plot()
    histogram!(samples; bins=LinRange(-4, 4, 50), normalize=:pdf)
    plot!(x -> pdf(nd, x))
    return current()
end

"""
    normal_ppd(µs, σs, n=1000)

Draw `n` values, each from `Normal(µ, σ)` with `µ = rand(µs)` and `σ = rand(σs)` drawn
afresh per value. `µs`/`σs` may be anything `rand` accepts (e.g. vectors of MCMC draws).
"""
function normal_ppd(µs, σs, n=1000)::Vector{Float64}
    # NOTE(review): µ and σ are drawn independently; if they are paired posterior draws,
    # this ignores their joint correlation.
    s = zeros(n)
    for i in 1:n
        µ, σ = rand(µs), rand(σs)
        s[i] = rand(Normal(µ, σ))
    end
    return s
end

"""
    test_sample_mcmc()

Infer the mean and standard deviation of 30 observations from `Normal(5, 2)` with HMC,
using priors `µ ~ Normal(3, 1)` and `σ ~ Normal(1, 2)`. Returns a `2 × 1000` matrix whose
columns are successive `(µ, σ)` chain states.
"""
function test_sample_mcmc()
    true_dist = Normal(5, 2)
    observations = rand(true_dist, 30)
    prior_µ = Normal(3, 1)
    prior_σ = Normal(1, 2)
    # Log-posterior up to a constant. `abs` keeps `Normal` constructible when the leapfrog
    # integrator evaluates the gradient at σ < 0; proposals ending there fail the support
    # check and are rejected.
    function loglik(θ)
        return logpdf(prior_µ, θ[1]) + logpdf(prior_σ, θ[2]) +
               sum(logpdf(Normal(θ[1], abs(θ[2])), o) for o in observations; init=0)
    end
    hmc = HMCnew(loglik; sup=[(-10., 10.), (0., 10.)])
    hmcs = HMCState(hmc)
    N = 1000
    samples = zeros(2, N)
    for i in 1:N
        hmcs = sample(hmc, hmcs)
        samples[:, i] = hmcs.x
    end
    return samples
end
