Skip to content

Instantly share code, notes, and snippets.

@mike-lawrence
Last active April 23, 2024 13:03
Show Gist options
  • Save mike-lawrence/39412616302925a1941ad92777daa036 to your computer and use it in GitHub Desktop.
Save mike-lawrence/39412616302925a1941ad92777daa036 to your computer and use it in GitHub Desktop.
Stan vs LME4 for hierarchical within-subjects designs with binomial outcomes
#' Installs any packages not already installed
#'
#' Each entry is looked up by its bare name (the part after the last `/`), so
#' a GitHub entry such as `'github.com/stan-dev/cmdstanr'` is checked as
#' `cmdstanr`. Missing CRAN packages are installed with `install.packages()`;
#' missing GitHub packages (entries containing `'github.com/'`) are installed
#' with `remotes::install_github()`, and `remotes` itself is queued for CRAN
#' installation first when it is needed but absent.
#'
#' @param pkgs Character vector of package names; GitHub packages are given as
#'   `'github.com/<owner>/<repo>'`.
#' @return `NULL`, invisibly. Called for its side effect of installing packages.
#' @examples
#' \dontrun{
#' install_if_missing(c('tidyverse','github.com/stan-dev/cmdstanr'))
#' }
install_if_missing = function(pkgs){
  pkgs = as.character(pkgs)
  # find.package(quiet = TRUE) silently drops packages it cannot find, i.e. it
  # returns character(0) (not NULL, not an error) for a missing package, so
  # "installed" must be tested via the length of its result.
  is_installed = function(pkg){
    length(find.package(pkg, quiet = TRUE)) > 0
  }
  missing_pkgs = pkgs[!vapply(basename(pkgs), is_installed, logical(1))]
  is_github = grepl('github.com/', missing_pkgs, fixed = TRUE)
  cran_missing = missing_pkgs[!is_github]
  github_missing = gsub('github.com/', '', missing_pkgs[is_github], fixed = TRUE)
  # GitHub installs go through remotes, so make sure it is installed first
  if(length(github_missing) > 0 && !is_installed('remotes')){
    cran_missing = union(cran_missing, 'remotes')
  }
  if(length(cran_missing)>0){
    message('The following required but uninstalled CRAN packages will now be installed:\n',paste(cran_missing,collapse='\n'))
    install.packages(cran_missing)
  }
  if(length(github_missing)>0){
    message('The following required but uninstalled Github packages will now be installed:\n',paste(github_missing,collapse='\n'))
    remotes::install_github(github_missing)
  }
  invisible()
}
#' Attach CmdStan diagnostic flags to an object as a metadata attribute
#'
#' Runs `fit$cmdstan_diagnose()` and records which of four warning phrases
#' appear in the `stdout` text it returns.
#'
#' @param x Object to annotate (in this script, a summary table).
#' @param fit Object with a `cmdstan_diagnose()` method whose result has a
#'   `stdout` element (e.g. a cmdstanr fit).
#' @return `x` with attribute `meta` set to
#'   `list(diagnostic_bools = list(treedepth_maxed, ebfmi_low, essp_low,
#'   rhat_high))`, each a single logical. Any existing `meta` is replaced.
add_diagnostic_bools = function(x,fit){
  # cmdstan_diagnose() prints its report and is not quiet-able; capture and
  # discard that console output. Unlike sink('/dev/null'), capture.output()
  # works on every OS and restores the output sink even if the call errors.
  utils::capture.output(
    diagnostics <- fit$cmdstan_diagnose()$stdout
  )
  # literal (non-regex) search for a phrase anywhere in the report
  mentions = function(phrase){
    any(grepl(phrase, diagnostics, fixed = TRUE))
  }
  diagnostic_bools = list(
    treedepth_maxed = mentions('transitions hit the maximum')
    , ebfmi_low = mentions(' is below the nominal threshold')
    , essp_low = mentions('The following parameters had fewer than ')
    , rhat_high = mentions('The following parameters had split R-hat greater than')
  )
  attr(x,'meta') = list(diagnostic_bools=diagnostic_bools)
  x
}
#' Print a Stan summary table, preceded by any CmdStan diagnostic warnings
#'
#' Flags are read from the `meta` attribute written by `add_diagnostic_bools()`.
#' A missing attribute, a missing flag, or an NA flag counts as "no warning",
#' so such tables print normally instead of erroring. The table itself is
#' printed by the next method in the class chain; `NextMethod()` forwards the
#' original arguments, so e.g. `n =` reaches the tibble print method.
#'
#' @param x A `stan_summary_tbl` object.
#' @param ... Passed on to the next print method.
#' @return `x`, invisibly.
print.stan_summary_tbl = function(x,...) {
  flags = attr(x,'meta')$diagnostic_bools
  message_for_flag = c(
    treedepth_maxed = 'Treedepth maxed\n'
    , ebfmi_low = 'E-BFMI low\n'
    , essp_low = 'ESS% low for one or more parameters\n'
    , rhat_high = 'R-hat high for one or more parameters\n'
  )
  # isTRUE() maps NULL (flag absent) and NA to FALSE, where a bare if() would
  # fail with "argument is of length zero"
  raised = vapply(
    names(message_for_flag)
    , function(flag) isTRUE(flags[[flag]])
    , logical(1)
  )
  if(any(raised)){
    cat(crayon::bgRed('WARNING:\n'))
    for(msg in message_for_flag[raised]){
      cat(crayon::bgRed(msg))
    }
  }
  NextMethod()
  invisible(x)
}
#' Tag an object with the `stan_summary_tbl` S3 class
#'
#' Prepending the class makes `print()` dispatch to `print.stan_summary_tbl()`.
#' An object that already inherits from `stan_summary_tbl` is returned
#' unchanged, so tagging twice does not chain the print method into itself
#' (which would print the diagnostic warnings twice).
#'
#' @param x Object to tag (in this script, a summary tibble).
#' @return `x` with `'stan_summary_tbl'` as the first class.
add_stan_summary_tbl_class = function(x){
  if(!inherits(x, 'stan_summary_tbl')){
    class(x) <- c('stan_summary_tbl', class(x))
  }
  x
}
#' Detect Stan variable names whose base name ends in an underscore
#'
#' For names ending in `]` (indexed elements such as `'b_[1,2]'`), only the
#' part before the first `[` is tested; other names are tested whole.
#'
#' @param x Character vector of variable names.
#' @return Logical vector, same length as `x`: `TRUE` where the (base) name
#'   ends in `_`.
has_underscore_suffix = function(x){
  x = as.character(x)
  base_name = x
  is_indexed = endsWith(x, ']')
  # strip everything from the first '[' onward, e.g. 'b_[1,2]' -> 'b_'
  base_name[is_indexed] = sub('\\[.*', '', x[is_indexed])
  endsWith(base_name, '_')
}
#' Flag the diagonal and lower-triangle elements of a Stan matrix variable
#'
#' For names starting with `prefix` (treated literally, e.g. `'cor'`), the
#' prefix and brackets are removed and the remainder is split into row and
#' column indices `i` and `j`. Those elements are flagged when `i >= j`.
#' Names without the prefix are never flagged.
#'
#' @param x Character vector of variable names, e.g. `'cor[2,1]'`.
#' @param prefix Literal name prefix of the matrix variable.
#' @return Logical vector, same length as `x`.
is_cor_diag_or_lower_tri = function(x,prefix){
  x = as.character(x)
  has_prefix = startsWith(x, prefix)
  # 'cor[10,9]' -> '10,9'
  idx = x[has_prefix]
  for(to_del in c(prefix, '[', ']')){
    idx = gsub(to_del, '', idx, fixed = TRUE)
  }
  # split on runs of non-alphanumerics (tidyr::separate()'s default separator)
  # and compare as integers: as strings, '10' > '9' would be FALSE
  parts = strsplit(idx, '[^[:alnum:]]+')
  i = as.integer(vapply(parts, `[`, character(1), 1))
  j = as.integer(vapply(parts, `[`, character(1), 2))
  has_prefix[has_prefix] = i >= j
  has_prefix
}
#' Reorder a Stan summary table so that smaller variables come first
#'
#' A row's base name is the text of `variable` before the first `[`. Rows are
#' ordered by how many rows share their base name (scalars first), then by
#' base name, then by the full `variable` string. The last key is compared as
#' text, so e.g. `x[10]` sorts before `x[2]`.
#'
#' @param x A data frame with a character column `variable`, such as a
#'   cmdstanr `$summary()` table.
#' @return The rows of `x`, reordered, with the same columns.
sort_by_variable_size = function(x){
  with_base_name = tidyr::separate(
    x
    , variable
    , sep = '\\['
    , into = 'var'
    , extra = 'drop'
    , remove = FALSE
  )
  rows_per_var = dplyr::summarise(
    dplyr::group_by(with_base_name, var)
    , count = dplyr::n()
    , .groups = 'drop'
  )
  joined = dplyr::full_join(rows_per_var, with_base_name, by = 'var')
  reordered = dplyr::arrange(joined, count, var, variable)
  dplyr::select(reordered, -count, -var)
}
#' Half-sum contrasts: sum-to-zero contrasts scaled by one half
#'
#' For a 2-level factor the first level is coded +0.5 and the second -0.5, so
#' the coefficient equals the difference between the two levels.
#'
#' @param ... Passed to `contr.sum()` (number of levels or level names).
#' @return `contr.sum(...)` with every entry halved.
halfsum_contrasts = function(...){
  sum_coded = contr.sum(...)
  sum_coded / 2
}
#' Build a model (contrast) matrix for the variables in a formula
#'
#' The main-effect variables named in `formula` are taken from `data`.
#' Character columns are converted to factors, and if `contrast_kind` is
#' given it is assigned as the contrasts of every factor. The result of
#' `model.matrix()` then has its intercept column renamed `'(I)'`. For every
#' 2-level factor `v`, the term token `paste0(v, 1)` is renamed to `v`, so
#' e.g. `'v11:v21'` becomes `'v1:v2'`.
#'
#' @param data Data frame or tibble containing every main-effect variable in
#'   `formula`.
#' @param formula One-sided formula, e.g. `~ v1*v2`.
#' @param contrast_kind `NULL` (keep R's default contrasts) or a value accepted
#'   by `contrasts<-`, e.g. `halfsum_contrasts`.
#' @return The model matrix, with attributes `formula` and `data` (the
#'   processed data frame of main-effect variables).
get_contrast_matrix = function(
  data
  , formula
  , contrast_kind = NULL
){
  if (inherits(data, "tbl_df")) {
    data = as.data.frame(data)
  }
  vars = attr(terms(formula),'term.labels')
  vars = vars[!grepl(':',vars)]
  # drop = FALSE keeps a data frame even when there is a single variable
  data = data[, vars, drop = FALSE]
  vars_to_rename = NULL
  for(v in vars){
    if(is.character(data[[v]])){
      data[[v]] = factor(data[[v]])
    }
    if(is.factor(data[[v]])){
      if(nlevels(data[[v]])==2){
        vars_to_rename = c(vars_to_rename, v)
      }
      if(!is.null(contrast_kind)){
        contrasts(data[[v]]) = contrast_kind
      }
    }
  }
  mm = model.matrix(data=data, object=formula)
  col_names = colnames(mm)
  col_names[col_names=='(Intercept)'] = '(I)'
  if(length(vars_to_rename) > 0){
    # Rename whole ':'-separated tokens via one exact lookup. A sequential
    # regex gsub() would also hit substrings and could rename a token twice
    # (e.g. with variables 'x1' and 'x', 'x11' -> 'x1' -> 'x').
    old_tokens = paste0(vars_to_rename, 1)
    col_names = vapply(
      strsplit(col_names, ':', fixed = TRUE)
      , function(tokens){
        hit = match(tokens, old_tokens)
        tokens[!is.na(hit)] = vars_to_rename[hit[!is.na(hit)]]
        paste(tokens, collapse = ':')
      }
      , character(1)
    )
  }
  colnames(mm) = col_names
  attr(mm,'formula') = formula
  attr(mm,'data') = data
  mm
}
#preamble (options, installs, imports & custom functions) ----
options(warn=1) #really should be default in R
`%!in%` = Negate(`%in%`) #should be in base R!
# specify the packages used; entries of the form 'github.com/<owner>/<repo>'
# are installed from GitHub (via remotes) by install_if_missing():
required_packages = c(
  'remotes' #for installing the Github packages below
  , 'github.com/rmcelreath/rethinking' #for rlkjcorr & rmvnorm2 (not on CRAN)
  , 'crayon' #for coloring terminal output
  , 'bayesplot' #for convenient posterior plots
  , 'lme4' #for glmer
  , 'github.com/stan-dev/cmdstanr' #for Stan stuff
  , 'tidyverse' #for all that is good and holy
)
#load the helper functions:
source('helper_functions.r')
#helper_functions.r defines:
# - install_if_missing()
# - add_diagnostic_bools()
# - print.stan_summary_tbl()
# - add_stan_summary_tbl_class()
# - has_underscore_suffix()
# - is_cor_diag_or_lower_tri()
# - sort_by_variable_size()
# - halfsum_contrasts()
# - get_contrast_matrix()
#install any required packages not already present
install_if_missing(required_packages)
# NOTE: cmdstanr also needs a CmdStan installation; if none is present, run
# cmdstanr::install_cmdstan() once before the model is compiled below.
# define a shorthand for the pipe operator
`%>%` = magrittr::`%>%`
#simulate data ----
set.seed(1) #change this to make different data
#setting the data simulation parameters
# (tibble::lst() evaluates entries in order, so the random draws below happen
# in this order; reordering entries changes the simulated data)
sim_pars = tibble::lst(
#parameters you can play with
num_subj = 100 #number of subjects, must be an integer >1
, num_vars = 3 #number of 2-level variables manipulated as crossed and within each subject, must be an integer >0
, num_trials = 100 #number of trials per subject/condition combo, must be an integer >1
#the rest of these you shouldn't touch
# one coefficient per column of the full-factorial design matrix
# (intercept, main effects and all interactions)
, num_coef = 2^(num_vars)
# true population means of the coefficients, drawn from normal(0,1)
, coef_means = rnorm(num_coef)
# true population SDs of the coefficients, drawn from Weibull(shape=2, scale=1)
, coef_sds = rweibull(num_coef,2,1)
# true correlation matrix amongst the coefficients, one draw from LKJ(eta=1)
, cor_mat = rethinking::rlkjcorr(1,num_coef,eta=1)
)
#compute the contrast matrix
# One row per condition of the full factorial design (every lo/hi combination
# of v1..v<num_vars>), one column per coefficient. factor(c('lo','hi')) has
# levels c('hi','lo'), so under half-sum contrasts 'hi' is coded +0.5 and 'lo'
# -0.5. The conditions (with their contrasts) are kept in
# attr(contrast_matrix,'data'), which fixes the row order of subj_cond and
# hence the order in which rbinom() consumes random numbers below.
# NOTE(review): purrr::cross_df() is deprecated as of purrr 1.0.0; verify that
# any replacement (e.g. expand.grid() or tidyr::expand_grid()) yields the same
# row order, or the simulated data will change.
contrast_matrix =
(
1:sim_pars$num_vars
%>% purrr::map(.f=function(x){
factor(c('lo','hi'))
})
%>% (function(x){
names(x) = paste0('v',1:sim_pars$num_vars)
return(x)
})
%>% purrr::cross_df()
%>% (function(x){
get_contrast_matrix(
data = x
, formula = as.formula(paste('~',paste0('v',1:sim_pars$num_vars,collapse='*')))
, contrast_kind = halfsum_contrasts
)
})
)
#get coefficients for each subject
# (a num_subj x num_coef matrix of multivariate-normal draws; per the
# rethinking docs, rmvnorm2 takes a vector of SDs `sigma` and a correlation
# matrix `Rho`)
subj_coef =
(
#subj coefs as mvn
rethinking::rmvnorm2(
n = sim_pars$num_subj
, Mu = sim_pars$coef_means
, sigma = sim_pars$coef_sds
, Rho = sim_pars$cor_mat
)
#add names to columns
%>% (function(x){
dimnames(x)=list(NULL,paste0('X',1:ncol(x)))
return(x)
})
#make a tibble
%>% tibble::as_tibble(.name_repair='unique')
#add subject identifier column
%>% dplyr::mutate(
subj = 1:sim_pars$num_subj
)
)
# get condition means implied by subject coefficients and contrast matrix
# For each subject, x is that subject's one-row X1..X<num_coef> data (the
# grouping column subj is excluded), so contrast_matrix %*% t(x) gives one
# logit-scale mean per condition; the result has one row per subject/condition.
# NOTE(review): dplyr 1.1.0 deprecates cur_data() (use pick()) and
# summarise() returning several rows per group (use reframe()); both still
# work but emit deprecation warnings.
subj_cond =
(
subj_coef
%>% dplyr::group_by(subj)
%>% dplyr::summarise(
(function(x){
out = attr(contrast_matrix,'data')
out$cond_mean = as.vector(contrast_matrix %*% t(x))
return(out)
})(dplyr::cur_data())
, .groups = 'drop'
)
)
# get noisy measurements in each condition for each subject
# (each subject/condition row is repeated num_trials times and obs is a 0/1
# draw with success probability plogis(cond_mean))
dat =
(
subj_cond
%>% tidyr::expand_grid(trial = 1:sim_pars$num_trials)
%>% dplyr::mutate(
obs = rbinom(dplyr::n(),1,plogis(cond_mean))
)
%>% dplyr::select(-cond_mean)
)
# Try glmer ----
#glmer takes its contrasts from the factor columns of the data it is given,
# so build a data.frame whose factors carry the half-sum contrasts that the
# Stan model uses, and fit glmer to that (fitting to `dat` would silently fall
# back to R's default treatment contrasts and make the comparison invalid)
dat_as_df = as.data.frame(dat)
for(v in paste0('v', seq_len(sim_pars$num_vars))){
  dat_as_df[[v]] = factor(dat_as_df[[v]])
  contrasts(dat_as_df[[v]]) = halfsum_contrasts
}
#subj is the grouping factor for the by-subject random effects
dat_as_df$subj = factor(dat_as_df$subj)
#compute the formula string
vars_string = paste0('v',1:sim_pars$num_vars,collapse='*')
formula_string = paste0('obs~1+',vars_string,'+(1+',vars_string,'|subj)')
#fit while timing
system.time(
  fit_glmer <- lme4::glmer(
    data = dat_as_df
    , formula = as.formula(formula_string)
    , family = binomial(link='logit')
    , control = lme4::glmerControl(
      optCtrl = list(maxfun=1e6)
      , optimizer = 'bobyqa'
    )
  )
)
# Compute inputs to Stan model ----
#first collapse to subject/condition stats
# (grouping by every column except obs and trial, i.e. subj and v1..v<num_vars>;
# num_obs is the trial count and sum_obs the number of 1's in each cell)
dat_summary =
(
dat
%>% dplyr::group_by(
dplyr::across(c(
-obs
, -trial
))
)
%>% dplyr::summarise(
num_obs = dplyr::n()
, sum_obs = sum(obs)
, .groups = 'drop'
)
)
#W: the contrast matrix with one row per row of dat_summary (i.e. per
# subject/condition cell); the row order of dat_summary is relied on below
W =
(
dat_summary
#get the contrast matrix (wrapper on stats::model.matrix)
%>% get_contrast_matrix(
# the following complicated specification of the formula is a by-product of making this example
# work for any value for sim_pars$num_vars; normally you would do something like this
# (for 2 variables for example):
# formula = ~ v1*v2
formula = as.formula(paste0('~',paste0('v',1:sim_pars$num_vars,collapse='*')))
# half-sum contrasts are nice for 2-level variables bc they yield parameters whose value
# is the difference between conditions
, contrast_kind = halfsum_contrasts
)
#convert to tibble
%>% tibble::as_tibble(.name_repair='unique')
)
#quick glimpse; lots of rows
print(W)
# get the unique entries in W (one row per condition)
uW = dplyr::distinct(W)
print(uW)
#far fewer rows!
#for each unique condition specified by uW, the stan model will
# work out values for that condition for each subject, and we'll need to index
# into the resulting subject-by-condition matrix. So we need to create our own
# subject-by-condition matrix and get the indices of the observed data into
# the array produced when that matrix is flattened.
# uW_per_subj stacks one copy of uW per subject (subjects in sorted order) and
# numbers the stacked rows, so condition c of the s-th subject gets
# row = (s-1)*nrow(uW) + c: the position of element [c,s] in the column-major
# flattening of the Stan model's [num_rows_uW, num_subj] value matrix.
uW_per_subj =
  (
    uW
    #first repeat the matrix so there's a copy for each subject
    %>% dplyr::slice(
      rep(
        dplyr::row_number()
        , length(unique(dat_summary$subj))
      )
    )
    #now add the subject labels
    %>% dplyr::mutate(
      subj = rep(sort(unique(dat_summary$subj)),each=nrow(uW))
    )
    #add row identifier
    %>% dplyr::mutate(
      row = 1:dplyr::n()
    )
  )
obs_index =
  (
    W
    #add the subject column
    %>% dplyr::mutate(subj=dat_summary$subj)
    # look up each W row (condition + subject) in uW_per_subj
    %>% dplyr::left_join(
      uW_per_subj
      , by = c(names(uW),'subj')
    )
    #pull the row label
    %>% dplyr::pull(row)
  )
#sanity check: every subject/condition cell must map to exactly one row. A
# failed match (NA) or a duplicated join key (extra rows) would otherwise only
# surface later as a Stan data error or a silently wrong likelihood.
stopifnot(
  length(obs_index) == nrow(dat_summary)
  , !anyNA(obs_index)
)
# package for stan & sample ----
#the named list passed to the Stan model; names must match its data block
data_for_stan = list(
  # uW: unique rows of the within-subject predictor (contrast) matrix
  uW = as.matrix(uW)
  # num_subj: number of subjects
  , num_subj = length(unique(dat_summary$subj))
  # num_obs: number of observations in each subject/condition
  , num_obs = dat_summary$num_obs
  # sum_obs: number of 1's in each subject/condition
  , sum_obs = dat_summary$sum_obs
  # obs_index: index of each subject/condition in the flattened
  # subject-by-condition value matrix
  , obs_index = obs_index
  # num_sums: number of subject/condition sums
  , num_sums = length(dat_summary$sum_obs)
  # num_rows_uW / num_cols_uW: dimensions of uW
  , num_rows_uW = nrow(uW)
  , num_cols_uW = ncol(uW)
)
#double-check:
tibble::glimpse(data_for_stan)
#compile the model
mod = cmdstanr::cmdstan_model('hwb_fast.stan')
#how many chains to run in parallel:
# parallel::detectCores() typically returns twice the physical core count
# thanks to most systems being able to "hyperthread", treating a single
# physical core as if it were two. However, only certain workloads benefit
# from hyperthreading, and Stan generally doesn't (indeed, it can hurt), so
# halve it to approximate the physical cores, and leave one core unused for
# other processes (inc. monitoring the Stan progress). floor() guards against
# odd core counts, and max(1, ..., na.rm=TRUE) guards against a result of zero
# or a core count that cannot be determined (NA).
phys_cores_minus_one = max(1, floor(parallel::detectCores()/2)-1, na.rm=TRUE)
#we want at least 4 chains; on machines with fewer spare cores the remaining
# chains run once earlier ones finish
num_chains = max(4, phys_cores_minus_one)
num_samples_to_obtain = 1e3
#this is the number of samples to run on each chain. If the model samples well, 1e3 should
# be plenty (especially since it'll be 1e3*num_chains) for stable inference on
# even tail quantities of the posterior
sampling_seed = 1
#setting the sampling seed explicitly helps ensure reproducibility
#sample the model
fit = mod$sample(
  data = data_for_stan
  , chains = num_chains
  , parallel_chains = phys_cores_minus_one
  , seed = sampling_seed
  , iter_warmup = num_samples_to_obtain
  , iter_sampling = num_samples_to_obtain
  # update every 10% (yeah, the progress info sucks; we're working on it)
  , refresh = (num_samples_to_obtain*2)/10
)
#gather summary (inc. diagnostics), keeping only the rows worth reading: the
# Cholesky factor, the non-centered helper, any underscore-suffixed variables
# and the redundant diagonal/lower triangle of `cor` are dropped
fit_summary =
  (
    fit$summary()
    %>% dplyr::select(variable, mean, q5, q95, rhat, dplyr::contains('ess'))
    %>% dplyr::filter(!stringr::str_starts(variable, 'chol_corr'))
    %>% dplyr::filter(!stringr::str_detect(variable, 'helper'))
    %>% dplyr::filter(!has_underscore_suffix(variable))
    %>% dplyr::filter(!is_cor_diag_or_lower_tri(variable, prefix = 'cor'))
    %>% sort_by_variable_size()
    %>% add_stan_summary_tbl_class()
    %>% add_diagnostic_bools(fit)
  )
#show every row rather than the tibble default
print(fit_summary, n = nrow(fit_summary))
#plot_recovery: posterior intervals for one Stan variable of `fit`, with the
# true simulation values overlaid as red points. `truth` needs columns `y`
# (element name, e.g. 'mean_coef[1]') and `value` (the true value).
plot_recovery = function(fit, variable, truth){
  (
    bayesplot::mcmc_intervals(fit$draws(variables=variable))
    + ggplot2::geom_point(
      data = truth
      , mapping = ggplot2::aes(
        y = y
        , x = value
      )
      , colour = 'red'
    )
  )
}
#vector_truth: pair the k-th true value with element name '<name>[k]'
vector_truth = function(name, values){
  tibble::tibble(
    value = values
    , y = paste0(name, '[', seq_along(values), ']')
  )
}
#the true correlation matrix in long form, one row per element cor[i,j].
# Columns are named V1..Vk explicitly so names_prefix='V' recovers the column
# index (as_tibble() on a matrix without column names may warn instead).
cor_mat_named = sim_pars$cor_mat
colnames(cor_mat_named) = paste0('V', seq_len(ncol(cor_mat_named)))
cor_truth =
  (
    tibble::as_tibble(cor_mat_named)
    %>% dplyr::mutate(var1 = 1:dplyr::n())
    %>% tidyr::pivot_longer(
      names_to = 'var2'
      , names_prefix = 'V'
      , cols = c(-var1)
    )
    %>% dplyr::mutate(
      y = paste0('cor[',var1,',',var2,']')
    )
  )
#print() so the plots are also drawn when this script is run via source() or
# Rscript, where top-level ggplot objects are not auto-printed
print(plot_recovery(fit, 'mean_coef', vector_truth('mean_coef', sim_pars$coef_means)))
print(plot_recovery(fit, 'sd_coef', vector_truth('sd_coef', sim_pars$coef_sds)))
print(plot_recovery(fit, 'cor', cor_truth))
data{
  // num_sums: number of subject/condition cells (one binomial count each)
  int<lower=1> num_sums ;
  // num_obs: number of observations per subject/condition
  array[num_sums] int<lower=0> num_obs ;
  // sum_obs: sum of observations per subject/condition
  array[num_sums] int<lower=0> sum_obs ;
  // num_subj: number of subj
  int<lower=1> num_subj ;
  // num_rows_uW: num rows in uW
  int<lower=1> num_rows_uW ;
  // num_cols_uW: num cols in uW
  int<lower=1> num_cols_uW ;
  // uW: unique entries in the within predictor matrix
  matrix[num_rows_uW,num_cols_uW] uW ;
  // obs_index: index of each subject/condition cell in the column-major
  // flattening (to_vector) of the [num_rows_uW, num_subj] value matrix below
  array[num_sums] int<lower=1, upper=num_rows_uW*num_subj> obs_index ;
}
parameters{
  // chol_corr: population-level correlations (on cholesky factor scale) amongst within-subject predictors
  cholesky_factor_corr[num_cols_uW] chol_corr ;
  // mean_coef: mean (across subj) for each coefficient
  row_vector[num_cols_uW] mean_coef ;
  // sd_coef: sd (across subj) for each coefficient
  vector<lower=0>[num_cols_uW] sd_coef ;
  // multi_normal_helper: a helper variable for implementing non-centered parameterization
  matrix[num_cols_uW,num_subj] multi_normal_helper ;
}
model{
  ////
  // Priors
  ////
  // multi_normal_helper must have normal(0,1) prior for non-centered parameterization
  to_vector(multi_normal_helper) ~ std_normal() ;
  // relatively flat prior on correlations
  chol_corr ~ lkj_corr_cholesky(2) ;
  // normal(0,1) priors on all sd_coef (half-normal, given the lower=0 bound)
  sd_coef ~ std_normal() ;
  // normal(0,1) priors on all coefficients
  mean_coef ~ std_normal() ;
  // compute coefficients for each subject (one row per subject)
  matrix[num_subj,num_cols_uW] subj_coef = (
    rep_matrix(mean_coef,num_subj)
    + transpose(
      diag_pre_multiply(sd_coef,chol_corr)
      * multi_normal_helper
    )
  ) ;
  // value_for_subj_cond[c,s] = dot_product(subj_coef[s], uW[c]), computed for
  // all subjects and unique conditions with a single matrix product
  matrix[num_rows_uW,num_subj] value_for_subj_cond = uW * transpose(subj_coef) ;
  // Likelihood: binomial_logit takes the logit-scale values directly, avoiding
  // an explicit inv_logit
  sum_obs ~ binomial_logit(
    num_obs
    , to_vector(value_for_subj_cond)[obs_index]
  ) ;
}
generated quantities{
  // cor: correlation matrix for the full set of within-subject predictors
  corr_matrix[num_cols_uW] cor = multiply_lower_tri_self_transpose(chol_corr) ;
  // tweak cor to avoid rhat false-alarm (the diagonal would otherwise be a
  // constant 1 in every draw)
  for(i in 1:num_cols_uW){
    cor[i,i] += uniform_rng(1e-16, 1e-15) ;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment