# composed_divergences: does composing warmup & sampling into a single run
# yield different divergence counts than running them as two separate stages?
#preamble (installs/imports & custom functions) ----
# specify the packages used:
required_packages = c(
'rethinking' #for rlkjcorr & rmvnorm2
, 'crayon' #for coloring terminal output
, 'tidyverse' #for all that is good and holy
, 'progress' #for progress bar
, 'github.com/stan-dev/cmdstanr' #for Stan stuff
, 'github.com/mike-lawrence/ezStan' #for extra Stan stuff (here just get_contrast_matrix & halfsum_contrasts)
)
#load the helper functions:
source('helper_functions.r')
#helper_functions.r contains:
# - install_if_missing()
# - simulate_data()
# - get_num_diverged()
#install any required packages not already present
install_if_missing(required_packages)
# define a shorthand for the pipe operator
`%>%` = magrittr::`%>%`
# initialize the model & iter/chain counts
mod = cmdstanr::cmdstan_model('hwg_fast.stan')
iter_warmup = 1e3
iter_sample = 1e3
parallel_chains = parallel::detectCores()/2 - 1 #physical cores minus one (assumes 2 threads per core)
# define the core loop function ----
run_seed = function(seed){
sink(nullfile()) #only way to keep things quiet (nullfile() works on all OSes)
data_for_stan = simulate_data(seed)
#traditional run: warmup & sampling composed in a single call
both = mod$sample(
data = data_for_stan
, chains = parallel_chains
, parallel_chains = parallel_chains
, refresh = 0
, show_messages = F
, seed = seed
, iter_warmup = iter_warmup
, iter_sampling = iter_sample
)
both_num_diverged = get_num_diverged(both)
#warmup-only run, saving the warmup draws so the final ones can serve as inits
warmup = mod$sample(
data = data_for_stan
, chains = parallel_chains
, parallel_chains = parallel_chains
, refresh = 0
, show_messages = F
, seed = seed
, iter_warmup = iter_warmup
, save_warmup = T #for inits
, sig_figs = 18
, iter_sampling = 0
)
#get_inits: extract a chain's final warmup draw as a named list of inits
get_inits = function(chain_id){
warmup_draws = warmup$draws(inc_warmup=T)
final_warmup_value = warmup_draws[iter_warmup,chain_id,]
init_list = as.list(final_warmup_value)
names(init_list) = dimnames(final_warmup_value)[[3]]
#drop lp__, which is a diagnostic quantity rather than a parameter
init_list = init_list[names(init_list)!='lp__']
return(init_list)
}
#sampling-only run: adaptation off, reusing the warmup run's step size,
# inverse metric, & final draws (as inits)
samples = mod$sample(
data = data_for_stan
, chains = parallel_chains
, parallel_chains = parallel_chains
, refresh = 0
, show_messages = F
, seed = seed+1
, iter_warmup = 0
, adapt_engaged = FALSE
, inv_metric = warmup$inv_metric(matrix=F)
, step_size = warmup$metadata()$step_size_adaptation
, iter_sampling = iter_sample
, init = get_inits
)
samples_num_diverged = get_num_diverged(samples)
check_inputs = function(chain){
#the hard-coded line numbers below assume CmdStan's CSV layout for this model:
# line 42 holds "# step_size = ...", line 44 the inverse metric diagonal
sampling_step_size = as.numeric(strsplit(readLines(samples$output_files()[chain])[42],'=')[[1]][2])
warmup_step_size = warmup$metadata()$step_size_adaptation[chain]
step_size_ok = dplyr::near(
sampling_step_size
, warmup_step_size
)
temp = readLines(samples$output_files()[chain])[44]
temp = stringr::str_replace(temp,'#','')
sampling_inv_metric = as.numeric(unlist(strsplit(temp,',')))
warmup_inv_metric = unlist(warmup$inv_metric(matrix=F)[chain])
inv_metric_ok = all(dplyr::near(
sampling_inv_metric
, warmup_inv_metric
))
samples_inits =
(
tibble::tibble(
v = readLines(samples$metadata()$init[chain])
)
#drop the opening & closing braces of the JSON init file
%>% dplyr::slice(2:(dplyr::n()-1))
%>% tidyr::separate(
v
, into = c('par','value')
, sep = ':'
)
%>% dplyr::mutate(
value = as.numeric(stringr::str_replace(value,',',''))
)
%>% dplyr::pull(value)
)
warmup_inits = as.numeric(get_inits(chain))
inits_ok = all(dplyr::near(samples_inits,warmup_inits))
return(all(step_size_ok,inv_metric_ok,inits_ok))
}
out = tibble::tibble(
#checking that all the step_sizes are identical
identical_step_size = identical(
warmup$metadata()$step_size_adaptation
, both$metadata()$step_size_adaptation
)
#checking that all the inv_metrics are identical
, identical_inv_metric = all(purrr::map2_lgl(
.x = warmup$inv_metric(matrix=F)
, .y = both$inv_metric(matrix=F)
, .f = identical
))
, identical_inputs = all(purrr::map_lgl(1:parallel_chains,check_inputs))
, both_num_diverged = both_num_diverged
, samples_num_diverged = samples_num_diverged
)
sink(NULL)
print(out)
pb$tick()
return(out)
}
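# e.g. run_seed(1) yields a one-row tibble: the three consistency checks
# (expected TRUE) plus divergence counts for the traditional & two-stage runs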
# loop over seeds ----
num_seeds = 1e2
pb = progress::progress_bar$new(
total = num_seeds
, format = "[:bar] :percent eta: :eta"
)
out = purrr::map_dfr(1:num_seeds,run_seed)
# compute some summaries
summary(out)
(
out
%>% tidyr::pivot_longer(
cols = contains('num_diverged')
)
%>% dplyr::group_by(
name
)
%>% dplyr::summarise(
any_diverged = mean(value>0)
, mean_num_diverged_when_any = mean(value[value>0])
, sd_num_diverged_when_any = sd(value[value>0])
, mean_num_diverged = mean(value)
, sd_num_diverged = sd(value)
)
)
(
out
%>% tidyr::pivot_longer(
cols = contains('num_diverged')
)
%>% dplyr::mutate(
name = dplyr::case_when(
name=='both_num_diverged' ~ 'traditional'
, name=='samples_num_diverged' ~ 'two-stage'
)
)
%>% ggplot2::ggplot()
+ ggplot2::geom_freqpoly(
ggplot2::aes(
x = value
# , y = after_stat(count)
, colour = name
, linetype = name
)
, stat = 'bin'
# , alpha = .5
, position = 'identity'
)
+ ggplot2::scale_x_log10()
+ ggplot2::scale_y_log10()
+ ggplot2::labs(
colour = 'run type'
, linetype = 'run type'
, x = '# divergences'
)
+ ggplot2::theme(
legend.position = 'top'
)
)
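
# helper_functions.r (sourced near the top of the script above) ----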
#' Installs any packages not already installed
#' @examples
#' \dontrun{
#' install_if_missing(c('tidyverse','github.com/stan-dev/cmdstanr'))
#' }
install_if_missing = function(pkgs){
missing_pkgs = NULL
for(this_pkg in pkgs){
path = NULL
try(
path <- find.package(basename(this_pkg),quiet=T,verbose=F)
, silent = T
)
if(length(path)==0){ #find.package(quiet=T) yields character(0) when missing
missing_pkgs = c(missing_pkgs,this_pkg)
}
}
cran_missing = missing_pkgs[!grepl('github.com/',missing_pkgs,fixed=T)]
if(length(cran_missing)>0){
message('The following required but uninstalled CRAN packages will now be installed:\n',paste(cran_missing,collapse='\n'))
install.packages(cran_missing)
}
github_missing = missing_pkgs[grepl('github.com/',missing_pkgs,fixed=T)]
github_missing = gsub('github.com/','',github_missing)
if(length(github_missing)>0){
message('The following required but uninstalled GitHub packages will now be installed:\n',paste(github_missing,collapse='\n'))
remotes::install_github(github_missing)
}
invisible()
}
get_num_diverged = function(x){
temp = x$cmdstan_diagnose()$stdout
temp = strsplit(temp,'Checking sampler transitions for divergences.\n')[[1]][2]
temp = strsplit(temp,' ')[[1]][1]
if(temp=='No'){
return(0)
}else{
return(as.numeric(temp))
}
}
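# A minimal alternative sketch that avoids parsing the diagnose output,
# assuming cmdstanr's $sampler_diagnostics() method (a draws array that
# includes the per-iteration divergent__ indicator):
get_num_diverged2 = function(x){
#sum the 0/1 divergent__ indicator over all chains & post-warmup iterations
sum(
posterior::subset_draws(
x$sampler_diagnostics()
, variable = 'divergent__'
)
)
}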
simulate_data = function(seed){
set.seed(seed)
sim_pars = tibble::lst(
num_subj = 10
, num_trials = 10
, num_vars = 1
, num_coef = 2^(num_vars)
, coef_means = rnorm(num_coef)
, coef_sds = rweibull(num_coef,2,1)
, cor_mat = rethinking::rlkjcorr(1,num_coef,eta=1)
, noise = rweibull(1,2,1)
)
#compute the contrast matrix
contrast_matrix =
(
1:sim_pars$num_vars
%>% purrr::map(.f=function(x){
factor(c('lo','hi'))
})
%>% (function(x){
names(x) = paste0('v',1:sim_pars$num_vars)
return(x)
})
%>% purrr::cross_df()
%>% (function(x){
ezStan::get_contrast_matrix(
data = x
, formula = as.formula(paste('~',paste0('v',1:sim_pars$num_vars,collapse='*')))
, contrast_kind = ezStan::halfsum_contrasts
)
})
)
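# For num_vars==1 the result is presumably a 2x2 matrix: an intercept column
# of 1s plus a condition column coding lo/hi as -/+.5 (halfsum contrasts), so
# coefficients read as a grand mean & a condition difference.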
subj_coef =
(
#subj coefs as mvn
rethinking::rmvnorm2(
n = sim_pars$num_subj
, Mu = sim_pars$coef_means
, sigma = sim_pars$coef_sds
, Rho = sim_pars$cor_mat
)
%>% (function(x){
dimnames(x) = list(NULL,paste0('X',1:ncol(x)))
return(x)
})
%>% tibble::as_tibble(.name_repair='unique')
%>% dplyr::mutate(
subj = 1:sim_pars$num_subj
)
)
subj_cond =
(
subj_coef
%>% dplyr::group_by(subj)
%>% dplyr::summarise(
(function(x){
out = attr(contrast_matrix,'data')
out$cond_mean = as.vector(contrast_matrix %*% t(x))
return(out)
})(dplyr::cur_data())
, .groups = 'drop'
)
)
dat =
(
subj_cond
%>% tidyr::expand_grid(trial = 1:sim_pars$num_trials)
%>% dplyr::mutate(
obs = rnorm(dplyr::n(),cond_mean,sim_pars$noise)
)
)
# Compute inputs to model ----
#W: the full trial-by-trial contrast matrix
W =
(
dat
#use ezStan to get the contrast matrix (wrapper on stats::model.matrix)
%>% ezStan::get_contrast_matrix(
formula = ~ v1 # breaks if sim_pars$num_vars!=1
, contrast_kind = ezStan::halfsum_contrasts
)
#convert to tibble
%>% tibble::as_tibble(.name_repair='unique')
)
# get the unique entries in W
uW = dplyr::distinct(W)
#for each unique condition specified by uW, the stan model will
# work out values for that condition for each subject, and we'll need to index
# into the resulting subject-by-condition matrix. So we need to create our own
# subject-by-condition matrix and get the indices of the observed data into
# the array produced when that matrix is flattened (see the toy illustration
# after this pipeline).
obs_index =
(
uW
#first repeat the matrix so there's a copy for each subject
%>% dplyr::slice(
rep(
dplyr::row_number()
, length(unique(dat$subj))
)
)
#now add the subject labels
%>% dplyr::mutate(
subj = rep(sort(unique(dat$subj)),each=nrow(uW))
)
#add row identifier
%>% dplyr::mutate(
row = 1:dplyr::n()
)
# join to the full contrast matrix W
%>% dplyr::right_join(
#add the subject column
dplyr::mutate(W,subj=dat$subj)
, by = c(names(uW),'subj')
)
#pull the row label
%>% dplyr::pull(row)
)
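# Toy illustration of the flattening convention assumed above: Stan's
# to_vector() flattens a matrix column-major, so in the num_rows_uW-by-num_subj
# matrix the model computes, condition c of subject s lands at position
# (s-1)*num_rows_uW + c, which is exactly the row numbering created above
# (conditions fastest, subjects as blocks). e.g.:
# m = matrix(1:6,nrow=2) #2 conditions x 3 subjects
# as.vector(m)[(3-1)*2+1] == m[1,3] #TRUE: condition 1 of subject 3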
# package for stan
data_for_stan = tibble::lst( #lst permits later entries to refer to earlier entries
####
# Entries we need to specify ourselves
####
# uW: unique rows of the within predictor matrix
uW = as.matrix(uW)
# num_subj: number of subjects
, num_subj = length(unique(dat$subj))
# outcome: outcome on each trial
, obs = dat$obs
# obs_index: index of each trial in flattened version of subject-by-condition value matrix
, obs_index = obs_index
####
# Entries computable from the above
####
# num_obs_total: total number of observations
, num_obs_total = length(obs)
# num_rows_uW: num rows in unique within predictor matrix uW
, num_rows_uW = nrow(uW)
# num_cols_uW: num cols in unique within predictor matrix uW
, num_cols_uW = ncol(uW)
)
return(data_for_stan)
}
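
// hwg_fast.stan (the model compiled at the top of the main script)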
data{
// num_obs_total: number of trials
int<lower=1> num_obs_total ;
// obs: observation on each trial
vector[num_obs_total] obs ;
// num_subj: number of subj
int<lower=1> num_subj ;
// num_rows_uW: num rows in uW
int<lower=1> num_rows_uW ;
// num_cols_uW: num cols in uW
int<lower=1> num_cols_uW ;
// uW: unique entries in the within predictor matrix
matrix[num_rows_uW,num_cols_uW] uW ;
// index: index of each trial in flattened subject-by-condition value matrix
int obs_index[num_obs_total] ;
}
transformed data{
// obs_mean: mean obs value
real obs_mean = mean(obs) ;
// obs_sd: sd of obs
real obs_sd = sd(obs) ;
// obs_: observations scaled to have zero mean and unit variance
vector[num_obs_total] obs_ = (obs-obs_mean)/obs_sd ;
}
parameters{
// chol_corr: population-level correlations (on cholesky factor scale) amongst within-subject predictors
cholesky_factor_corr[num_cols_uW] chol_corr ;
//for parameters below, trailing underscore denotes that they need to be un-scaled in generated quantities
// mean_coef_: mean (across subj) for each coefficient
row_vector[num_cols_uW] mean_coef_ ;
// sd_coef_: sd (across subj) for each coefficient
vector<lower=0>[num_cols_uW] sd_coef_ ;
// multi_normal_helper: a helper variable for implementing non-centered parameterization
matrix[num_cols_uW,num_subj] multi_normal_helper ;
// noise_: measurement noise
real<lower=0> noise_ ;
}
model{
////
// Priors
////
// multi_normal_helper must have normal(0,1) prior for non-centered parameterization
to_vector(multi_normal_helper) ~ std_normal() ;
// relatively flat prior on correlations
chol_corr ~ lkj_corr_cholesky(2) ;
// normal(0,1) priors on all coef_sd
sd_coef_ ~ std_normal() ;
// normal(0,1) priors on all coefficients
mean_coef_ ~ std_normal() ;
// low-but-non-zero prior on measurement noise
noise_ ~ weibull(2,1) ; // weibull(2,1) is peaked around .8
// compute coefficients for each subject/condition
matrix[num_subj,num_cols_uW] subj_coef_ = (
rep_matrix(mean_coef_,num_subj)
+ transpose(
diag_pre_multiply(sd_coef_,chol_corr)
* multi_normal_helper
)
) ;
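// math note: with multi_normal_helper ~ std_normal() elementwise and
// L = diag_pre_multiply(sd_coef_,chol_corr), each column of L * helper is
// multi_normal(0, diag(sd_coef_) * corr * diag(sd_coef_)); adding mean_coef_
// row-wise thus yields subject coefficients with the intended distribution
// without sampling them directly (the non-centered parameterization)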
// Loop over subj and conditions to compute unique entries in design matrix
matrix[num_rows_uW,num_subj] value_for_subj_cond ;
for(this_subj in 1:num_subj){
for(this_condition in 1:num_rows_uW){
value_for_subj_cond[this_condition,this_subj] = dot_product(
subj_coef_[this_subj]
, uW[this_condition]
) ;
}
// // slightly less explicit but equally fast:
// value_for_subj_cond[,this_subj] = rows_dot_product(
// rep_matrix(
// subj_coef_[this_subj]
// , num_rows_uW
// )
// , uW
// ) ;
}
// Likelihood
obs_ ~ normal(
to_vector(value_for_subj_cond)[obs_index]
, noise_
) ;
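// to_vector() above flattens value_for_subj_cond column-major (conditions
// fastest, subjects as blocks), matching how obs_index was constructed in
// simulate_data()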
}
// generated quantities{
//
// // cor: correlation matrix for the full set of within-subject predictors
// corr_matrix[num_cols_uW] cor = multiply_lower_tri_self_transpose(chol_corr) ;
//
// // sd_coef: sd (across subj) for each coefficient, on the original data scale
// vector[num_cols_uW] sd_coef = sd_coef_ * obs_sd ;
//
// // mean_coef: mean (across subj) for each coefficient, on the original data scale
// row_vector[num_cols_uW] mean_coef = mean_coef_ * obs_sd ;
// mean_coef[1] = mean_coef[1] + obs_mean ; //add the observed mean back to the intercept
//
// // noise: measurement noise
// real noise = noise_ * obs_sd ;
//
// // tweak cor to avoid rhat false-alarm
// for(i in 1:num_cols_uW){
// cor[i,i] += uniform_rng(1e-16, 1e-15) ;
// }
//
// }