Figuring out some ELBO stuff...
# Let's figure out this ELBO thing
using Distributions, StatsPlots, Optim, Random
Random.seed!(45)
# The target distribution. Assume we don't know its form, but that we
# can compute its (unnormalized) logpdf and sample from it. For
# illustration, let's make it a weird three-component mixture.
comps = [Normal(2, 3), Normal(-3, 1.5), LogNormal(3, 0.4)]
probs = [.1, .1, .8]
p = MixtureModel(comps, probs)
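# Illustrative check: the log-density of the mixture can be evaluated
# pointwise, which is all the ELBO estimator below needs from the target.
logpdf(p, 0.0)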
# Let's take a look!
histogram(rand(p, 100_000), normalize = true, label = "p")
# A function for computing the ELBO of an approximation q to a target p,
# using Monte Carlo integration as in equation 7 of the following paper:
# https://www.jmlr.org/papers/volume23/21-0889/21-0889.pdf
# The estimator averages log p(x) - log q(x) over K draws x ~ q.
function ELBO(p::Distribution, q::Distribution; K::Int = 1000)
    x = rand(q, K) # sample from q for MC integration
    return mean(logpdf.(p, x) .- logpdf.(q, x))
end
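# Sanity check: ELBO(p, q) = log Z - KL(q ‖ p). Our target here is
# normalized (Z = 1), so the ELBO is at most 0, and plugging in q = p
# should give roughly 0, up to Monte Carlo error.
ELBO(p, p)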
# We can compute the ELBO for a few candidate approximations
ELBO(p, Normal(0, 1))
ELBO(p, Normal(10, 8))
# Normal(10, 8) is the better approximation: it has the higher ELBO,
# which for a fixed target means a lower KL divergence to p
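# Note that the estimator is stochastic: each call draws fresh Monte
# Carlo samples, so repeated evaluations differ. This is why we use a
# gradient-free optimizer below.
[ELBO(p, Normal(10, 8)) for _ in 1:5]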
# Now define a loss to minimize: the negative ELBO, with the standard
# deviation parametrized on the log scale so it stays positive
loss = θ -> -ELBO(p, Normal(θ[1], exp(θ[2])))
# the loss is noisy (Monte Carlo), so use a gradient-free method
res = optimize(loss, ones(2), SimulatedAnnealing())
θ = res.minimizer
q = Normal(θ[1], exp(θ[2]))
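# The optimized q should achieve a higher ELBO than the earlier guesses,
# up to Monte Carlo noise in the estimate.
ELBO(p, q)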
# Now let's plot!
plot!(q, label = "q (ELBO)", color = "green")
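# For contrast (an illustrative sketch, not part of the fit above):
# overlay a moment-matched normal. For a Gaussian q, matching the mean
# and standard deviation of p minimizes the forward KL(p ‖ q), whereas
# maximizing the ELBO minimizes the reverse KL(q ‖ p), which tends to
# be mode-seeking rather than mass-covering.
x_p = rand(p, 100_000)
q_mm = Normal(mean(x_p), std(x_p))
plot!(q_mm, label = "q (moment match)", color = "red")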