Skip to content

Instantly share code, notes, and snippets.

@justinrporter
Last active Dec 21, 2021
Embed
What would you like to do?
def kinetic_timestep(c_prev, c_in_i, K):
    """Advance every system by one step: ``c_next[n] = c_prev[n] @ K[n] + c_in_i[n]``.

    The signature follows the ``(carry, x) -> (carry, y)`` contract of
    ``jax.lax.scan``. The new state is returned both as the carry and,
    wrapped in a 1-tuple, as the per-step output.
    """
    # Row-vector times matrix, batched over the leading axis n.
    propagated = jnp.einsum('nj,njk->nk', c_prev, K)
    c_next = propagated + c_in_i
    return c_next, (c_next,)
def kinetic_rollout(K, c_in, c_init):
    """Roll the per-system transfer matrices ``K`` forward over all timesteps.

    ``c_in`` holds independent systems on axis 0 (patients, as called from
    ``model_basic``) and timesteps on axis 1. It is reordered to time-major
    so that ``jax.lax.scan`` iterates over timesteps, applying
    ``kinetic_timestep`` at each one starting from ``c_init``. The stacked
    per-step states are returned in the same systems-first layout as
    ``c_in``.
    """
    time_major_inputs = jnp.moveaxis(c_in, 1, 0)
    n_steps = time_major_inputs.shape[0]

    def step(carry, c_in_t):
        return kinetic_timestep(carry, c_in_t, K=K)

    _, (states,) = jax.lax.scan(
        step,
        init=c_init,
        xs=time_major_inputs,
        length=n_steps,
    )
    return jnp.moveaxis(states, 0, 1)
def model_basic(c_obs, m_in):
    """NumPyro model of per-patient three-compartment discrete-time kinetics.

    For each of the ``N`` patients the model does the following:

    * samples a volume ``Vd1`` (``Vd0`` and ``Vd2`` are fixed at 1);
    * samples four rates ``k01``, ``k12``, ``u`` and ``k21``;
    * assembles them into a 3x3 per-step transfer matrix ``K[i]``;
    * propagates ``m_in / Vd`` through it with ``kinetic_rollout``,
      starting from an all-zero state.

    Compartment 1 of the prediction is compared with ``c_obs`` under a
    normal likelihood on the ``log(c + 1)`` scale. The noise scale
    ``sig_obs`` is shared by all observations.

    Args:
        c_obs: observed concentrations with patients on axis 0. NaN entries
            are treated as missing and excluded from the likelihood. It
            must broadcast against the ``(N, T)`` compartment-1 prediction,
            so it is presumably ``(N, T)``. TODO: confirm with the caller.
        m_in: input per patient and timestep, with patients on axis 0. It is
            divided by volumes of shape ``(N, 1, 3)``, and each timestep must
            yield an ``(N, 3)`` input, so it is presumably ``(N, T, 3)``.

    Raises:
        ValueError: if ``c_obs`` and ``m_in`` disagree on the number of
            patients.
    """
    N = m_in.shape[0]
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if N != c_obs.shape[0]:
        raise ValueError(f"{N} != {c_obs.shape[0]}")

    # Compartment volumes, shape (3, N); only the middle one is free.
    Vd = jnp.stack([
        numpyro.deterministic('Vd0', jnp.ones((N,))),
        numpyro.sample('Vd1', dist.TruncatedNormal(loc=30, scale=15, low=0.0), sample_shape=(N,)),
        numpyro.deterministic('Vd2', jnp.ones((N,))),
    ])

    def _sample_rate(name, k_max):
        # Sample the exponent ("energy") dG_<name> and return the rate
        # <name> = exp(-dG). Since k <= k_max  <=>  dG >= -log(k_max),
        # truncating dG at -log(k_max) caps the rate at k_max.
        # NOTE(review): a truncation point of +log(k_max) would only cap
        # the rate at 1/k_max (e.g. k01 <= 2), which lets diagonal entries
        # of K such as 1 - k01 go negative. The k_max reading is inferred
        # from the constants chosen: 0.5 for single-outflow compartments,
        # and 0.25 for each of compartment 1's two outflows. With these
        # caps every diagonal of K stays >= 0.5. Confirm with the model
        # author.
        dG = numpyro.sample(
            f'dG_{name}',
            dist.TruncatedNormal(loc=5, scale=5, low=-np.log(k_max)),
        )
        return numpyro.deterministic(name, jnp.exp(-dG))

    with numpyro.plate('N', N):
        # Each patient's rate exponents are drawn independently from a
        # truncated normal.
        k01 = _sample_rate('k01', 0.5)
        k12 = _sample_rate('k12', 0.25)
        u = _sample_rate('u', 0.25)
        k21 = _sample_rate('k21', 0.5)

    # Per-patient transfer matrices, shape (N, 3, 3). K[i, r, c] scales
    # compartment r's concentration into compartment c at the next step
    # (see kinetic_timestep). They are built with vectorized ops rather than
    # a Python loop over patients, which would unroll into O(N) traced ops.
    zero = jnp.zeros_like(k01)
    K = numpyro.deterministic('K', jnp.stack([
        jnp.stack([1 - k01, k01 * (Vd[0] / Vd[1]), zero], axis=-1),
        jnp.stack([zero, 1 - k12 - u, k12 * (Vd[1] / Vd[2])], axis=-1),
        jnp.stack([zero, k21 * (Vd[2] / Vd[1]), 1 - k21], axis=-1),
    ], axis=1))

    # Multiply by K and add m_in / Vd at every timepoint. Vd.T[:, None, :]
    # has shape (N, 1, 3), so it broadcasts over the timestep axis.
    c_pred = numpyro.deterministic('cs', kinetic_rollout(
        K,
        m_in / Vd.T[:, None, :],
        c_init=jnp.zeros((N, 3)),
    ))

    # Error is normal around log(concentration + epsilon).
    epsilon = 1
    sig_obs = numpyro.sample('sig_obs', dist.Exponential(8))
    mask = ~jnp.isnan(c_obs)
    # Depending on the NumPyro version, masked-out values can still reach
    # log_prob. A NaN there yields NaN gradients (0 * NaN) even though its
    # log-density is discarded, so replace missing observations with a
    # finite placeholder. The masked log-density is unchanged.
    c_obs_filled = jnp.where(mask, c_obs, 0.0)
    numpyro.sample(
        'c_obs',
        dist.Normal(
            loc=jnp.log(c_pred[:, :, 1] + epsilon),
            scale=sig_obs,
        ).mask(mask),
        obs=jnp.log(c_obs_filled + epsilon),
    )
# Sample the posterior of model_basic with NUTS. Chains are initialized
# with init_to_sample (initial values drawn from the model's prior).
nuts_kernel = NUTS(
    model_basic,
    init_strategy=numpyro.infer.initialization.init_to_sample,
)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
# NOTE(review): seeding from the wall clock makes every run different, and
# the seed is not recorded anywhere. Pin it if results must be reproducible.
rng_key = jax_random.PRNGKey(int(time.time()))
mcmc.run(rng_key, c_obs=np.array(Y), m_in=np.array(M))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment