Skip to content

Instantly share code, notes, and snippets.

@michaelchughes
Created November 1, 2021 04:04
Show Gist options
  • Save michaelchughes/278b4e70274ce778f1231b74701e1dce to your computer and use it in GitHub Desktop.
Save michaelchughes/278b4e70274ce778f1231b74701e1dce to your computer and use it in GitHub Desktop.
Demonstration of estimating the ELBO with the Monte Carlo method
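''' Variational inference (VI) for a Poisson likelihood with a Normal prior
Model
-----
Latent variable z is drawn from a Normal prior: z ~ Normal(mean=40, stddev=10)
Data y is drawn iid from a Poisson likelihood: y_n ~ Poisson(z), n = 1, ..., N
Approx Posterior
----------------
The posterior on z is approximated by a Normal, q(z) = Normal(mean, stddev),
whose mean and stddev are unknown; candidate values are scored by the ELBO.
'''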
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy.stats
import jax
import jax.numpy as jnp
import jax.scipy.stats as jstats
def calc_ELBO(q, prior, data, random_state=None, n_mc_samples=100):
    ''' Estimate the per-observation ELBO via Monte Carlo samples from q.

    Draws S samples z_s ~ Normal(q['mean'], q['stddev']) and averages
        log p(y_1..y_N | z_s) + log p(z_s) - log q(z_s)
    over the S samples, then divides by the number of observations N.

    Args
    ----
    q : dict with keys 'mean' and 'stddev' of the Normal approximate posterior
    prior : dict with keys 'mean' and 'stddev' of the Normal prior on z
    data : dict with key 'y_N', an array-like of N observed counts
    random_state : numpy RandomState (or any object with a ``randn`` method),
        int seed, or None. An int or None is passed to np.random.RandomState
        to build a fresh generator (None means an unpredictable seed).
    n_mc_samples : int, number of Monte Carlo samples S

    Returns
    -------
    elbo : scalar jax array, Monte Carlo estimate of ELBO / N

    Raises
    ------
    ValueError : if data['y_N'] contains no observations
    '''
    # Previously the default of None crashed on ``None.randn``; build a
    # generator instead so the documented default actually works.
    if random_state is None or isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)

    # asarray lets callers pass plain lists as well as numpy arrays.
    y_N = np.asarray(data['y_N'])
    N = y_N.size
    if N == 0:
        # Dividing by N below would silently produce inf/nan.
        raise ValueError("data['y_N'] must contain at least one observation")
    S = n_mc_samples

    # Reparameterized draws: z = mean + stddev * eps with eps ~ Normal(0, 1).
    z_S = random_state.randn(S) * q['stddev'] + q['mean']

    log_prior_pdf_S = jstats.norm.logpdf(
        z_S, prior['mean'], prior['stddev'])
    log_q_pdf_S = jstats.norm.logpdf(
        z_S, q['mean'], q['stddev'])
    # Broadcast observations (N,1) against samples (1,S) -> (N,S) log-likelihoods.
    # NOTE(review): Normal draws can be negative, which is not a valid Poisson
    # rate; such samples may yield nan/-inf here — confirm this is acceptable
    # for the chosen q (e.g. large q['stddev']).
    log_lik_pdf_NS = jstats.poisson.logpmf(
        y_N.reshape((N, 1)), z_S.reshape((1, S)))

    # Per-sample ELBO: total log-likelihood over data + log prior - log q.
    elbo_S = jnp.sum(log_lik_pdf_NS, axis=0) + log_prior_pdf_S - log_q_pdf_S
    # Report the ELBO per observation so values are comparable across N.
    return jnp.mean(elbo_S) / N
if __name__ == '__main__':
    n_mc_samples = 1000
    # One shared generator: used for the synthetic data and then for every
    # Monte Carlo draw, in that order.
    random_state = np.random.RandomState(0)

    # Synthetic dataset: N iid Poisson counts with known rate z_true.
    z_true = 50.0
    N = 100
    y_N = scipy.stats.poisson(z_true).rvs(N, random_state)
    data = {
        'y_N': y_N
    }
    prior = {
        'mean': 40.0,
        'stddev': 10.0,
    }

    # Experiment 1: vary the mean of q from far below to far above the true
    # value, with the stddev of q held fixed at a small value.
    m_list = list()
    elbo_list = list()
    for delta in [-10, -5, 0, 5, 10]:
        q = {
            'mean': z_true + delta,
            'stddev': 0.001
        }
        elbo = calc_ELBO(q, prior, data, random_state, n_mc_samples)
        # Store plain floats so np.min/np.max and plotting see ordinary scalars.
        elbo_list.append(float(elbo))
        m_list.append(q['mean'])
    plt.plot(m_list, elbo_list, label='ELBO')
    # Dashed vertical line marking the true z, spanning the observed ELBO range.
    plt.plot(z_true * np.ones(2), [np.min(elbo_list), np.max(elbo_list)], '--', label='true z')
    plt.xlabel('mean of q')
    plt.ylabel('ELBO (estimated with %d samples)' % n_mc_samples)
    plt.legend()

    plt.figure()
    # Experiment 2: vary the stddev of q from far below to far above the ideal
    # value, with the mean of q held at the true z. Repeat several times to
    # show the Monte Carlo noise of the estimate.
    n_reps = 5
    for rr in range(n_reps):
        s_list = list()
        elbo_list = list()
        for stddev in [0.01, 0.03, 0.1, 0.3, 1, 3.0, 10.0, 30.]:
            q = {
                'mean': z_true,
                'stddev': stddev,
            }
            elbo = calc_ELBO(q, prior, data, random_state, n_mc_samples)
            elbo_list.append(float(elbo))
            s_list.append(q['stddev'])
        plt.plot(np.log10(s_list), elbo_list, label='rep %02d' % (rr+1))
    plt.xlabel('log stddev of q')
    plt.ylabel('ELBO (estimated with %d samples)' % n_mc_samples)
    plt.legend()
    plt.show()
@michaelchughes
Copy link
Author

Expected output plot

image

@michaelchughes
Copy link
Author

Expected output plot: ELBO vs mean of q, holding the stddev of q fixed at a reasonable value

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment