
@maedoc
Last active October 20, 2021 06:33
Simple example of variational inference with autograd
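from autograd import elementwise_grad, numpy as np
from autograd.scipy.stats import norm

def simple_vi(n=2000, n_iter=2000):
    # synthetic data: 10 draws from N(1.5, 0.3^2)
    x = np.random.normal(loc=+1.5, scale=+0.3, size=10)
    # log p(x | z) averaged over data points; z[0] is the mean, z[1] the log-scale
    log_p = lambda z: np.mean(norm.logpdf(x[:, None], z[0], np.exp(z[1])), axis=0)
    # q(z | l): independent normals with means l and fixed scale 0.3
    log_q = lambda z, l: norm.logpdf(z, l[:, None], 0.3)
    # n samples of both latent variables from q, shape (2, n)
    samp_q = lambda l: np.random.normal(l[:, None], 0.3, (2, n))
    # https://arxiv.org/pdf/1401.0118.pdf, eq 3 (score-function gradient);
    # differentiating this product w.r.t. l also adds a -log_q * dlog_q/dl term,
    # which has zero mean here because q's scale does not depend on l
    elbo = lambda l, z: log_q(z, l) * (log_p(z) - log_q(z, l))
    # gradient w.r.t. l of elbo summed over all samples, shape (2,)
    gelbo = elementwise_grad(elbo)
    l = np.array([0.1, 0.1])
    for i in range(n_iter):
        z = samp_q(l)
        g = gelbo(l, z)
        # gradient-ascent step using the Monte Carlo average over n samples
        l += 0.01 * g / n
        if i % 100 == 0:
            print(l, g, np.sum(elbo(l, z)))
    print('found', l[0], np.exp(l[1]), 'expected', 1.5, 0.3)

simple_vi()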
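To see the eq. 3 estimator on its own, outside the VI loop, here is a minimal NumPy-only sketch. It assumes a 1-D Gaussian q with fixed scale and a toy quadratic target `f` whose exact gradient is known; the function name `score_function_grad` and the target are illustrative assumptions, not part of the gist. Unlike the gist's `elbo` lambda, the weight `f(z)` is used only as a constant multiplier and is not differentiated.

```python
import numpy as np


def score_function_grad(f, mu, sigma, n_samples, rng):
    """Score-function (REINFORCE) estimate of d/dmu E_{z~N(mu, sigma^2)}[f(z)].

    Uses grad E_q[f] = E_q[grad_mu log q(z | mu) * f(z)], the same form as
    eq. 3 of the BBVI paper when f = log p - log q.
    """
    z = rng.normal(mu, sigma, size=n_samples)
    # score of a normal w.r.t. its mean: d/dmu log N(z | mu, sigma^2) = (z - mu) / sigma^2
    score = (z - mu) / sigma**2
    # f(z) acts only as a weight; it is not differentiated w.r.t. mu
    return np.mean(score * f(z))


# toy target (an assumption for illustration): f(z) = -(z - a)^2 / 2,
# so E_q[f] = -((mu - a)^2 + sigma^2) / 2 and the exact gradient is -(mu - a)
a, mu, sigma = 1.5, 0.1, 0.3
f = lambda z: -0.5 * (z - a) ** 2

rng = np.random.default_rng(0)
estimate = score_function_grad(f, mu, sigma, n_samples=200_000, rng=rng)
exact = -(mu - a)
print(estimate, exact)
```

The estimate only needs samples from q and pointwise evaluations of `f`, which is why the same recipe works for any model where log p can be evaluated.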
import numpy as np
import sympy as sp
from autograd.misc.optimizers import adam
# helper to construct the log-density of a normal distribution
N = lambda x, mu, sig: sp.log((1/(sig*sp.sqrt(sp.pi*2)))*sp.exp((-1/2)*((x-mu)/sig)**2)).simplify()
# make vars & distributions
x, z, l = sp.symbols('x,z,l')
lp = N(x, z, 1/3)
lq = N(z, l, 1/3)
# negated score-function surrogate for the ELBO (adam minimizes) & its derivative w.r.t. l
elbo = -(lq * (lp - lq))
elbo_l = elbo.diff(l).simplify()
elbo_l_np = sp.lambdify([x,z,l], elbo_l)
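# data: a single observation drawn around 1.5
x_ = np.random.randn()/3 + 1.5
# adam calls its first argument as grad(params, iteration), so this returns a
# one-sample Monte Carlo estimate of the gradient, not a loss value
def grad_loss(l, i):
    z = np.random.normal(l, 1/3)
    return elbo_l_np(x_, z, l)
# run and print result
lhat = adam(grad_loss, 0.1, step_size=0.1, num_iters=1000)
print('found', lhat, 'expected', 1.5)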
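Because the likelihood and q in the second script are both Gaussian with the same fixed scale, the objective it estimates has a closed form, which shows what `lhat` should converge to. A short derivation, assuming (as the code does) no prior on z, σ = 1/3, writing x for the single observation `x_`, and using E_q[(x − z)²] = (x − l)² + σ²:

```latex
\begin{align*}
\mathcal{L}(l) &= \mathbb{E}_{q(z \mid l)}\left[\log \mathcal{N}(x \mid z, \sigma^2) - \log \mathcal{N}(z \mid l, \sigma^2)\right] \\
&= -\log\left(\sigma\sqrt{2\pi}\right) - \frac{(x - l)^2 + \sigma^2}{2\sigma^2} + \frac{1}{2}\log\left(2\pi e\,\sigma^2\right) \\
\frac{\partial \mathcal{L}}{\partial l} &= \frac{x - l}{\sigma^2} = 0 \quad\Longrightarrow\quad l^\star = x
\end{align*}
```

The entropy term does not depend on l, so with one observation the optimum is `x_` itself, which is drawn from N(1.5, (1/3)²); "expected 1.5" therefore holds only approximately.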
@maedoc
Author

maedoc commented Dec 16, 2019

The BBVI paper proposes some variance-reduction techniques (Rao-Blackwellization and control variates), but they seem like add-on tricks, whereas what is implemented here is closer to the core of VI.
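For comparison, here is a minimal NumPy sketch of a control variate for the score-function estimator, similar in form to the one discussed in the BBVI paper. It assumes a 1-D Gaussian q with fixed scale and a toy quadratic target; the function names `sf_grad_samples` and `control_variate_grad` and the target are hypothetical. The score ∇ log q has zero mean under q, so subtracting a multiple of it leaves the estimator unbiased (up to a small bias from estimating the scale `a` on the same samples) while cutting its variance.

```python
import numpy as np


def sf_grad_samples(f, mu, sigma, n_samples, rng):
    """Per-sample score-function gradient terms and scores for q = N(mu, sigma^2)."""
    z = rng.normal(mu, sigma, size=n_samples)
    score = (z - mu) / sigma**2  # d/dmu log q(z | mu), zero mean under q
    return score * f(z), score


def control_variate_grad(f, mu, sigma, n_samples, rng):
    """Per-sample terms with the score used as a control variate.

    a = Cov(g, h) / Var(h) is the variance-minimising scale, estimated from
    the same samples (hypothetical helper, for illustration only).
    """
    g, h = sf_grad_samples(f, mu, sigma, n_samples, rng)
    a = np.cov(g, h)[0, 1] / np.var(h, ddof=1)
    return g - a * h


# toy target (an assumption for illustration): exact gradient is -(mu - a_target)
a_target, mu, sigma = 1.5, 0.1, 0.3
f = lambda z: -0.5 * (z - a_target) ** 2

rng = np.random.default_rng(1)
naive, _ = sf_grad_samples(f, mu, sigma, 100_000, rng)
reduced = control_variate_grad(f, mu, sigma, 100_000, rng)
print('naive  ', naive.mean(), naive.var())
print('reduced', reduced.mean(), reduced.var())
```

Both means estimate the same gradient; the per-sample variance of the control-variate version is several times smaller for this target, so fewer samples give the same accuracy.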
