# Variational Importance Sampling
#
# Originally published as a GitHub gist by @cscherrer
# (cscherrer/72062a3b9f264a328f00159d84a61b98), last active July 1, 2019 23:44.
using Pkg
Pkg.add.(
["Distributions"
, "MonteCarloMeasurements"
, "StatsFuns"
, "LaTeXStrings"
, "Plots"])
using Distributions
using MonteCarloMeasurements
using StatsFuns
# Setup: register the comparison operators with MonteCarloMeasurements so that
# inequalities involving `Particles` propagate through the particles.
foreach(register_primitive, (<, >, <=, >=))
"""
    fromObs(x, y)

Return a closure `logp(α, β)` that evaluates the unnormalised log-posterior of the
linear-regression model

    α ~ Normal(0, 1),  β ~ Normal(0, 2),  y[i] ~ Normal(α + β * x[i], 1)

(`Normal(μ, σ)` is parameterised by its standard deviation) for the observed data
`x`, `y`. `α` and `β` may be plain numbers or `Particles`; in the latter case the
returned log-density is itself a `Particles` value.
"""
function fromObs(x, y)
    function logp(α, β)
        logprior = logpdf(Normal(0, 1), α) + logpdf(Normal(0, 2), β)
        yhat = α .+ β .* x
        loglik = sum(logpdf.(Normal.(yhat, 1), y))
        return logprior + loglik
    end
    return logp
end
"""
    drawcat(ℓ, k)

Draw `k` indices from the categorical distribution with probabilities proportional
to `exp.(ℓ)`, using the Gumbel-max trick. `ℓ` is either a vector of log-weights or a
`Particles` object whose particle values are used as the log-weights. Returns a
`Vector` of `k` indices into `ℓ`.
"""
function drawcat(logw::AbstractVector{<:Real}, k)
    n = length(logw)
    # Standard Gumbel noise by inverse CDF: G = -log(-log(U)), U ~ Uniform(0, 1).
    # The trick needs *independent* Gumbel noise for every index and every draw;
    # stratified/systematic noise would bias the resulting categorical sample.
    return [argmax(logw .- log.(-log.(rand(n)))) for _ in 1:k]
end

# `Particles` input (e.g. the log-weights ℓ from runInference): use its particle
# values as the log-weights, so the number of categories matches the particle count.
drawcat(ℓ, k) = drawcat(ℓ.particles, k)
# Stack the given Particles objects into a matrix with one row per object and one
# column per particle (relies on MonteCarloMeasurements' `Matrix` conversion of a
# vector of Particles; the result is returned as a lazy adjoint).
function asmatrix(ps...)
    stacked = Matrix([ps...])
    return stacked'
end
"""
    n_eff(ℓ)

Return Kish's effective sample size for importance weights given on the log scale:

    n_eff = (Σᵢ wᵢ)² / Σᵢ wᵢ²,   wᵢ = exp(ℓᵢ)

`ℓ` is either a vector of log-weights or a `Particles` object whose particle values
are the log-weights. Returns `0` for an empty input or when every weight is zero
(all log-weights are `-Inf`).
"""
function n_eff(logw::AbstractVector{<:Real})
    T = float(eltype(logw))
    isempty(logw) && return zero(T)
    m = maximum(logw)
    m == -Inf && return zero(T)  # every weight is exactly zero
    # The ratio is invariant to rescaling all weights, so shift by the maximum
    # log-weight to keep exp() from overflowing.
    w = exp.(logw .- m)
    return sum(w)^2 / sum(abs2, w)
end

# `Particles` input: use the particle values as the log-weights.
n_eff(ℓ) = n_eff(ℓ.particles)
"""
    f(a, b; n = 100)

Simulate `n` observations from the regression `y = a + b * x + ε` with
`x ~ Normal(0, 1)` and `ε ~ Normal(0, 1)`, build the log-density with `fromObs`,
and return the result of `runInference` on it: the tuple `(α, β, q, ℓ, elbo)`.
"""
function f(a, b; n = 100)
    # generate data
    x = rand(Normal(), n)
    yhat = a .+ b .* x
    y = rand.(Normal.(yhat, 1))
    # generate p
    logp = fromObs(x, y)
    return runInference(x, y, logp)
end
# Draw `N` particles (α, β) from `q` and return them, their stacked matrix form `m`
# (one row per parameter), and the log importance weights ℓ = logp(α, β) - log q(α, β).
function _importance_draw(logp, q, N)
    α, β = Particles(N, q)
    m = asmatrix(α, β)
    ℓ = logp(α, β) - Particles(logpdf(q, m))
    return α, β, m, ℓ
end

"""
    runInference(x, y, logp; N = 1000, numiters = 60)

Fit a 2-D Gaussian proposal `q` to the unnormalised log-density `logp(α, β)` by
iterated importance sampling. Each iteration draws `N` particles from the current
`q`, computes the log importance weights `ℓ = logp - log q`, records `mean(ℓ)` (a
Monte Carlo estimate of the ELBO), and refits `q` by weighted maximum likelihood.

`x` and `y` are not used here; the data enter only through `logp`.

Returns `(α, β, q, ℓ, elbo)`:
- `α`, `β`, `ℓ` come from the last draw.
- `q` is the proposal refit after that draw.
- `elbo` is the length-`numiters` vector of per-iteration ELBO estimates.
"""
function runInference(x, y, logp; N = 1000, numiters = 60)
    # Very broad isotropic starting proposal (standard deviation 1e5 per coordinate).
    # Really this would be fit from a sample from the prior.
    # NOTE(review): `MvNormal(d, σ)` is the older Distributions constructor; confirm
    # it is still supported by the Distributions version in use.
    q = MvNormal(2, 100000.0)

    # Initial draw so that α, β and ℓ are defined even when numiters == 0.
    α, β, m, ℓ = _importance_draw(logp, q, N)
    elbo = Vector{Float64}(undef, numiters)
    for j in 1:numiters
        α, β, m, ℓ = _importance_draw(logp, q, N)
        elbo[j] = mean(ℓ)
        # Weights exp(ℓ - max ℓ) lie in (0, 1]; subtracting the max avoids overflow.
        # The 1/N added to every weight presumably keeps the weighted fit from
        # collapsing onto a handful of particles.
        w = exp(ℓ - maximum(ℓ)).particles .+ 1 / N
        q = fit_mle(MvNormal, suffstats(MvNormal, m, w))
    end
    return (α, β, q, ℓ, elbo)
end
# Simulate data with intercept 3 and slope 4, then fit the proposal.
(α, β, q, ℓ, elbo) = f(3, 4)

using LaTeXStrings
using Plots

# Plot the negative ELBO per iteration. The x-axis follows the actual number of
# iterations in `elbo` instead of a hard-coded 60. The log scale only works while
# every ELBO estimate is negative.
plot(eachindex(elbo), -elbo
    , xlabel="Iteration"
    , ylabel="Negative ELBO"
    , legend=false
    , yscale=:log10)
# Tick positions below assume the default 60 iterations and the magnitudes seen there.
xticks!([0,20,40,60], [L"0",L"20", L"40",L"60"])
yticks!(10 .^ [3,6,9,12], [L"10^3", L"10^6",L"10^9",L"10^{12}"])
savefig("neg-elbo.svg")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment