Skip to content

Instantly share code, notes, and snippets.

@dermesser
Last active September 6, 2022 09:06
Show Gist options
  • Save dermesser/0030ad422e1aa9cb90743fed1e8a890e to your computer and use it in GitHub Desktop.
Save dermesser/0030ad422e1aa9cb90743fed1e8a890e to your computer and use it in GitHub Desktop.
Primitive Hamiltonian Monte Carlo (HMC) sampler
using Plots
using Random
using Distributions, DistributionsAD
using LinearAlgebra
import Zygote: gradient
"""
    HMC{T,F}

Configuration of a Hamiltonian Monte Carlo sampler (normally built via `HMCnew`).

NOTE(review): `sup` and `pdist` have abstract field types, which prevents the
compiler from specializing on them; consider adding type parameters.
"""
struct HMC{T,F}
# Box support: one `(lower, upper)` interval per coordinate. Proposals whose
# final position lies outside it are rejected in `transition_probability`.
sup::AbstractArray{Tuple{T,T}}
# Inverse mass matrix; used in the kinetic energy and in the position update.
invM::Matrix{T}
# Distribution the momentum is resampled from at the start of every proposal.
pdist::AbstractMvNormal
# Number of leapfrog steps per proposal.
L::Int64
# Target log-density, called with the position vector. Only its gradient and
# differences of it are used, so an additive constant does not matter.
logpdf::F
# Leapfrog step size.
ΔT::T
end
"""
    HMCnew(logpdf; L=10, ΔT=0.1, sup=[(-10., 10.)], M=diagm(ones(length(sup))))

Build an `HMC` sampler for the log-density `logpdf`.

# Keywords
- `L`: number of leapfrog steps per proposal.
- `ΔT`: leapfrog step size; its type fixes the element type `T`.
- `sup`: vector of `(lower, upper)` bounds, one per coordinate.
- `M`: mass matrix. Momenta are drawn from `MvNormal(0, M)`, and `inv(M)` is stored.
"""
function HMCnew(logpdf::F; L=10, ΔT::T=0.1, sup=[(-10., 10.)], M=(diagm(ones(length(sup)))))::HMC{T,F} where {T<:Real,F<:Function}
    # TODO: Adapt mass.
    inverse_mass = inv(M)
    dim = length(sup)
    momentum_dist = MvNormal(zeros(dim), M)
    return HMC(sup, inverse_mass, momentum_dist, L, logpdf, ΔT)
end
"""
    HMCState{T}

Mutable state of one HMC chain: position, momentum and accumulated integration time.

NOTE(review): the array fields are abstractly typed (`AbstractArray{T}`), which
prevents specialization; consider a concrete array type parameter.
"""
mutable struct HMCState{T}
# Current position in parameter space.
x::AbstractArray{T}
# Current momentum (resampled at the start of every proposal).
p::AbstractArray{T}
# Integration time; each leapfrog step adds `ΔT` to it.
t::T
end
"""
    HMCState(hmc::HMC)

Create an initial chain state for `hmc`. Each position coordinate is drawn
uniformly from its support interval, the momentum starts at zero, and the
integration time starts at `0.0`.
"""
function HMCState(hmc::HMC{T,F}) where {T <: Real, F <: Function}
    position = map(((lo, hi),) -> rand(Uniform(lo, hi)), hmc.sup)
    momentum = zeros(length(hmc.sup))
    return HMCState(position, momentum, 0.0)
end
"""
    copy(s::HMCState)

Return a new `HMCState` with copies of the position and momentum arrays, so that
mutating either state does not affect the other.

This extends `Base.copy`. A bare `function copy(...)` would create a separate
`Main.copy` that shadows `Base.copy`, which breaks e.g. `copy([1, 2])` afterwards.
"""
function Base.copy(s::HMCState{T})::HMCState{T} where {T}
    return HMCState(copy(s.x), copy(s.p), s.t)
end
"""
    H(logpdf, invM, p, x)

Hamiltonian of the phase-space point `(x, p)`: the potential energy `-logpdf(x)`
plus the kinetic energy `pᵀ invM p / 2`.
"""
function H(logpdf::F, invM::Matrix{T}, p::A, x::A)::T where {F <: Function, T <: Real, A <: AbstractArray{T}}
    potential = -logpdf(x)
    kinetic = 1/2 * p' * invM * p
    return potential + kinetic
end
"""
    in_support(x, sup)

Return `true` if every coordinate `x[i]` lies in the closed interval `sup[i]`.
The check stops at the first coordinate that is out of bounds.
"""
function in_support(x::A, sup::AbstractArray{Tuple{T,T}})::Bool where {T<:Real, A<:AbstractArray{T}}
    for (k, (lo, hi)) in enumerate(sup)
        lo <= x[k] <= hi || return false
    end
    return true
end
"""
    transition_probability(hmc, s0, s1)

Metropolis acceptance probability for moving from `s0` (start of the trajectory)
to `s1` (end of the trajectory): `min(1, exp(H(s0) - H(s1)))`. Returns `0` when
`s1.x` is outside the sampler's support box.
"""
function transition_probability(hmc::HMC{T,F}, s0::HMCState{T}, s1::HMCState{T})::T where {T<:Real, F}
    # Proposals that leave the support are always rejected.
    in_support(s1.x, hmc.sup) || return zero(T)
    new = H(hmc.logpdf, hmc.invM, s1.p, s1.x)
    old = H(hmc.logpdf, hmc.invM, s0.p, s0.x)
    # A NaN energy makes `min` return NaN, and any `rand() < NaN` comparison
    # is false, so such proposals end up rejected.
    return min(one(T), exp(-(new - old)))
end
"""
    leapfrog_step(hmc, s)

Advance `s` by one leapfrog step of size `hmc.ΔT`: a half momentum kick along
`∇logpdf`, a full position drift by `invM * p`, then another half kick.
`s` is mutated in place and also returned.

The support box is not checked here; intermediate positions may leave it.
"""
function leapfrog_step(hmc::HMC{T,F}, s::HMCState{T})::HMCState{T} where {T <: Real, F <: Function}
    half = hmc.ΔT / 2
    # dp/dt = -∂H/∂x = ∇logpdf(x)
    s.p = s.p + half * gradient(hmc.logpdf, s.x)[1]
    # dx/dt = ∂H/∂p = invM * p
    s.x = s.x + hmc.ΔT * hmc.invM * s.p
    # The gradient at the new `x` is recomputed by the next step's first kick;
    # caching it would halve the number of gradient evaluations.
    s.p = s.p + half * gradient(hmc.logpdf, s.x)[1]
    s.t += hmc.ΔT
    return s
end
"""
    sample(hmc, s)

Perform one HMC transition starting from state `s`:
1. resample the momentum from `hmc.pdist`;
2. integrate `hmc.L` leapfrog steps;
3. accept the endpoint with the Metropolis probability.

Return the accepted proposal, or the starting point if the proposal is rejected.
The caller's `s` is not mutated.
"""
function sample(hmc::HMC{T,F}, s::HMCState{T})::HMCState{T} where {T, F}
    # Work on a copy so that the caller's state is never modified.
    s1 = copy(s)
    rand!(hmc.pdist, s1.p)
    # Snapshot the start of the trajectory AFTER drawing the fresh momentum:
    # the initial Hamiltonian must use that momentum, not the previous one.
    s0 = copy(s1)
    for _ in 1:hmc.L
        s1 = leapfrog_step(hmc, s1)
    end
    α = transition_probability(hmc, s0, s1)
    # rand() ∈ [0, 1), so `<` accepts with probability exactly α and never
    # accepts when α == 0 (proposal outside the support).
    return rand() < α ? s1 : s0
end
"""
    test_sample_snd()

Visual check: draw 10 000 HMC samples targeting a standard normal and plot their
normalized histogram on [-4, 4] together with the N(0, 1) density. Returns the
current plot.
"""
function test_sample_snd()
    nd = Normal(0, 1)
    # The support box must be much wider than the plotted range [-4, 4]. A
    # narrow box such as [0, 1] would truncate the chain, and the histogram
    # could then never match the N(0, 1) density overlaid below.
    hmc = HMCnew(x -> logpdf(nd, x[1]), sup=[(-10., 10.)])
    hmcs = HMCState(hmc)
    N = 10000
    samples = zeros(N)
    for i in 1:N
        hmcs = sample(hmc, hmcs)
        samples[i] = hmcs.x[1]
    end
    plot()
    histogram!(samples, bins=LinRange(-4, 4, 50), normalize=:pdf)
    plot!(x -> pdf(nd, x))
    return current()
end
"""
    normal_ppd(µs, σs, n=1000)

Draw `n` values from a mixture of normals. For each draw, a mean is picked at
random from `µs` and a standard deviation from `σs`, and one value is then
sampled from `Normal(µ, σ)`.
"""
function normal_ppd(µs, σs, n=1000)::Vector{Float64}
    out = zeros(n)
    rng = Random.GLOBAL_RNG
    mean_sampler = Random.Sampler(rng, µs, Val(1))
    sd_sampler = Random.Sampler(rng, σs, Val(1))
    for k in eachindex(out)
        # Arguments are evaluated left to right: mean, then sd, then the draw.
        out[k] = rand(Normal(rand(mean_sampler), rand(sd_sampler)))
    end
    return out
end
"""
    test_sample_mcmc()

Infer the mean and standard deviation of a normal model from 30 synthetic draws
of `Normal(5, 2)` using HMC. Returns a 2×1000 matrix of draws (row 1: µ, row 2: σ).
"""
function test_sample_mcmc()
    true_dist = Normal(5, 2)
    observations = rand(true_dist, 30)
    prior_µ = Normal(3, 1)
    prior_σ = Normal(1, 2)
    # Unnormalized log-posterior over θ = (µ, σ). The support check happens only
    # at trajectory ends, so intermediate leapfrog positions can have σ < 0;
    # `abs` keeps the likelihood's scale valid there.
    function loglik(θ)
        log_prior = logpdf(prior_µ, θ[1]) + logpdf(prior_σ, θ[2])
        log_lik = sum(logpdf(Normal(θ[1], abs(θ[2])), o) for o in observations; init=0)
        return log_prior + log_lik
    end
    hmc = HMCnew(loglik; sup=[(-10., 10.), (0., 10.)])
    state = HMCState(hmc)
    nsamples = 1000
    draws = zeros(2, nsamples)
    for column in eachcol(draws)
        state = sample(hmc, state)
        column .= state.x
    end
    return draws
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment