Skip to content

Instantly share code, notes, and snippets.

@p-gw
Created August 24, 2022 08:38
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save p-gw/3588bbfc16a48889f7ae3b1ac40d8973 to your computer and use it in GitHub Desktop.
Save p-gw/3588bbfc16a48889f7ae3b1ac40d8973 to your computer and use it in GitHub Desktop.
Benchmarking Turing.jl vs. Stan on a simple item response theory (IRT) model
using Distributions, Turing, ReverseDiff, Memoization, LogExpFunctions
using BenchmarkTools
using Turing: @addlogprob!
using StanSample
# Turing configuration shared by every benchmark run: no progress output, ReverseDiff as
# the AD backend, and ReverseDiff tape caching enabled (the file imports Memoization,
# which Turing uses for that cache).
Turing.setprogress!(false)
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
# Location of the CmdStan installation used by StanSample.
let cmdstan_home = string(homedir(), "/Applications/cmdstan/")
    set_cmdstan_home!(cmdstan_home)
end
"""
    sim(I, P)

Simulate binary responses for a complete `I × P` design.

Draws `beta` (length `I`) and `theta` (length `P`) from a standard normal, then for
every pair `(i, p)` draws one response `y ~ BernoulliLogit(theta[p] - beta[i])`.
Responses are stored in long format, with `p` varying fastest within each `i`.

Returns the tuple `(y, i, p, theta, beta)`, where `y`, `i` and `p` are
`Vector{Int}`s of length `I * P`.
"""
function sim(I, P)
    N = I * P
    responses = Vector{Int}(undef, N)
    item = similar(responses)
    person = similar(responses)
    beta = rand(Normal(), I)
    theta = rand(Normal(), P)
    for i in 1:I
        offset = (i - 1) * P  # linear position of the first response for this i
        for p in 1:P
            k = offset + p
            item[k] = i
            person[k] = p
            responses[k] = rand(BernoulliLogit(theta[p] - beta[i]))
        end
    end
    return responses, item, person, theta, beta
end
# Naive Turing implementation of the IRT model: standard-normal priors on `theta`
# (length P) and `beta` (length I), and one `~` statement per observation, i.e.
# length(y) separate Bernoulli terms whose success probability is computed with
# `logistic` on the probability scale.
# `I` and `P` default to the largest index present in `i` and `p`.
# NOTE(review): this model is not used by the benchmark loop below; presumably kept
# as a reference point for the vectorized `irt` model.
@model function irt_naive(y, i, p; I=maximum(i), P=maximum(p))
theta ~ filldist(Normal(), P)
beta ~ filldist(Normal(), I)
# one observation statement per response: logit = theta[p[n]] - beta[i[n]]
for n in eachindex(y)
y[n] ~ Bernoulli(logistic(theta[p[n]] - beta[i[n]]))
end
end
# Vectorized Turing model (the one benchmarked below): same priors as `irt_naive`,
# but the whole Bernoulli-logit log-likelihood is summed and added in a single
# `@addlogprob!` call instead of one `~` statement per observation.
# `I` and `P` default to the largest index present in `i` and `p`.
# NOTE(review): since `y` only enters through `@addlogprob!`, Turing does not see it
# as an observed random variable (e.g. for posterior prediction) — confirm this is
# acceptable beyond benchmarking.
@model function irt(y, i, p; I=maximum(i), P=maximum(p))
theta ~ filldist(Normal(), P)
beta ~ filldist(Normal(), I)
# theta[p] - beta[i] is the vector of per-observation logits theta[p[n]] - beta[i[n]]
@addlogprob! sum(logpdf.(BernoulliLogit.(theta[p] - beta[i]), y))
end
# Stan program equivalent to the Turing `irt` model: std_normal priors on `theta` and
# `beta` and a vectorized bernoulli_logit likelihood on theta[p] - beta[i].
# Data: I, P, the number of observations N, and per-observation index arrays i, p
# with binary responses y (see `setup_stan`).
# NOTE(review): the `array[N] int` declarations use Stan's newer array syntax, which
# presumably requires a recent CmdStan — verify against the installed version.
stan_model = """
data {
int<lower=1> I;
int<lower=1> P;
int<lower=1> N;
array[N] int<lower=1, upper=I> i;
array[N] int<lower=1, upper=P> p;
array[N] int<lower=0, upper=1> y;
}
parameters {
vector[I] beta;
vector[P] theta;
}
model {
theta ~ std_normal();
beta ~ std_normal();
y ~ bernoulli_logit(theta[p] - beta[i]);
}
"""
# Compile the Stan program once and run a single chain, matching Turing's single chain.
sm = let model = SampleModel(
        "irt",
        stan_model,
        string(homedir(), "/Documents/Projects/turing-irt-benchmark/stan"),
    )
    model.num_chains = 1
    model
end
# Match Stan's default sampler settings: NUTS(n_adapts, target acceptance; max_depth),
# i.e. 1000 adaptation steps, target acceptance 0.8 and maximum tree depth 10,
# followed by 1000 draws.
n_samples = 1_000
alg = NUTS(1_000, 0.8; max_depth=10)
"""
    setup_stan(y, i, p)

Build the data dictionary expected by the Stan program `stan_model`.

`y` holds the binary responses; `i` and `p` hold, for each response, its index into
`beta` (1..I) and into `theta` (1..P). All three must have the same length.

`I` and `P` are the largest index present in `i` and `p`. `N` is the number of
responses, `length(y)`. Unlike `I * P`, this stays correct when the design is
incomplete, i.e. not every `(i, p)` pair is observed.

Throws `DimensionMismatch` if `y`, `i` and `p` differ in length.
"""
function setup_stan(y, i, p)
    N = length(y)
    if length(i) != N || length(p) != N
        throw(DimensionMismatch(
            "y, i and p must have the same length; got $(length(y)), $(length(i)) " *
            "and $(length(p))",
        ))
    end
    I = maximum(i)
    P = maximum(p)
    return Dict("I" => I, "P" => P, "N" => N, "i" => i, "p" => p, "y" => y)
end
# Values of P to benchmark. The timings recorded below cover P = 10, 100, 1000 and
# 10000; this range currently runs P = 100 only.
ps = [10^i for i in 2:2]
stan_trials = BenchmarkTools.Trial[]
turing_trials = BenchmarkTools.Trial[]
for P in ps
    # Complete design with 20 `beta` parameters and P `theta` parameters.
    y, i, p, _, _ = sim(20, P)

    # Turing
    @info "Running benchmark P = $P using Turing.jl"
    m = irt(y, i, p)
    turing_trial = @benchmark sample($m, $alg, $n_samples)
    push!(turing_trials, turing_trial)

    # Stan: reuse the model `sm` compiled once at top level. The Stan program does not
    # depend on P, so recompiling it every iteration is wasted work. Not assigning to
    # `sm` inside this top-level loop also avoids Julia's soft-scope ambiguity warning
    # (and a shadowing local) when the script is run non-interactively.
    @info "Running benchmark P = $P using Stan"
    stan_data = setup_stan(y, i, p)
    stan_trial = @benchmark stan_sample($sm, data=$stan_data)
    push!(stan_trials, stan_trial)
end
# Timings on a MacBook Pro (M1); ratio = Turing time / Stan time
#
# P = 10
# Turing: 575.398 ms
# Stan: 169.081 ms
# ratio: 3.40
#
# P = 100
# Turing: 14.295 s
# Stan: 1.462 s
# ratio: 9.78
#
# P = 1000
# Turing: 150.454 s
# Stan: 20.029 s
# ratio: 7.51
#
# P = 10000
# Turing: 2293.964 s
# Stan: 405.192 s
# ratio: 5.66
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment