Last active
June 30, 2021 15:38
-
-
Save dermesser/c7bb020179c305a4633e1ec1a7a4bdea to your computer and use it in GitHub Desktop.
A very simple implementation of the multi-dimensional Metropolis–Hastings algorithm.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### A Pluto.jl notebook ### | |
# v0.14.7 | |
using Markdown | |
using InteractiveUtils | |
# ╔═╡ 4f4ffc64-cdf9-11eb-32ef-e5e79401898e | |
using Random | |
# ╔═╡ fdc3d3af-f680-41b0-903e-779be1b24206 | |
using Distributions | |
# ╔═╡ 13db19ca-1904-4bd3-860d-0d9a0e73a543 | |
using Plots | |
# ╔═╡ b80b2eb1-6b72-4862-b6d4-53e0ec4f8cb0 | |
# Select the GR plotting backend and have it emit PNG images in HTML output.
gr(; html_output_format = :png)
# ╔═╡ 3b247ed5-aa3b-476b-9854-8e61bb5b7274
# Shared RNG drawn from by the sampler's `iterate`; no seed is given, so each
# session produces a different chain.
rng = MersenneTwister()
# ╔═╡ d34ddc22-1ada-42e4-8d46-433f616245a8
# Standard normal distribution. NOTE(review): appears unused by the cells in this notebook.
nd = Normal(0.0, 1.0)
# ╔═╡ 6ad545f7-04c4-43b4-b2eb-0bcd9664b873 | |
"""
    MH(target, width=1, dims=1)

State of a random-walk Metropolis–Hastings sampler.

# Arguments
- `target`: unnormalized density. It is called with the length-`dims` state vector and is
  expected to return a non-negative real number.
- `width`: scale passed to the proposal distribution for each coordinate (the standard
  deviation of the `Normal` proposal). Must be positive.
- `dims`: dimensionality of the state. Must be at least 1.

The initial state is `rand(dims)`, i.e. drawn uniformly from `[0, 1)^dims` with the global
RNG. Iterating over an `MH` value advances the chain.

# Throws
- `ArgumentError` if `width` is not positive or `dims < 1`.
"""
mutable struct MH
    target::Function
    uniform::UnivariateDistribution  # source of acceptance thresholds; Uniform(0, 1)
    jumpg_type::Type                 # proposal family, instantiated as jumpg_type(center, width)
    x::Vector{Float64}               # current chain state (concrete type avoids boxing)
    width::Float64
    dims::Int64

    function MH(target, width=1, dims=1)
        # Fail fast instead of erroring (or never moving) later inside `iterate`.
        width > 0 || throw(ArgumentError("width must be positive, got $width"))
        dims >= 1 || throw(ArgumentError("dims must be at least 1, got $dims"))
        return new(target, Uniform(0, 1), Normal, rand(dims), width, dims)
    end
end
# ╔═╡ 8259e911-6880-4a59-ad2c-c9db7ad99e79 | |
"""
    sinsq_target(x)

Unnormalized 1-D target density `sin(t)^2` for `t = x[1]`, supported on `[0, 2π]`.

Only the first element of `x` is read, so a length-1 state vector is expected. Outside the
support the result is a floating-point zero of the same type as `sin(t)^2`, so the return
type does not depend on the input value.
"""
function sinsq_target(x)
    t = x[1]
    if t < 0 || t > 2pi
        # `zero(float(t))` matches the type of `sin(t)^2` (Int -> Float64, Float32 -> Float32).
        return zero(float(t))
    end
    return sin(t)^2
end
# ╔═╡ ad08aef0-a737-439f-8932-c9e9395b2012 | |
"""
    sinsq2d(xy)

Unnormalized 2-D target density `sin(x)^2 * sin(y)^2` with `x, y = xy[1], xy[2]`,
supported on the square `[0, 2π] × [0, 2π]`. Any elements beyond the second are ignored.

Outside the support the result is a floating-point zero of the same type the in-support
formula would produce, so the return type does not depend on the input value.
"""
function sinsq2d(xy)
    x, y = xy[1], xy[2]
    if x < 0 || y < 0 || x > 2pi || y > 2pi
        # Built from float(x), float(y) rather than sin(...) so that out-of-support values
        # such as ±Inf never reach `sin`, which would throw a DomainError.
        return zero(float(x)) * zero(float(y))
    end
    return sin(x)^2 * sin(y)^2
end
# ╔═╡ 3735e359-ced9-492f-b423-ff43ae43bb63 | |
sinsqmh = MH(sinsq_target, 3)
# ╔═╡ 62e8c485-c255-4d53-969b-0137b02bb7f3
# Plot the 1-D target against t itself, so the x axis is in radians rather than sample index.
let ts = LinRange(0, 2pi, 100)
    plot(ts, [sinsq_target([t]) for t in ts])
end
# ╔═╡ 98a85bdb-17d8-4179-9f30-789773235523
begin
    dim = 50
    lr = LinRange(0, 2pi, dim)
    # heatmap(x, y, z) expects z[i, j] to be the value at (x[j], y[i]): rows follow y.
    # The two-iterator comprehension already yields a dim × dim matrix, so no reshape is needed.
    heatmap(lr, lr, [sinsq2d((x, y)) for y in lr, x in lr])
end
# ╔═╡ e1d595bc-d9f2-411c-9d9f-74fc55270a44 | |
md"""We implement the sampling as an iterator, allowing easy looping over generated samples."""
# ╔═╡ dc500c3f-47dc-476a-9114-dbc5d188a56d | |
"""
    iterate(mh::MH, st=nothing)

Advance the Metropolis–Hastings chain `mh` by one step and return `(sample, nothing)`.

A proposal is drawn coordinate-wise from `mh.jumpg_type(c, mh.width)` centred on each
coordinate `c` of the current state, using the global `rng`. It is accepted when a draw
`dice` from `mh.uniform` satisfies `dice * target(current) <= target(proposal)`. With a
`Uniform(0, 1)` threshold this gives the usual acceptance probability
`min(1, target(proposal) / target(current))`.

On acceptance `mh.x` is rebound to the new vector rather than mutated, so previously
returned samples are never modified. The chain never ends (this method never returns
`nothing`), so bound loops externally, e.g. with `Iterators.take`.
"""
function Base.iterate(mh::MH, st=nothing)
    proposal = [rand(rng, mh.jumpg_type(c, mh.width)) for c in mh.x]
    p_proposal = mh.target(proposal)
    p_current = mh.target(mh.x)
    dice = rand(rng, mh.uniform)
    # Same test as `dice <= p_proposal / p_current` whenever p_current > 0. Multiplying
    # through avoids 0/0 = NaN, which would reject every move forever once the chain sits
    # at a zero-density point (e.g. an initial state outside the target's support).
    if dice * p_current <= p_proposal
        mh.x = proposal
        return (proposal, nothing)
    end
    return (mh.x, nothing)
end
# ╔═╡ 4b51d5a9-4f06-4d0b-a794-960e66a85199 | |
"""
    sample(mh::MH, n=10000)

Draw `n` consecutive chain states from the Metropolis–Hastings sampler `mh` and return them
as an `n × mh.dims` matrix, one state per row.

The sampler is advanced in place, so a later call continues from where this one stopped.
"""
function sample(mh::MH, n=10000)
    samples = zeros(n, mh.dims)
    # `Iterators.take` pulls exactly n states, so no chain step is drawn and then thrown away.
    for (i, x) in enumerate(Iterators.take(mh, n))
        samples[i, :] .= x
    end
    return samples
end
# ╔═╡ 0cc7983d-dea6-4bdc-9560-8336f8be02cd | |
@time samples = sample(sinsqmh);
# ╔═╡ 6e5d0dd7-7333-4c5b-9522-78daadd24448
# Histogram of the 1-D chain; should resemble the sin² target on [0, 2π].
histogram(samples; bins = 30)
# ╔═╡ d16d5dfd-0af4-4318-833c-b7e2a2198b34
# 2-D sampler with proposal scale 0.5.
mh2d = MH(sinsq2d, 0.5, 2)
# ╔═╡ 114d7240-9bbd-49f8-9a4e-0153c3eddbd1
@time samples2d = sample(mh2d, 10000)
# ╔═╡ c1957efd-43f7-47fc-9945-fc75de656eb5
# Joint histogram; the two column views are passed as the x and y data.
histogram2d(eachcol(samples2d)...; bins = 100)
# ╔═╡ 23fbe57b-5b4f-42f1-b56c-1a2de9c35ae8
# Trace of the first 700 steps of the 2-D chain.
@views plot(samples2d[1:700, 1], samples2d[1:700, 2]; marker = :o, legend = false)
# ╔═╡ Cell order: | |
# ╠═4f4ffc64-cdf9-11eb-32ef-e5e79401898e | |
# ╠═fdc3d3af-f680-41b0-903e-779be1b24206 | |
# ╠═13db19ca-1904-4bd3-860d-0d9a0e73a543 | |
# ╠═b80b2eb1-6b72-4862-b6d4-53e0ec4f8cb0 | |
# ╠═3b247ed5-aa3b-476b-9854-8e61bb5b7274 | |
# ╠═d34ddc22-1ada-42e4-8d46-433f616245a8 | |
# ╠═6ad545f7-04c4-43b4-b2eb-0bcd9664b873 | |
# ╠═8259e911-6880-4a59-ad2c-c9db7ad99e79 | |
# ╠═ad08aef0-a737-439f-8932-c9e9395b2012 | |
# ╠═3735e359-ced9-492f-b423-ff43ae43bb63 | |
# ╠═62e8c485-c255-4d53-969b-0137b02bb7f3 | |
# ╠═98a85bdb-17d8-4179-9f30-789773235523 | |
# ╠═e1d595bc-d9f2-411c-9d9f-74fc55270a44 | |
# ╠═dc500c3f-47dc-476a-9114-dbc5d188a56d | |
# ╠═4b51d5a9-4f06-4d0b-a794-960e66a85199 | |
# ╠═0cc7983d-dea6-4bdc-9560-8336f8be02cd | |
# ╠═6e5d0dd7-7333-4c5b-9522-78daadd24448 | |
# ╠═d16d5dfd-0af4-4318-833c-b7e2a2198b34 | |
# ╠═114d7240-9bbd-49f8-9a4e-0153c3eddbd1 | |
# ╠═c1957efd-43f7-47fc-9945-fc75de656eb5 | |
# ╠═23fbe57b-5b4f-42f1-b56c-1a2de9c35ae8 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment