@torfjelde
Last active May 17, 2019 11:17
An example of how one can implement Automatic Differentiation Variational Inference (ADVI) for Turing.jl.
using Turing, Bijectors, ForwardDiff, DiffResults, LinearAlgebra, Optim, LineSearches
import StatsBase: sample
import Optim: optimize
#############
# Utilities #
#############
function jac_inv_transform(dist::Distribution, x::T where T<:Real)
ForwardDiff.derivative(x -> invlink(dist, x), x)
end
function jac_inv_transform(dist::Distribution, x::Array{T} where T <: Real)
ForwardDiff.jacobian(x -> invlink(dist, x), x)
end
function center_diag_gaussian(x, μ, σ)
# instead of creating a diagonal matrix, we just do elementwise multiplication
(σ .^(-1)) .* (x - μ)
end
function center_diag_gaussian_inv(η, μ, σ)
(η .* σ) + μ
end
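# Quick sanity check (illustrative values only, not part of the model): the two
# helpers above should be inverses of each other.
let μ = [1.0, -2.0], σ = [0.5, 2.0], x = [0.3, 0.7]
η = center_diag_gaussian(x, μ, σ)
@assert center_diag_gaussian_inv(η, μ, σ) ≈ x
end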
#########################
# Variational Inference #
#########################
abstract type VariationalInference end
"""
sample(vi::VariationalInference, num_samples)
Produces `num_samples` samples from the variational posterior of the given VI method.
"""
function sample(vi::VariationalInference, num_samples) end
"""
elbo(vi::VariationalInference, num_samples)
Computes an empirical estimate of the ELBO for the given VI method using `num_samples` Monte Carlo samples.
"""
function elbo(vi::VariationalInference, num_samples) end
"""
optimize(vi::VariationalInference)
Finds the variational parameters which maximize the ELBO for the given VI method.
"""
function optimize(vi::VariationalInference) end
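# A concrete VI scheme is expected to subtype `VariationalInference` and provide methods
# for `sample`, `elbo` and `optimize`; `ADVI` below is one such implementation.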
"""
ADVI(model::Turing.Model)
Automatic Differentiation Variational Inference (ADVI) for a given model.
"""
struct ADVI{T <: Real} <: VariationalInference
model::Turing.Model
μ::Vector{T}
ω::Vector{T}
end
ADVI(model::Turing.Model) = begin
# setup
var_info = Turing.VarInfo()
model(var_info, Turing.SampleFromUniform())
num_params = size(var_info.vals, 1)
ADVI(model, zeros(num_params), zeros(num_params))
end
function sample(vi::ADVI, num_samples)
# setup
var_info = Turing.VarInfo()
vi.model(var_info, Turing.SampleFromUniform())
num_params = size(var_info.vals, 1)
# convenience
μ, ω = vi.μ, vi.ω
# buffer
samples = zeros(num_samples, num_params)
for i = 1:size(var_info.dists, 1)
prior = var_info.dists[i]
r = var_info.ranges[i]
# initials
μ_i = μ[r]
ω_i = ω[r]
# sample from the variational posterior for this block of parameters
for j = 1:num_samples
η = randn(length(μ_i))
ζ = center_diag_gaussian_inv(η, μ_i, exp.(ω_i))
θ = invlink(prior, ζ)
samples[j, r] = θ
end
end
return samples
end
function optimize(vi::ADVI; samples_per_step = 10, max_iters = 500)
# setup
var_info = Turing.VarInfo()
vi.model(var_info, Turing.SampleFromUniform())
num_params = size(var_info.vals, 1)
function objective(x)
# extract the mean-field Gaussian params
μ, ω = x[1:num_params], x[num_params + 1: end]
- elbo(vi, μ, ω, samples_per_step)
end
# for every parameter we need a mean μ and a log std. dev. ω (so σ = exp(ω))
x = zeros(2 * num_params)
diff_result = DiffResults.GradientResult(x)
# used for the truncated AdaGrad step-size rule suggested in Kucukelbir et al. (2015) [1]
η = 0.1
τ = 1.0
ρ = zeros(2 * num_params)
s = zeros(2 * num_params)
g² = zeros(2 * num_params)
# number of previous gradients to use to compute `s` in adaGrad
stepsize_num_prev = 10
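# Step-size rule used in the loop below (applied elementwise):
#   ρ = η / (τ + √s)
# where `s` accumulates squared gradients. Past the first `stepsize_num_prev` iterations,
# the previous g² is subtracted before the current one is added, so `s` stays bounded and
# the step-size does not shrink towards zero as it would in plain AdaGrad.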
i = 0
while (i < max_iters) # & converged # <= add criterion? A running mean maybe?
# compute gradient
ForwardDiff.gradient!(diff_result, objective, x)
# recursive implementation of updating the step-size
# if beyond the first sequence of steps, we subtract off the previous g² before adding the next
if i > stepsize_num_prev
s -= g²
end
# update parameters for adaGrad
g² .= DiffResults.gradient(diff_result).^2
s += g²
# compute stepsize
@. ρ = η / (τ + sqrt(s))
x .= x - ρ .* DiffResults.gradient(diff_result)
@info "Step $i" ρ DiffResults.value(diff_result) norm(DiffResults.gradient(diff_result))
i += 1
end
μ, ω = x[1:num_params], x[num_params + 1: end]
return μ, ω
end
function elbo(vi::ADVI, μ::Vector{T}, ω::Vector{T}, num_samples) where T <: Real
# setup
var_info = Turing.VarInfo()
# initialize the `VarInfo` object
vi.model(var_info, Turing.SampleFromUniform())
num_params = size(var_info.vals, 1)
elbo_acc = 0.0
for i = 1:num_samples
# iterate through priors, sample and update
for j = 1:size(var_info.dists, 1)
prior = var_info.dists[j]
r = var_info.ranges[j]
# mean-field params for this set of model params
μ_i = μ[r]
ω_i = ω[r]
# obtain samples from mean-field posterior approximation
η = randn(length(μ_i))
ζ = center_diag_gaussian_inv(η, μ_i, exp.(ω_i))
# inverse-transform back to original param space
θ = invlink(prior, ζ)
# update
var_info.vals[r] = θ
# add the log-det-jacobian of inverse transform
elbo_acc += log(abs(det(jac_inv_transform(prior, ζ)))) / num_samples
end
# sample with updated variables
vi.model(var_info)
elbo_acc += var_info.logp / num_samples
end
# add the entropy of the mean-field Gaussian variational posterior; since ω = log σ,
# this is sum(ω) up to an additive constant (0.5 * d * (1 + log(2π))) which does not
# affect the optimization
variational_posterior_entropy = sum(ω)
elbo_acc += variational_posterior_entropy
elbo_acc
end
function elbo(vi::ADVI, num_samples)
# extract the mean-field Gaussian params
μ, ω = vi.μ, vi.ω
elbo(vi, μ, ω, num_samples)
end
##################
# Simple example #
##################
@model demo(x) = begin
s ~ InverseGamma(2,3)
m ~ Normal(0.0, sqrt(s)) # `Normal(μ, σ)` has mean μ and variance σ², i.e. it is parametrized by the std. dev., not the variance
for i = 1:length(x)
x[i] ~ Normal(m, sqrt(s))
end
end
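# Note: this is the conjugate Normal-Inverse-Gamma model, i.e. s ~ InverseGamma(α₀ = 2, β₀ = 3)
# and m | s ~ Normal(μ₀ = 0, √(s / κ₀)) with κ₀ = 1, which is why an exact posterior is
# available further down via `ConjugatePriors.NormalGamma`.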
# generate data
x = randn(1, 1000);
# produce "true" samples using NUTS
m = demo(x)
chain = sample(m, NUTS(2000, 200, 0.65))
# ADVI
m = demo(x)
vi = ADVI(m) # default construction of ADVI
μ, ω = optimize(vi, samples_per_step = 5, max_iters = 5000) # maximize ELBO
vi = ADVI(m, μ, ω) # construct new from optimized values
samples = sample(vi, 2000)
# quick check
println([mean(samples, dims=1), [var(x), mean(x)]])
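# With 1000 observations the posterior concentrates around the sample statistics, so the
# ADVI means of (s, m) should roughly match (var(x), mean(x)).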
# closed form
using ConjugatePriors
# prior
# notation mapping has been verified by explicitly computing expressions
# in "Conjugate Bayesian analysis of the Gaussian distribution" by Murphy
μ₀ = 0.0 # => μ
κ₀ = 1.0 # => ν, which scales the precision of the Normal
α₀ = 2.0 # => "shape"
β₀ = 3.0 # => "rate", which is 1 / θ, where θ is "scale"
pri = NormalGamma(μ₀, κ₀, α₀, β₀)
# posterior
post = posterior(pri, Normal, x)
# marginal distribution of τ = 1 / σ²
# Eq. (90) in "Conjugate Bayesian analysis of the Gaussian distribution" by Murphy
# `scale(post)` = θ
p_τ = Gamma(post.shape, scale(post))
p_σ²_pdf = z -> pdf(p_τ, 1 / z) # τ => 1 / σ²
# marginal of μ
# Eq. (91) in "Conjugate Bayesian analysis of the Gaussian distribution" by Murphy
p_μ = TDist(2 * post.shape)
μₙ = post.mu # μ → μ
κₙ = post.nu # κ → ν
αₙ = post.shape # α → shape
βₙ = post.rate # β → rate
# numerically more stable, but doesn't seem to have an effect; the issue is probably internal to
# `pdf`, which needs to compute ≈ Γ(1000)
p_μ_pdf = z -> exp(logpdf(p_μ, (z - μₙ) * exp(- 0.5 * log(βₙ) + 0.5 * log(αₙ) + 0.5 * log(κₙ))))
# p_μ_pdf1 = z -> pdf(p_μ, (z - μₙ) / √(βₙ / (αₙ * κₙ)))
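# Note: as written both closed-form densities are unnormalized: `p_σ²_pdf` drops the
# Jacobian factor 1/z² of the change of variables τ = 1/σ², and `p_μ_pdf` drops the
# constant scale factor √(αₙ κₙ / βₙ) from the standardization. This is (at least part of)
# why the plotting code below renormalizes with a Riemann sum.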
#################
# Visualization #
#################
# visualize
using Plots, StatsPlots, LaTeXStrings
pyplot()
p1 = plot();
density!(samples[:, 1], label = "s (ADVI)", color = :blue, linestyle = :dash)
histogram!(samples[:, 1], label = "", normed = true, alpha = 0.3, color = :blue);
density!([chain[:s].value...], label = "s (NUTS)", color = :green, linestyle = :dashdot)
histogram!([chain[:s].value...], label = "", normed = true, color = :green, alpha = 0.3)
# normalize using Riemann approx. because of (almost certainly) numerical issues
Δ = 0.001
r = 0.75:Δ:1.25
norm_const = sum(p_σ²_pdf.(r) .* Δ)
plot!(r, z -> p_σ²_pdf(z) / norm_const, label = "s (posterior)", color = :red);
vline!([var(x)], label = "s (data)", linewidth = 1.5, color = :black, alpha = 0.7)
xlims!(0.5, 1.5)
title!(L"$x_i \sim \mathcal{N}(0, 1)$ for $i = 1,\dots,1000$")
p2 = plot()
density!(samples[:, 2], label = "m (ADVI)", color = :blue, linestyle = :dash)
histogram!(samples[:, 2], label = "", normed = true, alpha = 0.3, color = :blue)
density!([chain[:m].value...], label = "m (NUTS)", color = :green, linestyle = :dashdot)
histogram!([chain[:m].value...], label = "", normed = true, color = :green, alpha = 0.3)
# normalize using Riemann approx. because of (almost certainly) numerical issues
Δ = 0.0001
r = -0.1 + mean(x):Δ:0.1 + mean(x)
norm_const = sum(p_μ_pdf.(r) .* Δ)
plot!(r, z -> p_μ_pdf(z) / norm_const, label = "m (posterior)", color = :red);
vline!([mean(x)], label = "m (data)", linewidth = 1.5, color = :black, alpha = 0.7)
xlims!(-0.25, 0.25)
p = plot(p1, p2; layout = (2, 1))
savefig(p, "advi_proper.png")

xukai92 commented Apr 3, 2019

OK. I found it a bit difficult to comment on the code. Do you know if there is a way that I can do the comment thing like during reviewing a PR?


torfjelde commented Apr 5, 2019

Made the suggested changes and addressed the comments you made before.

Also realized I'd put the entropy term for the variational posterior inside the empirical estimate loop; fixed it!

torfjelde commented

I've made it so that we instead get an empirical estimate of the distribution using the ADVI estimate, rather than a point-estimate as before. Added comparison with NUTS.

I'll look into the closed form expression you mentioned as a point of comparison next.


torfjelde commented Apr 9, 2019

Added the closed-form posterior for comparison, though I have to use a temporary work-around for numerical issues, probably due to internal computations using the gamma function for large values. Would be interesting to look into further.

Also, this snippet has gotten to a point where it might be suitable for a standalone project. But oh well.

torfjelde commented

Using Optim.jl with GradientDescent and LineSearches.BackTracking(order=3) would at times fail (the line search would reach its maximum number of iterations). As mentioned before, Optim.jl is not made for stochastic optimization. Though when it did converge, the results were really good.

Therefore, for consistency of experiments, I've now switched to a manual implementation of AdaGrad. This at least produces consistent results across runs, in contrast to using Optim.jl. It will be interesting to explore the choice of optimization algorithm further in the future.
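For reference, one way to wire the negative ELBO into Optim.jl with that setup would be roughly the following (a sketch only, not the exact code that was used; `optimize_with_optim` is just an illustrative name):

# Sketch only: hand the negative ELBO to Optim.jl with GradientDescent + BackTracking.
function optimize_with_optim(vi::ADVI; samples_per_step = 10, max_iters = 500)
    # same setup as `optimize(vi::ADVI; ...)` above
    var_info = Turing.VarInfo()
    vi.model(var_info, Turing.SampleFromUniform())
    num_params = size(var_info.vals, 1)
    # negative ELBO as the objective to be minimized
    objective(x) = -elbo(vi, x[1:num_params], x[num_params + 1:end], samples_per_step)
    x0 = zeros(2 * num_params)
    res = Optim.optimize(
        objective,
        x0,
        GradientDescent(linesearch = LineSearches.BackTracking(order = 3)),
        Optim.Options(iterations = max_iters)
    )
    x = Optim.minimizer(res)
    return x[1:num_params], x[num_params + 1:end]
end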

References:
[1] Kucukelbir, A., Ranganath, R., Gelman, A., & Blei, D. M. (2015). Automatic Variational Inference in Stan. CoRR.
