-
-
Save chvandorp/82eeb7ae8f5ad7d437d1d100eb44959a to your computer and use it in GitHub Desktop.
A python script demonstrating a Stan implementation of generalized profiling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Fit the Lotka-Volterra predator-prey model to artificial data | |
with the generalized profiling method implemented in Stan | |
""" | |
import pystan | |
import numpy as np | |
import scipy.stats as sts | |
import matplotlib.pyplot as plt | |
plt.rcParams.update({'font.size': 18}) | |
from matplotlib.gridspec import GridSpec | |
from scipy.integrate import solve_ivp | |
import sdeint | |
## compile the Stan program (read from the file "gen-prof-LV.stan")
sm = pystan.StanModel(file="gen-prof-LV.stan")
## parameter values used to simulate the data; collected in theta so
## that the estimates can be compared with them afterwards
a, b, c, d = 1.0, 0.4, 0.4, 0.5
theta = [a, b, c, d]
## observation schedule: NumObs equally spaced times on [tmin, tmax]
NumObs = 25
tmin = 0
tmax = 50
TimesObs = np.linspace(tmin, tmax, NumObs)
## initial state of the system
u0 = [1.0, 1.0] ## x0, y0
## the "system size" parameter K
K = 10
## define the Lotka-Volterra predator-prey model | |
def LV_sys(t, u):
    """
    Right-hand side of the Lotka-Volterra predator-prey ODEs.

    u[0] is the prey and u[1] the predator; the time t is not used.
    The rates a, b, c and d are read from the module-level globals.
    Returns the two time derivatives as a list [dx/dt, dy/dt].
    """
    prey, pred = u[0], u[1]
    growth_prey = a*prey - b*prey*pred
    growth_pred = c*b*prey*pred - d*pred
    return [growth_prey, growth_pred]
## integrate the ODEs and report the state at the observation times
sol = solve_ivp(LV_sys, (tmin, tmax), u0, t_eval=TimesObs)
## draw Poisson-distributed observations whose mean is K times the state;
## K therefore determines the measurement noise
Obs = sts.poisson.rvs(K * sol.y.T)
def run_gen_prof(sm, obs, times, lam, system_size, deg=3,
                 chains=4, chain_len=1000, thin=5):
    """
    Build the data dictionary for the Stan model, run the sampler and
    return the samples together with the grid used for plotting.

    sm          -- compiled model; only its `sampling` method is called
    obs, times  -- observations and their time points ("Obs", "TimesObs")
    lam         -- passed to Stan as "Lambda"
    system_size -- passed to Stan as "K"
    deg         -- passed to Stan as "SplineDeg"
    chains, chain_len, thin -- forwarded to the sampler
                   (chain_len is passed as `iter`)

    Returns the tuple (samples, grid_pts), where grid_pts are
    3*(2*len(times)-1)-1 equally spaced points between the first and
    the last entry of times.
    """
    num_obs = len(times)
    ## one knot at every observation plus one between each adjacent pair
    knot_count = 2*num_obs - 1
    ## number of points for numerical integration
    grid_size = 3*knot_count - 1
    grid_pts = np.linspace(times[0], times[num_obs - 1], grid_size)
    stan_data = dict(
        NumKnots=knot_count,
        SplineDeg=deg,
        NumGridPts=grid_size,
        NumObs=num_obs,
        TimesObs=times,
        Obs=obs,
        K=system_size,
        Lambda=lam,
    )
    sampler_control = dict(max_treedepth=15, adapt_delta=0.9)
    fit = sm.sampling(
        data=stan_data,
        chains=chains,
        iter=chain_len,
        control=sampler_control,
        thin=thin,
    )
    ## the grid points are handed back so the caller can plot against them
    return fit, grid_pts
## fit the same observations twice: once with a small and once with
## a large value of lambda
lam_small = 1
lam_large = 100
sam_small, GridPts = run_gen_prof(sm, Obs, TimesObs,
                                  lam=lam_small, system_size=K)
## the grid depends only on the observation times, so the second
## copy is discarded
sam_large, _ = run_gen_prof(sm, Obs, TimesObs,
                            lam=lam_large, system_size=K)
def plot_gen_prof_fit(sam, times, obs, grid_pts, system_size, n=None):
    """
    Draw the data together with the fitted spline, and below it the
    derivative of the spline next to the derivative given by the model,
    which indicates how far the fit deviates from the LV model.

    sam         -- fit object; extract(permuted=True) must return a dict
                   with the entries "uhat", "usim", "duhat_real" and
                   "duhat_target", each indexed as [:, :, variable]
    times, obs  -- observation times and observations (columns: prey,
                   predator); only the first n are shown
    grid_pts    -- time points against which the fitted curves are drawn
    system_size -- factor applied to "uhat" and both derivatives before
                   plotting ("usim" is plotted unscaled)
    n           -- number of observations to show (default: all of them)

    Returns (fig, (ax, bx_prey, bx_predator)).
    """
    num_shown = len(times) if n is None else n
    draws = sam.extract(permuted=True)

    def per_grid_point(name, variable):
        ## transpose so that iterating yields one row per plotted time
        ## point (presumably holding the posterior draws for that point)
        return draws[name][:, :, variable].T

    def scaled_means(rows):
        return [system_size*np.mean(row) for row in rows]

    ## large panel on top for data and fit, two small panels for derivatives
    fig = plt.figure(figsize=(14, 7))
    layout = GridSpec(4, 1)
    ax = fig.add_subplot(layout[:2, 0])
    bxs = [fig.add_subplot(layout[row, 0], sharex=ax) for row in (2, 3)]
    pcts = [2.5, 97.5]
    species = zip(["tab:blue", "tab:orange"],
                  ["Prey ($X$)", "Predator ($Y$)"],
                  bxs)
    ## one pass for the prey, one for the predators
    for k, (color, label, bx) in enumerate(species):
        ax.scatter(times[:num_shown], obs[:num_shown, k], color=color,
                   edgecolors='k', zorder=3, label=label)
        ## fitted trajectory: mean curve and percentile band
        fitted = per_grid_point("uhat", k)
        ax.plot(grid_pts, scaled_means(fitted), color='k', zorder=2,
                label='fit' if k == 0 else None)
        fit_band = [system_size*np.percentile(row, pcts) for row in fitted]
        ax.fill_between(grid_pts, *np.array(fit_band).T, color=color,
                        alpha=0.5, linewidth=0, zorder=1)
        ## simulations: percentile band only, drawn more transparently
        simulated = per_grid_point("usim", k)
        sim_band = [np.percentile(row, pcts) for row in simulated]
        ax.fill_between(grid_pts, *np.array(sim_band).T, color=color,
                        alpha=0.3, linewidth=0)
        ## derivative of the spline versus the target derivative
        bx.plot(grid_pts, scaled_means(per_grid_point("duhat_real", k)),
                color=color, linewidth=3, label="spline")
        bx.plot(grid_pts, scaled_means(per_grid_point("duhat_target", k)),
                color='k', linestyle='--', label="LV model")
        bx.legend(loc=1, ncol=2, prop={'size': 10})
    ## legends, axis labels and tick visibility
    ax.legend(loc=1, ncol=3, prop={'size': 10})
    ax.set_ylabel("data and fit")
    for bx, c in zip(bxs, 'xy'):
        bx.set_ylabel(f"$\\frac{{d{c}}}{{dt}}$",
                      rotation=0, va='center')
    ## only the bottom panel keeps its time axis
    for hidden in (ax, bxs[0]):
        hidden.get_xaxis().set_visible(False)
    bxs[1].set_xlabel("Time ($t$)")
    fig.align_ylabels()
    return fig, (ax, *bxs)
## figure of the fit obtained with the small lambda
fig, axs = plot_gen_prof_fit(sam_small, TimesObs, Obs,
                             grid_pts=GridPts, system_size=K)
fig.savefig("gen-prof-fit-small-lambda.png", bbox_inches='tight', dpi=300)
## figure of the fit obtained with the large lambda
fig, axs = plot_gen_prof_fit(sam_large, TimesObs, Obs,
                             grid_pts=GridPts, system_size=K)
fig.savefig("gen-prof-fit-large-lambda.png", bbox_inches='tight', dpi=300)
def plot_par_est(ax, sam, real_par_vals):
    """
    Plot parameter estimates and compare them with the real values.

    For each of the parameters a, b, c and d the mean of the samples is
    drawn as a red diamond, with a vertical bar from the 2.5th to the
    97.5th percentile. The real values are drawn as black dots on top.

    ax            -- axes object to draw on (modified in place)
    sam           -- fit object; extract(permuted=True) must return a
                     dict with the entries "a", "b", "c" and "d"
    real_par_vals -- real parameter values, in the order a, b, c, d

    Returns None.
    """
    chain_dict = sam.extract(permuted=True)
    parnames = ["a", "b", "c", "d"]
    latex_parnames = [f"${x}$" for x in parnames]
    pcts = [2.5, 97.5]
    ## plot estimates and 95 percentiles
    pos = range(len(parnames))
    means = [np.mean(chain_dict[x]) for x in parnames]
    ranges = [np.percentile(chain_dict[x], pcts) for x in parnames]
    ## the means are computed once and reused for the markers
    ax.scatter(pos, means, color='tab:red', label="estimate",
               marker='D', zorder=1)
    for p, r in zip(pos, ranges):
        ax.plot([p, p], r, color='tab:red', linewidth=2, zorder=1)
    ax.set_xticks(pos)
    ax.set_xticklabels(latex_parnames)
    ## plot real parameter values
    ax.scatter(pos, real_par_vals, color='k', label="real value", zorder=2)
    ax.legend(loc=1, ncol=1, prop={'size': 10})
    ax.set_ylabel("parameter value")
    ax.set_xlabel("parameter name")
## show the estimates for the small and the large lambda side by side,
## on a shared vertical axis
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(7,3), sharey=True)
plot_par_est(ax1, sam_small, theta)
ax1.set_title(f"$\\lambda = {lam_small}$")
plot_par_est(ax2, sam_large, theta)
ax2.set_title(f"$\\lambda = {lam_large}$")
fig.savefig("gen-prof-estimates.png", bbox_inches='tight', dpi=300)
"""
Sample a trajectory from a stochastic LV model defined by an SDE.
Again use the generalized profiling method to estimate parameters.
"""
## magnitude of the process noise that is added to the model
sigma = 0.1
## a, b, c, d, u0 and K are reused; the observation schedule is redefined
NumObs = 50
tmin = 0
tmax = 100
## the SDE is solved on a finer grid that has Thin steps between
## two consecutive observation times
Thin = 10
TimesObsSDE = np.linspace(tmin, tmax, (NumObs - 1)*Thin + 1)
## every Thin-th point of the fine grid is an observation time
TimesObs = TimesObsSDE[::Thin]
## define the system | |
def LV_sys_drift(u, t):
    """
    Drift term of the stochastic Lotka-Volterra model.

    Evaluates the same expressions as LV_sys, but takes the state u
    before the (unused) time t and returns a numpy array.
    The rates a, b, c and d are read from the module-level globals.
    """
    prey, pred = u[0], u[1]
    drift = [a*prey - b*prey*pred, c*b*prey*pred - d*pred]
    return np.array(drift)
def LV_sys_diffusion(u, t):
    """
    Diffusion term of the stochastic Lotka-Volterra model.

    Returns a diagonal matrix with the entries of u on the diagonal,
    multiplied by the global sigma. The time t is not used.
    """
    state_on_diagonal = np.diag(u)
    return sigma * state_on_diagonal
## sample one trajectory of the SDE on the fine time grid
sol = sdeint.itoint(LV_sys_drift, LV_sys_diffusion, u0, TimesObsSDE)
## solve the ODE on the same grid for comparison
sol_ode = solve_ivp(LV_sys, (tmin, tmax), u0, t_eval=TimesObsSDE)
## generate random data (observations) from every Thin-th point
## of the stochastic trajectory
Obs = sts.poisson.rvs(K * sol[::Thin, :])
## make a figure of the stochastic process,
## the data and the deterministic skeleton
## NOTE: the plotting functions in this file define their own percentiles;
## pcts is kept as a module-level name in case other code relies on it
pcts = [2.5, 97.5] ## percentiles used later for CrIs
colors = ["tab:blue", "tab:orange"] ## color for prey and predator
## one panel per variable, with a shared time axis
fig, axs = plt.subplots(2, 1, figsize=(14, 7), sharex=True)
for i, color in enumerate(colors):
    axs[i].scatter(TimesObs, Obs[:,i], color=color,
                   edgecolors='k', zorder=2, label='observations')
    ## sol holds the state per row; sol_ode.y holds it per column,
    ## both are scaled by K to match the observations
    axs[i].plot(TimesObsSDE, sol[:,i]*K, color=color,
                linewidth=3, zorder=1, label='stochastic process')
    axs[i].plot(TimesObsSDE, sol_ode.y[i]*K, color='k', alpha=0.75,
                label='deterministic skeleton')
    axs[i].legend(loc=0, ncol=3, prop={'size': 10})
axs[-1].set_xlabel("Time ($t$)")
axs[0].set_ylabel("$x$ SDE (blue)\n$x$ ODE (black)")
axs[1].set_ylabel("$y$ SDE (orange)\n$y$ ODE (black)")
fig.savefig("stochastic-LV-sim.png", dpi=300, bbox_inches='tight')
## fit the model to the SDE-generated data with an intermediate lambda
lam = 10
sam, GridPts = run_gen_prof(sm, Obs, TimesObs, lam=lam, system_size=K)
## figure of the fit
fig, axs = plot_gen_prof_fit(sam, TimesObs, Obs,
                             grid_pts=GridPts, system_size=K)
fig.savefig("gen-prof-fit-sde.png", bbox_inches='tight', dpi=300)
## figure of the parameter estimates next to the real values
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4,3))
plot_par_est(ax, sam, theta)
fig.savefig("gen-prof-estimates-sde.png", bbox_inches='tight', dpi=300)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment