Created
June 24, 2022 10:21
-
-
Save teddygroves/6fe30c97bca6fbc9646a25fc24f69457 to your computer and use it in GitHub Desktop.
Interpolation example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
functions { | |
real interpolate_linear(vector z, vector ts, real t) {
  /* Piecewise-linear interpolation of the points (ts[i], z[i]) at time t.
     Equation from https://en.wikipedia.org/wiki/Linear_interpolation

     z:  values at the knots
     ts: knot locations, same length as z; assumed strictly increasing so
         the denominator below is never zero
     t:  time at which to evaluate the interpolant

     Rejects when t lies outside [ts[1], ts[rows(ts)]]. */
  for (i in 1:(rows(ts) - 1)) {
    /* Inclusive bounds: a t that coincides with a knot must be accepted,
       otherwise an ODE stage that lands exactly on ts[i] would hit the
       reject below. At a shared knot both neighbouring intervals give the
       same value, so taking the first match is fine. */
    if ((ts[i] <= t) && (t <= ts[i + 1])) {
      return z[i] + (t - ts[i]) * (z[i + 1] - z[i]) / (ts[i + 1] - ts[i]);
    }
  }
  reject("Bad inputs: ", z, ts, t);
}
vector sir(real t, vector y, vector theta, vector ts, vector z) {
  /* Right-hand side of the three-state ODE system handed to the solver.

     t:     current time
     y:     current state (three components)
     theta: two rate parameters
     ts, z: knots and values of the time-varying input, evaluated at t with
            interpolate_linear

     The term theta[1] * y[2] * y[1] leaves y[1] and enters y[2]; the term
     theta[2] * y[2] leaves y[2] and enters y[3]; the interpolated input is
     added to the derivative of y[2] only. */
  return [
    -theta[1] * y[2] * y[1],
    theta[1] * y[2] * y[1] - theta[2] * y[2] + interpolate_linear(z, ts, t),
    theta[2] * y[2]
  ]';
}
} | |
data {
  // Number of interpolation knots. At least two are required: with a single
  // knot there is no interval and interpolate_linear rejects every time.
  int<lower=2> N;
  int<lower=1> N_ode_integration_times;
  real t0;                      // initial time of the ODE
  vector[3] y0;                 // state at t0
  array[N_ode_integration_times] real ode_integration_times;
  ordered[N] ts;                // knot locations for the interpolated input
  vector[N] z;                  // input values at the knots
}
parameters {
  // NOTE(review): theta is unconstrained, so the sampler can propose
  // negative rates; confirm whether a lower bound of zero is intended.
  vector[2] theta;
}
transformed parameters {
  // ODE solution at each requested time; ts and z are forwarded to sir.
  array[N_ode_integration_times] vector[3] yhat =
    ode_rk45(sir, y0, t0, ode_integration_times, theta, ts, z);
}
model {
  theta ~ normal(0, 1);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cmdstanpy | |
# Input for the data block of interpolation.stan.
DATA = {
    "N": 3,  # number of interpolation knots; must equal len(ts) and len(z)
    "N_ode_integration_times": 4,
    "t0": 0.1,
    "y0": [1, 2, 3],
    "ode_integration_times": [1, 3, 5, 10],
    "ts": [0.05, 4, 12],
    "z": [1.2, 4.5, 6.0],
}

# Keyword arguments forwarded to CmdStanModel.sample; kept small so the
# example runs quickly.
SAMPLE_KWARGS = {
    "chains": 1,
    "iter_sampling": 20,
    "iter_warmup": 20,
}
def main(data: dict, sample_kwargs: dict) -> None:
    """Compile interpolation.stan, sample from it and print the summary.

    Args:
        data: Input for the Stan model's data block. Must contain "ts"
            (the interpolation knots) and "t0" (the ODE start time).
        sample_kwargs: Keyword arguments forwarded to
            ``CmdStanModel.sample``.

    Raises:
        ValueError: If ``t0`` does not lie strictly between the first and
            the last knot in ``ts``.
    """
    # Raise explicitly instead of asserting: assert statements are stripped
    # when Python runs with -O, which would let bad data reach the sampler.
    if not (data["ts"][0] < data["t0"] < data["ts"][-1]):
        raise ValueError("Bad data!")
    model = cmdstanpy.CmdStanModel(stan_file="interpolation.stan")
    mcmc = model.sample(data=data, **sample_kwargs)
    print(mcmc.summary())
# Run the example with the module-level defaults when executed as a script.
if __name__ == '__main__':
    main(DATA, SAMPLE_KWARGS)
Hi again! I don't think I've seen that exact error before but it looks like what might happen with incompatible versions of cmdstanpy and cmdstan. To check this you could try seeing if the following code runs without errors (it's the hello world example from here):
import os
from cmdstanpy import cmdstan_path, CmdStanModel
# Locate the Bernoulli example model and data shipped with CmdStan.
bernoulli_dir = os.path.join(cmdstan_path(), 'examples', 'bernoulli')
stan_file = os.path.join(bernoulli_dir, 'bernoulli.stan')
data_file = os.path.join(bernoulli_dir, 'bernoulli.data.json')
model = CmdStanModel(stan_file=stan_file)
# The Bernoulli model declares data (N and y), so sampling without it fails
# for a reason unrelated to the cmdstanpy/cmdstan version check.
fit = model.sample(data=data_file)
You can also just check the versions in an interactive python session with cmdstanpy imported like this (I've included the output I got):
In [36]: cmdstanpy.cmdstan_version()
Out[36]: (2, 29)
In [37]: cmdstanpy.__version__
Out[37]: '1.0.1'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi Teddy, that is great. I am just getting an error when running the code, after the chain has finished. It seems to be a problem with printing the results!
INFO:cmdstanpy:found newer exe file, not recompiling
INFO:cmdstanpy:compiled model file: /content/interpolation
INFO:cmdstanpy:start chain 1
INFO:cmdstanpy:finish chain 1
RuntimeError Traceback (most recent call last)
in ()
26
27 if __name__ == '__main__':
---> 28 main(DATA, SAMPLE_KWARGS)
29
30 from google.colab import drive
2 frames
/usr/local/lib/python3.7/dist-packages/cmdstanpy/utils.py in do_command(cmd, cwd, logger)
731 if stderr:
732 msg = 'ERROR\n {} '.format(stderr.decode('utf-8').strip())
--> 733 raise RuntimeError(msg)
734 if stdout:
735 return stdout.decode('utf-8').strip()
RuntimeError: ERROR
The following argument was not expected: --csv_file=/tmp/tmpeob9l_5g/stansummary-interpolation-1-chain-ftm6x3a4.csv
Run with --help for more information.