Stan vs LME4 for hierarchical within-subjects designs with binomial outcomes
helper_functions.r
#' Installs any packages not already installed
#' @examples
#' \dontrun{
#' install_if_missing(c('tidyverse','github.com/stan-dev/cmdstanr'))
#' }
install_if_missing = function(pkgs){
	missing_pkgs = NULL
	for(this_pkg in pkgs){
		path = NULL
		try(
			path <- find.package(basename(this_pkg),quiet=T,verbose=F)
			, silent = T
		)
		if(is.null(path)){
			missing_pkgs = c(missing_pkgs,this_pkg)
		}
	}
	cran_missing = missing_pkgs[!grepl('github.com/',missing_pkgs,fixed=T)]
	if(length(cran_missing)>0){
		message('The following required but uninstalled CRAN packages will now be installed:\n',paste(cran_missing,collapse='\n'))
		install.packages(cran_missing)
	}
	github_missing = missing_pkgs[grepl('github.com/',missing_pkgs,fixed=T)]
	github_missing = gsub('github.com/','',github_missing)
	if(length(github_missing)>0){
		message('The following required but uninstalled Github packages will now be installed:\n',paste(github_missing,collapse='\n'))
		remotes::install_github(github_missing)
	}
	invisible()
}
#define a function that adds diagnostics as a metadata attribute
add_diagnostic_bools = function(x,fit){
	sink('/dev/null')
	diagnostics = fit$cmdstan_diagnose()$stdout #annoyingly not quiet-able
	sink(NULL)
	diagnostic_bools = list(
		treedepth_maxed = stringr::str_detect(diagnostics,'transitions hit the maximum')
		, ebfmi_low = stringr::str_detect(diagnostics,' is below the nominal threshold')
		, essp_low = stringr::str_detect(diagnostics,'The following parameters had fewer than ')
		, rhat_high = stringr::str_detect(diagnostics,'The following parameters had split R-hat greater than')
	)
	attr(x,'meta') = list(diagnostic_bools=diagnostic_bools)
	return(x)
}
#define a custom print method that shows the diagnostics in the metadata attribute
print.stan_summary_tbl = function(x,...) {
	meta = attr(x,'meta')
	if(any(unlist(meta$diagnostic_bools))){
		cat(crayon::bgRed('WARNING:\n'))
	}
	if(meta$diagnostic_bools$treedepth_maxed){
		cat(crayon::bgRed('Treedepth maxed\n'))
	}
	if(meta$diagnostic_bools$ebfmi_low){
		cat(crayon::bgRed('E-BFMI low\n'))
	}
	if(meta$diagnostic_bools$essp_low){
		cat(crayon::bgRed('ESS% low for one or more parameters\n'))
	}
	if(meta$diagnostic_bools$rhat_high){
		cat(crayon::bgRed('R-hat high for one or more parameters\n'))
	}
	NextMethod()
	invisible(x)
}
#create a new S3 class for custom-printing stanfit summary tables
add_stan_summary_tbl_class = function(x){
	class(x) <- c("stan_summary_tbl",class(x))
	return(x)
}
#function to detect whether a variable name (with or without a trailing index like '[1]')
# ends in an underscore, marking it as a helper parameter to drop from summaries
has_underscore_suffix = function(x){
	bare_has_underscore_suffix = stringr::str_ends(x,'_')
	has_index_suffix = stringr::str_ends(x,']')
	indexed_has_underscore_suffix =
		(
			tibble::tibble(x = x[has_index_suffix])
			%>% tidyr::separate(x,sep='\\[',into='x',extra='drop')
			%>% dplyr::mutate( out=stringr::str_ends(x,'_') )
			%>% dplyr::pull(out)
		)
	bare_has_underscore_suffix[has_index_suffix] = indexed_has_underscore_suffix
	return(bare_has_underscore_suffix)
}
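#e.g. (illustrative call with hypothetical variable names):
# has_underscore_suffix(c('sd_coef','mean_coef_','mean_coef_[2]'))
# #> FALSE TRUE TRUE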
#function to detect whether a variable name indicates that it's on the diagonal
# or lower-tri element of a correlation parameter matrix
is_cor_diag_or_lower_tri = function(x,prefix){
	has_prefix = stringr::str_starts(x,prefix)
	x = x[has_prefix]
	del = function(x,to_del){gsub(to_del,'',x,fixed=T)}
	to_toss =
		(
			x
			%>% del(prefix)
			%>% del('[')
			%>% del(']')
			%>% tibble::tibble(x = .)
			%>% tidyr::separate(x,into=c('i','j'),convert=T) #convert=T so the indices compare numerically
			%>% dplyr::mutate( to_toss = (i==j) | (i>j) )
			%>% dplyr::pull(to_toss)
		)
	has_prefix[has_prefix] = to_toss
	return(has_prefix)
}
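#e.g. (illustrative):
# is_cor_diag_or_lower_tri(c('cor[1,1]','cor[2,1]','cor[1,2]','sd_coef[1]'),prefix='cor')
# #> TRUE TRUE FALSE FALSE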
#function to sort a stan summary table by size of variables
sort_by_variable_size = function(x){
	x2 =
		(
			x
			%>% tidyr::separate(
				variable
				, sep = '\\['
				, into = 'var'
				, extra = 'drop'
				, remove = F
			)
		)
	(
		x2
		%>% dplyr::group_by(var)
		%>% dplyr::summarise(count = dplyr::n(),.groups = 'drop')
		%>% dplyr::full_join(x2,by='var')
		%>% dplyr::arrange(count,var,variable)
		%>% dplyr::select(-count,-var)
	)
}
halfsum_contrasts = function(...){
	contr.sum(...)*.5
}
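#e.g. halfsum_contrasts(2) codes a 2-level factor as +0.5/-0.5 rather than contr.sum's +1/-1,
# so each fitted coefficient equals the difference between conditions on the linear-predictor scale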
#wrapper on stats::model.matrix: builds a design matrix from `data` for `formula`,
# optionally applying `contrast_kind` to factor columns, and tidies the column names
get_contrast_matrix = function(
	data
	, formula
	, contrast_kind = NULL
){
	if (inherits(data, "tbl_df")) {
		data = as.data.frame(data)
	}
	vars = attr(terms(formula),'term.labels')
	vars = vars[!grepl(':',vars)]
	if(length(vars)==1){
		data = data.frame(data[,vars])
		names(data) = vars
	}else{
		data = data[,vars]
	}
	vars_to_rename = NULL
	for(i in vars){
		if(is.character(data[,i])){
			data[,i] = factor(data[,i])
		}
		if( is.factor(data[,i])){
			if(length(levels(data[,i]))==2){
				vars_to_rename = c(vars_to_rename,i)
			}
			if(!is.null(contrast_kind) ){
				contrasts(data[,i]) = contrast_kind
			}
		}
	}
	mm = model.matrix(data=data,object=formula)
	dimnames(mm)[[2]][dimnames(mm)[[2]]=='(Intercept)'] = '(I)'
	for(i in vars_to_rename){
		dimnames(mm)[[2]] = gsub(paste0(i,1),i,dimnames(mm)[[2]])
	}
	attr(mm,'formula') = formula
	attr(mm,'data') = data
	return(mm)
}
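#e.g. (an illustrative sketch, not part of the original script):
# get_contrast_matrix(
# 	data = data.frame(v1 = factor(c('lo','hi')))
# 	, formula = ~ v1
# 	, contrast_kind = halfsum_contrasts
# )
# #yields a 2x2 model matrix with columns '(I)' (all 1s) and 'v1' (coded -0.5/+0.5)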
#preamble (options, installs, imports & custom functions) ----
options(warn=1) #really should be default in R
`%!in%` = Negate(`%in%`) #should be in base R!
# specify the packages used:
required_packages = c(
	'rethinking' #for rlkjcorr & rmvnorm2
	, 'crayon' #for coloring terminal output
	, 'bayesplot' #for convenient posterior plots
	, 'github.com/stan-dev/cmdstanr' #for Stan stuff
	, 'tidyverse' #for all that is good and holy
)
#load the helper functions:
source('helper_functions.r')
#helper_functions.r defines:
# - install_if_missing()
# - add_diagnostic_bools()
# - print.stan_summary_tbl()
# - add_stan_summary_tbl_class()
# - has_underscore_suffix()
# - is_cor_diag_or_lower_tri()
# - sort_by_variable_size()
# - halfsum_contrasts()
# - get_contrast_matrix()
#install any required packages not already present
install_if_missing(required_packages)
# define a shorthand for the pipe operator
`%>%` = magrittr::`%>%`
#simulate data ----
set.seed(1) #change this to make different data
#setting the data simulation parameters
sim_pars = tibble::lst(
	#parameters you can play with
	num_subj = 100 #number of subjects, must be an integer >1
	, num_vars = 3 #number of 2-level variables manipulated as crossed and within each subject, must be an integer >0
	, num_trials = 100 #number of trials per subject/condition combo, must be an integer >1
	#the rest of these you shouldn't touch
	, num_coef = 2^(num_vars)
	, coef_means = rnorm(num_coef)
	, coef_sds = rweibull(num_coef,2,1)
	, cor_mat = rethinking::rlkjcorr(1,num_coef,eta=1)
)
#compute the contrast matrix
contrast_matrix =
	(
		1:sim_pars$num_vars
		%>% purrr::map(.f=function(x){
			factor(c('lo','hi'))
		})
		%>% (function(x){
			names(x) = paste0('v',1:sim_pars$num_vars)
			return(x)
		})
		%>% purrr::cross_df()
		%>% (function(x){
			get_contrast_matrix(
				data = x
				, formula = as.formula(paste('~',paste0('v',1:sim_pars$num_vars,collapse='*')))
				, contrast_kind = halfsum_contrasts
			)
		})
	)
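#with the default num_vars = 3, contrast_matrix is 8x8: one row per condition (2^3 cells) and
# one column per coefficient (intercept, 3 main effects, 3 two-way interactions, 1 three-way)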
#get coefficients for each subject
subj_coef =
	(
		#subj coefs as mvn
		rethinking::rmvnorm2(
			n = sim_pars$num_subj
			, Mu = sim_pars$coef_means
			, sigma = sim_pars$coef_sds
			, Rho = sim_pars$cor_mat
		)
		#add names to columns
		%>% (function(x){
			dimnames(x)=list(NULL,paste0('X',1:ncol(x)))
			return(x)
		})
		#make a tibble
		%>% tibble::as_tibble(.name_repair='unique')
		#add subject identifier column
		%>% dplyr::mutate(
			subj = 1:sim_pars$num_subj
		)
	)
# get condition means implied by subject coefficients and contrast matrix
subj_cond =
	(
		subj_coef
		%>% dplyr::group_by(subj)
		%>% dplyr::summarise(
			(function(x){
				out = attr(contrast_matrix,'data')
				out$cond_mean = as.vector(contrast_matrix %*% t(x))
				return(out)
			})(dplyr::cur_data())
			, .groups = 'drop'
		)
	)
# get noisy measurements in each condition for each subject
dat =
	(
		subj_cond
		%>% tidyr::expand_grid(trial = 1:sim_pars$num_trials)
		%>% dplyr::mutate(
			obs = rbinom(dplyr::n(),1,plogis(cond_mean))
		)
		%>% dplyr::select(-cond_mean)
	)
# Try glmer ----
#annoyingly need this to set contrasts
dat_as_df = as.data.frame(dat)
#first apply half-sum contrasts as will be used in the Stan model
for(i in 1:sim_pars$num_vars){
	dat_as_df[,names(dat)==paste0('v',i)] = factor(dat_as_df[,names(dat)==paste0('v',i)])
	contrasts(dat_as_df[,names(dat)==paste0('v',i)]) = halfsum_contrasts
}
#compute the formula string
vars_string = paste0('v',1:sim_pars$num_vars,collapse='*')
formula_string = paste0('obs~1+',vars_string,'+(1+',vars_string,'|subj)')
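#e.g. with num_vars = 3 this yields: obs~1+v1*v2*v3+(1+v1*v2*v3|subj)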
#fit while timing
system.time(
	fit_glmer <- lme4::glmer(
		data = dplyr::mutate(dat_as_df,subj = factor(subj)) #dat_as_df so the half-sum contrasts set above are used
		, formula = as.formula(formula_string)
		, family = binomial(link='logit')
		, control = lme4::glmerControl(
			optCtrl = list(maxfun=1e6)
			, optimizer = 'bobyqa'
		)
	)
)
# Compute inputs to Stan model ----
#first collapse to subject/condition stats
dat_summary =
	(
		dat
		%>% dplyr::group_by(
			dplyr::across(c(
				-obs
				, -trial
			))
		)
		%>% dplyr::summarise(
			num_obs = dplyr::n()
			, sum_obs = sum(obs)
			, .groups = 'drop'
		)
	)
#W: the full contrast matrix, one row per subject/condition cell in dat_summary
W =
	(
		dat_summary
		#get the contrast matrix (wrapper on stats::model.matrix)
		%>% get_contrast_matrix(
			# the following complicated specification of the formula is a by-product of making this example
			# work for any value of sim_pars$num_vars; normally you would do something like this
			# (for 2 variables, for example):
			# formula = ~ v1*v2
			formula = as.formula(paste0('~',paste0('v',1:sim_pars$num_vars,collapse='*')))
			# half-sum contrasts are nice for 2-level variables because they yield parameters whose value
			# is the difference between conditions
			, contrast_kind = halfsum_contrasts
		)
		#convert to tibble
		%>% tibble::as_tibble(.name_repair='unique')
	)
#quick glimpse; lots of rows
print(W)
# get the unique entries in W
uW = dplyr::distinct(W)
print(uW)
#far fewer rows!
#for each unique condition specified by uW, the stan model will
# work out values for that condition for each subject, and we'll need to index
# into the resulting subject-by-condition matrix. So we need to create our own
# subject-by-condition matrix and get the indices of the observed data into
# the array produced when that matrix is flattened.
uW_per_subj =
	(
		uW
		#first repeat the matrix so there's a copy for each subject
		%>% dplyr::slice(
			rep(
				dplyr::row_number()
				, length(unique(dat_summary$subj))
			)
		)
		#now add the subject labels
		%>% dplyr::mutate(
			subj = rep(sort(unique(dat_summary$subj)),each=nrow(uW))
		)
		#add row identifier
		%>% dplyr::mutate(
			row = 1:dplyr::n()
		)
	)
obs_index =
	(
		W
		#add the subject column
		%>% dplyr::mutate(subj=dat_summary$subj)
		# join to the per-subject copy of uW
		%>% dplyr::left_join(
			uW_per_subj
			, by = c(names(uW),'subj')
		)
		#pull the row label
		%>% dplyr::pull(row)
	)
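#optional sanity check (a sketch, not part of the original pipeline): each row of W paired
# with its subject should map back to the matching row of uW_per_subj
# stopifnot(all(
# 	dplyr::select(uW_per_subj[obs_index,],-row)
# 	== dplyr::mutate(W,subj=dat_summary$subj)
# ))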
# package for stan & sample ----
data_for_stan = tibble::lst( #lst permits later entries to refer to earlier entries
	####
	# Entries we need to specify ourselves
	####
	# uW: unique rows of the within-subject predictor matrix
	uW = as.matrix(uW)
	# num_subj: number of subjects
	, num_subj = length(unique(dat_summary$subj))
	# num_obs: number of observations in each subject/condition
	, num_obs = dat_summary$num_obs
	# sum_obs: number of 1's in each subject/condition
	, sum_obs = dat_summary$sum_obs
	# obs_index: index of each subject/condition cell in the flattened subject-by-condition value matrix
	, obs_index = obs_index
	####
	# Entries computable from the above
	####
	# num_sums: number of subject/condition cells (rows in dat_summary)
	, num_sums = length(sum_obs)
	# num_rows_uW: num rows in uW
	, num_rows_uW = nrow(uW)
	# num_cols_uW: num cols in uW
	, num_cols_uW = ncol(uW)
)
#double-check:
tibble::glimpse(data_for_stan)
#compile the model
mod = cmdstanr::cmdstan_model('hwb_fast.stan')
#how many chains to run in parallel
phys_cores_minus_one = parallel::detectCores()/2-1
#we want at least 4 chains. Most CPUs have >=4 cores these days, but parallel::detectCores()
# typically returns twice the physical core count thanks to most systems being able to
# "hyperthread", treating a single physical core as if it were two. However, only certain
# workloads benefit from hyperthreading, and Stan generally doesn't (indeed, it can hurt),
# so it's best to run only as many chains as there are physical cores. Additionally, it's
# probably a good idea to leave one core unused for other processes (inc. monitoring the Stan progress).
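#(a hypothetical guard, not used below: on a machine with few physical cores one could instead
# use something like num_chains = max(4,phys_cores_minus_one) to honour the at-least-4-chains
# preference noted above)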
num_samples_to_obtain = 1e3
#this is the number of samples to run on each chain. If the model samples well, 1e3 should
# be plenty (especially since it'll be 1e3*phys_cores_minus_one) for stable inference on
# even tail quantities of the posterior
sampling_seed = 1
#setting the sampling seed explicitly helps ensure reproducibility
#sample the model
fit = mod$sample(
	data = data_for_stan
	, chains = phys_cores_minus_one
	, parallel_chains = phys_cores_minus_one
	, seed = sampling_seed
	, iter_warmup = num_samples_to_obtain
	, iter_sampling = num_samples_to_obtain
	# update every 10% (yeah, the progress info sucks; we're working on it)
	, refresh = (num_samples_to_obtain*2)/10
)
#gather summary (inc. diagnostics)
fit_summary =
	(
		fit$summary()
		%>% dplyr::select(variable,mean,q5,q95,rhat,contains('ess'))
		%>% dplyr::filter(
			!stringr::str_starts(variable,'chol_corr')
			, !stringr::str_detect(variable,'helper')
			, !has_underscore_suffix(variable)
			, !is_cor_diag_or_lower_tri(variable,prefix='cor')
		)
		%>% sort_by_variable_size()
		%>% add_stan_summary_tbl_class()
		%>% add_diagnostic_bools(fit)
	)
print(fit_summary,n=nrow(fit_summary))
#plot posterior intervals for the population-level coefficient means, with the true
# simulation values in red
(
	bayesplot::mcmc_intervals(fit$draws(variables='mean_coef'))
	+ ggplot2::geom_point(
		data =
			(
				tibble::tibble(value=sim_pars$coef_means)
				%>% dplyr::mutate(
					y = paste0('mean_coef[',1:dplyr::n(),']')
				)
			)
		, mapping = ggplot2::aes(
			y = y
			, x = value
		)
		, colour = 'red'
	)
)
#plot posterior intervals for the population-level coefficient SDs, with the true
# simulation values in red
(
	bayesplot::mcmc_intervals(fit$draws(variables='sd_coef'))
	+ ggplot2::geom_point(
		data =
			(
				tibble::tibble(value=sim_pars$coef_sds)
				%>% dplyr::mutate(
					y = paste0('sd_coef[',1:dplyr::n(),']')
				)
			)
		, mapping = ggplot2::aes(
			y = y
			, x = value
		)
		, colour = 'red'
	)
)
#plot posterior intervals for the coefficient correlations, with the true simulation
# values in red
(
	bayesplot::mcmc_intervals(fit$draws(variables='cor'))
	+ ggplot2::geom_point(
		data =
			(
				tibble::as_tibble(
					sim_pars$cor_mat
				)
				%>% dplyr::mutate(var1 = 1:dplyr::n())
				%>% tidyr::pivot_longer(
					names_to = 'var2'
					, names_prefix = 'V'
					, cols = c(-var1)
				)
				%>% dplyr::mutate(
					y = paste0('cor[',var1,',',var2,']')
				)
			)
		, mapping = ggplot2::aes(
			y = y
			, x = value
		)
		, colour = 'red'
	)
)
hwb_fast.stan
data{
	// num_sums: number of subject/condition cells
	int<lower=1> num_sums ;
	// num_obs: number of observations per subject/condition
	int num_obs[num_sums] ;
	// sum_obs: sum of observations per subject/condition
	int sum_obs[num_sums] ;
	// num_subj: number of subjects
	int<lower=1> num_subj ;
	// num_rows_uW: num rows in uW
	int<lower=1> num_rows_uW ;
	// num_cols_uW: num cols in uW
	int<lower=1> num_cols_uW ;
	// uW: unique entries in the within predictor matrix
	matrix[num_rows_uW,num_cols_uW] uW ;
	// obs_index: index of each subject/condition cell in the flattened subject-by-condition value matrix
	int obs_index[num_sums] ;
}
parameters{
	// chol_corr: population-level correlations (on cholesky-factor scale) amongst within-subject predictors
	cholesky_factor_corr[num_cols_uW] chol_corr ;
	// mean_coef: mean (across subj) for each coefficient
	row_vector[num_cols_uW] mean_coef ;
	// sd_coef: sd (across subj) for each coefficient
	vector<lower=0>[num_cols_uW] sd_coef ;
	// multi_normal_helper: a helper variable for implementing the non-centered parameterization
	matrix[num_cols_uW,num_subj] multi_normal_helper ;
}
model{
	////
	// Priors
	////
	// multi_normal_helper must have a normal(0,1) prior for the non-centered parameterization
	to_vector(multi_normal_helper) ~ std_normal() ;
	// relatively flat prior on correlations
	chol_corr ~ lkj_corr_cholesky(2) ;
	// normal(0,1) priors on all coefficient sds
	sd_coef ~ std_normal() ;
	// normal(0,1) priors on all coefficient means
	mean_coef ~ std_normal() ;
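	// Note on the non-centered parameterization used below: each row of subj_coef is
	//   mean_coef + (diag_matrix(sd_coef) * chol_corr * z)'  with z ~ std_normal() ,
	// which is distributionally equivalent to multi_normal(mean_coef, Sigma) with
	//   Sigma = diag_matrix(sd_coef) * chol_corr * chol_corr' * diag_matrix(sd_coef) ,
	// but typically samples better for hierarchical models like this one.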
	// compute coefficients for each subject/condition
	matrix[num_subj,num_cols_uW] subj_coef = (
		rep_matrix(mean_coef,num_subj)
		+ transpose(
			diag_pre_multiply(sd_coef,chol_corr)
			* multi_normal_helper
		)
	) ;
	// Loop over subj and conditions to compute unique entries in design matrix
	matrix[num_rows_uW,num_subj] value_for_subj_cond ;
	for(this_subj in 1:num_subj){
		for(this_condition in 1:num_rows_uW){
			value_for_subj_cond[this_condition,this_subj] = dot_product(
				subj_coef[this_subj]
				, uW[this_condition]
			) ;
		}
		// // slightly less explicit but equally fast:
		// value_for_subj_cond[,this_subj] = rows_dot_product(
		// 	rep_matrix(
		// 		subj_coef[this_subj]
		// 		, num_rows_uW
		// 	)
		// 	, uW
		// ) ;
	}
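	// Note: the R script collapsed the 0/1 trials to per-subject-per-condition counts
	// (sum_obs successes out of num_obs trials), so the binomial likelihood below is
	// equivalent to a bernoulli_logit over individual trials but requires far fewer
	// likelihood evaluations.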
	// Likelihood
	sum_obs ~ binomial(
		num_obs
		, inv_logit(to_vector(value_for_subj_cond))[obs_index]
	) ;
}
generated quantities{
	// cor: correlation matrix for the full set of within-subject predictors
	corr_matrix[num_cols_uW] cor = multiply_lower_tri_self_transpose(chol_corr) ;
	// add negligible jitter to the diagonal entries (which are constant at exactly 1 across
	// draws, leaving split R-hat undefined) to avoid a false-alarm diagnostic warning
	for(i in 1:num_cols_uW){
		cor[i,i] += uniform_rng(1e-16, 1e-15) ;
	}
}