-
-
Save itsdfish/eacb892d0cd38e58bfed5176a9663474 to your computer and use it in GitHub Desktop.
Turing Hierarchical MPT (multinomial processing tree) model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Distributions: rand, logpdf, loglikelihood
using Random: AbstractRNG, default_rng
"""
    Model(ns, r, g)

Per-subject observation model with two conditions.

# Fields
- `ns`: number of trials in each condition.
- `r`: retrieval probability.
- `g`: probability of guessing "yes".

The two response probabilities are derived from `r` and `g` by `compute_probs`.

NOTE(review): the type is declared as a `ContinuousUnivariateDistribution`, although an
observation is a vector of integer counts. `logpdf`, `loglikelihood` and `rand` are
overloaded directly for `Model` instead of going through the Distributions multivariate
interface.
"""
struct Model{T1,T2} <: ContinuousUnivariateDistribution
    ns::Vector{Int}
    r::T1
    g::T2
end

# In broadcast expressions, treat a `Model` as a single scalar-like value rather than
# trying to iterate over it.
function Broadcast.broadcastable(d::Model)
    return Ref(d)
end
"""
    loglikelihood(d::Model, data::AbstractVector{<:Integer})

Return the log-likelihood of one subject's count vector `data`. This is the same value as
`logpdf(d, data)`.
"""
loglikelihood(d::Model, data::AbstractVector{<:Integer}) = logpdf(d, data)

"""
    logpdf(dist::Model, data::AbstractVector{<:Integer})

Return the summed binomial log-probability of the counts in `data`.

`data[k]` is treated as a Binomial(`dist.ns[k]`, `θ[k]`) outcome, where
`θ = compute_probs(dist.r, dist.g)`. `ns`, `θ` (a 2-tuple) and `data` are broadcast
together, so `ns` and `data` are expected to hold one entry per condition (two entries).
TODO confirm with callers.

Any integer vector is accepted (e.g. views, `Int32` data), not only `Vector{Int}`.
Otherwise such inputs would fall through to Distributions' generic univariate methods.
"""
function logpdf(dist::Model, data::AbstractVector{<:Integer})
    θ = compute_probs(dist.r, dist.g)
    return sum(@. logpdf(Binomial(dist.ns, θ), data))
end
"""
    compute_probs(r, g)

Return the tuple `(θt, θf)` of response probabilities for the two conditions.

- `θt = r + (1 - r) * g`: respond via retrieval with probability `r`; otherwise guess
  with probability `g`.
- `θf = g`: guessing only.
"""
compute_probs(r, g) = (r + (1 - r) * g, g)
"""
    rand([rng::AbstractRNG,] dist::Model)

Draw one subject's response counts.

Entry `k` of the result is a draw from Binomial(`dist.ns[k]`, `θ[k]`), where
`θ = compute_probs(dist.r, dist.g)`. The method without `rng` uses `default_rng()`.

The explicit-`rng` method is what the Distributions / Turing sampling machinery calls
(`rand(rng, d)`). Without it, that call would fall back to a generic univariate method
that does not apply to `Model`.
"""
function rand(rng::AbstractRNG, dist::Model)
    θ = compute_probs(dist.r, dist.g)
    # Ref(rng) stops broadcasting from trying to iterate the generator.
    return rand.(Ref(rng), Binomial.(dist.ns, θ))
end

rand(dist::Model) = rand(default_rng(), dist)
"""
    simulate(ns, θr, κr, θg, κg)

Simulate response counts for a single subject from the hierarchical model.

1. Draw the subject's retrieval probability from `Beta(θr * κr, (1 - θr) * κr)`.
2. Draw the subject's guessing probability from `Beta(θg * κg, (1 - θg) * κg)`.
3. Return `rand(Model(ns, r, g))` using those two draws.

`θ*` is a group mean and `κ*` a group concentration.
"""
function simulate(ns, θr, κr, θg, κg)
    # Convert mean/concentration to Beta shape parameters: α = θκ, β = (1 - θ)κ.
    r_dist = Beta(θr * κr, (1 - θr) * κr)
    g_dist = Beta(θg * κg, (1 - θg) * κg)
    # Arguments are evaluated left to right, so r is drawn before g.
    subject = Model(ns, rand(r_dist), rand(g_dist))
    return rand(subject)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
###################################################################################################
# Load Packages
###################################################################################################
# Work from this script's folder and activate the environment there.
# NOTE(review): `Pkg.activate("")` presumably resolves to the current directory;
# `Pkg.activate(".")` is the conventional spelling.
cd(@__DIR__)
using Pkg
Pkg.activate("")
using Turing
using Distributions
using Memoization
using ReverseDiff
using Random
# Model definition and simulation helpers.
include("functions.jl")
# Reverse-mode AD through ReverseDiff, with tape caching switched on.
# (Memoization is presumably loaded above for the caching.)
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
# Fixed seed for reproducible simulation and sampling.
Random.seed!(12147)
###################################################################################################
# Generate Data
###################################################################################################
# trials in each of the two conditions
ns = fill(100, 2)
# number of simulated subjects
n_subj = 50
# retrieval probability: group mean θr and concentration κr
θr = 0.8
κr = 20
# probability of guessing "yes": group mean θg and concentration κg
θg = 0.5
κg = 20
# one vector of response counts per subject
data = map(_ -> simulate(ns, θr, κr, θg, κg), 1:n_subj)
###################################################################################################
# Define Model
###################################################################################################
@model function model(ns, data)
    n_subjects = length(data)
    # Group-level priors: θ* is a mean probability, κ* a concentration.
    θr ~ Beta(8, 2)
    κr ~ Gamma(2, 10)
    θg ~ Beta(5, 5)
    κg ~ Gamma(2, 10)
    # Group-level Beta distributions in shape form: α = θκ, β = (1 - θ)κ.
    r_group = Beta(θr * κr, (1 - θr) * κr)
    g_group = Beta(θg * κg, (1 - θg) * κg)
    # Subject-level retrieval and guessing probabilities.
    r ~ filldist(r_group, n_subjects)
    g ~ filldist(g_group, n_subjects)
    # Observe each subject separately.
    # A broadcast form, `data .~ Model((ns,), r, g)`, was tried and did not work.
    for i in 1:n_subjects
        data[i] ~ Model(ns, r[i], g[i])
    end
end
###################################################################################################
# Estimate Parameters
###################################################################################################
n_adapt = 1000      # NUTS adaptation steps
δ = 0.65            # NUTS target acceptance rate
n_samples = 1000    # samples per chain
n_chains = 4        # chains, run on separate threads
sampler = NUTS(n_adapt, δ)
chains = sample(model(ns, data), sampler, MCMCThreads(), n_samples, n_chains)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment