Variational inference example
import logging
import pandas as pd
import numpy as np
import scipy.stats as sc
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
from tabulate import tabulate
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
tfd = tfp.distributions
tfb = tfp.bijectors
dtype = tf.float32
tf.get_logger().setLevel('ERROR')
def growth_analytical_exp_prior(
        lam_0: float, sig: float, ys: np.ndarray, xs: np.ndarray
):
    """
    Analytical solution for the parameters of the truncated normal distribution
    of the growth rate
    """
    mu_num = (ys * xs).sum() - lam_0 * sig ** 2
    denom = (xs ** 2).sum()

    mu_post = mu_num / denom
    sig_post = sig / np.sqrt(denom)

    return dict(mu=mu_post, sig=sig_post)
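
# A short sketch of where the formulas above come from (exponential prior,
# Gaussian likelihood; same notation as in `build_model_log_prob_fn` below):
#   ln p(w|y,x) = const - lam_0 * w - sum_i (y^{i} - w * x^{i})^2 / (2 * sig^2)
#               = const - (w - mu_post)^2 / (2 * sig_post^2),   for w >= 0
# with
#   mu_post  = ( sum_i x^{i} * y^{i} - lam_0 * sig^2 ) / sum_i (x^{i})^2
#   sig_post = sig / sqrt( sum_i (x^{i})^2 )
# i.e. a normal density truncated at 0, which is what the function returns.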
def build_model_log_prob_fn(lam_0: float, sig: float, xs: np.ndarray, ys: np.ndarray):
    """ Build the model and the joint log-likelihood fct

    p(y,w|x) = p(y|w,x) * p(w|x) = p(y|w,x) * p(w)

    p(w) = lam_0 * exp( -lam_0 * w )
    p(y|w,x) = \prod_{i} Normal(y^{i}; w * x^{i}, sig)

    ln( p(y,w|x) ) = ln( p(y|w,x) ) + ln( p(w) )
    """
    lam_0 = tf.cast(lam_0, dtype)
    sig = tf.cast(sig, dtype)
    xs = tf.cast(xs, dtype)
    ys = tf.cast(ys, dtype)

    model = tfd.JointDistributionSequentialAutoBatched([
        tfd.Exponential(rate=lam_0, name='w', force_probs_to_zero_outside_support=True),
        lambda w: tfd.Normal(loc=w * xs, scale=sig, name='ys')
    ])

    # Pin the observations `ys`; the returned fct maps w to ln p(ys, w | xs)
    model_log_prob_fn = lambda *x: model.log_prob(x + (ys,))

    return model, model_log_prob_fn
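
# Usage sketch (values are illustrative; the data and priors are defined below):
#   model, log_prob_fn = build_model_log_prob_fn(lam_0=200, sig=3, xs=xs, ys=ys)
#   log_prob_fn(tf.constant(.5, dtype))   # joint log-density ln p(ys, w=0.5 | xs)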
def build_surrogate_posterior(type_: str = 'truncated_normal'):
    """ Build surrogate posterior with trainable parameters """
    assert type_ in ('truncated_normal', 'log_normal')

    if type_ == 'log_normal':
        Q = tfd.JointDistributionSequentialAutoBatched([
            tfd.LogNormal(
                loc=tf.Variable(tf.random.normal(shape=(1,), stddev=.01, dtype=dtype), name='w_loc'),
                scale=tfp.util.TransformedVariable(tf.constant([1], dtype), tfb.Softplus(), name='w_scale'))
        ])
    else:
        Q = tfd.JointDistributionSequentialAutoBatched([
            tfd.TruncatedNormal(
                loc=tf.Variable(tf.random.normal(shape=(1,), stddev=.01, dtype=dtype), name='w_loc'),
                scale=tfp.util.TransformedVariable(tf.constant([1], dtype), tfb.Softplus(), name='w_scale'),
                low=0.,
                high=10.)
        ])

    return Q
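
# Sketch of inspecting the (still untrained) surrogate; the variable order
# [w_loc, pre-softplus w_scale] is an assumption that is also relied on below:
#   q = build_surrogate_posterior('truncated_normal')
#   q.sample(5)               # draws of w from the current surrogate
#   q.trainable_variables     # parameters optimised by `fit_surrogate_posterior`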
def generate_data(w_mu: float = .5, w_sig: float = 0, ep: float = 2.1, n: int = 500):
    np.random.seed(10)

    xs = sc.expon(scale=20, loc=30).rvs(n)
    ys = (w_mu + np.random.randn(n) * w_sig) * xs + np.random.randn(n) * ep

    params = dict(w_mu=w_mu, w_sig=w_sig, ep=ep)

    return params, ys, xs
def truncated_normal_dist(x, mu, sig):
    """ Truncated normal distribution that is set to 0 for x < 0 """
    rv = sc.norm(loc=mu, scale=sig)
    norm = 1 - rv.cdf(0)
    return (x >= 0) * rv.pdf(x) / norm
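
# Equivalent cross-check with scipy's built-in truncated normal (note that
# `a` and `b` are given in standardised units; truncation only from below):
#   sc.truncnorm(a=(0 - mu) / sig, b=np.inf, loc=mu, scale=sig).pdf(x)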
@tf.function(autograph=False, experimental_compile=True)
def run_chain(init_state, step_size, target_log_prob_fn, unconstraining_bijectors,
              num_steps=500, burnin=50):
    """ MCMC using the No-U-Turn Sampler

    source:
    https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks
    """

    def trace_fn(_, pkr):
        return (pkr.inner_results.inner_results.target_log_prob,
                pkr.inner_results.inner_results.leapfrogs_taken,
                pkr.inner_results.inner_results.has_divergence,
                pkr.inner_results.inner_results.energy,
                pkr.inner_results.inner_results.log_accept_ratio)

    kernel = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=step_size),
        bijector=unconstraining_bijectors
    )

    adaptive_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=kernel,
        num_adaptation_steps=burnin,
        step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
            inner_results=pkr.inner_results._replace(step_size=new_step_size)),
        step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
        log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio
    )

    # Sampling from the chain.
    chain_state, sampler_stat = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=burnin,
        current_state=init_state,
        kernel=adaptive_kernel,
        trace_fn=trace_fn)

    return chain_state, sampler_stat
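
# The traced tuple (see `trace_fn` above) is
#   (target_log_prob, leapfrogs_taken, has_divergence, energy, log_accept_ratio),
# so divergent transitions could be counted e.g. with (sketch):
#   _, stats = run_chain(init_state, step_size, target_log_prob_fn, unconstraining_bijectors)
#   int(tf.reduce_sum(tf.cast(stats[2], tf.int32)))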
# ##### Generate data
n_max = 500
var_posterior_type = 'truncated_normal' # log_normal
model_params, ys, xs = generate_data(n=n_max)
plt.figure(figsize=(10, 4))
plt.scatter(xs, ys, alpha=.3)
plt.grid()
plt.xlabel('age [yr]')
plt.ylabel('height')
plt.show()
# ##### Analytical solution
lam_0 = 200
sig = 3
ns = [2, 3, 10, 100]
results = []
for n in ns:
    res = growth_analytical_exp_prior(lam_0, sig, ys[:n], xs[:n])
    results.append(res)
df_true = pd.DataFrame(results, index=ns)
df_true.columns = pd.MultiIndex.from_tuples([('true', c) for c in df_true.columns])
logger.info('\n\ndf_true:\n' + tabulate(df_true, headers="keys"))
# ##### Variational inference
qs = []
for n in ns:
    logger.info(f'n = {n}')

    q = build_surrogate_posterior(var_posterior_type)
    model, target_log_prob_fn = build_model_log_prob_fn(lam_0, sig, xs[:n], ys[:n])
    optimizer = tf.optimizers.Adam(learning_rate=1e-2)

    # fails if `jit_compile=True`
    @tf.function()
    def fit_vi():
        return tfp.vi.fit_surrogate_posterior(
            target_log_prob_fn=target_log_prob_fn,
            surrogate_posterior=q,
            optimizer=optimizer,
            num_steps=100000,
            # sample_size=1
        )

    losses = fit_vi()
    qs.append(q)
# `q.variables[1]` holds the scale before the Softplus transform is applied
results_vi = [dict(mu=q.variables[0].numpy(),
                   sig=tfb.Softplus(1.).forward(q.variables[1]).numpy()) for q in qs]
df_vi = pd.DataFrame(results_vi, index=ns)
df_vi.columns = pd.MultiIndex.from_tuples([('vi', c) for c in df_vi.columns])
logger.info('\n\ndf_vi:\n' + tabulate(df_vi, headers="keys"))
# ##### Variational inference lognormal
qs_lognorm = []
for n in ns:
    logger.info(f'n = {n}')

    q = build_surrogate_posterior(type_='log_normal')
    model, target_log_prob_fn = build_model_log_prob_fn(lam_0, sig, xs[:n], ys[:n])
    optimizer = tf.optimizers.Adam(learning_rate=1e-2)

    @tf.function(jit_compile=True)
    def fit_vi():
        return tfp.vi.fit_surrogate_posterior(
            target_log_prob_fn=target_log_prob_fn,
            surrogate_posterior=q,
            optimizer=optimizer,
            num_steps=20000
        )

    losses = fit_vi()
    qs_lognorm.append(q)

results_vi_lognorm = [dict(mu=q.variables[0].numpy(),
                           sig=tfb.Softplus(1.).forward(q.variables[1]).numpy())
                      for q in qs_lognorm]
df_vi_lognorm = pd.DataFrame(results_vi_lognorm, index=ns)
df_vi_lognorm.columns = pd.MultiIndex.from_tuples([('vi_lognormal', c) for c in df_vi_lognorm.columns])
logger.info('\n\ndf_vi_lognorm:\n' + tabulate(df_vi_lognorm, headers="keys"))
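# Note: for the log-normal surrogate, (mu, sig) are the parameters of ln(w);
# the implied moments of w itself follow from the standard identities
#   E[w]   = exp(mu + sig^2 / 2)
#   Std[w] = sqrt( (exp(sig^2) - 1) * exp(2 * mu + sig^2) )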
# ##### MCMC
results_mcmc = []
for n in ns:
    nchain = 8

    model, target_log_prob_fn = build_model_log_prob_fn(lam_0, sig, xs[:n], ys[:n])
    w0, _ = model.sample(nchain)

    init_state = [w0]
    step_size = [tf.cast(i, dtype=dtype) for i in [.01]]
    unconstraining_bijectors = [tfb.Identity()]

    samples, sampler_stat = run_chain(
        init_state, step_size, target_log_prob_fn, unconstraining_bijectors, num_steps=5000)

    results_mcmc.append((samples, sampler_stat))
data = []
for r in results_mcmc:
    samples, _ = r
    w_samples = samples[0].numpy().reshape(-1)
    data.append([w_samples.mean(), w_samples.std()])
df_mcmc = pd.DataFrame(data, columns=['mu', 'sig'], index=ns)
df_mcmc.columns = pd.MultiIndex.from_tuples([('mcmc', c) for c in df_mcmc.columns])
logger.info('\n\ndf_mcmc:\n' + tabulate(df_mcmc, headers="keys"))
# ##### Plot results
# The (mu, sig) in MCMC refer to the posterior sample mean and std, not to the
# parameters of the truncated normal distribution
df = pd.concat([df_true, df_vi, df_mcmc], axis=1)
logger.info('\n\nAll results:\n' + tabulate(df, headers="keys"))
rows = len(ns) // 2
fig = plt.figure(figsize=(15, 4 * rows))
for idx, n in enumerate(ns, 0):
    res_true = results[idx]
    res_vi = results_vi[idx]
    w_samples = results_mcmc[idx][0][0].numpy().reshape(-1)

    ax = fig.add_subplot(rows, 2, idx + 1)
    x = np.linspace(-.05, .55, 200)

    ax.plot(x, truncated_normal_dist(x, res_true['mu'], res_true['sig']),
            label=f'analytical (n={n})')
    ax.plot(x, truncated_normal_dist(x, res_vi['mu'], res_vi['sig']),
            label=f'variational inference (n={n})')
    ax.plot(x, sc.gaussian_kde(w_samples)(x), label=f'MCMC (n={n})')

    ax.legend()
    ax.set_xlabel(r'$\omega$')
plt.show()
fig.savefig('results.png', dpi=200, bbox_inches='tight')
# ##### Plot results Lognormal dist
rows = len(ns) // 2
fig = plt.figure(figsize=(15, 4 * rows))
for idx, n in enumerate(ns, 0):
    res_true = results[idx]
    res_vi_lognorm = results_vi_lognorm[idx]
    w_samples = results_mcmc[idx][0][0].numpy().reshape(-1)

    ax = fig.add_subplot(rows, 2, idx + 1)
    x = np.linspace(-.05, .55, 200)

    ax.plot(x, truncated_normal_dist(x, res_true['mu'], res_true['sig']),
            label=f'analytical (n={n})')

    # LogNormal pdf via change of variables: p(x) = Normal(ln x; mu, sig) / x
    rv_norm = sc.norm(loc=res_vi_lognorm['mu'], scale=res_vi_lognorm['sig'])
    ax.plot(x, rv_norm.pdf(np.log(x)) / x,
            label=f'variational inference (n={n})')

    ax.plot(x, sc.gaussian_kde(w_samples)(x), label=f'MCMC (n={n})')

    ax.legend()
    ax.set_xlabel(r'$\omega$')
plt.show()
fig.savefig('results_lognormal.png', dpi=200, bbox_inches='tight')