Skip to content

Instantly share code, notes, and snippets.

@itsdfish
Created October 10, 2021 11:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save itsdfish/5fc2713acedd7b0813d11a85f9bf7d17 to your computer and use it in GitHub Desktop.
Save itsdfish/5fc2713acedd7b0813d11a85f9bf7d17 to your computer and use it in GitHub Desktop.
Hierarchical MPT Turing
import Distributions: rand, logpdf, loglikelihood
"""
    Model(n, c, r, u)

Distribution object for one subject's response counts.

# Fields
- `n`: number of trials; passed to `Multinomial` as the trial count.
- `c`, `r`, `u`: parameters that `compute_probs` maps to four category probabilities.
  Each has its own type parameter so plain floats and AD-tracked numbers both fit.

NOTE(review): the supertype is `ContinuousUnivariateDistribution` even though `logpdf`
and `rand` work with count vectors — confirm this is deliberate.
"""
struct Model{T1,T2,T3} <: ContinuousUnivariateDistribution
    n::Int
    c::T1
    r::T2
    u::T3
end
# Wrap a `Model` in a `Ref` so broadcasting treats it as a single scalar-like value
# instead of trying to iterate over it.
function Broadcast.broadcastable(x::Model)
    return Ref(x)
end
# Log-likelihood of a single count vector under `d`; forwards to `logpdf`.
function loglikelihood(d::Model, data::Vector{Int})
    return logpdf(d, data)
end
"""
    logpdf(dist::Model, data::Vector{Int})

Return the multinomial log-probability of the count vector `data`, using `dist.n` trials
and the category probabilities produced by `compute_probs(dist)`.
"""
function logpdf(dist::Model, data::Vector{Int})
    category_probs = compute_probs(dist)
    counts_dist = Multinomial(dist.n, category_probs)
    return logpdf(counts_dist, data)
end
"""
    compute_probs(dist)
    compute_probs(c, r, u)

Return the four response-category probabilities implied by the parameters `c`, `r`
and `u` (read from the fields of `dist` in the one-argument form):

1. `c * r`
2. `(1 - c) * u^2`
3. `(1 - c) * 2u * (1 - u)`
4. `c * (1 - r) + (1 - c) * (1 - u)^2`

The four entries sum to one for any `c`, `r`, `u`. The element type of the returned
vector is the promoted type of the three parameters, so mixed inputs (for example an
integer `c` with floating-point `r` and `u`) are stored without error or truncation.
"""
compute_probs(dist) = compute_probs(dist.c, dist.r, dist.u)

function compute_probs(c, r, u)
    # Size the buffer for the widest of the three parameter types; when all three share
    # a type this is exactly that type.
    T = promote_type(typeof(c), typeof(r), typeof(u))
    preds = zeros(T, 4)
    preds[1] = c * r
    preds[2] = (1 - c) * u^2
    preds[3] = (1 - c) * 2 * u * (1 - u)
    preds[4] = c * (1 - r) + (1 - c) * (1 - u)^2
    return preds
end
# Draw one vector of category counts: a single multinomial sample with `dist.n` trials
# and the probabilities returned by `compute_probs(dist)`.
rand(dist::Model) = rand(Multinomial(dist.n, compute_probs(dist)))
"""
    rand_subj_parms(n_trials, μs, σs, ρ)

Draw parameter vectors from a multivariate normal and map every value through `Φ`.

The covariance is built as `(σs * σs') .* ρ`. `n_trials` is the number of draws taken
from the multivariate normal (despite its name), and the result has one row per draw
and one column per entry of `μs`.
"""
function rand_subj_parms(n_trials, μs, σs, ρ)
    covariance = (σs * σs') .* ρ
    group_dist = MvNormal(μs, covariance)
    # `rand` yields one draw per column; transpose so each draw occupies a row.
    draws = rand(group_dist, n_trials)'
    return Φ.(draws)
end
"""
    Φ(x)

Evaluate the cumulative distribution function of `Normal(0, 1)` at `x`.
"""
function Φ(x)
    return cdf(Normal(0, 1), x)
end
# Packages this project depends on, keyed by name with their registry UUIDs.
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
# Version constraints for the packages listed under [deps].
[compat]
Distributions = "0.25.0"
Memoization = "0.1.0"
ReverseDiff = "1.9.0"
Turing = "0.18.0"
###################################################################################################
# Load Packages
###################################################################################################
# Work from the directory that contains this script so the relative paths below
# (the project environment and "functions.jl") resolve next to it.
cd(@__DIR__)
using Pkg
# NOTE(review): an empty path presumably resolves to the current directory set by `cd`
# above — confirm; `Pkg.activate(@__DIR__)` would state that explicitly.
Pkg.activate("")
using Turing, Distributions, Memoization, ReverseDiff, Random
# Load helper definitions; must come after `using` because the helpers extend
# Distributions functions.
include("functions.jl")
# Fix the global RNG seed so the simulated data are reproducible.
Random.seed!(12147)
# Use ReverseDiff as Turing's automatic-differentiation backend.
Turing.setadbackend(:reversediff)
# Turn on Turing's ReverseDiff cache option (presumably the reason Memoization is loaded
# — TODO confirm).
Turing.setrdcache(true)
###################################################################################################
# Generate Data
###################################################################################################
# trials per subject passed to each simulated Model
n_trials = 100
# how many subjects to simulate
n_subj = 200
# group-level means in ℝ
μ_c = 1.0
μ_r = 1.0
μ_u = 0.5
μs = [μ_c, μ_r, μ_u]
# group-level standard deviations in ℝ
σs = [0.3, 0.3, 0.3]
# group-level correlation matrix in ℝ
ρ = [
    1.0 0.3 0.2
    0.3 1.0 0.2
    0.2 0.2 1.0
]
# one row of (c, r, u) per subject
subj_parms = rand_subj_parms(n_subj, μs, σs, ρ)
# simulate one vector of response counts per subject, in row order
data = [rand(Model(n_trials, parms...)) for parms in eachrow(subj_parms)]
###################################################################################################
# Define Model
###################################################################################################
# Hierarchical model for per-subject response counts.
#
# Arguments:
#   n_trials: number of trials behind each subject's count vector
#   data:     vector with one count vector per subject
#
# Subject-level parameters are drawn on the unbounded scale from a multivariate normal
# centred on the group means and then mapped to [0,1] with Φ, mirroring how
# `rand_subj_parms` simulates them.
@model function model(n_trials, data)
    n_subj = length(data)
    # priors for group means in ℝ
    μ_c ~ Normal(0, 1)
    μ_r ~ Normal(0, 1)
    μ_u ~ Normal(0, 1)
    μs = [μ_c, μ_r, μ_u]
    # correlation matrix prior
    ρ ~ LKJ(3, 1)
    # prior on standard deviation
    σ ~ filldist(truncated(Cauchy(0, 2.5), 0, Inf), 3)
    # covariance matrix
    Σ = (σ * σ') .* ρ
    # subject parameters in ℝ: one column per subject, already centred on μs
    θ ~ filldist(MvNormal(μs, Σ), n_subj)
    # subject parameters: ℝ → [0,1]
    # The group means are NOT added again here: θ already carries them through the
    # MvNormal mean, and adding them a second time would double-count the mean.
    # Φ is broadcast because it is defined for scalar input.
    c = Φ.(θ[1, :])
    r = Φ.(θ[2, :])
    u = Φ.(θ[3, :])
    for i in 1:n_subj
        data[i] ~ Model(n_trials, c[i], r[i], u[i])
    end
end
###################################################################################################
# Estimate Parameters
###################################################################################################
# Run NUTS (arguments 1000 and 0.65) on the model, requesting 1000 samples for each of
# 4 chains executed via MCMCThreads.
chains = sample(
    model(n_trials, data),
    NUTS(1000, 0.65),
    MCMCThreads(),
    1000,
    4,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment