Skip to content

Instantly share code, notes, and snippets.

@tpaixao
Created October 11, 2022 09:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tpaixao/9475fd4121bab80bff9dbe7c83f7f8c7 to your computer and use it in GitHub Desktop.
Save tpaixao/9475fd4121bab80bff9dbe7c83f7f8c7 to your computer and use it in GitHub Desktop.
Stochastic differential equation (SDE) examples in PyMC3
import numpy as np
import pymc3 as pm
import arviz as az
import theano.tensor as tt
from pymc3.distributions.timeseries import EulerMaruyama
import matplotlib.pyplot as plt
def simulate(sde_fun, params, x0, t_max, n_reps, dt=0.01, rng=None):
    """
    Simulate replicate paths of an SDE with the Euler-Maruyama scheme.

    :sde_fun: function ``sde_fun(x, *params) -> (f, g)`` returning the
        deterministic (drift) and stochastic (diffusion) parts of the SDE
    :params: extra positional parameters passed to sde_fun after x
    :x0: initial condition (scalar - all replicates will start from x0)
    :t_max: maximum time; the grid is ``np.arange(0, t_max, dt)`` so t_max
        itself is excluded
    :n_reps: number of replicates of the time series - stacked along axis 0
    :dt: delta t for the discretization
    :rng: optional numpy ``Generator``/``RandomState`` providing ``normal``;
        when None the global ``np.random`` state is used (original behaviour)
    :returns: array of shape (n_reps, number of grid points). Column 0 holds
        x0 and column i+1 is obtained from column i by one Euler-Maruyama step.
    """
    normal = np.random.normal if rng is None else rng.normal
    tts = np.arange(0, t_max, dt)
    n_steps = tts.shape[0]
    data = np.zeros((n_reps, n_steps))
    if n_steps == 0:
        # empty time grid (t_max <= 0): nothing to simulate, not even x0
        return data
    data[:, 0] = x0
    sqrt_dt = np.sqrt(dt)
    for i in range(n_steps - 1):
        x = data[:, i]
        f, g = sde_fun(x, *params)
        # Build a new array instead of `x += ...`: x is a *view* of column i,
        # and an in-place update would overwrite the stored state at time i.
        data[:, i + 1] = x + (f * dt + sqrt_dt * g * normal(size=(n_reps,)))
    return data
## Model 1: linear drift with constant noise, dx = (k - d*x) dt + s dW
def sde1(x, k, d, s):
    """
    Drift and diffusion of model 1.

    The SDE is dx = (k - d*x) dt + s dW, which has Ornstein-Uhlenbeck form
    when d > 0.

    :x: current state (scalar or array; broadcasts with k and d)
    :k: constant term of the drift
    :d: coefficient multiplying x in the drift
    :s: constant diffusion amplitude
    :returns: tuple (drift, diffusion)
    """
    drift = k - d * x
    diffusion = s
    return drift, diffusion
# Time step shared by the simulation and the Euler-Maruyama likelihood.
# The two must agree, or the likelihood is evaluated on the wrong time grid.
dt = 0.01
Tmax = 100
data = simulate(sde1, params=(1, 2, .1), x0=0, t_max=Tmax, n_reps=10, dt=dt)
# Index of the simulated replicate (row of `data`) used as the observed series.
# NOTE(review): a single 1-D row is passed, assuming EulerMaruyama treats
# axis 0 of `observed` as time - confirm before passing the stacked 2-D array.
n = 0
with pm.Model() as model1:
    k = pm.Normal("k", 0, 3)
    d = pm.Normal("d", 0, 3)
    s = pm.Exponential("s", 1)
    xt = EulerMaruyama("xt", dt=dt, sde_fn=sde1, sde_pars=(k, d, s), observed=data[n])
    trace1 = pm.sample(1000, return_inferencedata=True, cores=1, chains=2)
# print so the summary is visible when run as a script, not only in a notebook
print(az.summary(trace1))
## Model2
def sde2(p, s, N=500):
    """
    Drift and diffusion of model 2.

    The SDE is dp = s*p*(1-p)/(1+s*p) dt + sqrt(p*(1-p)/N) dW.

    :p: current state; a numpy array during simulation, or a theano tensor
        when used as EulerMaruyama's sde_fn. Meaningful only in [0, 1].
    :s: drift parameter
    :N: constant dividing the variance term (presumably a population size);
        defaults to 500, the value previously hard-coded
    :returns: tuple (drift, diffusion)
    """
    drift = s * p * (1 - p) / (1 + s * p)
    # `** 0.5` rather than np.sqrt: the operator dispatches correctly for both
    # numpy arrays and theano tensors, whereas np.sqrt on a tensor relies on
    # numpy's object-ufunc fallback. For numpy inputs outside [0, 1] the result
    # is nan, as before.
    # NOTE(review): a plain Python float outside [0, 1] yields a complex number.
    diffusion = (p * (1 - p) / N) ** 0.5
    return drift, diffusion
# Time step shared by the simulation and the Euler-Maruyama likelihood.
# The two must agree, or the likelihood is evaluated on the wrong time grid.
dt = 0.001
Tmax = 50
data = simulate(sde2, params=(0.1,), x0=0.1, t_max=Tmax, n_reps=10, dt=dt)
## fix nans that come from numerical errors:
# once a path leaves [0, 1], p*(1-p)/N is negative and its square root is nan,
# so every later value of that path is nan. Such paths are filled with their
# last finite value rounded to the nearest integer (i.e. the boundary it hit).
for i, run in enumerate(data):
    if np.isnan(run).any():
        last_val = np.round(run[~np.isnan(run)][-1])
        data[i] = np.nan_to_num(run, nan=last_val)
# Clip every path, not only the ones that produced nans: the final sample of a
# path can overshoot [0, 1] without a nan ever being generated.
data = np.clip(data, 0, 1)
# Observe the first replicate that stays strictly inside (0, 1) (fall back to 0).
# At p == 0 or p == 1 the diffusion term of sde2 is 0, so the transition
# density has zero standard deviation.
# NOTE(review): this presumably makes the log-likelihood -inf - confirm.
interior = np.all((data > 0) & (data < 1), axis=1)
n = int(np.argmax(interior)) if interior.any() else 0
with pm.Model() as model2:
    s = pm.Normal("s", 0, .5)
    xt = EulerMaruyama("xt", dt=dt, sde_fn=sde2, sde_pars=(s,), observed=data[n])
    ## multicore does not work on windows
    ## initialization needs to be performed so that the SDE starts with all values in the [0,1] interval
    ## (this is because of jitter; we could also simply use adapt_diag without jitter for initialization)
    # NOTE(review): "xt" is observed rather than a free variable, so it is unclear
    # whether this start entry has any effect. Its shape now matches the observed series.
    trace2 = pm.sample(1000, return_inferencedata=True, cores=1, chains=2, start={"xt": np.full(data[n].shape, .5)})
# print so the summary is visible when run as a script, not only in a notebook
print(az.summary(trace2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment