A Python script demonstrating a Stan implementation of generalized profiling
#!/usr/bin/env python3
"""
Fit the Lotka-Volterra predator-prey model to artificial data
with the generalized profiling method implemented in Stan
"""
import pystan
import numpy as np
import scipy.stats as sts
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 18})
from matplotlib.gridspec import GridSpec
from scipy.integrate import solve_ivp
import sdeint
## compile the Stan model
sm = pystan.StanModel(file="gen-prof-LV.stan")
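## the Stan file itself is not included in this script; as a sketch,
## its data block is assumed to match the dictionary built in
## run_gen_prof below, i.e. something like
##   data {
##     int NumKnots; int SplineDeg; int NumGridPts; int NumObs;
##     real TimesObs[NumObs]; int Obs[NumObs,2];
##     real K; real Lambda;
##   }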
## choose nice parameter values
a = 1.0
b = 0.4
c = 0.4
d = 0.5
theta = [a, b, c, d]
## observation times and initial conditions
NumObs = 25
tmin, tmax = 0, 50
TimesObs = np.linspace(tmin, tmax, NumObs)
## initial values
u0 = [1.0, 1.0] ## x0, y0
## the "system size" parameter K
K = 10
## define the Lotka-Volterra predator-prey model
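## dx/dt = a*x - b*x*y (prey), dy/dt = c*b*x*y - d*y (predators)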
def LV_sys(t, u):
    return [a*u[0] - b*u[0]*u[1], c*b*u[0]*u[1] - d*u[1]]
## use an ODE integrator to produce a trajectory
sol = solve_ivp(LV_sys, (tmin, tmax), u0, t_eval=TimesObs)
## generate random data (observations);
## K sets the scale of the counts and hence the measurement noise
Obs = sts.poisson.rvs(sol.y.T*K)
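## quick sanity check: sol.y has shape (2, NumObs), so Obs should have
## one row per observation time and one column per species
assert Obs.shape == (NumObs, 2)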
def run_gen_prof(sm, obs, times, lam, system_size, deg=3,
                 chains=4, chain_len=1000, thin=5):
    """
    convenience function to build a data dictionary for Stan
    and run the Stan model
    """
    n = len(times)
    ## put a knot at every observation and midway between observations
    num_knots = 2*n - 1
    ## number of grid points for numerical integration
    num_grid_pts = 3*num_knots - 1
    grid_pts = np.linspace(times[0], times[n-1], num_grid_pts)
    data = {
        "NumKnots" : num_knots,
        "SplineDeg" : deg,
        "NumGridPts" : num_grid_pts,
        "NumObs" : n,
        "TimesObs" : times,
        "Obs" : obs,
        "K" : system_size,
        "Lambda" : lam,
    }
    settings = {
        "max_treedepth" : 15,
        "adapt_delta" : 0.9
    }
    sam = sm.sampling(data=data, chains=chains, iter=chain_len,
                      control=settings, thin=thin)
    ## return the grid points for plotting
    return sam, grid_pts
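## In generalized profiling, Lambda weighs the penalty on the difference
## between the spline derivative (duhat_real below) and the LV vector
## field (duhat_target): a small Lambda lets the spline track the data
## closely, a large Lambda forces it towards an exact ODE solution.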
## fit the model twice with different lambdas
lam_small = 1
sam_small, GridPts = run_gen_prof(sm, Obs, TimesObs, lam_small, K)
lam_large = 100
sam_large, _ = run_gen_prof(sm, Obs, TimesObs, lam_large, K)
def plot_gen_prof_fit(sam, times, obs, grid_pts, system_size, n=None):
    """
    Make a figure with the data and the fitted spline.
    Also plot the derivative of the spline and the vector field
    to give an indication of the deviation from the LV model
    """
    if n is None:
        n = len(times)
    chain_dict = sam.extract(permuted=True)
    fig = plt.figure(figsize=(14, 7))
    gs = GridSpec(4, 1)
    ax = fig.add_subplot(gs[:2,0])
    bxs = [fig.add_subplot(gs[2,0], sharex=ax),
           fig.add_subplot(gs[3,0], sharex=ax)]
    labels = ["Prey ($X$)", "Predator ($Y$)"]
    colors = ["tab:blue", "tab:orange"]
    pcts = [2.5, 97.5]
    ## make plots for predators and prey
    for i, color in enumerate(colors):
        ax.scatter(times[:n], obs[:n,i], color=color, edgecolors='k',
                   zorder=3, label=labels[i])
        ## plot trajectories
        uss = chain_dict["uhat"][:,:,i].T
        mean_uhat = [system_size*np.mean(us) for us in uss]
        ax.plot(grid_pts, mean_uhat, color='k', zorder=2,
                label='fit' if i == 0 else None)
        range_uhat = [system_size*np.percentile(us, pcts) for us in uss]
        ax.fill_between(grid_pts, *np.array(range_uhat).T, color=color,
                        alpha=0.5, linewidth=0, zorder=1)
        ## plot simulations
        uss = chain_dict["usim"][:,:,i].T
        range_usim = [np.percentile(us, pcts) for us in uss]
        ax.fill_between(grid_pts, *np.array(range_usim).T, color=color,
                        alpha=0.3, linewidth=0)
        ## plot derivative of the spline and the target derivative
        uss = chain_dict["duhat_real"][:,:,i].T
        mean_duhat_real = [system_size*np.mean(us) for us in uss]
        bxs[i].plot(grid_pts, mean_duhat_real, color=color,
                    linewidth=3, label="spline")
        uss = chain_dict["duhat_target"][:,:,i].T
        mean_duhat_target = [system_size*np.mean(us) for us in uss]
        bxs[i].plot(grid_pts, mean_duhat_target, color='k',
                    linestyle='--', label="LV model")
        bxs[i].legend(loc=1, ncol=2, prop={'size': 10})
    ## axis labels etc.
    ax.legend(loc=1, ncol=3, prop={'size': 10})
    ax.set_ylabel("data and fit")
    for i, c in enumerate('xy'):
        bxs[i].set_ylabel(f"$\\frac{{d{c}}}{{dt}}$",
                          rotation=0, va='center')
    ax.get_xaxis().set_visible(False)
    bxs[0].get_xaxis().set_visible(False)
    bxs[1].set_xlabel("Time ($t$)")
    fig.align_ylabels()
    return fig, (ax, *bxs)
fig, axs = plot_gen_prof_fit(sam_small, TimesObs, Obs, GridPts, K)
fig.savefig("gen-prof-fit-small-lambda.png", dpi=300, bbox_inches='tight')
fig, axs = plot_gen_prof_fit(sam_large, TimesObs, Obs, GridPts, K)
fig.savefig("gen-prof-fit-large-lambda.png", dpi=300, bbox_inches='tight')
def plot_par_est(ax, sam, real_par_vals):
    """
    plot parameter estimates and compare them with the real values
    """
    chain_dict = sam.extract(permuted=True)
    parnames = ["a", "b", "c", "d"]
    latex_parnames = [f"${x}$" for x in parnames]
    pcts = [2.5, 97.5]
    ## plot posterior means and 95% credible intervals
    pos = range(len(parnames))
    means = [np.mean(chain_dict[x]) for x in parnames]
    ranges = [np.percentile(chain_dict[x], pcts) for x in parnames]
    ax.scatter(pos, means, color='tab:red', label="estimate",
               marker='D', zorder=1)
    for p, r in zip(pos, ranges):
        ax.plot([p, p], r, color='tab:red', linewidth=2, zorder=1)
    ax.set_xticks(pos)
    ax.set_xticklabels(latex_parnames)
    ## plot real parameter values
    ax.scatter(pos, real_par_vals, color='k', label="real value", zorder=2)
    ax.legend(loc=1, ncol=1, prop={'size': 10})
    ax.set_ylabel("parameter value")
    ax.set_xlabel("parameter name")
## compare the parameter estimates with small and large lambda
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7,3), sharey=True)
plot_par_est(ax1, sam_small, theta)
plot_par_est(ax2, sam_large, theta)
ax1.set_title(f"$\\lambda = {lam_small}$")
ax2.set_title(f"$\\lambda = {lam_large}$")
fig.savefig("gen-prof-estimates.png", dpi=300, bbox_inches='tight')
"""
Sample a trajectory from a stochastic LV model defined by an SDE.
Again use the generalized profiling method to estimate parameters.
"""
## add process noise to the model
sigma = 0.1
## the model parameters stay the same, but we observe more often and for longer
NumObs = 50
tmin, tmax = 0, 100
Thin = 10
TimesObsSDE = np.linspace(tmin, tmax, Thin*(NumObs-1)+1)
TimesObs = TimesObsSDE[::Thin]
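## the SDE is integrated on a fine grid (TimesObsSDE); the observations
## are then taken at every Thin-th grid point (TimesObs)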
## define the system
def LV_sys_drift(u, t):
    return np.array([a*u[0] - b*u[0]*u[1], c*b*u[0]*u[1] - d*u[1]])

def LV_sys_diffusion(u, t):
    return sigma * np.diag(u)
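## together these define the SDE dU = f(U) dt + sigma*diag(U) dW:
## the LV drift with multiplicative noise proportional to each component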
sol = sdeint.itoint(LV_sys_drift, LV_sys_diffusion, u0, TimesObsSDE)
sol_ode = solve_ivp(LV_sys, (tmin, tmax), u0, t_eval=TimesObsSDE)
## generate random data (observations)
Obs = sts.poisson.rvs(sol[::Thin,:]*K)
## make a figure of the stochastic process,
## the data and the deterministic skeleton
pcts = [2.5, 97.5] ## percentiles used later for CrIs
colors = ["tab:blue", "tab:orange"] ## color for prey and predator
fig, axs = plt.subplots(2, 1, figsize=(14, 7), sharex=True)
for i, color in enumerate(colors):
    axs[i].scatter(TimesObs, Obs[:,i], color=color,
                   edgecolors='k', zorder=2, label='observations')
    axs[i].plot(TimesObsSDE, sol[:,i]*K, color=color,
                linewidth=3, zorder=1, label='stochastic process')
    axs[i].plot(TimesObsSDE, sol_ode.y[i]*K, color='k', alpha=0.75,
                label='deterministic skeleton')
    axs[i].legend(loc=0, ncol=3, prop={'size': 10})
axs[-1].set_xlabel("Time ($t$)")
axs[0].set_ylabel("$x$ SDE (blue)\n$x$ ODE (black)")
axs[1].set_ylabel("$y$ SDE (orange)\n$y$ ODE (black)")
fig.savefig("stochastic-LV-sim.png", dpi=300, bbox_inches='tight')
## choose a lambda, and fit the model...
lam = 10
sam, GridPts = run_gen_prof(sm, Obs, TimesObs, lam, K)
## plot the fit
fig, axs = plot_gen_prof_fit(sam, TimesObs, Obs, GridPts, K)
fig.savefig("gen-prof-fit-sde.png", dpi=300, bbox_inches='tight')
## plot estimates
fig, ax = plt.subplots(1, 1, figsize=(4,3))
plot_par_est(ax, sam, theta)
fig.savefig("gen-prof-estimates-sde.png", dpi=300, bbox_inches='tight')