

@ImScientist · Last active August 29, 2022
Local level model for M trajectories of 2 elements each
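The model implemented below is a local level (random walk plus noise) process with a constant drift w. In the notation of the code,

    z_t = z_{t-1} + w + eta_t,    eta_t ~ N(0, sig_z^2),    z_0 = 0
    y_t = z_t + eps_t,            eps_t ~ N(0, sig_y^2)

so that y_t = w * t + sum_{s <= t} eta_s + eps_t, and two observations of the same trajectory at times t0 < t1 have covariance

    Cov = sig_z^2 * [[t0, t0], [t0, t1]] + sig_y^2 * I_2,

which is exactly what cov_mat computes.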
import logging
from tabulate import tabulate
from typing import List, Dict, Tuple, Any

import numpy as np
import pandas as pd
import scipy.stats as sc
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

tfd = tfp.distributions
tfb = tfp.bijectors
dtype = tf.float32

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
tf.get_logger().setLevel('ERROR')
def cov_mat(t0, t1, sig_z, sig_y):
    """ 2x2 covariance matrix of two observations from a local level
    process (t0 < t1)
    """
    cov = (sig_z ** 2 * np.array([[t0, t0],
                                  [t0, t1]]) +
           sig_y ** 2 * np.eye(2))
    return cov
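# Quick empirical check of cov_mat (a sketch, not used by the experiment):
# simulate many trajectories with a noise-free start and compare the sample
# covariance of (y_2, y_5) with cov_mat(2, 5, .2, .4). The deterministic
# drift w * t does not affect the covariance, so it is omitted here.
#
#     eta = np.random.randn(200_000, 6) * .2
#     eta[:, 0] = 0
#     obs = eta.cumsum(axis=1)[:, [2, 5]] + np.random.randn(200_000, 2) * .4
#     np.cov(obs.T)  # ~ cov_mat(2, 5, sig_z=.2, sig_y=.4)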
def generate_data(
        w: float = 1.4,
        sig_z: float = .2,
        sig_y: float = .4,
        nmax: int = 8,
        tmin: int = 2,
        tmax: int = 10,
        seed: int = 12,
        **kwargs
):
    """ Generate nmax trajectories from a local level model with drift w """
    assert 0 < tmin < tmax, "0 < tmin < tmax should be fulfilled"
    np.random.seed(seed)

    ep = np.random.randn(nmax, tmax) * sig_y
    et = np.random.randn(nmax, tmax) * sig_z
    et[:, 0] = 0  # the state starts noise-free at z_0 = 0

    drift = np.arange(tmax) * np.ones(shape=(nmax, 1)) * w
    ys = drift + ep + et.cumsum(axis=1)

    # take 2 time indices from every trajectory; shape = (nmax, 2)
    ts_obs = np.stack([np.random.choice(range(tmin, tmax),
                                        size=(2,),
                                        replace=False) for _ in range(nmax)])
    ts_obs = np.sort(ts_obs)

    # shape = (nmax, 2)
    ys_obs = np.stack([y[t] for t, y in zip(ts_obs, ys)])

    return ys, ts_obs, ys_obs
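# Usage sketch (shapes only; the values depend on the seed):
#
#     ys, ts_obs, ys_obs = generate_data(w=.5, nmax=8, tmin=1, tmax=5)
#     ys.shape      # (8, 5)  full trajectories
#     ts_obs.shape  # (8, 2)  two sorted observation times per trajectory
#     ys_obs.shape  # (8, 2)  the observations at those times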
def get_ab_params(
        y: np.ndarray, yp: np.ndarray, t: np.ndarray, tp: np.ndarray,
        sig_z: float, sig_y: float
):
    """ Get the a, b parameters used in the analytical solution """
    sig_t = np.sqrt(t * sig_z ** 2 + sig_y ** 2)
    sig_tp = np.sqrt(tp * sig_z ** 2 + sig_y ** 2)

    den = ((sig_t ** 2) * (sig_tp ** 2) -
           np.minimum(t, tp) ** 2 * (sig_z ** 4))

    b_num = (y * t * sig_tp ** 2 +
             yp * tp * sig_t ** 2 -
             np.minimum(t, tp) * (sig_z ** 2) * (y * tp + yp * t))

    a_num = (t ** 2 * sig_tp ** 2 +
             tp ** 2 * sig_t ** 2 -
             2 * np.minimum(t, tp) * sig_z ** 2 * t * tp)

    a = a_num / den
    b = b_num / den

    return a, b
def growth_analytical_exp_prior(
        y: np.ndarray, yp: np.ndarray, t: np.ndarray, tp: np.ndarray,
        sig_z: float, sig_y: float, lam: float
):
    """ Extend get_ab_params to the case of multiple trajectories and an
    Exponential(lam) prior on the drift w
    """
    assert len(y) == len(yp) == len(t) == len(tp)
    assert np.all(tp < t), "tp < t should be fulfilled"

    a, b = get_ab_params(y, yp, t, tp, sig_z, sig_y)

    mu = (b.sum() - lam) / a.sum()
    sig = np.sqrt(1 / a.sum())

    return dict(mu=mu, sig=sig)
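# Where mu and sig come from: for one trajectory the likelihood of w given the
# pair (yp, y) is Gaussian in w, with precision a and mean b / a (a, b as in
# get_ab_params). Multiplying the n per-trajectory likelihoods with the
# Exponential(lam) prior p(w) ~ exp(-lam * w) * 1[w >= 0] gives
#
#     p(w | data) ~ exp(- (sum_i a_i) / 2 * w^2 + (sum_i b_i - lam) * w),  w >= 0
#
# i.e. a normal with mu = (sum_i b_i - lam) / sum_i a_i and
# sig^2 = 1 / sum_i a_i, truncated to w >= 0 (see truncated_normal_dist below).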
def truncated_normal_dist(x, mu, sig):
    """ Normal density truncated to x >= 0 (renormalized over the support) """
    rv = sc.norm(loc=mu, scale=sig)
    norm = 1 - rv.cdf(0)
    return (x >= 0) * rv.pdf(x) / norm
def build_model(
        ts_obs: np.ndarray,  # (n, 2)
        ys_obs: np.ndarray,  # (n, 2)
        sig_z: float,
        sig_y: float,
        lam_0: float
):
    assert ts_obs.shape[0] == ys_obs.shape[0]
    assert ts_obs.shape[1] == ys_obs.shape[1] == 2

    covs = np.stack([cov_mat(*t, sig_z, sig_y) for t in ts_obs])

    ts_obs = tf.cast(ts_obs, dtype)
    ys_obs = tf.cast(ys_obs, dtype)
    covs = tf.cast(covs, dtype)
    lam_0 = tf.cast(lam_0, dtype)

    model = tfd.JointDistributionSequentialAutoBatched([
        tfd.Exponential(rate=lam_0, name='w',
                        force_probs_to_zero_outside_support=True),
        lambda w: tfd.MultivariateNormalFullCovariance(
            loc=w * ts_obs,
            covariance_matrix=covs)
    ])

    # pin the observations: the returned function depends on w only
    model_log_prob_fn = lambda *w: model.log_prob(w + (ys_obs,))

    return model, model_log_prob_fn
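# Usage sketch: the joint model factorizes as p(w) * p(y_obs | w), and
# model_log_prob_fn evaluates the unnormalized log posterior of w:
#
#     model, lp = build_model(ts_obs, ys_obs, sig_z=.2, sig_y=.4, lam_0=70)
#     w0, y0 = model.sample(4)           # prior / prior-predictive draws
#     lp(tf.constant(.5, dtype))         # unnormalized log posterior at w = .5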
@tf.function(autograph=False, jit_compile=True)
def run_chain(init_state, step_size, target_log_prob_fn, unconstraining_bijectors,
              num_steps=500, burnin=50):
    """ MCMC using the No-U-Turn Sampler

    source:
    https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks
    """

    def trace_fn(_, pkr):
        # kernel results are nested as DualAveragingStepSizeAdaptation ->
        # TransformedTransitionKernel -> NoUTurnSampler, hence the double
        # .inner_results below
        return (pkr.inner_results.inner_results.target_log_prob,
                pkr.inner_results.inner_results.leapfrogs_taken,
                pkr.inner_results.inner_results.has_divergence,
                pkr.inner_results.inner_results.energy,
                pkr.inner_results.inner_results.log_accept_ratio)

    kernel = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=step_size),
        bijector=unconstraining_bijectors)

    adaptive_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=kernel,
        num_adaptation_steps=burnin,
        step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
            inner_results=pkr.inner_results._replace(step_size=new_step_size)),
        step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
        log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio)

    # Sample from the chain.
    chain_state, sampler_stat = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=burnin,
        current_state=init_state,
        kernel=adaptive_kernel,
        trace_fn=trace_fn)

    return chain_state, sampler_stat
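# sampler_stat mirrors trace_fn: (target_log_prob, leapfrogs_taken,
# has_divergence, energy, log_accept_ratio), each of shape (num_steps, nchain).
# A quick health check (a sketch):
#
#     _, stat = run_chain(init_state, step_size, target_log_prob_fn,
#                         unconstraining_bijectors)
#     tf.reduce_mean(tf.exp(tf.minimum(stat[-1], 0.)))  # mean acceptance prob
#     tf.reduce_sum(tf.cast(stat[2], tf.int32))         # number of divergences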
def plot_results(
        ns: List[int],
        results: List[Dict[str, float]],
        results_mcmc: List[Tuple[Any, Any]],
        xmin: float = .2,
        xmax: float = 1.5
):
    rows = len(ns) // 2
    fig = plt.figure(figsize=(15, 4 * rows))

    for idx, n in enumerate(ns):
        res_true = results[idx]
        w_samples = results_mcmc[idx][0][0].numpy().reshape(-1)

        ax = fig.add_subplot(rows, 2, idx + 1)
        x = np.linspace(xmin, xmax, 200)
        ax.plot(x, truncated_normal_dist(x, res_true['mu'], res_true['sig']),
                label=f'analytical (n={n})')
        ax.plot(x, sc.gaussian_kde(w_samples)(x), label=f'MCMC (n={n})')
        ax.legend()
        ax.set_xlabel(r'$\omega$')

    return fig
if __name__ == '__main__':
    """ Run experiment """

    ns = [1, 2, 4, 8]

    params = dict(
        w=.5,
        sig_z=.2,
        sig_y=.4,
        lam_0=70,
        nmax=max(ns),
        tmin=1,
        tmax=5,
        seed=12
    )

    ys, ts_obs, ys_obs = generate_data(**params)

    # ##### Analytical solution
    results = []
    for n in ns:
        res = growth_analytical_exp_prior(
            y=ys_obs[:n, 1], yp=ys_obs[:n, 0],
            t=ts_obs[:n, 1], tp=ts_obs[:n, 0],
            sig_z=params['sig_z'],
            sig_y=params['sig_y'],
            lam=params['lam_0']
        )
        results.append(res)

    df_true = pd.DataFrame(results, index=ns)
    df_true.columns = pd.MultiIndex.from_tuples([('true', c) for c in df_true.columns])

    # ##### MCMC
    results_mcmc = []
    for n in ns:
        nchain = 8

        model, target_log_prob_fn = build_model(
            ts_obs=ts_obs[:n], ys_obs=ys_obs[:n],
            sig_z=params['sig_z'], sig_y=params['sig_y'], lam_0=params['lam_0'])

        w0, _ = model.sample(nchain)
        init_state = [w0]
        step_size = [tf.cast(i, dtype=dtype) for i in [.01]]
        unconstraining_bijectors = [tfb.Identity()]

        samples, sampler_stat = run_chain(
            init_state, step_size, target_log_prob_fn, unconstraining_bijectors,
            num_steps=5000, burnin=500)

        results_mcmc.append((samples, sampler_stat))

    data = []
    for r in results_mcmc:
        samples, _ = r
        w_samples = samples[0].numpy().reshape(-1)
        data.append([w_samples.mean(), w_samples.std()])

    df_mcmc = pd.DataFrame(data, columns=['mu', 'sig'], index=ns)
    df_mcmc.columns = pd.MultiIndex.from_tuples([('mcmc', c) for c in df_mcmc.columns])

    # ##### Concatenate solutions
    df = pd.concat([df_true, df_mcmc], axis=1)
    logger.info('\n\ndf:\n' + tabulate(df, headers="keys"))

    fig = plot_results(ns=ns, results=results, results_mcmc=results_mcmc)
    plt.show()
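# Optional convergence diagnostics (a sketch, not part of the experiment):
# samples[0] has shape (num_steps, nchain), which is the layout the
# tfp.mcmc diagnostics expect.
#
#     ess = tfp.mcmc.effective_sample_size(samples[0])       # per chain
#     rhat = tfp.mcmc.potential_scale_reduction(samples[0])  # should be ~1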