# Turing Hierarchical MPT
# GitHub gist by @itsdfish (id eacb892d0cd38e58bfed5176a9663474), created October 3, 2021 19:32.
#
# NOTE: the definitions before the "Load Packages" banner appear to be the gist's
# `functions.jl`, which the script that follows loads with `include("functions.jl")`.
import Distributions: rand, logpdf, loglikelihood
import Random
"""
    Model(ns, r, g)

Per-subject distribution of response counts for the two-condition MPT model.

# Fields
- `ns::Vector{Int}`: number of trials in each condition.
- `r`: retrieval probability.
- `g`: probability of guessing "yes".

`r` and `g` are type parameters rather than fixed `Float64` fields, presumably so
they can carry AD-tracked numbers during gradient-based sampling — confirm.

NOTE(review): declared as a `ContinuousUnivariateDistribution` even though the
`logpdf`/`loglikelihood`/`rand` methods defined for it work on `Vector{Int}` count
vectors. Confirm this supertype is intentional before relying on generic
Distributions.jl fallbacks for this type.
"""
struct Model{R,G} <: ContinuousUnivariateDistribution
    ns::Vector{Int}
    r::R
    g::G
end
# Treat a `Model` as a single scalar value when it appears in a broadcast.
function Broadcast.broadcastable(m::Model)
    return Ref(m)
end
"""
    loglikelihood(d::Model, data::Vector{Int})

Return the log-likelihood of the count vector `data` under `d`; this is exactly
`logpdf(d, data)`.
"""
function loglikelihood(d::Model, data::Vector{Int})
    return logpdf(d, data)
end
"""
    logpdf(dist::Model, data::AbstractVector{<:Integer})

Return the joint log-probability of the count vector `data` under `dist`.

Element `i` of `data` is scored as a `Binomial(dist.ns[i], θ[i])` count, where
`θ = compute_probs(dist.r, dist.g)`, and the per-condition log-probabilities are
summed.

# Throws
- `DimensionMismatch` if `dist.ns` or `data` does not have exactly one entry per
  probability returned by `compute_probs`.
"""
function logpdf(dist::Model, data::AbstractVector{<:Integer})
    θ = compute_probs(dist.r, dist.g)
    n_probs = length(θ)
    # Broadcasting would silently expand a length-1 `ns` or `data` across every
    # condition and return a wrong log-likelihood, so demand an exact match.
    # The check branches only on lengths, never on the values of `r`/`g`.
    if length(dist.ns) != n_probs || length(data) != n_probs
        throw(DimensionMismatch(
            "expected $n_probs trial counts and $n_probs observations, " *
            "got $(length(dist.ns)) and $(length(data))",
        ))
    end
    return sum(logpdf(Binomial(n, p), k) for (n, p, k) in zip(dist.ns, θ, data))
end
"""
    compute_probs(r, g)

Return the pair of response probabilities `(θt, θf)` implied by retrieval
probability `r` and guessing probability `g`:
`θt = r + (1 - r) * g` and `θf = g`.
"""
compute_probs(r, g) = (r + (1 - r) * g, g)
"""
    rand([rng::AbstractRNG], dist::Model)

Draw one vector of response counts from `dist`: element `i` is a
`Binomial(dist.ns[i], θ[i])` draw, where `θ = compute_probs(dist.r, dist.g)`.
Without `rng`, `Random.default_rng()` is used.

# Throws
- `DimensionMismatch` if `dist.ns` does not have one entry per probability
  returned by `compute_probs`.
"""
function rand(rng::Random.AbstractRNG, dist::Model)
    θ = compute_probs(dist.r, dist.g)
    # Refuse to silently broadcast a length-1 `ns` across both conditions.
    length(dist.ns) == length(θ) || throw(DimensionMismatch(
        "expected $(length(θ)) trial counts, got $(length(dist.ns))",
    ))
    # `Ref(rng)` keeps the RNG scalar under broadcasting; draws happen in index order.
    return rand.(Ref(rng), Binomial.(dist.ns, θ))
end

# Convenience form kept for existing callers: sample with the default RNG.
rand(dist::Model) = rand(Random.default_rng(), dist)
"""
    simulate(ns, θr, κr, θg, κg)

Simulate one subject's response counts.

The subject's `r` is drawn from `Beta(θr * κr, (1 - θr) * κr)` and then `g` from
`Beta(θg * κg, (1 - θg) * κg)`, i.e. Beta distributions with means `θr`/`θg` and
concentrations `κr`/`κg`. The result is one draw from `Model(ns, r, g)`.
"""
function simulate(ns, θr, κr, θg, κg)
    # Draw order (r first, then g) matters for reproducibility under a fixed seed.
    r_subject = rand(Beta(θr * κr, (1 - θr) * κr))
    g_subject = rand(Beta(θg * κg, (1 - θg) * κg))
    return rand(Model(ns, r_subject, g_subject))
end
###################################################################################################
# Load Packages
###################################################################################################
# Make this script's directory the working directory so relative paths resolve against it.
cd(@__DIR__)
using Pkg
# NOTE(review): presumably intended to activate the project in this directory (the cwd set
# above); `Pkg.activate(@__DIR__)` would state that explicitly — confirm.
Pkg.activate("")
# NOTE(review): Memoization is presumably loaded because Turing's ReverseDiff tape cache
# (`setrdcache` below) requires it — confirm for the Turing version in use.
using Turing, Distributions, Memoization, ReverseDiff, Random
include("functions.jl")
# Fix the RNG seed so the simulated data below are reproducible.
Random.seed!(12147)
# Compute gradients with ReverseDiff and cache its compiled tapes.
# NOTE(review): these global setters are deprecated/removed in recent Turing releases
# (the AD backend is then chosen via the sampler's `adtype`); confirm the pinned version.
Turing.setrdcache(true)
Turing.setadbackend(:reversediff)
###################################################################################################
# Generate Data
###################################################################################################
# Design: trials in each condition (one entry per condition) and number of subjects.
ns = [100, 100]
n_subj = 50
# Group-level generating values: θ* is a group mean, κ* a group concentration.
# r = retrieval probability, g = probability of guessing "yes".
θr = 0.8
κr = 20
θg = 0.5
κg = 20
# One simulated vector of response counts per subject.
data = map(_ -> simulate(ns, θr, κr, θg, κg), 1:n_subj)
###################################################################################################
# Define Model
###################################################################################################
@model function model(ns, data)
    n_subjects = length(data)
    # Group-level priors: θ* is the group mean probability, κ* the group concentration.
    θr ~ Beta(8, 2)
    κr ~ Gamma(2, 10)
    θg ~ Beta(5, 5)
    κg ~ Gamma(2, 10)
    # Subject-level r and g come from Betas parameterised by mean θ and concentration κ,
    # i.e. Beta(θ * κ, (1 - θ) * κ).
    r ~ filldist(Beta(θr * κr, (1 - θr) * κr), n_subjects)
    g ~ filldist(Beta(θg * κg, (1 - θg) * κg), n_subjects)
    # Likelihood: one count vector per subject. The author reports that the broadcast
    # form `data .~ Model((ns,), r, g)` does not work, hence the explicit loop.
    for subj in eachindex(data)
        data[subj] ~ Model(ns, r[subj], g[subj])
    end
end
###################################################################################################
# Estimate Parameters
###################################################################################################
chains = let
    n_adapt = 1000        # NUTS adaptation steps
    target_accept = 0.65  # NUTS target acceptance rate
    n_samples = 1000      # samples per chain
    n_chains = 4          # chains, distributed over Julia threads by MCMCThreads()
    # NOTE(review): confirm that the cached ReverseDiff tapes (setrdcache) are safe to
    # share across threaded chains for the Turing version in use.
    sample(
        model(ns, data),
        NUTS(n_adapt, target_accept),
        MCMCThreads(),
        n_samples,
        n_chains,
    )
end
# End of gist (GitHub page footer text removed).