Created
October 11, 2022 09:04
-
-
Save tpaixao/9475fd4121bab80bff9dbe7c83f7f8c7 to your computer and use it in GitHub Desktop.
SDE examples in pymc3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pymc3 as pm | |
import arviz as az | |
import theano.tensor as tt | |
from pymc3.distributions.timeseries import EulerMaruyama | |
import matplotlib.pyplot as plt | |
def simulate(sde_fun,params,x0,t_max,n_reps,dt=0.01):
    """
    Simulate replicate paths of an SDE with an Euler-Maruyama update.

    :sde_fun: function called as sde_fun(x, *params) that returns the
        deterministic and stochastic parts of the SDE
    :params: extra positional parameters for sde_fun
    :x0: initial condition (scalar - all replicates will start from x0)
    :t_max: maximum time; the time grid is np.arange(0, t_max, dt)
    :n_reps: number of replicates of timeseries - will be stacked
    :dt: delta t for discretization
    :returns: array of shape (n_reps, n_steps); column 0 holds x0 and
        column i+1 holds the state one step after column i
    """
    tts = np.arange(0, t_max, dt)
    n_steps = tts.shape[0]
    data = np.zeros((n_reps, n_steps))
    if n_steps == 0:
        # Empty time grid: there is no column to hold the initial condition.
        return data
    data[:, 0] = np.ones(n_reps) * x0
    for i in range(n_steps - 1):
        x = data[:, i]
        f, g = sde_fun(x, *params)
        noise = np.random.normal(size=(n_reps,))
        # `x` is a view into `data`; build the next state as a new array
        # instead of updating `x` in place, otherwise column i (and with it
        # the initial condition) would be overwritten by the next step.
        data[:, i + 1] = x + f * dt + np.sqrt(dt) * g * noise
    return data
## Model 1: | |
def sde1(x,k,d,s):
    """
    Deterministic and stochastic parts of Model 1.

    :x: current state (scalar or array)
    :k: constant term of the deterministic part
    :d: coefficient multiplying x in the deterministic part
    :s: stochastic part, returned unchanged
    :returns: tuple (k - d*x, s)
    """
    drift = k - d * x
    return (drift, s)
Tmax=100
# Simulate 10 replicate paths of sde1 with (k, d, s) = (1, 2, 0.1).
data = simulate(sde1,params=(1,2,.1),x0=0,t_max=Tmax,n_reps=10,dt=0.01)
# Index of the replicate passed as observed data. It was referenced as `n`
# without being defined anywhere in this file (NameError).
# NOTE(review): replicate 0 is an assumption - confirm which one was intended.
n = 0
with pm.Model() as model1:
    k=pm.Normal("k",0,3)
    d=pm.Normal("d",0,3)
    s=pm.Exponential("s",1)
    # dt must match the dt used to simulate `data`; `::1` keeps every time step.
    xt = EulerMaruyama("xt",dt=0.01,sde_fn=sde1,sde_pars=(k,d,s),observed = data[n,::1])
    trace1 = pm.sample(1000,return_inferencedata=True,cores=1,chains=2)
az.summary(trace1)
## Model2 | |
def sde2(p,s,N=500):
    """
    Deterministic and stochastic parts of Model 2.

    :p: current state (scalar or array); the stochastic part takes the
        square root of p*(1-p)/N, so it is only real for p in [0, 1]
    :s: parameter of the deterministic part
    :N: constant dividing p*(1-p) in the stochastic part (default 500,
        the value that used to be hard-coded)
    :returns: tuple (s*p*(1-p)/(1+s*p), sqrt(p*(1-p)/N))
    """
    drift = s*p*(1-p)/(1+s*p)
    diffusion = np.sqrt(p*(1-p)/N)
    return drift, diffusion
Tmax=50
# Simulate 10 replicate paths of sde2 with s = 0.1, starting from 0.1.
data = simulate(sde2,params=(0.1,),x0=0.1,t_max=Tmax,n_reps=10,dt=0.001)
## fix nans that come from numerical errors
for i,run in enumerate(data):
    if np.isnan(run).any():
        # Replace NaNs with the last non-NaN value of the run, rounded.
        last_val = np.round(run[~np.isnan(run)][-1])
        data[i] = np.nan_to_num(data[i],nan=last_val)
    # Clamp the run to [0, 1] (same result as the nested np.where it replaces).
    data[i] = np.clip(data[i],0,1)
# Index of the replicate passed as observed data. It was referenced as `n`
# without being defined anywhere in this file (NameError).
# NOTE(review): replicate 0 is an assumption - confirm which one was intended.
n = 0
with pm.Model() as model2:
    s=pm.Normal("s",0,.5)
    # dt must match the dt used to simulate `data`; `::1` keeps every time step.
    xt = EulerMaruyama("xt",dt=0.001,sde_fn=sde2,sde_pars=(s,),observed = data[n,::1])
    ## multicore does not work on windows
    ## initialization needs to be performed so that the SDE starts with all values in the [0,1] interval
    ## (this is because of jitter, we could also use simply adapt_diag without jitter for initialization )
    # NOTE(review): "xt" is passed as observed above, so this start value is
    # presumably only relevant if "xt" is made a free variable - confirm.
    trace2 = pm.sample(1000,return_inferencedata=True,cores=1,chains=2,start={"xt":np.ones(data.shape)*.5})
az.summary(trace2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment