Skip to content

Instantly share code, notes, and snippets.

@sammosummo
Created March 31, 2021 20:18
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sammosummo/24d9fcd3686d3d8ffa72864a85e91ffe to your computer and use it in GitHub Desktop.
Save sammosummo/24d9fcd3686d3d8ffa72864a85e91ffe to your computer and use it in GitHub Desktop.
JAX implementation of the full DDM log likelihood function
"""JAX functions for calculating the probability density of the Wiener diffusion first-
passage time (WFPT) distribution used in drift diffusion models (DDMs).
"""
import jax
import jax.numpy as jnp
def jax_wfpt_pdf_sv(x, v, sv, a, z, t):
    """Probability density function of the WFPT distribution with drift rates normally
    distributed over trials.

    When the standard deviation of drift-rate variability is 0, this reduces down to
    the "simple" DDM likelihood function without contaminants.

    Args:
        x: Reaction times. Responses to the lower bound must be negative.
        v: Mean drift rate.
        sv: Standard deviation of drift rate. [0, inf)
        a: Value of decision upper bound. (0, inf).
        z: Normalized decision starting point. (0, 1).
        t: Non-decision time. [0, inf)

    Returns:
        The density evaluated at each reaction time. Entries whose absolute reaction
        time does not exceed ``t`` (no time left for a decision) have density 0.
    """
    # Mirror v and z for upper-bound responses so that every response can be handled
    # with the same (lower-bound) formula.
    upper = x > 0
    v = jnp.where(upper, -v, v)
    z = jnp.where(upper, 1 - z, z)

    # Decision time: what is left of the absolute RT once non-decision time is removed.
    dt = jnp.abs(x) - t
    valid = dt > 0
    # Evaluate impossible entries at a harmless placeholder so that no NaN/inf is
    # produced along the way; they are overwritten with 0 at the very end.
    dt = jnp.where(valid, dt, 1.0)

    tt = dt / a ** 2  # use normalized time
    w = z  # z is already normalized
    err = 1e-7  # error tolerance used only to decide which expansion to trust

    # Number of terms the small-t expansion would need. `where` (rather than
    # multiplying by the mask) keeps a NaN in the unused branch from leaking through.
    small_arg = 2 * jnp.sqrt(2 * jnp.pi * tt) * err
    ks = jnp.where(
        small_arg < 1,
        jnp.maximum(2 + jnp.sqrt(-2 * tt * jnp.log(small_arg)), jnp.sqrt(tt) + 1),
        2.0,
    )

    # Number of terms the large-t expansion would need.
    large_arg = jnp.pi * tt * err
    kl_floor = 1.0 / (jnp.pi * jnp.sqrt(tt))
    kl = jnp.where(
        large_arg < 1,
        jnp.maximum(
            kl_floor, jnp.sqrt(-2 * jnp.log(large_arg) / (jnp.pi ** 2 * tt))
        ),
        kl_floor,
    )

    # Small-t expansion, always summed over the fixed terms k = -3, ..., 3.
    ps = sum(
        (w + 2 * k) * jnp.exp(-jnp.power(w + 2 * k, 2) / 2 / tt)
        for k in range(-3, 4)
    )
    ps = ps / jnp.sqrt(2 * jnp.pi * jnp.power(tt, 3))

    # Large-t expansion, always summed over the fixed terms k = 1, ..., 7.
    half_pi2_tt = jnp.power(jnp.pi, 2) * tt / 2
    pl = sum(
        k * jnp.exp(-(k ** 2) * half_pi2_tt) * jnp.sin(k * jnp.pi * w)
        for k in range(1, 8)
    )
    pl = pl * jnp.pi

    # Select, per element, the expansion that needs fewer terms.
    normp = jnp.where(ks < kl, ps, pl)

    # Convert the normalized density to f(t | v, sv, a, w). Only the drift term
    # belongs inside the exponential; the sqrt(sv^2 * dt + 1) and a^2 factors scale
    # the density itself. All time-dependent terms use decision time, not the raw RT.
    exponent = ((a * z * sv) ** 2 - 2 * a * v * z - (v ** 2) * dt) / (
        2 * (sv ** 2) * dt + 2
    )
    p = jnp.exp(jnp.log(normp) + exponent)
    p = p / jnp.sqrt((sv ** 2) * dt + 1) / (a ** 2)

    return jnp.where(valid, p, 0.0)
def jax_wfpt_pdf_sv_sz(x, v, sv, a, lz, uz, t):
    """Probability density function of the WFPT distribution with normally distributed
    drift rate and uniformly distributed starting point.

    The density is averaged over the starting-point interval [lz, uz] using a
    three-node Simpson's rule (end points weighted 1, midpoint weighted 4, sum divided
    by 6). Because the weights sum to 6, the result collapses exactly to
    ``jax_wfpt_pdf_sv`` at that point when ``lz == uz``, so no special case is needed.

    Args:
        x: Reaction times. Responses to the lower bound must be negative.
        v: Drift rate if sv == 0 or mean drift rate if sv > 0.
        sv: Standard deviation of drift rate. [0, inf)
        a: Value of upper bound. (0, inf).
        lz: Lower bound on normalized starting point. (0, uz].
        uz: Upper bound on normalized starting point. [lz, 1).
        t: Non-decision time. [0, inf)

    Returns:
        The density at each reaction time, averaged over the starting point.
    """
    mid_z = (lz + uz) / 2
    weighted = (
        jax_wfpt_pdf_sv(x, v, sv, a, lz, t)
        + 4 * jax_wfpt_pdf_sv(x, v, sv, a, mid_z, t)
        + jax_wfpt_pdf_sv(x, v, sv, a, uz, t)
    )
    # Simpson integral is (uz - lz) / 6 * weighted; dividing by the interval width to
    # get the mean over the uniform distribution leaves weighted / 6.
    return weighted / 6
def jax_wfpt_pdf_sv_sz_st(x, v, sv, a, lz, uz, lt, ut):
    """Probability density function of the WFPT distribution with normally distributed
    drift rate, uniformly distributed starting point, uniformly distributed nondecision
    time.

    The density is averaged over the nondecision-time interval [lt, ut] using a
    three-node Simpson's rule (end points weighted 1, midpoint weighted 4, sum divided
    by 6). Because the weights sum to 6, the result collapses exactly to
    ``jax_wfpt_pdf_sv_sz`` at that point when ``lt == ut``, so no special case is
    needed.

    Args:
        x: Reaction times. Responses to the lower bound must be negative.
        v: Drift rate if sv == 0 or mean drift rate if sv > 0.
        sv: Standard deviation of drift rate. [0, inf)
        a: Value of upper bound. (0, inf).
        lz: Lower bound on normalized starting point. (0, uz].
        uz: Upper bound on normalized starting point. [lz, 1).
        lt: Lower bound on nondecision time. [0, ut].
        ut: Upper bound on nondecision time. [lt, inf).

    Returns:
        The density at each reaction time, averaged over nondecision time.
    """
    mid_t = (lt + ut) / 2
    weighted = (
        jax_wfpt_pdf_sv_sz(x, v, sv, a, lz, uz, lt)
        + 4 * jax_wfpt_pdf_sv_sz(x, v, sv, a, lz, uz, mid_t)
        + jax_wfpt_pdf_sv_sz(x, v, sv, a, lz, uz, ut)
    )
    # Simpson integral is (ut - lt) / 6 * weighted; dividing by the interval width to
    # get the mean over the uniform distribution leaves weighted / 6.
    return weighted / 6
def jax_wfpt_pdf_sv_sz_st_q(x, v, sv, a, lz, uz, lt, ut, q):
    """Probability density function of the WFPT distribution with normally distributed
    drift rate, uniformly distributed starting point, uniformly distributed nondecision
    time, and uniformly distributed contaminants.

    The result is a two-component mixture: with weight ``1 - q`` the density from
    ``jax_wfpt_pdf_sv_sz_st`` and with weight ``q`` a constant contaminant density.

    Args:
        x: Reaction times. Responses to the lower bound must be negative.
        v: Drift rate if sv == 0 or mean drift rate if sv > 0.
        sv: Standard deviation of drift rate. [0, inf)
        a: Value of upper bound. (0, inf).
        lz: Lower bound on normalized starting point. (0, uz].
        uz: Upper bound on normalized starting point. [lz, 1).
        lt: Lower bound on nondecision time. [0, ut].
        ut: Upper bound on nondecision time. [lt, inf).
        q: Contaminant probability

    Returns:
        The mixture density at each reaction time.
    """
    # Constant contaminant density: reciprocal of the largest absolute RT in `x`.
    # NOTE(review): this value depends on the data passed in this call -- confirm
    # that 1 / max|x| is the intended contaminant density.
    contaminant = 1 / jnp.max(jnp.abs(x))
    wfpt = jax_wfpt_pdf_sv_sz_st(x, v, sv, a, lz, uz, lt, ut)
    return (1 - q) * wfpt + q * contaminant
def jax_wfpt_sumlogp(x, v, sv, a, lz, uz, lt, ut, q):
    """Sum of log probability densities of the WFPT distribution with normally
    distributed drift rate, uniformly distributed starting point, uniformly distributed
    nondecision time, and uniformly distributed contaminants.

    Args:
        x: Reaction times. Responses to the lower bound must be negative.
        v: Drift rate if sv == 0 or mean drift rate if sv > 0.
        sv: Standard deviation of drift rate. [0, inf)
        a: Value of upper bound. (0, inf).
        lz: Lower bound on normalized starting point. (0, uz].
        uz: Upper bound on normalized starting point. [lz, 1).
        lt: Lower bound on nondecision time. [0, ut].
        ut: Upper bound on nondecision time. [lt, inf).
        q: Contaminant probability

    Returns:
        The log of ``jax_wfpt_pdf_sv_sz_st_q`` summed over all reaction times.
    """
    density = jax_wfpt_pdf_sv_sz_st_q(x, v, sv, a, lz, uz, lt, ut, q)
    log_density = jnp.log(density)
    return jnp.sum(log_density)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment