Skip to content

Instantly share code, notes, and snippets.

# dermesser/hmc.jl

Last active September 6, 2022 09:06
Show Gist options
• Save dermesser/0030ad422e1aa9cb90743fed1e8a890e to your computer and use it in GitHub Desktop.
Primitive Hamiltonian Monte Carlo (HMC) sampler
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
 using Plots
using Random
using Distributions, DistributionsAD
using LinearAlgebra
import Zygote: gradient

"""
    HMC{T,F}

Configuration of a primitive Hamiltonian Monte Carlo sampler.

# Fields
- `sup`: per-dimension closed support `(lower, upper)`; proposals whose position falls
  outside it are rejected.
- `invM`: inverse of the mass matrix; used in the kinetic energy and the position update.
- `pdist`: momentum distribution (built by `HMCnew` as `MvNormal(0, M)`).
- `L`: number of leapfrog steps per proposal.
- `logpdf`: target log density (an additive constant is irrelevant), called with a
  position vector.
- `ΔT`: leapfrog step size.
"""
struct HMC{T,F}
    sup::AbstractArray{Tuple{T,T}}
    invM::Matrix{T}
    pdist::AbstractMvNormal
    L::Int64
    logpdf::F
    ΔT::T
end

"""
    HMCnew(logpdf; L=10, ΔT=0.1, sup=[(-10., 10.)], M=identity)

Construct an `HMC` sampler for the log density `logpdf`.

# Keywords
- `L`: leapfrog steps per proposal (default `10`).
- `ΔT`: leapfrog step size (default `0.1`).
- `sup`: vector of `(lower, upper)` bounds, one per dimension; its length fixes the
  dimension of the problem.
- `M`: mass matrix; momenta are drawn from `MvNormal(0, M)`. Defaults to the identity of
  size `length(sup)`.
"""
function HMCnew(logpdf::F; L=10, ΔT::T=0.1, sup=[(-10., 10.)],
                M=(diagm(ones(length(sup)))))::HMC{T,F} where {T<:Real,F<:Function}
    # TODO: Adapt mass.
    return HMC(sup, inv(M), MvNormal(zeros(length(sup)), M), L, logpdf, ΔT)
end

"""
    HMCState{T}

Mutable sampler state: position `x`, momentum `p`, and the integration time `t`
accumulated by `leapfrog_step` (incremented by `ΔT` per step).
"""
mutable struct HMCState{T}
    x::AbstractArray{T}
    p::AbstractArray{T}
    t::T
end

"""
    HMCState(hmc::HMC)

Initial state: position drawn uniformly inside each interval of `hmc.sup`, zero momentum,
and time `0`.
"""
function HMCState(hmc::HMC{T,F}) where {T<:Real,F<:Function}
    dim = length(hmc.sup)
    x0 = [rand(Uniform(lo, hi)) for (lo, hi) in hmc.sup]
    p0 = zeros(dim)
    t = 0.
    return HMCState(x0, p0, t)
end

"""
    copy(s::HMCState)

Return a copy of `s` whose position and momentum arrays are not shared with `s`.
"""
# Extend Base.copy rather than defining a fresh `copy` in this module, which would shadow
# Base.copy and break `copy(array)` for all code that follows.
function Base.copy(s::HMCState{T})::HMCState{T} where {T}
    return HMCState(copy(s.x), copy(s.p), s.t)
end

"""
    H(logpdf, invM, p, x)

Hamiltonian: potential energy `-logpdf(x)` plus kinetic energy `p' * invM * p / 2`.
"""
function H(logpdf::F, invM::Matrix{T}, p::A, x::A)::T where {F<:Function,T<:Real,A<:AbstractArray{T}}
    return -logpdf(x) + 1/2 * p' * invM * p
end

"""
    in_support(x, sup)

Return `true` iff `sup[i][1] <= x[i] <= sup[i][2]` for every interval `i` in `sup`.
"""
function in_support(x::A, sup::AbstractArray{Tuple{T,T}})::Bool where {T<:Real,A<:AbstractArray{T}}
    return all(lo <= x[i] && x[i] <= hi for (i, (lo, hi)) in enumerate(sup))
end

"""
    transition_probability(hmc, s0, s1)

Metropolis acceptance probability `min(1, exp(H(s0) - H(s1)))` for moving from `s0` to
the proposal `s1`. Proposals whose position lies outside `hmc.sup` get probability zero.
If the energies are `NaN`, the result is `NaN`, and `sample` then rejects the proposal.
"""
function transition_probability(hmc::HMC{T,F}, s0::HMCState{T}, s1::HMCState{T})::T where {T<:Real,F}
    in_support(s1.x, hmc.sup) || return zero(T)
    new = H(hmc.logpdf, hmc.invM, s1.p, s1.x)
    old = H(hmc.logpdf, hmc.invM, s0.p, s0.x)
    return min(one(T), exp(-(new - old)))
end

# Gradient of the target log density at `x`. Zygote reports a structurally zero gradient
# as `nothing`; map that to an explicit zero array so the momentum update stays defined.
_logpdf_gradient(hmc::HMC, x) = something(gradient(hmc.logpdf, x)[1], zero(x))

"""
    leapfrog_step(hmc, s)

Advance `s` by one leapfrog step of size `hmc.ΔT`: half momentum step, full position step,
half momentum step. Mutates and returns `s`.
"""
function leapfrog_step(hmc::HMC{T,F}, s::HMCState{T})::HMCState{T} where {T<:Real,F<:Function}
    half = hmc.ΔT / 2
    # dp/dt = -∂H/∂x = ∇logpdf(x)
    s.p = s.p + half * _logpdf_gradient(hmc, s.x)
    s.x = s.x + hmc.ΔT * hmc.invM * s.p
    s.p = s.p + half * _logpdf_gradient(hmc, s.x)
    s.t += hmc.ΔT
    return s
end

"""
    sample(hmc, s)

Perform one HMC transition starting from `s`. A fresh momentum is drawn from `hmc.pdist`,
the system is integrated for `hmc.L` leapfrog steps, and the endpoint is accepted or
rejected with `transition_probability`. Returns the accepted proposal or a copy of the
start state. `s` itself is mutated in either case.
"""
function sample(hmc::HMC{T,F}, s::HMCState{T})::HMCState{T} where {T,F}
    # The momentum must be resampled *before* the start state is snapshotted: the
    # Metropolis correction compares H(x1, p1) with H(x0, p0), where p0 is the momentum
    # that actually launched the trajectory, not a stale one from the previous step.
    rand!(hmc.pdist, s.p)
    s0 = copy(s)
    s1 = s
    for _ in 1:hmc.L
        s1 = leapfrog_step(hmc, s1)
    end
    α = transition_probability(hmc, s0, s1)
    return rand() <= α ? s1 : s0  # accept with probability α
end

"""
    test_sample_snd()

Demo: draw 10000 HMC samples of a one-dimensional `Normal(0, 1)` target and plot their
histogram together with the exact pdf. Returns the current plot.
"""
function test_sample_snd()
    nd = Normal(0, 1)
    # NOTE(review): the support is restricted to [0, 1], so the chain samples a truncated
    # normal, while the overlay below is the full N(0, 1) pdf. The two will not match.
    # Confirm whether a wider `sup` was intended.
    hmc = HMCnew(x -> logpdf(nd, x[1]), sup=[(0., 1.)])
    hmcs = HMCState(hmc)
    N = 10000
    samples = zeros(N)
    for i in 1:N
        hmcs = sample(hmc, hmcs)
        samples[i] = hmcs.x[1]
    end
    plot()
    histogram!(samples, bins=LinRange(-4, 4, 50), normalize=:pdf)
    plot!(x -> pdf(nd, x))
    return current()
end

"""
    normal_ppd(µs, σs, n=1000)

Draw `n` values. For each value, a mean is drawn via `rand` from `µs` and a standard
deviation from `σs`, and then one sample is taken from `Normal(µ, σ)`. Presumably `µs`
and `σs` are posterior samples, giving a posterior-predictive draw. Verify with callers.
"""
function normal_ppd(µs, σs, n=1000)::Vector{Float64}
    s = zeros(n)
    µs = Random.Sampler(Random.GLOBAL_RNG, µs, Val(1))
    σs = Random.Sampler(Random.GLOBAL_RNG, σs, Val(1))
    for i in 1:n
        µ, σ = rand(µs), rand(σs)
        s[i] = rand(Normal(µ, σ))
    end
    return s
end

"""
    test_sample_mcmc()

Demo: infer the mean and standard deviation of 30 observations drawn from
`Normal(5, 2)`. The priors are `Normal(3, 1)` on the mean and `Normal(1, 2)` on the
scale, and the scale enters as `abs(θ[2])`. Returns a `2×1000` matrix whose columns are
HMC samples of `θ`.
"""
function test_sample_mcmc()
    true_dist = Normal(5, 2)
    observations = rand(true_dist, 30)
    prior_µ = Normal(3, 1)
    prior_σ = Normal(1, 2)
    loglik(θ) = begin
        logpdf(prior_µ, θ[1]) + logpdf(prior_σ, θ[2]) +
            sum(logpdf(Normal(θ[1], abs(θ[2])), o) for o in observations; init=0)
    end
    hmc = HMCnew(loglik; sup=[(-10., 10.), (0., 10.)])
    hmcs = HMCState(hmc)
    N = 1000
    samples = zeros(2, N)
    for i in 1:N
        hmcs = sample(hmc, hmcs)
        samples[:, i] = hmcs.x
    end
    return samples
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.