Created
August 24, 2022 08:38
-
-
Save p-gw/3588bbfc16a48889f7ae3b1ac40d8973 to your computer and use it in GitHub Desktop.
Benchmarking Turing.jl vs. Stan on a simple IRT model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Random
using Distributions, Turing, ReverseDiff, Memoization, LogExpFunctions
using BenchmarkTools
using Turing: @addlogprob!
using StanSample
# Global sampler configuration shared by every Turing model in this script.
Turing.setprogress!(false)            # keep benchmark output free of progress bars
Turing.setadbackend(:reversediff)     # reverse-mode AD for the gradient used by NUTS
# Reuse the compiled ReverseDiff tape across gradient calls. This is only valid when the
# log density has no control flow that depends on parameter values, as in the models below.
Turing.setrdcache(true)
# Location of the CmdStan installation. The `CMDSTAN` environment variable overrides it
# so the script runs on other machines; otherwise the author's original path is used.
set_cmdstan_home!(get(ENV, "CMDSTAN", joinpath(homedir(), "Applications", "cmdstan")))
# data simulation | |
"""
    sim([rng::AbstractRNG,] I, P)

Simulate binary responses for a fully crossed design of `I` items and `P` persons under
a Rasch (1PL) model, `logit P(y = 1) = theta[p] - beta[i]`, with standard-normal
`theta` and `beta`.

Pass an explicit `rng` to make the simulated data reproducible. Without it the default
RNG is used, which reproduces the draws of the original two-argument call.

Returns `(y, i, p, theta, beta)`. `y`, `i` and `p` are `Vector{Int}`s of length `I * P`,
ordered item-major: observation `n` has `i[n] = cld(n, P)`.
"""
function sim(rng::AbstractRNG, I, P)
    # Keep the draw order (beta, then theta, then responses) so that seeded runs match
    # the data produced by the original implementation.
    beta = rand(rng, Normal(), I)
    theta = rand(rng, Normal(), P)
    # Fully crossed index vectors: items repeat in blocks of P, persons cycle within each.
    ivec = repeat(1:I; inner=P)
    pvec = repeat(1:P; outer=I)
    yvec = Vector{Int}(undef, length(ivec))
    for n in eachindex(yvec, ivec, pvec)
        # rand(::BernoulliLogit) returns a Bool; storing it in yvec converts it to Int.
        yvec[n] = rand(rng, BernoulliLogit(theta[pvec[n]] - beta[ivec[n]]))
    end
    return yvec, ivec, pvec, theta, beta
end

sim(I, P) = sim(Random.default_rng(), I, P)
# naive implementation | |
# Rasch IRT model written "naively": one `~` statement per observation.
#   y    -- binary responses
#   i, p -- integer item / person index of each observation (same length as y)
#   I, P -- number of item / person parameters (default: the largest index present)
@model function irt_naive(y, i, p; I=maximum(i), P=maximum(p))
    theta ~ filldist(Normal(), P)
    beta ~ filldist(Normal(), I)
    for n in eachindex(y)
        # BernoulliLogit evaluates the likelihood on the logit scale. This avoids the
        # precision loss (and -Inf log densities) when logistic(...) saturates to 0 or 1,
        # and it matches the likelihood used by `irt` and the Stan model.
        y[n] ~ BernoulliLogit(theta[p[n]] - beta[i[n]])
    end
end
# turing model | |
# Vectorized Rasch IRT model. The full Bernoulli-logit log-likelihood is evaluated in a
# single broadcast and added to the joint log density with `@addlogprob!`. This replaces
# the one-`~`-per-observation loop used by `irt_naive`.
#   y    -- binary responses
#   i, p -- integer item / person index of each observation
#   I, P -- number of item / person parameters (default: the largest index present)
@model function irt(y, i, p; I=maximum(i), P=maximum(p))
    theta ~ filldist(Normal(), P)   # one parameter per person index p
    beta ~ filldist(Normal(), I)    # one parameter per item index i
    eta = theta[p] - beta[i]        # logit of P(y[n] = 1) for every observation n
    loglik = sum(logpdf.(BernoulliLogit.(eta), y))
    @addlogprob! loglik
end
# stan model | |
# Stan version of the same Rasch model: standard-normal priors and a vectorized
# bernoulli_logit likelihood. I and P bound the item index i and the person index p;
# N is the number of observations (see `setup_stan` for the matching data Dict).
stan_model = raw"""
data {
int<lower=1> I;
int<lower=1> P;
int<lower=1> N;
array[N] int<lower=1, upper=I> i;
array[N] int<lower=1, upper=P> p;
array[N] int<lower=0, upper=1> y;
}
parameters {
vector[I] beta;
vector[P] theta;
}
model {
theta ~ std_normal();
beta ~ std_normal();
y ~ bernoulli_logit(theta[p] - beta[i]);
}
"""
# Build the StanSample model in the project's stan/ directory and sample a single chain,
# matching the single Turing chain.
sm = let stan_dir = string(homedir(), "/Documents/Projects/turing-irt-benchmark/stan")
    model = SampleModel("irt", stan_model, stan_dir)
    model.num_chains = 1
    model
end
# Match Stan's default sampler settings: 1000 adaptation (warmup) iterations, then
# 1000 retained draws, with target acceptance 0.8 and maximum tree depth 10.
n_samples = 1000
alg = NUTS(1000, 0.8, max_depth = 10)
"""
    setup_stan(y, i, p)

Build the data `Dict` expected by the `data` block of `stan_model`. The keys are
`"I"` and `"P"` (the largest item / person index), `"N"` (the number of observations)
and the observation vectors `"i"`, `"p"` and `"y"`.

`N` is taken from `length(y)`. The design therefore does not have to be a complete
person-by-item grid.

Throws `DimensionMismatch` if `y`, `i` and `p` differ in length.
"""
function setup_stan(y, i, p)
    length(y) == length(i) == length(p) || throw(DimensionMismatch(
        "y, i and p must have equal lengths, got $(length(y)), $(length(i)), $(length(p))"))
    I = maximum(i)
    P = maximum(p)
    N = length(y)   # not I * P: that only holds for a complete, fully crossed design
    return Dict("I" => I, "P" => P, "N" => N, "i" => i, "p" => p, "y" => y)
end
# Numbers of persons to benchmark. Use `[10^k for k in 1:4]` to reproduce the full
# timing table below.
ps = [10^k for k in 2:2]
stan_trials = BenchmarkTools.Trial[]
turing_trials = BenchmarkTools.Trial[]
for P in ps
    # Both samplers get the same simulated data set: 20 items, P persons.
    y, i, p, _, _ = sim(20, P)

    # Turing
    @info "Running benchmark P = $P using Turing.jl"
    m = irt(y, i, p)
    turing_trial = @benchmark sample($m, $alg, $n_samples)
    push!(turing_trials, turing_trial)

    # Stan
    @info "Running benchmark P = $P using Stan"
    # A name distinct from the global `sm` avoids the soft-scope ambiguity of assigning
    # a global inside a top-level loop. That ambiguity warns when run as a script and
    # behaves differently in the REPL.
    stan_dir = homedir() * "/Documents/Projects/turing-irt-benchmark/stan"
    sm_P = SampleModel("irt", stan_model, stan_dir)
    sm_P.num_chains = 1
    stan_data = setup_stan(y, i, p)
    stan_trial = @benchmark stan_sample($sm_P, data=$stan_data)
    push!(stan_trials, stan_trial)
end
# Timings on a MacBook Pro (M1); ratio = Turing time / Stan time
# | |
# P = 10 | |
# Turing: 575.398 ms | |
# Stan: 169.081 ms | |
# ratio: 3.40 | |
# | |
# P = 100 | |
# Turing: 14.295 s | |
# Stan: 1.462 s | |
# ratio: 9.78 | |
# | |
# P = 1000 | |
# Turing: 150.454 s | |
# Stan: 20.029 s | |
# ratio: 7.51 | |
# | |
# P = 10000 | |
# Turing: 2293.964 s | |
# Stan: 405.192 s | |
# ratio: 5.66 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment