Local level model for M trajectories of 2 elements each
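A short orientation, reconstructed from the code below (the gist itself carries only the one-line description above): each trajectory follows a local level model with drift, y_t = w * t + sum_{s <= t} eta_s + eps_t, with state noise eta_s ~ N(0, sig_z**2) and observation noise eps_t ~ N(0, sig_y**2), so a pair of observations at times t0 <= t1 has covariance sig_z**2 * [[t0, t0], [t0, t1]] + sig_y**2 * I (see cov_mat). The script keeps two observations per trajectory, places an Exponential(lam_0) prior on the drift w, and compares the analytical posterior (a normal truncated at 0) with samples from TFP's No-U-Turn sampler.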
import logging
from tabulate import tabulate
from typing import List, Dict, Tuple, Any
import numpy as np
import pandas as pd
import scipy.stats as sc
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

tfd = tfp.distributions
tfb = tfp.bijectors
dtype = tf.float32

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
tf.get_logger().setLevel('ERROR')
def cov_mat(t0, t1, sig_z, sig_y):
    """ 2x2 covariance matrix of two observations (at times t0 <= t1)
    from a local level process
    """
    cov = (sig_z ** 2 * np.array([[t0, t0],
                                  [t0, t1]]) +
           sig_y ** 2 * np.eye(2)
           )
    return cov
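# Illustrative sanity check (not part of the original gist): with sig_z=.2, sig_y=.4 and
# observation times t0=2 < t1=5, the entries are t0*sig_z**2 + sig_y**2 = .24,
# t0*sig_z**2 = .08 and t1*sig_z**2 + sig_y**2 = .36:
# >>> cov_mat(2, 5, sig_z=.2, sig_y=.4)
# array([[0.24, 0.08],
#        [0.08, 0.36]])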
def generate_data(
        w: float = 1.4,
        sig_z: float = .2,
        sig_y: float = .4,
        nmax: int = 8,
        tmin: int = 2,
        tmax: int = 10,
        seed: int = 12,
        **kwargs
):
    """ Generate nmax trajectories from a local level model with drift w """
    assert 0 < tmin < tmax, "0 < tmin < tmax should be fulfilled"

    np.random.seed(seed)

    ep = np.random.randn(nmax, tmax) * sig_y  # observation noise
    et = np.random.randn(nmax, tmax) * sig_z  # state (level) noise
    et[:, 0] = 0

    drift = np.arange(tmax) * np.ones(shape=(nmax, 1)) * w
    ys = drift + ep + et.cumsum(axis=1)

    # take 2 time indices from every trajectory; shape = (nmax, 2)
    ts_obs = np.stack([np.random.choice(range(tmin, tmax),
                                        size=(2,),
                                        replace=False) for _ in range(nmax)])
    ts_obs = np.sort(ts_obs)

    # shape = (nmax, 2)
    ys_obs = np.stack([y[t] for t, y in zip(ts_obs, ys)])

    return ys, ts_obs, ys_obs
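# Usage sketch (added for illustration; shapes follow the defaults above):
# >>> ys, ts_obs, ys_obs = generate_data()
# >>> ys.shape, ts_obs.shape, ys_obs.shape
# ((8, 10), (8, 2), (8, 2))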
def get_ab_params(
        y: float, yp: float, t: float, tp: float, sig_z: float, sig_y: float
):
    """ Get the a, b parameters used in the analytical solution: for observations
    (yp, y) at times (tp, t) the log-likelihood of the drift w is
    -.5 * a * w ** 2 + b * w + const
    """
    sig_t = np.sqrt(t * sig_z ** 2 + sig_y ** 2)
    sig_tp = np.sqrt(tp * sig_z ** 2 + sig_y ** 2)

    den = ((sig_t ** 2) * (sig_tp ** 2) -
           np.minimum(t, tp) ** 2 * (sig_z ** 4)
           )

    b_num = (y * t * sig_tp ** 2 +
             yp * tp * sig_t ** 2 -
             np.minimum(t, tp) * (sig_z ** 2) * (y * tp + yp * t)
             )

    a_num = (t ** 2 * sig_tp ** 2 +
             tp ** 2 * sig_t ** 2 -
             2 * np.minimum(t, tp) * sig_z ** 2 * t * tp
             )

    a = a_num / den
    b = b_num / den

    return a, b
def growth_analytical_exp_prior(
        y: np.ndarray, yp: np.ndarray, t: np.ndarray, tp: np.ndarray,
        sig_z: float, sig_y: float, lam: float
):
    """ Extend the previous function to multiple trajectories: analytical posterior
    over the drift w under an Exponential(lam) prior
    """
    assert len(y) == len(yp) == len(t) == len(tp)
    assert np.all(tp < t), "tp < t should be fulfilled"

    a, b = get_ab_params(y, yp, t, tp, sig_z, sig_y)

    mu = (b.sum() - lam) / a.sum()
    sig = np.sqrt(1 / a.sum())

    return dict(mu=mu, sig=sig)
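# Background (a hedged reconstruction from get_ab_params and the model below, not text
# from the original gist): each observation pair contributes
#     log N(y_i; w * t_i, cov_i) = -.5 * a_i * w ** 2 + b_i * w + const
# to the log-likelihood, and the Exponential(lam) prior adds -lam * w for w >= 0.
# Summing over trajectories gives a log-posterior quadratic in w, i.e. a normal with
# mu = (sum(b) - lam) / sum(a) and sig = 1 / sqrt(sum(a)), truncated at 0. This is what
# growth_analytical_exp_prior returns and what truncated_normal_dist (below) evaluates.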
def truncated_normal_dist(x, mu, sig):
    """ Density of a normal distribution truncated at 0 (zero for x < 0) """
    rv = sc.norm(loc=mu, scale=sig)
    norm = 1 - rv.cdf(0)
    return (x >= 0) * rv.pdf(x) / norm
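# Quick check (illustrative, not in the original): the truncated density should
# integrate to ~1 on [0, inf):
# >>> x = np.linspace(0, 10, 10001)
# >>> round(float(np.trapz(truncated_normal_dist(x, mu=.5, sig=.3), x)), 3)
# 1.0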
def build_model(
        ts_obs: np.ndarray,  # (n, 2)
        ys_obs: np.ndarray,  # (n, 2)
        sig_z: float,
        sig_y: float,
        lam_0: float
):
    """ Joint model: w ~ Exponential(lam_0),
    ys_obs[i] ~ MultivariateNormal(loc=w * ts_obs[i], cov=cov_mat(*ts_obs[i], sig_z, sig_y))
    """
    assert ts_obs.shape[0] == ys_obs.shape[0]
    assert ts_obs.shape[1] == ys_obs.shape[1] == 2

    covs = np.stack([cov_mat(*t, sig_z, sig_y) for t in ts_obs])

    ts_obs = tf.cast(ts_obs, dtype)
    ys_obs = tf.cast(ys_obs, dtype)
    covs = tf.cast(covs, dtype)
    lam_0 = tf.cast(lam_0, dtype)

    model = tfd.JointDistributionSequentialAutoBatched([
        tfd.Exponential(rate=lam_0, name='w',
                        force_probs_to_zero_outside_support=True),
        lambda w: tfd.MultivariateNormalFullCovariance(
            loc=w * ts_obs,
            covariance_matrix=covs
        )
    ])

    model_log_prob_fn = lambda *w: model.log_prob(w + (ys_obs,))

    return model, model_log_prob_fn
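# Usage sketch (illustrative; shapes assume n trajectories and 4 prior samples):
# >>> model, target_log_prob_fn = build_model(ts_obs, ys_obs, sig_z=.2, sig_y=.4, lam_0=70)
# >>> w0, y0 = model.sample(4)        # w0: (4,), y0: (4, n, 2) prior-predictive draws
# >>> lp = target_log_prob_fn(w0)     # log p(w0) + log p(ys_obs | w0), shape (4,)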
@tf.function(autograph=False, jit_compile=True)
def run_chain(init_state, step_size, target_log_prob_fn, unconstraining_bijectors,
              num_steps=500, burnin=50):
    """ MCMC using the No-U-Turn Sampler

    source:
    https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks
    """

    def trace_fn(_, pkr):
        return (pkr.inner_results.inner_results.target_log_prob,
                pkr.inner_results.inner_results.leapfrogs_taken,
                pkr.inner_results.inner_results.has_divergence,
                pkr.inner_results.inner_results.energy,
                pkr.inner_results.inner_results.log_accept_ratio)

    kernel = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=step_size),
        bijector=unconstraining_bijectors
    )

    adaptive_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=kernel,
        num_adaptation_steps=burnin,
        step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
            inner_results=pkr.inner_results._replace(step_size=new_step_size)),
        step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
        log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio
    )

    # Sampling from the chain.
    chain_state, sampler_stat = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=burnin,
        current_state=init_state,
        kernel=adaptive_kernel,
        trace_fn=trace_fn)

    return chain_state, sampler_stat
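# Diagnostics sketch (added; the traced tuple follows the TFP example notebooks referenced
# above). The last traced quantity is the NUTS log-accept-ratio, so an average acceptance
# probability can be estimated as:
# >>> samples, sampler_stat = run_chain(init_state, step_size, target_log_prob_fn,
# ...                                   unconstraining_bijectors)
# >>> accept_rate = tf.reduce_mean(tf.exp(tf.minimum(sampler_stat[-1], 0.)))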
def plot_results(
        ns: List[int],
        results: List[Dict[str, float]],
        results_mcmc: List[Tuple[Any, Any]],
        xmin: float = .2,
        xmax: float = 1.5
):
    """ Compare the analytical posterior with a KDE of the MCMC samples """
    rows = len(ns) // 2

    fig = plt.figure(figsize=(15, 4 * rows))

    for idx, n in enumerate(ns):
        res_true = results[idx]
        w_samples = results_mcmc[idx][0][0].numpy().reshape(-1)

        ax = fig.add_subplot(rows, 2, idx + 1)

        x = np.linspace(xmin, xmax, 200)
        ax.plot(x, truncated_normal_dist(x, res_true['mu'], res_true['sig']),
                label=f'analytical (n={n})')
        ax.plot(x, sc.gaussian_kde(w_samples)(x), label=f'MCMC (n={n})')

        ax.legend()
        ax.set_xlabel(r'$\omega$')

    return fig
if __name__ == '__main__':
    """ Run experiment """

    ns = [1, 2, 4, 8]

    params = dict(
        w=.5,
        sig_z=.2,
        sig_y=.4,
        lam_0=70,
        nmax=max(ns),
        tmin=1,
        tmax=5,
        seed=12
    )

    ys, ts_obs, ys_obs = generate_data(**params)

    # ##### Analytical solution

    results = []

    for n in ns:
        res = growth_analytical_exp_prior(
            y=ys_obs[:n, 1], yp=ys_obs[:n, 0],
            t=ts_obs[:n, 1], tp=ts_obs[:n, 0],
            sig_z=params['sig_z'],
            sig_y=params['sig_y'],
            lam=params['lam_0']
        )
        results.append(res)

    df_true = pd.DataFrame(results, index=ns)
    df_true.columns = pd.MultiIndex.from_tuples([('true', c) for c in df_true.columns])

    # ##### MCMC

    results_mcmc = []

    for n in ns:
        nchain = 8

        model, target_log_prob_fn = build_model(
            ts_obs=ts_obs[:n], ys_obs=ys_obs[:n],
            sig_z=params['sig_z'], sig_y=params['sig_y'], lam_0=params['lam_0'])

        w0, _ = model.sample(nchain)
        init_state = [w0]
        step_size = [tf.cast(i, dtype=dtype) for i in [.01]]
        unconstraining_bijectors = [tfb.Identity()]

        samples, sampler_stat = run_chain(
            init_state, step_size, target_log_prob_fn, unconstraining_bijectors,
            num_steps=5000, burnin=500)

        results_mcmc.append((samples, sampler_stat))

    data = []
    for r in results_mcmc:
        samples, _ = r
        w_samples = samples[0].numpy().reshape(-1)
        data.append([w_samples.mean(), w_samples.std()])

    df_mcmc = pd.DataFrame(data, columns=['mu', 'sig'], index=ns)
    df_mcmc.columns = pd.MultiIndex.from_tuples([('mcmc', c) for c in df_mcmc.columns])

    # ##### Concatenate solutions

    df = pd.concat([df_true, df_mcmc], axis=1)

    logger.info('\n\ndf:\n' + tabulate(df, headers="keys"))

    fig = plot_results(ns=ns, results=results, results_mcmc=results_mcmc)
    plt.show()