Skip to content

Instantly share code, notes, and snippets.

@devmotion
Last active December 19, 2019 22:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save devmotion/601d95112df0920cdb9098cd2f2942e2 to your computer and use it in GitHub Desktop.
Save devmotion/601d95112df0920cdb9098cd2f2942e2 to your computer and use it in GitHub Desktop.
Elliptical slice sampling (ESS) examples with Turing.jl
using Turing
using StatsPlots
using Random
using Statistics
"""
    demo(N::Int; n::Int = 10, sampler = ESS())

Draw `N` MCMC samples of the mean `m` of a Gaussian model with known noise and compare
them with the analytic posterior.

The model is `m ~ Normal(0, 1)` with observations `x ~ MvNormal(fill(m, n), sqrt(σ²))`,
where `σ² = 0.3`. The `n` observations are simulated from `Normal(1.4, sqrt(σ²))` after
seeding the global RNG with `1234`, so repeated calls see the same data.

# Keywords
- `n::Int = 10`: number of simulated observations.
- `sampler = ESS()`: Turing sampling algorithm used to draw `m` (e.g. pass `NUTS()` to
  compare against a gradient-based sampler).

Prints the analytic vs. estimated posterior mean and variance (`@show`). Returns the plot
of the chain with the analytic posterior density overlaid as a dashed line on subplot 2.
"""
function demo(N::Int; n::Int = 10, sampler = ESS())
    # observation noise (variance, treated as known by the model)
    σ² = 0.3

    # define model; named distinctly so it is not confused with the top-level `gdemo`
    @model gaussmodel(x) = begin
        m ~ Normal(0, 1)
        x ~ MvNormal(fill(m, length(x)), sqrt(σ²))
    end

    # simulate observations reproducibly
    Random.seed!(1234)
    x = rand(Normal(1.4, sqrt(σ²)), n)

    # generate MCMC chain
    chain = sample(gaussmodel(x), sampler, N)

    # analytic posterior: posterior precision = prior precision (1) + data precision (n/σ²)
    τ² = inv(1 + length(x) / σ²)
    μ = τ² / σ² * sum(x)
    posterior = Normal(μ, sqrt(τ²))

    # compare estimates
    chain_array = vec(convert(Array, chain[:m]))
    @show μ, mean(chain_array)
    @show τ², var(chain_array)

    # plot chain and posterior pdf
    # NOTE(review): subplot 2 is assumed to be the density panel of `plot(chain)` — confirm
    plot(chain)
    return plot!(posterior; subplot = 2, linestyle = :dash)
end
"""
    gdemo(N::Int; nparticles::Int = 15, msampler = ESS(:m), obs = (1.5, 2.0))

Sample the normal–inverse-gamma model

    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
    y ~ Normal(m, sqrt(s))

conditioned on `(x, y) = obs`. The sampler is Gibbs, with `CSMC(nparticles, :s)` for `s`
and `msampler` for `m`. The `N` draws are compared with the analytic conjugate posterior.

# Keywords
- `nparticles::Int = 15`: number of particles of the conditional SMC step for `s`.
- `msampler = ESS(:m)`: Gibbs component used for `m` (e.g. `HMC(0.2, 4, :m)`).
- `obs::NTuple{2,Real} = (1.5, 2.0)`: observed values of `x` and `y`.

The global RNG is seeded with `100` before sampling. Prints analytic vs. estimated means
and variances of `m` and `s` (`@show`). Returns a two-row plot of both chains with the
analytic marginal posteriors overlaid as dashed lines.
"""
function gdemo(
    N::Int;
    nparticles::Int = 15,
    msampler = ESS(:m),
    obs::NTuple{2,Real} = (1.5, 2.0),
)
    # prior hyperparameters: s ~ InverseGamma(α₀, β₀), m | s ~ Normal(μ₀, sqrt(s / λ₀));
    # shared by the model and the analytic posterior below so the two cannot drift apart
    μ₀, λ₀, α₀, β₀ = 0, 1, 2, 3

    # define model; named distinctly so it does not shadow this function's own name
    @model gdemomodel(x, y) = begin
        s ~ InverseGamma(α₀, β₀)
        m ~ Normal(μ₀, sqrt(s / λ₀))
        x ~ Normal(m, sqrt(s))
        y ~ Normal(m, sqrt(s))
        return s, m
    end

    # generate MCMC chain
    Random.seed!(100)
    chain = sample(gdemomodel(obs...), Gibbs(CSMC(nparticles, :s), msampler), N)

    # analytic posterior via the normal–inverse-gamma conjugate update
    # (for the default data this gives μ = 7/6, λ = 3, α = 3, β = 49/12)
    n = length(obs)
    x̄ = mean(obs)
    λ = λ₀ + n
    μ = (λ₀ * μ₀ + n * x̄) / λ
    α = α₀ + n / 2
    β = β₀ + sum(o -> abs2(o - x̄), obs) / 2 + λ₀ * n * abs2(x̄ - μ₀) / (2 * λ)

    # marginal posteriors: m is location-scale Student-t, s is inverse-gamma
    posterior_m = LocationScale(μ, sqrt(β / (λ * α)), TDist(2 * α))
    posterior_s = InverseGamma(α, β)

    # compare estimates
    chain_array_m = vec(convert(Array, chain[:m]))
    @show mean(posterior_m), mean(chain_array_m)
    @show var(posterior_m), var(chain_array_m)
    chain_array_s = vec(convert(Array, chain[:s]))
    @show mean(posterior_s), mean(chain_array_s)
    @show var(posterior_s), var(chain_array_s)

    # plot chains and posterior pdfs
    # NOTE(review): subplot 2 is assumed to be the density panel of each chain plot — confirm
    p1 = plot(chain[:m])
    plot!(p1, posterior_m; subplot = 2, linestyle = :dash)
    p2 = plot(chain[:s])
    plot!(p2, posterior_s; subplot = 2, linestyle = :dash)
    return plot(p1, p2; layout = @layout [a; b])
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment