-
-
Save justinrporter/8fd1651a790bfcf0dd6aecd894305cc1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def kinetic_timestep(c_prev, c_in_i, K):
    """Propagate compartment concentrations forward by one discrete time step.

    Shaped as a ``jax.lax.scan`` step function: ``c_prev`` is the carry and
    ``c_in_i`` is the slice of the scanned input for this step.

    Args:
        c_prev: state at the previous step, shape (N, C).
        c_in_i: amount added at this step; added to the propagated state.
        K: per-subject transition matrices, shape (N, C, C). ``K[n, j, k]`` is
            the weight with which compartment ``j`` of subject ``n`` feeds
            compartment ``k`` on the next step.

    Returns:
        ``(c_next, (c_next,))`` -- the new carry, plus a one-element tuple
        that ``scan`` stacks into the per-step trajectory.
    """
    # propagated[n, k] = sum_j c_prev[n, j] * K[n, j, k]
    propagated = jnp.einsum('ij,ijk->ik', c_prev, K)
    state = propagated + c_in_i
    return state, (state,)
def kinetic_rollout(K, c_in, c_init):
    """Run the linear compartment model forward over every time step.

    Args:
        K: per-subject transition matrices, shape (N, C, C), held fixed for
            the whole rollout.
        c_in: per-step inputs with subjects on axis 0 and time on axis 1;
            the two leading axes are swapped so ``scan`` iterates over time.
        c_init: starting state, shape (N, C).

    Returns:
        The state after each step, with subjects on axis 0 and time on axis 1.
        Entry ``t`` along the time axis is the state *after* step ``t`` has
        been applied; ``c_init`` itself is not included.
    """
    time_major_inputs = jnp.swapaxes(c_in, 0, 1)

    def step(carry, inputs_t):
        return kinetic_timestep(carry, inputs_t, K=K)

    _, (trajectory,) = jax.lax.scan(
        step,
        c_init,
        time_major_inputs,
        length=time_major_inputs.shape[0],
    )
    # Back to subject-major layout to match the input convention.
    return jnp.swapaxes(trajectory, 0, 1)
def model_basic(c_obs, m_in):
    """NumPyro model of a three-compartment kinetic system, one rate matrix per subject.

    Compartment 1 is the observed compartment. Its predicted concentration is
    compared with ``c_obs`` on a ``log(c + 1)`` scale under Gaussian noise.

    Args:
        c_obs: observations for compartment 1, indexed ``c_obs[subject, time]``
            to line up with ``c_pred[:, :, 1]``. NaN marks a missing observation;
            those entries are masked out of the likelihood.
        m_in: per-subject, per-step input for each of the 3 compartments; it is
            divided by ``Vd`` and fed to ``kinetic_rollout``. Expected layout
            (N, T, 3) -- NOTE(review): inferred from the broadcasting, confirm.

    Sample sites:
        ``Vd1`` (compartment-1 scale factor; ``Vd0`` and ``Vd2`` are fixed at 1),
        ``dG_k01``, ``dG_k12``, ``dG_u``, ``dG_k21`` (per-subject negative log
        rates), ``sig_obs`` (observation noise standard deviation, log scale).
        Deterministic sites: ``Vd0``, ``Vd2``, ``k01``, ``k12``, ``u``, ``k21``,
        ``K``, ``cs``.

    Raises:
        AssertionError: if ``m_in`` and ``c_obs`` disagree on the number of
            subjects.
    """
    N = m_in.shape[0]
    assert N == c_obs.shape[0], "%s != %s" % (N, c_obs.shape[0])

    # Per-compartment scale factors, shape (3, N); m_in is divided by these,
    # and their ratios rescale transfers between compartments (presumably
    # volumes of distribution). Only compartment 1's value is inferred.
    Vd = jnp.stack([
        numpyro.deterministic('Vd0', jnp.ones((N,))),
        numpyro.sample('Vd1', dist.TruncatedNormal(loc=30, scale=15, low=0.0), sample_shape=(N,)),
        numpyro.deterministic('Vd2', jnp.ones((N,))),
    ])

    with numpyro.plate('N', N):
        # The exponent ("energy") of each rate is drawn per subject, and the
        # rate is k = exp(-dG). A lower bound on dG is therefore an upper bound
        # on k: dG >= -log(k_max)  <=>  k <= k_max. (Using +log(k_max) here
        # would be negative and allow k up to 1/k_max.)
        # k01 and k21 are capped at 0.5, and k12 and u at 0.25 each. That keeps
        # every diagonal entry of K (1-k01, 1-k12-u, 1-k21) at >= 0.5, so K
        # never drives a concentration negative.
        dG_k01 = numpyro.sample('dG_k01', dist.TruncatedNormal(loc=5, scale=5, low=-np.log(0.5)))
        k01 = numpyro.deterministic('k01', jnp.exp(-dG_k01))

        dG_k12 = numpyro.sample('dG_k12', dist.TruncatedNormal(loc=5, scale=5, low=-np.log(0.25)))
        k12 = numpyro.deterministic('k12', jnp.exp(-dG_k12))

        dG_u = numpyro.sample('dG_u', dist.TruncatedNormal(loc=5, scale=5, low=-np.log(0.25)))
        u = numpyro.deterministic('u', jnp.exp(-dG_u))

        dG_k21 = numpyro.sample('dG_k21', dist.TruncatedNormal(loc=5, scale=5, low=-np.log(0.5)))
        k21 = numpyro.deterministic('k21', jnp.exp(-dG_k21))

    # Per-subject transition matrix K[n, row, col], built vectorized across
    # subjects (no per-subject Python loop at trace time). Row j holds what
    # compartment j contributes to each compartment on the next step; u is
    # the loss from compartment 1 that goes to no other compartment.
    zeros = jnp.zeros_like(k01)
    K = numpyro.deterministic('K', jnp.stack([
        jnp.stack([1 - k01, k01 * (Vd[0] / Vd[1]), zeros], axis=-1),
        jnp.stack([zeros, 1 - k12 - u, k12 * (Vd[1] / Vd[2])], axis=-1),
        jnp.stack([zeros, k21 * (Vd[2] / Vd[1]), 1 - k21], axis=-1),
    ], axis=-2))

    # Multiply by K and add m_in/Vd at every time point.
    c_pred = numpyro.deterministic('cs', kinetic_rollout(
        K,
        m_in / Vd.T[:, None, :],
        c_init=jnp.zeros((N, 3)),
    ))

    # The error is normal around log(concentration + epsilon); epsilon keeps
    # the log finite at zero concentration.
    epsilon = 1
    # A standard deviation (it is passed as `scale`), not a variance.
    sig_obs = numpyro.sample('sig_obs', dist.Exponential(8))
    mask = ~jnp.isnan(c_obs)
    # Put a finite placeholder into the missing entries before the log. Their
    # log-probabilities are masked out, but a NaN left in `obs` can still
    # propagate NaN into gradients through the masking `where` (depending on
    # the NumPyro version).
    c_obs_filled = jnp.where(mask, c_obs, 0.0)
    numpyro.sample(
        'c_obs',
        dist.Normal(
            loc=jnp.log(c_pred[:, :, 1] + epsilon),
            scale=sig_obs,
        ).mask(mask),
        obs=jnp.log(c_obs_filled + epsilon),
    )
# Posterior sampling for model_basic with NUTS: 1000 warmup iterations
# followed by 1000 retained samples, starting from the init_to_sample
# initialization strategy.
nuts_kernel = NUTS(
    model_basic,
    init_strategy=numpyro.infer.initialization.init_to_sample,
)
mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=1000)
# The PRNG seed is taken from the wall clock, so separate runs are not
# reproducible; substitute a fixed integer to make them so.
mcmc.run(
    jax_random.PRNGKey(int(time.time())),
    c_obs=np.array(Y),
    m_in=np.array(M),
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment