Last active
November 1, 2021 11:58
-
-
Save ImScientist/88091389e0c91669187bb77ff5a3845b to your computer and use it in GitHub Desktop.
Variational inference example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import pandas as pd | |
import numpy as np | |
import scipy.stats as sc | |
import tensorflow as tf | |
import tensorflow_probability as tfp | |
import matplotlib.pyplot as plt | |
from tabulate import tabulate | |
# Short aliases for the TensorFlow Probability namespaces used in this script.
tfd = tfp.distributions
tfb = tfp.bijectors

# Floating-point type that the helpers below cast their inputs to.
dtype = tf.float32

# Configure the root logger and let INFO-level messages through.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Restrict TensorFlow's own logger to the ERROR level.
tf.get_logger().setLevel('ERROR')
def growth_analytical_exp_prior(
        lam_0: float, sig: float, ys: np.ndarray, xs: np.ndarray
) -> dict:
    """
    Analytical solution for the parameters of the truncated normal distribution
    of the growth rate.

    Completing the square in
        -sum((ys - w * xs) ** 2) / (2 * sig ** 2) - lam_0 * w
    with respect to ``w`` gives a normal shape with the location and scale
    returned here.

    :param lam_0: rate of the exponential prior on the growth rate
    :param sig: standard deviation of the observation noise
    :param ys: observed responses (any array-like, e.g. list or ndarray)
    :param xs: observed inputs, same length as ``ys`` (any array-like)
    :return: dict with keys ``mu`` (location) and ``sig`` (scale)
    """
    # Coerce so that plain Python sequences work as well as ndarrays.
    xs = np.asarray(xs)
    ys = np.asarray(ys)

    denom = (xs ** 2).sum()
    mu_post = ((ys * xs).sum() - lam_0 * sig ** 2) / denom
    sig_post = sig / np.sqrt(denom)

    return dict(mu=mu_post, sig=sig_post)
def build_model_log_prob_fn(lam_0: float, sig: float, xs: np.ndarray, ys: np.ndarray):
    r""" Build the joint model and its log-probability function

    p(y,w|x) = p(y|w,x) * p(w|x) = p(y|w,x) * p(w)

    p(w) = lam * exp( -lam * w )
    p(y|w,x) = \prod_{i} Normal(y^{i} | w * x^{i}, sig)

    ln( p(y,w|x) ) = ln(p(y|w,x)) + ln(p(w))

    :param lam_0: rate of the exponential prior on w
    :param sig: scale of the normal likelihood
    :param xs: inputs; the likelihood mean is w * xs
    :param ys: observations the returned log-prob function is conditioned on
    :return: (model, model_log_prob_fn) where model_log_prob_fn(w) evaluates
        model.log_prob((w, ys))
    """
    # Bring every input to the module-wide dtype before building the model.
    lam_0 = tf.cast(lam_0, dtype)
    sig = tf.cast(sig, dtype)
    xs = tf.cast(xs, dtype)
    ys = tf.cast(ys, dtype)

    model = tfd.JointDistributionSequentialAutoBatched([
        tfd.Exponential(rate=lam_0, name='w', force_probs_to_zero_outside_support=True),
        lambda w: tfd.Normal(loc=w * xs, scale=sig, name='ys')
    ])

    def model_log_prob_fn(*x):
        # Append the fixed observations to the latent value(s) passed in.
        return model.log_prob(x + (ys,))

    return model, model_log_prob_fn
def build_surrogate_posterior(type_: str = 'truncated_normal'):
    """ Build surrogate posterior with trainable parameters

    :param type_: 'truncated_normal' (support limited to [0, 10]) or 'log_normal'
    :return: a single-component JointDistributionSequentialAutoBatched whose
        location ('w_loc') and scale ('w_scale') are trainable
    :raises ValueError: if `type_` is not one of the supported names
    """
    supported = ('truncated_normal', 'log_normal')
    if type_ not in supported:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise ValueError(f'type_ must be one of {supported}, got {type_!r}')

    # Both surrogate families share the same trainable parameters. They are
    # created in the same order as before: location first, then scale.
    loc = tf.Variable(
        tf.random.normal(shape=(1,), stddev=.01, dtype=dtype), name='w_loc')
    # The scale is kept positive by optimising it through a Softplus bijector.
    scale = tfp.util.TransformedVariable(
        tf.constant([1], dtype), tfb.Softplus(), name='w_scale')

    if type_ == 'log_normal':
        component = tfd.LogNormal(loc=loc, scale=scale)
    else:
        component = tfd.TruncatedNormal(loc=loc, scale=scale, low=0., high=10.)

    return tfd.JointDistributionSequentialAutoBatched([component])
def generate_data(w_mu: float = .5, w_sig: float = 0, ep: float = 2.1, n: int = 500):
    """ Draw a reproducible synthetic data set ys = w * xs + noise

    The global NumPy seed is reset to 10 on every call.

    :param w_mu: mean of the per-sample slope w
    :param w_sig: standard deviation of the per-sample slope w
    :param ep: standard deviation of the additive noise
    :param n: number of samples
    :return: (params, ys, xs) where params echoes w_mu, w_sig and ep
    """
    np.random.seed(10)

    # Inputs: exponential with scale 20, shifted by 30.
    xs = sc.expon(scale=20, loc=30).rvs(n)

    # Slopes are drawn before the noise, so the random stream is consumed in
    # a fixed order.
    slopes = w_mu + np.random.randn(n) * w_sig
    noise = np.random.randn(n) * ep
    ys = slopes * xs + noise

    params = {'w_mu': w_mu, 'w_sig': w_sig, 'ep': ep}
    return params, ys, xs
def truncated_normal_dist(x, mu, sig):
    """ truncated normal distribution that is set to 0 for x < 0

    :param x: point(s) at which to evaluate the density
    :param mu: location of the underlying (untruncated) normal
    :param sig: scale of the underlying (untruncated) normal
    :return: normal pdf renormalised by the mass on [0, inf), zeroed for x < 0
    """
    rv = sc.norm(loc=mu, scale=sig)

    # Mass of the untruncated normal on [0, inf). The survival function is
    # used instead of `1 - cdf(0)` because the subtraction cancels to exactly
    # 0 when mu lies far below zero, which would make the result inf/nan.
    norm = rv.sf(0)

    return (x >= 0) * rv.pdf(x) / norm
# `jit_compile` replaces the deprecated `experimental_compile` keyword and
# matches the spelling used by the other compiled function in this script.
@tf.function(autograph=False, jit_compile=True)
def run_chain(init_state, step_size, target_log_prob_fn, unconstraining_bijectors,
              num_steps=500, burnin=50):
    """ MCMC using the No-U-Turn Sampler

    The NUTS kernel is wrapped in a TransformedTransitionKernel (using
    `unconstraining_bijectors`) and then in a DualAveragingStepSizeAdaptation
    kernel. `burnin` is used both as the number of step-size adaptation steps
    and as the number of burn-in steps passed to `sample_chain`.

    source:
    https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks

    :param init_state: initial chain state(s)
    :param step_size: initial step size(s) for the NUTS kernel
    :param target_log_prob_fn: log-density to sample from
    :param unconstraining_bijectors: bijector(s) for the transformed kernel
    :param num_steps: number of results to keep
    :param burnin: number of adaptation / burn-in steps
    :return: (chain_state, sampler_stat) where sampler_stat is the tuple
        (target_log_prob, leapfrogs_taken, has_divergence, energy,
        log_accept_ratio) traced at every step
    """

    def trace_fn(_, pkr):
        # Two levels of nesting: adaptation kernel -> transformed kernel -> NUTS.
        nuts_results = pkr.inner_results.inner_results
        return (nuts_results.target_log_prob,
                nuts_results.leapfrogs_taken,
                nuts_results.has_divergence,
                nuts_results.energy,
                nuts_results.log_accept_ratio)

    kernel = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=step_size),
        bijector=unconstraining_bijectors
    )

    # The getters/setters reach through the transformed kernel's results to
    # the NUTS results that actually hold the step size and accept ratio.
    adaptive_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=kernel,
        num_adaptation_steps=burnin,
        step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
            inner_results=pkr.inner_results._replace(step_size=new_step_size)),
        step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
        log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio
    )

    # Sampling from the chain.
    chain_state, sampler_stat = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=burnin,
        current_state=init_state,
        kernel=adaptive_kernel,
        trace_fn=trace_fn)

    return chain_state, sampler_stat
# ##### Generate data
n_max = 500
# Alternative value: 'log_normal'.
# NOTE(review): this name is not read anywhere else in the visible script.
var_posterior_type = 'truncated_normal'

model_params, ys, xs = generate_data(n=n_max)

# Scatter plot of the generated observations.
plt.figure(figsize=(10, 4))
plt.scatter(xs, ys, alpha=.3)
plt.xlabel('age [yr]')
plt.ylabel('height')
plt.grid()
plt.show()
# ##### Analytical solution
lam_0 = 200  # rate of the exponential prior
sig = 3      # scale of the normal likelihood
ns = [2, 3, 10, 100]  # sizes of the data prefixes that are fitted

# One closed-form (mu, sig) pair per prefix size.
results = [
    growth_analytical_exp_prior(lam_0, sig, ys[:n], xs[:n])
    for n in ns
]

df_true = pd.DataFrame(results, index=ns)
# Put the columns under a 'true' top level.
df_true.columns = pd.MultiIndex.from_tuples(
    [('true', c) for c in df_true.columns])

logger.info('\n\ndf_true:\n' + tabulate(df_true, headers="keys"))
# ##### Variational inference
# Fit a truncated-normal surrogate posterior for every data prefix size.
qs = []
for n in ns:
    logger.info(f'n = {n}')

    q = build_surrogate_posterior('truncated_normal')
    model, target_log_prob_fn = build_model_log_prob_fn(lam_0, sig, xs[:n], ys[:n])
    optimizer = tf.optimizers.Adam(learning_rate=1e-2)

    # fails if `jit_compile=True`
    @tf.function()
    def fit_vi():
        return tfp.vi.fit_surrogate_posterior(
            target_log_prob_fn=target_log_prob_fn,
            surrogate_posterior=q,
            optimizer=optimizer,
            num_steps=100000,
            # sample_size=1
        )

    losses = fit_vi()
    qs.append(q)

# Read the fitted parameters back from each surrogate: variables[0] is used
# as the location, and variables[1] is mapped through Softplus to the scale.
results_vi = []
for fitted in qs:
    loc_var = fitted.variables[0]
    raw_scale_var = fitted.variables[1]
    results_vi.append(dict(
        mu=loc_var.numpy(),
        sig=tfb.Softplus(1.).forward(raw_scale_var).numpy()))

df_vi = pd.DataFrame(results_vi, index=ns)
df_vi.columns = pd.MultiIndex.from_tuples(
    [('vi', c) for c in df_vi.columns])

logger.info('\n\ndf_vi:\n' + tabulate(df_vi, headers="keys"))
# ##### Variational inference lognormal
# Same procedure as above, with a log-normal surrogate and an XLA-compiled fit.
qs_lognorm = []
for n in ns:
    logger.info(f'n = {n}')

    q = build_surrogate_posterior(type_='log_normal')
    model, target_log_prob_fn = build_model_log_prob_fn(lam_0, sig, xs[:n], ys[:n])
    optimizer = tf.optimizers.Adam(learning_rate=1e-2)

    @tf.function(jit_compile=True)
    def fit_vi():
        return tfp.vi.fit_surrogate_posterior(
            target_log_prob_fn=target_log_prob_fn,
            surrogate_posterior=q,
            optimizer=optimizer,
            num_steps=20000
        )

    losses = fit_vi()
    qs_lognorm.append(q)

# variables[0] is used as the location; variables[1] is mapped through
# Softplus to obtain the scale.
results_vi_lognorm = []
for fitted in qs_lognorm:
    loc_var = fitted.variables[0]
    raw_scale_var = fitted.variables[1]
    results_vi_lognorm.append(dict(
        mu=loc_var.numpy(),
        sig=tfb.Softplus(1.).forward(raw_scale_var).numpy()))

df_vi_lognorm = pd.DataFrame(results_vi_lognorm, index=ns)
df_vi_lognorm.columns = pd.MultiIndex.from_tuples(
    [('vi_lognormal', c) for c in df_vi_lognorm.columns])

logger.info('\n\ndf_vi_lognorm:\n' + tabulate(df_vi_lognorm, headers="keys"))
# ##### MCMC
# Number of chains; each one starts from a draw of the joint model.
nchain = 8

results_mcmc = []
for n in ns:
    model, target_log_prob_fn = build_model_log_prob_fn(lam_0, sig, xs[:n], ys[:n])

    # Keep only the latent w from the joint sample; the sampled ys are dropped.
    w0, _ = model.sample(nchain)
    init_state = [w0]
    step_size = [tf.cast(i, dtype=dtype) for i in [.01]]
    unconstraining_bijectors = [tfb.Identity()]

    samples, sampler_stat = run_chain(
        init_state, step_size, target_log_prob_fn, unconstraining_bijectors,
        num_steps=5000)
    results_mcmc.append((samples, sampler_stat))

# Pool all chains of a run and summarise them by sample mean and std.
pooled_draws = [chain[0].numpy().reshape(-1) for chain, _ in results_mcmc]
data = [[draws.mean(), draws.std()] for draws in pooled_draws]

df_mcmc = pd.DataFrame(data, columns=['mu', 'sig'], index=ns)
df_mcmc.columns = pd.MultiIndex.from_tuples(
    [('mcmc', c) for c in df_mcmc.columns])

logger.info('\n\ndf_mcmc:\n' + tabulate(df_mcmc, headers="keys"))
# ##### Plot results
# The (mu, sig) reported for MCMC are the posterior sample mean and std, not the
# parameters of the Truncated Normal distribution.
# `axis` is passed by keyword: newer pandas rejects it as a positional argument.
df = pd.concat([df_true, df_vi, df_mcmc], axis=1)
logger.info('\n\nAll results:\n' + tabulate(df, headers="keys"))

# Two panels per row; round up so an odd number of entries still fits.
rows = (len(ns) + 1) // 2
fig = plt.figure(figsize=(15, 4 * rows))

# Common evaluation grid; it starts slightly below 0 to show the truncation.
x = np.linspace(-.05, .55, 200)

for idx, n in enumerate(ns):
    res_true = results[idx]
    res_vi = results_vi[idx]
    # Pool the draws of all chains for this run.
    w_samples = results_mcmc[idx][0][0].numpy().reshape(-1)

    ax = fig.add_subplot(rows, 2, idx + 1)
    ax.plot(x, truncated_normal_dist(x, res_true['mu'], res_true['sig']),
            label=f'analytical (n={n})')
    ax.plot(x, truncated_normal_dist(x, res_vi['mu'], res_vi['sig']),
            label=f'variational inference (n={n})')
    # `gaussian_kde` is taken from `scipy.stats` directly; the
    # `scipy.stats.kde` submodule path is deprecated.
    ax.plot(x, sc.gaussian_kde(w_samples)(x), label=f'MCMC (n={n})')
    ax.legend()
    ax.set_xlabel(r'$\omega$')

plt.show()
fig.savefig('results.png', dpi=200, bbox_inches='tight')
# ##### Plot results Lognormal dist
# Two panels per row; round up so an odd number of entries still fits.
rows = (len(ns) + 1) // 2
fig = plt.figure(figsize=(15, 4 * rows))

# Common evaluation grid; it starts slightly below 0 to show the truncation.
x = np.linspace(-.05, .55, 200)

for idx, n in enumerate(ns):
    res_true = results[idx]
    res_vi_lognorm = results_vi_lognorm[idx]
    # Pool the draws of all chains for this run.
    w_samples = results_mcmc[idx][0][0].numpy().reshape(-1)

    ax = fig.add_subplot(rows, 2, idx + 1)
    ax.plot(x, truncated_normal_dist(x, res_true['mu'], res_true['sig']),
            label=f'analytical (n={n})')

    # Log-normal density with log-space location `mu` and scale `sig`.
    # scipy's `lognorm` returns 0 for x <= 0, so the negative part of the
    # grid no longer goes through `np.log` (NaN plus a RuntimeWarning).
    rv_lognorm = sc.lognorm(s=res_vi_lognorm['sig'],
                            scale=np.exp(res_vi_lognorm['mu']))
    ax.plot(x, rv_lognorm.pdf(x),
            label=f'variational inference (n={n})')

    # `gaussian_kde` is taken from `scipy.stats` directly; the
    # `scipy.stats.kde` submodule path is deprecated.
    ax.plot(x, sc.gaussian_kde(w_samples)(x), label=f'MCMC (n={n})')
    ax.legend()
    ax.set_xlabel(r'$\omega$')

plt.show()
fig.savefig('results_lognormal.png', dpi=200, bbox_inches='tight')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment