Skip to content

Instantly share code, notes, and snippets.

@sammosummo
Created March 22, 2021 22:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save sammosummo/d20de1063c11e013454d936d2ac4b1fe to your computer and use it in GitHub Desktop.
Save sammosummo/d20de1063c11e013454d936d2ac4b1fe to your computer and use it in GitHub Desktop.
Python implementation (via Numba) of the full DDM log likelihood function
"""Attempts to implement DDM likelihoods in Python.
"""
from math import pi, sqrt, log, ceil, floor, exp, sin, fabs, inf
from numba import jit, vectorize
@jit(nopython=True)
def simpson_1d(x, v, sv, a, z, t, err, lb_z, ub_z, n_sz, lb_t, ub_t, n_st):
    """Average ``pdf_sv`` over a uniform range of z *or* t with composite Simpson's rule.

    Exactly one dimension is integrated. When ``n_st == 0`` the integral runs over z
    in ``[lb_z, ub_z]`` at the fixed non-decision time ``t`` (``lb_t``/``ub_t`` are
    overwritten). Otherwise it runs over t in ``[lb_t, ub_t]`` at the fixed starting
    point ``z`` (``lb_z``/``ub_z`` are overwritten).

    Parameters
    ----------
    x : float
        Absolute response time.
    v, sv, a : float
        Drift rate, drift-rate variability and boundary separation for ``pdf_sv``.
    z, t : float
        Fixed starting point / non-decision time for the dimension not integrated.
    err : float
        Series truncation error forwarded to ``pdf_sv``.
    lb_z, ub_z, n_sz : float, float, int
        Bounds and number of sub-intervals for the z integral.
    lb_t, ub_t, n_st : float, float, int
        Bounds and number of sub-intervals for the t integral.

    Returns
    -------
    float
        The integral divided by the width of the integration range, i.e. the mean
        of ``pdf_sv`` under a uniform distribution of z (or t).
    """
    n = max(n_st, n_sz)
    # Composite Simpson's rule needs an even number (>= 2) of sub-intervals: an
    # odd n would give the final node weight 3 instead of 1, and n == 0 would
    # divide by zero below.
    if n < 2:
        n = 2
    elif n % 2 == 1:
        n += 1
    if n_st == 0:  # integration over z
        hz = (ub_z - lb_z) / n
        ht = 0.0
        lb_t = t
        ub_t = t
    else:  # integration over t
        hz = 0.0
        ht = (ub_t - lb_t) / n
        lb_z = z
        ub_z = z
    # Simpson weights: 1 at both ends, 4 at odd interior nodes, 2 at even ones.
    s = pdf_sv(x - lb_t, v, sv, a, lb_z, err)
    for i in range(1, n):
        y = pdf_sv(x - (lb_t + ht * i), v, sv, a, lb_z + hz * i, err)
        s += 4 * y if i & 1 else 2 * y
    s += pdf_sv(x - (lb_t + ht * n), v, sv, a, lb_z + hz * n, err)
    # Only one of the two widths is non-zero; dividing by it turns the integral
    # into a mean over the uniform range of z (or t).
    s = s / ((ub_t - lb_t) + (ub_z - lb_z))
    return (ht + hz) * s / 3
@jit(nopython=True)
def simpson_2d(x, v, sv, a, z, t, err, lb_z, ub_z, n_sz, lb_t, ub_t, n_st):
    """Average ``pdf_sv`` over uniform ranges of both z and t (composite Simpson).

    The outer composite Simpson rule runs over t in ``[lb_t, ub_t]`` with ``n_st``
    sub-intervals. Each outer sample is ``simpson_1d`` over z in ``[lb_z, ub_z]``
    with ``n_sz`` sub-intervals at that t. The outer sum is divided by the t-range
    width, so the result is a mean over both uniform distributions.

    ``t`` is accepted for signature symmetry with ``simpson_1d`` but is not used.

    Returns
    -------
    float
        Mean of ``pdf_sv`` over the uniform (z, t) rectangle.
    """
    n = n_st
    # Simpson's rule needs an even number (>= 2) of sub-intervals; odd counts
    # mis-weight the final node and 0 would divide by zero.
    if n < 2:
        n = 2
    elif n % 2 == 1:
        n += 1
    ht = (ub_t - lb_t) / n
    # Passing (0, 0, 0) for the t arguments makes simpson_1d integrate over z only.
    s = simpson_1d(x, v, sv, a, z, lb_t, err, lb_z, ub_z, n_sz, 0, 0, 0)
    for i_t in range(1, n):
        t_tag = lb_t + ht * i_t
        y = simpson_1d(x, v, sv, a, z, t_tag, err, lb_z, ub_z, n_sz, 0, 0, 0)
        s += 4 * y if i_t & 1 else 2 * y
    s += simpson_1d(x, v, sv, a, z, lb_t + ht * n, err, lb_z, ub_z, n_sz, 0, 0, 0)
    s = s / (ub_t - lb_t)
    return ht * s / 3
@jit(nopython=True)
def adapt_simpson_aux(
    x,
    v,
    sv,
    a,
    z,
    t,
    pdf_err,
    lb_z,
    ub_z,
    lb_t,
    ub_t,
    ZT,
    simps_err,
    S,
    f_beg,
    f_end,
    f_mid,
    bottom,
):
    """Recursive step of adaptive Simpson quadrature of ``pdf_sv`` over z or t.

    The integrated dimension is selected by the t-range. If ``ub_t - lb_t`` is
    exactly zero, the z interval ``[lb_z, ub_z]`` is bisected at fixed ``t``.
    Otherwise the t interval ``[lb_t, ub_t]`` is bisected at fixed ``z``. Every new
    integrand sample is ``pdf_sv(x - t_i, v, sv, a, z_i, pdf_err) / ZT``.

    Parameters
    ----------
    x, v, sv, a : float
        Response time and model parameters forwarded to ``pdf_sv``.
    z, t : float
        Fixed starting point / non-decision time used for the dimension that is
        not being integrated.
    pdf_err : float
        Series truncation error forwarded to ``pdf_sv``.
    lb_z, ub_z, lb_t, ub_t : float
        Bounds of the current sub-interval.
    ZT : float
        Normaliser applied to every new sample.
    simps_err : float
        Error tolerance for this sub-interval (halved on each bisection).
    S : float
        Simpson estimate over the whole current sub-interval.
    f_beg, f_end, f_mid : float
        Already-normalised integrand values at the start, end and midpoint.
    bottom : int
        Remaining recursion depth; no further bisection once it is <= 0.

    Returns
    -------
    float
        Integral estimate over the current sub-interval.
    """
    if (ub_t - lb_t) == 0:  # integration over sz
        h = ub_z - lb_z
        z_c = (ub_z + lb_z) / 2.0
        z_d = (lb_z + z_c) / 2.0  # quarter point
        z_e = (z_c + ub_z) / 2.0  # three-quarter point
        t_c = t
        t_d = t
        t_e = t
    else:  # integration over t
        h = ub_t - lb_t
        t_c = (ub_t + lb_t) / 2.0
        t_d = (lb_t + t_c) / 2.0  # quarter point
        t_e = (t_c + ub_t) / 2.0  # three-quarter point
        z_c = z
        z_d = z
        z_e = z
    # Only the two quarter points are new; the end and mid values are reused.
    fd = pdf_sv(x - t_d, v, sv, a, z_d, pdf_err) / ZT
    fe = pdf_sv(x - t_e, v, sv, a, z_e, pdf_err) / ZT
    Sleft = (h / 12) * (f_beg + 4 * fd + f_mid)
    Sright = (h / 12) * (f_mid + 4 * fe + f_end)
    S2 = Sleft + Sright
    # Accept when the two-panel estimate agrees with the one-panel estimate
    # within 15 * tolerance (or depth is exhausted), adding the (S2 - S) / 15
    # correction term.
    if bottom <= 0 or fabs(S2 - S) <= 15 * simps_err:
        return S2 + (S2 - S) / 15
    # Otherwise bisect: recurse on both halves with half the tolerance and one
    # level less depth. Passing lb_z/z_c and lb_t/t_c splits whichever
    # dimension is active (the inactive one has equal bounds).
    return adapt_simpson_aux(
        x,
        v,
        sv,
        a,
        z,
        t,
        pdf_err,
        lb_z,
        z_c,
        lb_t,
        t_c,
        ZT,
        simps_err / 2,
        Sleft,
        f_beg,
        f_mid,
        fd,
        bottom - 1,
    ) + adapt_simpson_aux(
        x,
        v,
        sv,
        a,
        z,
        t,
        pdf_err,
        z_c,
        ub_z,
        t_c,
        ub_t,
        ZT,
        simps_err / 2,
        Sright,
        f_mid,
        f_end,
        fe,
        bottom - 1,
    )
@jit(nopython=True)
def adapt_simpson_1d(
    x, v, sv, a, z, t, pdf_err, lb_z, ub_z, lb_t, ub_t, simps_err, maxRecursionDepth
):
    """Mean of ``pdf_sv`` over a uniform range of z or t via adaptive Simpson.

    A zero-width t range (``ub_t == lb_t``) selects integration over z in
    ``[lb_z, ub_z]`` at the fixed time ``t``. Any other t range selects
    integration over t in ``[lb_t, ub_t]`` at the fixed starting point ``z``.
    Every sample is divided by the range width, so the returned value is the
    average density under a uniform distribution of the integrated parameter.

    ``maxRecursionDepth`` bounds the number of bisection levels and ``simps_err``
    is the initial tolerance handed to ``adapt_simpson_aux``.
    """
    integrate_z = (ub_t - lb_t) == 0
    if integrate_z:
        # Collapse the t range onto the fixed non-decision time.
        lb_t = t
        ub_t = t
        width = ub_z - lb_z
    else:
        width = ub_t - lb_t
        # Collapse the z range onto the fixed starting point.
        lb_z = z
        ub_z = z

    mid_t = (lb_t + ub_t) / 2.0
    mid_z = (lb_z + ub_z) / 2.0
    f_lo = pdf_sv(x - lb_t, v, sv, a, lb_z, pdf_err) / width
    f_hi = pdf_sv(x - ub_t, v, sv, a, ub_z, pdf_err) / width
    f_c = pdf_sv(x - mid_t, v, sv, a, mid_z, pdf_err) / width
    whole = (width / 6) * (f_lo + 4 * f_c + f_hi)

    return adapt_simpson_aux(
        x, v, sv, a, z, t, pdf_err,
        lb_z, ub_z, lb_t, ub_t,
        width, simps_err, whole,
        f_lo, f_hi, f_c,
        maxRecursionDepth,
    )
@jit(nopython=True)
def adapt_simpson_aux_2d(
    x,
    v,
    sv,
    a,
    z,
    t,
    pdf_err,
    err_1d,
    lb_z,
    ub_z,
    lb_t,
    ub_t,
    st,
    err_2d,
    S,
    f_beg,
    f_end,
    f_mid,
    maxRecursionDepth_sz,
    bottom,
):
    """Recursive outer (t) step of nested adaptive Simpson quadrature.

    Bisects the t interval ``[lb_t, ub_t]``. Each new outer sample is an inner
    adaptive integral over z in ``[lb_z, ub_z]`` (``adapt_simpson_1d`` with a
    zero-width t range) divided by ``st``.

    Parameters
    ----------
    x, v, sv, a, z : float
        Response time and model parameters forwarded to ``adapt_simpson_1d``.
    t : float
        Passed through unchanged to the recursive calls; not otherwise used here.
    pdf_err : float
        Series truncation error forwarded to ``pdf_sv``.
    err_1d : float
        Tolerance of each inner (z) integral.
    lb_z, ub_z : float
        Bounds of the inner z integral.
    lb_t, ub_t : float
        Bounds of the current outer sub-interval.
    st : float
        Normaliser applied to every outer sample.
    err_2d : float
        Outer tolerance for this sub-interval (halved on each bisection).
    S : float
        Simpson estimate over the whole current sub-interval.
    f_beg, f_end, f_mid : float
        Already-normalised outer samples at the start, end and midpoint.
    maxRecursionDepth_sz : int
        Recursion depth limit for every inner z integral.
    bottom : int
        Remaining outer recursion depth.

    Returns
    -------
    float
        Integral estimate over the current t sub-interval.
    """
    t_c = (ub_t + lb_t) / 2.0
    t_d = (lb_t + t_c) / 2.0  # quarter point
    t_e = (t_c + ub_t) / 2.0  # three-quarter point
    h = ub_t - lb_t
    # New outer samples: inner z-integrals at the two quarter points.
    fd = (
        adapt_simpson_1d(
            x, v, sv, a, z, t_d, pdf_err, lb_z, ub_z, 0, 0, err_1d, maxRecursionDepth_sz
        )
        / st
    )
    fe = (
        adapt_simpson_1d(
            x, v, sv, a, z, t_e, pdf_err, lb_z, ub_z, 0, 0, err_1d, maxRecursionDepth_sz
        )
        / st
    )
    Sleft = (h / 12) * (f_beg + 4 * fd + f_mid)
    Sright = (h / 12) * (f_mid + 4 * fe + f_end)
    S2 = Sleft + Sright
    # Accept when the estimates agree within 15 * tolerance (or depth is
    # exhausted), adding the (S2 - S) / 15 correction term.
    if bottom <= 0 or fabs(S2 - S) <= 15 * err_2d:
        return S2 + (S2 - S) / 15
    # Otherwise bisect the t interval, halving the outer tolerance.
    return adapt_simpson_aux_2d(
        x,
        v,
        sv,
        a,
        z,
        t,
        pdf_err,
        err_1d,
        lb_z,
        ub_z,
        lb_t,
        t_c,
        st,
        err_2d / 2,
        Sleft,
        f_beg,
        f_mid,
        fd,
        maxRecursionDepth_sz,
        bottom - 1,
    ) + adapt_simpson_aux_2d(
        x,
        v,
        sv,
        a,
        z,
        t,
        pdf_err,
        err_1d,
        lb_z,
        ub_z,
        t_c,
        ub_t,
        st,
        err_2d / 2,
        Sright,
        f_mid,
        f_end,
        fe,
        maxRecursionDepth_sz,
        bottom - 1,
    )
@jit(nopython=True)
def adapt_simpson_2d(
    x,
    v,
    sv,
    a,
    z,
    t,
    pdf_err,
    lb_z,
    ub_z,
    lb_t,
    ub_t,
    simps_err,
    maxRecursionDepth_sz,
    maxRecursionDepth_st,
):
    """Mean of ``pdf_sv`` over uniform z and t ranges via nested adaptive Simpson.

    The outer adaptive integral runs over t in ``[lb_t, ub_t]``. Each outer sample
    is an inner adaptive integral over z in ``[lb_z, ub_z]`` computed by
    ``adapt_simpson_1d``, which already averages over z. Dividing the outer
    samples by the t-range width makes the result a mean over t as well.

    Parameters
    ----------
    x, v, sv, a, z : float
        Response time and model parameters forwarded to ``pdf_sv``.
    t : float
        Passed through to ``adapt_simpson_aux_2d``.
    pdf_err : float
        Series truncation error forwarded to ``pdf_sv``.
    lb_z, ub_z, lb_t, ub_t : float
        Integration bounds.
    simps_err : float
        Tolerance used for both the inner and the outer integrals.
    maxRecursionDepth_sz, maxRecursionDepth_st : int
        Recursion depth limits for the inner (z) and outer (t) integrals.
    """
    # Width of the t range; also the normaliser applied to each outer sample.
    st = ub_t - lb_t
    c_t = (lb_t + ub_t) / 2.0
    # The same tolerance is used for the inner (z) and outer (t) integrals.
    err_1d = simps_err
    err_2d = simps_err
    # Passing (0, 0) as the t bounds makes adapt_simpson_1d integrate over z.
    f_beg = (
        adapt_simpson_1d(
            x, v, sv, a, z, lb_t, pdf_err, lb_z, ub_z, 0, 0, err_1d, maxRecursionDepth_sz
        )
        / st
    )
    f_end = (
        adapt_simpson_1d(
            x, v, sv, a, z, ub_t, pdf_err, lb_z, ub_z, 0, 0, err_1d, maxRecursionDepth_sz
        )
        / st
    )
    f_mid = (
        adapt_simpson_1d(
            x, v, sv, a, z, c_t, pdf_err, lb_z, ub_z, 0, 0, err_1d, maxRecursionDepth_sz
        )
        / st
    )
    S = (st / 6) * (f_beg + 4 * f_mid + f_end)
    return adapt_simpson_aux_2d(
        x,
        v,
        sv,
        a,
        z,
        t,
        pdf_err,
        err_1d,
        lb_z,
        ub_z,
        lb_t,
        ub_t,
        st,
        err_2d,
        S,
        f_beg,
        f_end,
        f_mid,
        maxRecursionDepth_sz,
        maxRecursionDepth_st,
    )
@jit(nopython=True)
def ftt_01w(tt, w, err=1e-4):
    """Compute f(t|0,1,w) according to Navarro and Fuss (2009).

    Evaluates the first-passage density at the lower boundary for unit boundary
    separation and zero drift, at normalised time ``tt`` and relative starting
    point ``w``. Either the small-time or the large-time series expansion is used,
    whichever needs fewer terms to reach the truncation error ``err``.
    """
    # Terms required by the large-time expansion (never below the
    # boundary condition 1 / (pi * sqrt(tt))).
    large_t_arg = pi * tt * err
    kl_floor = 1.0 / (pi * sqrt(tt))
    if large_t_arg < 1:
        kl = max(sqrt(-2 * log(large_t_arg) / (pi ** 2 * tt)), kl_floor)
    else:
        kl = kl_floor

    # Terms required by the small-time expansion (minimum of 2).
    small_t_arg = 2 * sqrt(2 * pi * tt) * err
    if small_t_arg < 1:
        ks = max(2 + sqrt(-2 * tt * log(small_t_arg)), sqrt(tt) + 1)
    else:
        ks = 2.0

    if ks < kl:
        # Small-time expansion: sum over k symmetric around zero.
        n_terms = ceil(ks)
        half_span = (n_terms - 1) / 2.0
        total = 0.0
        for k in range(-floor(half_span), ceil(half_span) + 1):
            shift = w + 2 * k
            total += shift * exp(-pow(shift, 2) / 2 / tt)
        return total / sqrt(2 * pi * pow(tt, 3))

    # Large-time expansion: sine series over k = 1..K.
    n_terms = ceil(kl)
    total = 0.0
    for k in range(1, n_terms + 1):
        total += k * exp(-(pow(k, 2)) * (pi ** 2) * tt / 2) * sin(k * pi * w)
    return total * pi
@jit(nopython=True)
def pdf(x, v, a, w, err=1e-4):
    """Compute f(t|v,a,z) according to Navarro and Fuss (2009).

    ``x`` is the (positive) decision time, ``v`` the drift rate, ``a`` the boundary
    separation and ``w`` the relative starting point. Non-positive times have
    zero density.
    """
    if x <= 0:
        return 0
    a_squared = a ** 2
    # Density for unit boundary / zero drift at normalised time x / a^2 ...
    unit_density = ftt_01w(x / a_squared, w, err)
    # ... rescaled to drift v and boundary separation a.
    drift_exponent = -v * a * w - (v ** 2) * x / 2.0
    return unit_density * exp(drift_exponent) / a_squared
@jit(nopython=True)
def pdf_sv(x, v, sv, a, z, err=1e-4):
    """Compute f(t|v,a,z,sv) using the method of Navarro and Fuss (2009), with analytic
    integration of v according to Tuerlinckx et al. (2001).

    Parameters
    ----------
    x : float
        Decision time; non-positive values have zero density.
    v, sv : float
        Drift rate and its across-trial standard deviation. ``sv == 0`` delegates
        to ``pdf``.
    a : float
        Boundary separation.
    z : float
        Relative starting point.
    err : float
        Series truncation error passed to ``ftt_01w``.
    """
    # time must be positive
    if x <= 0:
        return 0
    # if sv=0 don't integrate
    if sv == 0:
        return pdf(x, v, a, z, err)
    tt = x / (pow(a, 2))  # use normalized time
    p = ftt_01w(tt, z, err)  # get f(t|0,1,w)
    # The truncated series can underflow to exactly 0, or come out slightly
    # negative because of truncation error. log() is undefined for both, so fall
    # back to a very small log-density instead.
    if p <= 0:
        logp = -500.0
    else:
        logp = log(p)
    # convert to f(t|v,a,w); combined in log space before exponentiating
    return (
        exp(
            logp
            + ((a * z * sv) ** 2 - 2 * a * v * z - (v ** 2) * x)
            / (2 * (sv ** 2) * x + 2)
        )
        / sqrt((sv ** 2) * x + 1)
        / (a ** 2)
    )
@jit(nopython=True)
def full_pdf(
    x, v, sv, a, z, sz, t, st, err=1e-4, n_st=2, n_sz=2, use_adaptive=1, simps_err=1e-3
):
    """Compute the probability density function of the full drift diffusion model.

    Computes f(t|v,a,z,sv,sz,st) using three methods:

    - Navarro and Fuss (2009) for the basic DDM likelihood.
    - Analytic integration over v when sv is non-zero (Tuerlinckx et al., 2001).
    - Numerical integration over t_er and/or z when st and/or sz are non-zero,
      respectively (Ratcliff & Tuerlinckx, 2002).

    This function accepts negative or positive reaction times. Negative values
    correspond to the lower bound and positive values to the upper bound. For
    upper-bound responses the drift rate is negated and the relative starting
    point is reflected (z -> 1 - z), so the lower-bound density serves both.

    Parameters
    ----------
    x : float
        Signed response time.
    v, sv : float
        Drift rate and its across-trial standard deviation.
    a : float
        Boundary separation.
    z, sz : float
        Relative starting point and the width of its uniform variability.
    t, st : float
        Non-decision time and the width of its uniform variability.
    err : float, optional
        Series truncation error forwarded to the Navarro-Fuss density.
    n_st, n_sz : int, optional
        Number of sub-intervals for t and z with the fixed Simpson rule, or the
        maximum recursion depth with the adaptive rule.
    use_adaptive : int or bool, optional
        If truthy, use adaptive Simpson quadrature; otherwise the fixed
        composite Simpson rule.
    simps_err : float, optional
        Tolerance of the adaptive rule.

    Returns
    -------
    float
        The density. No parameter validation is performed here.
    """
    # transform x, v, z if x is upper bound response
    if x > 0:
        v = -v
        z = 1.0 - z
    # absolute RT
    x = fabs(x)
    # Treat very small variabilities as zero (skips the matching integration).
    if st < 1e-3:
        st = 0.0
    if sz < 1e-3:
        sz = 0.0
    if sv < 1e-3:
        sv = 0.0
    if sz == 0:
        if st == 0:
            # No numerical integration; sv (if any) is handled inside pdf_sv.
            return pdf_sv(x - t, v, sv, a, z, err)
        # Integrate over t only.
        if use_adaptive:
            return adapt_simpson_1d(
                x, v, sv, a, z, t, err, z, z, t - st / 2.0, t + st / 2.0, simps_err, n_st
            )
        return simpson_1d(
            x, v, sv, a, z, t, err, z, z, 0, t - st / 2.0, t + st / 2.0, n_st
        )
    if st == 0:
        # Integrate over z only.
        if use_adaptive:
            return adapt_simpson_1d(
                x, v, sv, a, z, t, err, z - sz / 2.0, z + sz / 2.0, t, t, simps_err, n_sz
            )
        return simpson_1d(
            x, v, sv, a, z, t, err, z - sz / 2.0, z + sz / 2.0, n_sz, t, t, 0
        )
    # Integrate over both z and t.
    if use_adaptive:
        return adapt_simpson_2d(
            x,
            v,
            sv,
            a,
            z,
            t,
            err,
            z - sz / 2.0,
            z + sz / 2.0,
            t - st / 2.0,
            t + st / 2.0,
            simps_err,
            n_sz,
            n_st,
        )
    return simpson_2d(
        x,
        v,
        sv,
        a,
        z,
        t,
        err,
        z - sz / 2.0,
        z + sz / 2.0,
        n_sz,
        t - st / 2.0,
        t + st / 2.0,
        n_st,
    )
@jit(nopython=True)
def pdf_contaminant_uniform(x, w=0.05):
    """Compute the probability density of a uniform contaminant distribution.

    The density is ``w`` on the symmetric interval ``[-0.5 / w, 0.5 / w]`` (width
    ``1 / w``) and 0 outside it.
    """
    half_width = 0.5 / w
    in_support = -half_width <= x <= half_width
    return w if in_support else 0
@jit(nopython=True)
def pdf_contaminant_exponential(x, l):
    """Compute the probability density of an exponential contaminant distribution.

    ``x`` is a signed response time (its sign encodes the response boundary), so
    the contaminant is an exponential with rate ``l`` on ``|x|``, split evenly
    between the two signs. That is the Laplace density
    ``0.5 * l * exp(-l * |x|)``. The factor 0.5 makes the density integrate to 1
    over the whole real line, as required for a mixture component. Without it,
    ``l * exp(-l * |x|)`` integrates to 2.
    """
    return 0.5 * l * exp(-l * fabs(x))
@vectorize
def logpdf_with_contaminant_exponential(x, v, sv, a, z, sz, t, st, p_outlier, l):
    """Compute the log likelihood of the full DDM mixed with an exponential
    contaminant distribution.

    The likelihood is ``(1 - p_outlier) * full_pdf(x, ...)`` plus
    ``p_outlier * pdf_contaminant_exponential(x, l)``. When ``p_outlier == 0`` the
    contaminant is skipped and ``l`` is not checked.

    Returns ``-inf`` in three cases, instead of taking ``log`` of a non-positive
    number:

    - invalid parameters, including ``a <= 0``;
    - ``l <= 0`` when ``p_outlier > 0``;
    - observations whose likelihood is zero or numerically non-positive.
    """
    # check if all parameters are valid
    if (
        (z < 0)
        or (z > 1)
        or (a <= 0)  # a == 0 would divide by zero when normalising time
        or (t < 0)
        or (st < 0)
        or (sv < 0)
        or (sz < 0)
        or (sz > 1)
        or (z + sz / 2.0 > 1)
        or (z - sz / 2.0 < 0)
        or (t - st / 2.0 < 0)
        or (p_outlier < 0)
        or (p_outlier > 1)
    ):
        return -inf
    if p_outlier == 0:
        p = full_pdf(x, v, sv, a, z, sz, t, st)
    else:
        if l <= 0:
            return -inf
        p0 = full_pdf(x, v, sv, a, z, sz, t, st) * (1 - p_outlier)
        p1 = pdf_contaminant_exponential(x, l) * p_outlier
        p = p0 + p1
    # The DDM density is exactly 0 e.g. when |x| is below the non-decision time,
    # and numerical error can make it slightly negative; log() of either is
    # undefined, so report an impossible observation explicitly.
    if p <= 0:
        return -inf
    return log(p)
def test():
    """Smoke-test the log likelihood on simulated signed response times.

    Draws 10000 log-normal RTs, appends the negated first 1000 of them
    (lower-bound responses) and prints the index, RT and log likelihood of
    every observation.
    """
    from scipy.stats import lognorm
    import numpy as np

    np.random.seed(0)
    upper_rts = lognorm.rvs(s=0.5, size=10000)
    signed_rts = np.concatenate([upper_rts, -upper_rts[:1000]])
    # Positional order: v, sv, a, z, sz, t, st, p_outlier, l
    params = (-1.1, 0.01, 1.0, 0.5, 0.01, 0.5, 0.01, 0.001, 0.1)
    for idx, rt in enumerate(signed_rts):
        print("i =", idx)
        print("x =", rt)
        loglik = logpdf_with_contaminant_exponential(rt, *params)
        print("y =", loglik)


if __name__ == "__main__":
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment