Skip to content

Instantly share code, notes, and snippets.

@dermesser
Last active June 30, 2021 15:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dermesser/c7bb020179c305a4633e1ec1a7a4bdea to your computer and use it in GitHub Desktop.
Save dermesser/c7bb020179c305a4633e1ec1a7a4bdea to your computer and use it in GitHub Desktop.
A very simple implementation of the multi-dimensional Metropolis-Hastings algorithm
### A Pluto.jl notebook ###
# v0.14.7
using Markdown
using InteractiveUtils
# ╔═╡ 4f4ffc64-cdf9-11eb-32ef-e5e79401898e
using Random
# ╔═╡ fdc3d3af-f680-41b0-903e-779be1b24206
using Distributions
# ╔═╡ 13db19ca-1904-4bd3-860d-0d9a0e73a543
using Plots
# ╔═╡ b80b2eb1-6b72-4862-b6d4-53e0ec4f8cb0
# Select the GR plotting backend with PNG as the HTML output format.
gr(html_output_format=:png)
# ╔═╡ 3b247ed5-aa3b-476b-9854-8e61bb5b7274
# Random number generator used by `Base.iterate(::MH)` for proposals and acceptance draws.
# No seed is passed, so results differ from run to run.
rng = MersenneTwister()
# ╔═╡ d34ddc22-1ada-42e4-8d46-433f616245a8
# Normal distribution with mean 0 and standard deviation 1.
# NOTE(review): not referenced by any other cell visible here — confirm whether it is still needed.
nd = Normal(0, 1)
# ╔═╡ 6ad545f7-04c4-43b4-b2eb-0bcd9664b873
"""
    MH(target, width=1, dims=1)

Mutable state of a random-walk Metropolis-Hastings sampler.

# Fields
- `target`: function returning the (unnormalised) density at a position; it is
  called with the current and the proposed position.
- `uniform`: distribution the acceptance threshold is drawn from; the constructor
  sets it to `Uniform(0, 1)`.
- `jumpg_type`: proposal distribution type, instantiated per coordinate as
  `jumpg_type(center, width)`; the constructor sets it to `Normal`.
- `x`: current position of the chain, one entry per dimension.
- `width`: second argument handed to the proposal distribution.
- `dims`: number of coordinates of the position.

The constructor draws the starting position with `rand(dims)`, i.e. every
coordinate starts in `[0, 1)`.
"""
mutable struct MH
    target::Function
    uniform::UnivariateDistribution
    jumpg_type::Type
    # Concrete element type and rank: `Array{Float64}` leaves the rank open and
    # is therefore an abstract field type.
    x::Vector{Float64}
    width::Float64
    dims::Int64
    function MH(target, width=1, dims=1)
        return new(target, Uniform(0, 1), Normal, rand(dims), width, dims)
    end
end
# ╔═╡ 8259e911-6880-4a59-ad2c-c9db7ad99e79
"""
    sinsq_target(x)

Unnormalised one-dimensional target density: `sin(x[1])^2` for `x[1]` in
`[0, 2π]` and zero outside that interval.

Only the first element of `x` is read, so `x` may be a vector, a tuple or a
scalar. The result is always a floating-point value.
"""
function sinsq_target(x)
    t = x[1]
    if t < 0 || t > 2pi
        # Float zero (not the Int literal `0`) so the return type matches the
        # in-support branch regardless of the value of `t`.
        return zero(float(t))
    end
    return sin(t)^2
end
# ╔═╡ ad08aef0-a737-439f-8932-c9e9395b2012
"""
    sinsq2d(xy)

Unnormalised two-dimensional target density: `sin(x)^2 * sin(y)^2` with
`x = xy[1]` and `y = xy[2]` when both lie in `[0, 2π]`, and zero otherwise.

`xy` may be any indexable pair (vector or tuple). The result is always a
floating-point value.
"""
function sinsq2d(xy)
    x, y = xy[1], xy[2]
    if x < 0 || y < 0 || x > 2pi || y > 2pi
        # Float zero of the same promoted type as the in-support product, so the
        # return type does not depend on the argument values.
        return zero(float(x)) * zero(float(y))
    end
    return sin(x)^2 * sin(y)^2
end
# ╔═╡ 3735e359-ced9-492f-b423-ff43ae43bb63
# One-dimensional sampler for `sinsq_target` with proposal width 3 (`dims` keeps its default of 1).
sinsqmh = MH(sinsq_target, 3)
# ╔═╡ 62e8c485-c255-4d53-969b-0137b02bb7f3
# Plot the 1D target at 100 evenly spaced points in [0, 2π]. Each point is wrapped in a
# one-element vector because `sinsq_target` reads `x[1]`. Only the function values are
# passed to `plot`; the grid positions themselves are not.
plot(sinsq_target.([[x] for x = LinRange(0, 2pi, 100)]))
# ╔═╡ 98a85bdb-17d8-4179-9f30-789773235523
begin
    # Number of grid points along each axis.
    dim = 50
    # Evenly spaced grid covering [0, 2π].
    lr = LinRange(0, 2pi, dim)
    # Evaluate the 2D target on the full grid. The two-iterator comprehension
    # already yields a dim × dim matrix, so it is passed to `heatmap` directly.
    heatmap(lr, lr, [sinsq2d((x, y)) for x in lr, y in lr])
end
# ╔═╡ e1d595bc-d9f2-411c-9d9f-74fc55270a44
md"""We implement the sampler as an iterator, which makes it easy to loop over the generated samples."""
# ╔═╡ dc500c3f-47dc-476a-9114-dbc5d188a56d
"""
    Base.iterate(mh::MH, st=nothing)

Advance the chain `mh` by one Metropolis-Hastings step and return
`(position, nothing)`.

A proposal is built by drawing every coordinate from
`mh.jumpg_type(coordinate, mh.width)` using the global `rng`. It is accepted
when a draw from `mh.uniform` is less than or equal to the ratio
`mh.target(proposal) / mh.target(mh.x)`; on acceptance `mh.x` is replaced by the
proposal. Otherwise the current position is returned unchanged.

The iteration state `st` is not used, and the method never returns `nothing`,
so the iterator does not terminate on its own.
"""
function Base.iterate(mh::MH, st=nothing) :: Union{Nothing, Tuple{Any,Any}}
    # Propose a move: perturb each coordinate independently around its current value.
    proposal = [rand(rng, mh.jumpg_type(coord, mh.width)) for coord in mh.x]
    # Density ratio of proposed over current position.
    ratio = mh.target(proposal) / mh.target(mh.x)
    threshold = rand(rng, mh.uniform)
    # A NaN ratio compares false here and therefore counts as a rejection.
    if threshold <= ratio
        mh.x = proposal
        return (proposal, nothing)
    else
        return (mh.x, nothing)
    end
end
# ╔═╡ 4b51d5a9-4f06-4d0b-a794-960e66a85199
"""
    sample(mh::MH, n=10000)

Draw `n` consecutive states from the chain `mh` and return them as an
`n × mh.dims` matrix with one sample per row.

The chain is advanced exactly `n` times; `mh.x` afterwards holds the state of
the last step.
"""
function sample(mh::MH, n=10000)
    samples = zeros(n, mh.dims)
    # `Iterators.take` bounds the otherwise endless iterator, so no extra chain
    # step is computed and discarded after the n-th sample.
    for (i, x) in enumerate(Iterators.take(mh, n))
        samples[i, :] .= x
    end
    return samples
end
# ╔═╡ 0cc7983d-dea6-4bdc-9560-8336f8be02cd
# Draw samples from the 1D chain using the default sample count of `sample` and report the timing.
@time samples = sample(sinsqmh);
# ╔═╡ 6e5d0dd7-7333-4c5b-9522-78daadd24448
# Histogram of the 1D samples using 30 bins.
histogram(samples, bins=30)
# ╔═╡ d16d5dfd-0af4-4318-833c-b7e2a2198b34
# Two-dimensional sampler for `sinsq2d` with proposal width 0.5 and dims = 2.
mh2d = MH(sinsq2d, .5, 2)
# ╔═╡ 114d7240-9bbd-49f8-9a4e-0153c3eddbd1
# Draw 10000 two-dimensional samples (one per row) and report the timing.
@time samples2d = sample(mh2d, 10000)
# ╔═╡ c1957efd-43f7-47fc-9945-fc75de656eb5
# 2D histogram of the first sample coordinate against the second, using 100 bins.
histogram2d(view(samples2d, :, 1), view(samples2d, :, 2), bins=100)
# ╔═╡ 23fbe57b-5b4f-42f1-b56c-1a2de9c35ae8
# Trace of the first 700 chain states (first coordinate against second), without a legend.
plot(view(samples2d, 1:700, 1), view(samples2d, 1:700, 2), marker=:o, legend=false)
# ╔═╡ Cell order:
# ╠═4f4ffc64-cdf9-11eb-32ef-e5e79401898e
# ╠═fdc3d3af-f680-41b0-903e-779be1b24206
# ╠═13db19ca-1904-4bd3-860d-0d9a0e73a543
# ╠═b80b2eb1-6b72-4862-b6d4-53e0ec4f8cb0
# ╠═3b247ed5-aa3b-476b-9854-8e61bb5b7274
# ╠═d34ddc22-1ada-42e4-8d46-433f616245a8
# ╠═6ad545f7-04c4-43b4-b2eb-0bcd9664b873
# ╠═8259e911-6880-4a59-ad2c-c9db7ad99e79
# ╠═ad08aef0-a737-439f-8932-c9e9395b2012
# ╠═3735e359-ced9-492f-b423-ff43ae43bb63
# ╠═62e8c485-c255-4d53-969b-0137b02bb7f3
# ╠═98a85bdb-17d8-4179-9f30-789773235523
# ╠═e1d595bc-d9f2-411c-9d9f-74fc55270a44
# ╠═dc500c3f-47dc-476a-9114-dbc5d188a56d
# ╠═4b51d5a9-4f06-4d0b-a794-960e66a85199
# ╠═0cc7983d-dea6-4bdc-9560-8336f8be02cd
# ╠═6e5d0dd7-7333-4c5b-9522-78daadd24448
# ╠═d16d5dfd-0af4-4318-833c-b7e2a2198b34
# ╠═114d7240-9bbd-49f8-9a4e-0153c3eddbd1
# ╠═c1957efd-43f7-47fc-9945-fc75de656eb5
# ╠═23fbe57b-5b4f-42f1-b56c-1a2de9c35ae8
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment