Created October 10, 2021 11:33
-
-
Save itsdfish/5fc2713acedd7b0813d11a85f9bf7d17 to your computer and use it in GitHub Desktop.
Hierarchical MPT Turing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Distributions: rand, logpdf, loglikelihood | |
"""
    Model(n, c, r, u)

Distribution of the four response-category counts of a multinomial processing
tree (MPT) model observed over `n` trials with parameters `c`, `r` and `u`.
Its `logpdf`/`rand` methods delegate to a `Multinomial(n, compute_probs(c, r, u))`.

Each parameter has its own type parameter so that plain numbers and
AD-tracked numbers can be stored side by side.

NOTE(review): declared as a `ContinuousUnivariateDistribution` even though a
draw is an integer count vector; presumably chosen so Turing accepts it in a
`~` statement — confirm before changing the supertype.
"""
struct Model{Tc,Tr,Tu} <: ContinuousUnivariateDistribution
    n::Int   # number of trials
    c::Tc
    r::Tr
    u::Tu
end
# Wrap a `Model` in `Ref` so that broadcasting treats it as a single scalar value
# (the same distribution is paired with every element) rather than iterating it.
function Broadcast.broadcastable(m::Model)
    return Ref(m)
end
"""
    loglikelihood(d::Model, data::Vector{Int})

Log-likelihood of one vector of category counts; identical to `logpdf(d, data)`.
Presumably defined so that code calling `loglikelihood` (e.g. Turing's
observation path) scores the whole count vector as a single observation —
NOTE(review): confirm.
"""
function loglikelihood(d::Model, data::Vector{Int})
    return logpdf(d, data)
end
"""
    logpdf(dist::Model, data::Vector{Int})

Log-probability of the count vector `data` under a multinomial with `dist.n`
trials and the category probabilities returned by `compute_probs(dist)`.
"""
logpdf(dist::Model, data::Vector{Int}) =
    logpdf(Multinomial(dist.n, compute_probs(dist)), data)
# Convenience method: category probabilities for the parameters stored in `dist`.
function compute_probs(dist)
    c, r, u = dist.c, dist.r, dist.u
    return compute_probs(c, r, u)
end
"""
    compute_probs(c, r, u)

Return the vector of the four response-category probabilities of the MPT model:

1. `c * r`
2. `(1 - c) * u^2`
3. `(1 - c) * 2 * u * (1 - u)`
4. `c * (1 - r) + (1 - c) * (1 - u)^2`

The four entries sum to 1 algebraically, so they form a valid probability
vector whenever `c`, `r` and `u` lie in [0, 1].

The element type is promoted from all three inputs, so arguments may mix
types (e.g. an `Int` with a `Float64`, or a `Float64` with an AD number type).
"""
function compute_probs(c, r, u)
    # A vector literal promotes over every entry. Allocating with the type of `c`
    # alone (as before) failed when `r` or `u` produced a wider type.
    return [
        c * r,
        (1 - c) * u^2,
        (1 - c) * 2 * u * (1 - u),
        c * (1 - r) + (1 - c) * (1 - u)^2,
    ]
end
"""
    rand(dist::Model)

Simulate one vector of category counts: a single draw from a multinomial with
`dist.n` trials and probabilities `compute_probs(dist)`.
"""
rand(dist::Model) = rand(Multinomial(dist.n, compute_probs(dist)))
"""
    rand_subj_parms(n_subj, μs, σs, ρ)

Simulate subject-level parameters on the probability scale.

Draws `n_subj` latent vectors from `MvNormal(μs, Σ)`, where the covariance is
`Σ = (σs * σs') .* ρ` (the correlation matrix `ρ` scaled by the outer product of
the standard deviations `σs`). Each element is then mapped into [0, 1] with the
standard normal CDF `Φ`.

Returns an `n_subj × length(μs)` matrix with one row per subject.
"""
function rand_subj_parms(n_subj, μs, σs, ρ)
    Σ = (σs * σs') .* ρ
    # rand(::MvNormal, n) stores one draw per column; flip to one row per subject.
    latent = permutedims(rand(MvNormal(μs, Σ), n_subj))
    return Φ.(latent)
end
"""
    Φ(x)

Standard normal cumulative distribution function, used to map latent values
from the real line into [0, 1]. Intended for scalar `x`; broadcast with `Φ.(xs)`.
"""
function Φ(x)
    standard_normal = Normal(0, 1)
    return cdf(standard_normal, x)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
Distributions = "0.25.0"
Memoization = "0.1.0"
ReverseDiff = "1.9.0"
Turing = "0.18.0"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
###################################################################################################
# Load Packages
###################################################################################################
# Work from this script's directory.
cd(@__DIR__)
using Pkg
# Activate the project (Project.toml) stored next to this script. Passing the
# directory explicitly does not depend on how an empty path ("") is resolved.
Pkg.activate(@__DIR__)
using Turing, Distributions, Memoization, ReverseDiff, Random
include("functions.jl")
Random.seed!(12147)
# Reverse-mode AD with a cached (compiled) tape. Turing requires Memoization to be
# loaded for the tape cache. A cached tape is only valid when the model's control
# flow does not depend on parameter values (the model below uses a fixed-size loop).
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
###################################################################################################
# Generate Data
###################################################################################################
# Trials per subject and number of simulated subjects.
n_trials = 100
n_subj = 200
# Group-level means on the latent (real-line) scale, ordered (c, r, u).
μ_c, μ_r, μ_u = 1.0, 1.0, 0.5
μs = [μ_c, μ_r, μ_u]
# Group-level standard deviations and correlation matrix on the latent scale.
σs = [0.3, 0.3, 0.3]
ρ = [1.0 0.3 0.2
     0.3 1.0 0.2
     0.2 0.2 1.0]
# One row of (c, r, u) on the probability scale per subject, then one simulated
# vector of category counts per subject (drawn in subject order).
subj_parms = rand_subj_parms(n_subj, μs, σs, ρ)
data = map(eachrow(subj_parms)) do parms
    rand(Model(n_trials, parms...))
end
###################################################################################################
# Define Model
###################################################################################################
# Hierarchical MPT model. `data` is a vector with one count vector per subject,
# and each subject's counts come from `n_trials` trials.
@model function model(n_trials, data)
    n_subj = length(data)
    # priors for the group means on the latent (real-line) scale
    μ_c ~ Normal(0, 1)
    μ_r ~ Normal(0, 1)
    μ_u ~ Normal(0, 1)
    μs = [μ_c, μ_r, μ_u]
    # prior on the 3×3 correlation matrix of (c, r, u)
    ρ ~ LKJ(3, 1)
    # half-Cauchy priors on the group standard deviations
    σ ~ filldist(truncated(Cauchy(0, 2.5), 0, Inf), 3)
    # covariance: correlations scaled by the outer product of the standard deviations.
    # NOTE(review): if sampling raises PosDefException from tiny asymmetries in ρ,
    # wrapping Σ in LinearAlgebra.Symmetric is a common workaround — confirm.
    Σ = (σ * σ') .* ρ
    # subject parameters on the latent scale, one column per subject. θ already has
    # mean μs, so the group means must not be added again below (doing so doubled
    # the effective mean relative to the data-generating process).
    θ ~ filldist(MvNormal(μs, Σ), n_subj)
    # map subject parameters ℝ → [0, 1]; Φ is scalar, so it is broadcast over subjects
    c = Φ.(θ[1, :])
    r = Φ.(θ[2, :])
    u = Φ.(θ[3, :])
    # each subject's count vector is scored as a single observation
    for i in 1:n_subj
        data[i] ~ Model(n_trials, c[i], r[i], u[i])
    end
end
###################################################################################################
# Estimate Parameters
###################################################################################################
# NUTS with 1000 adaptation steps and target acceptance 0.65; 1000 samples in
# each of 4 chains, run on threads.
chains = let
    n_adapts, target_accept = 1000, 0.65
    n_samples, n_chains = 1000, 4
    sampler = NUTS(n_adapts, target_accept)
    sample(model(n_trials, data), sampler, MCMCThreads(), n_samples, n_chains)
end
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.