composed_divergences
#preamble (installs/imports & custom functions) ----

# specify the packages used:
required_packages = c(
  'rethinking' #for rlkjcorr & rmvnorm2
  , 'crayon' #for coloring terminal output
  , 'tidyverse' #for all that is good and holy
  , 'progress' #for progress bar
  , 'github.com/stan-dev/cmdstanr' #for Stan stuff
  , 'github.com/mike-lawrence/ezStan' #for extra Stan stuff (here just get_contrast_matrix)
)

#load the helper functions:
source('helper_functions.r')
#helper_functions.r contains:
# - install_if_missing()
# - simulate_data()
# - get_num_diverged()

#install any required packages not already present
install_if_missing(required_packages)

# define a shorthand for the pipe operator
`%>%` = magrittr::`%>%`

# initialize the model & iter/chain counts
mod = cmdstanr::cmdstan_model('hwg_fast.stan')
iter_warmup = 1e3
iter_sample = 1e3
parallel_chains = parallel::detectCores()/2-1
# define the core loop function ----
# for a given seed: simulate data, fit it with a standard warmup+sampling run
# ("both"), then fit it again as a composed pair of runs (warmup-only, followed by
# sampling-only with the adapted step size, inverse metric, & final warmup draws
# carried over), and return the divergence counts from each approach
run_seed = function(seed){
  sink('/dev/null') #only way to keep things quiet
  data_for_stan = simulate_data(seed)
  both = mod$sample(
    data = data_for_stan
    , chains = parallel_chains
    , parallel_chains = parallel_chains
    , refresh = 0
    , show_messages = F
    , seed = seed
    , iter_warmup = iter_warmup
    , iter_sampling = iter_sample
  )
  both_num_diverged = get_num_diverged(both)
  warmup = mod$sample(
    data = data_for_stan
    , chains = parallel_chains
    , parallel_chains = parallel_chains
    , refresh = 0
    , show_messages = F
    , seed = seed
    , iter_warmup = iter_warmup
    , save_warmup = T #for inits
    , sig_figs = 18
    , iter_sampling = 0
  )
  get_inits = function(chain_id){
    warmup_draws = warmup$draws(inc_warmup=T)
    final_warmup_value = warmup_draws[iter_warmup,chain_id,]
    init_list = as.list(final_warmup_value)
    names(init_list) = dimnames(final_warmup_value)[[3]]
    init_list = init_list[names(init_list)!='lp__']
    return(init_list)
  }
  samples = mod$sample(
    data = data_for_stan
    , chains = parallel_chains
    , parallel_chains = parallel_chains
    , refresh = 0
    , show_messages = F
    , seed = seed+1
    , iter_warmup = 0
    , adapt_engaged = FALSE
    , inv_metric = warmup$inv_metric(matrix=F)
    , step_size = warmup$metadata()$step_size_adaptation
    , iter_sampling = iter_sample
    , init = get_inits
  )
  samples_num_diverged = get_num_diverged(samples)
  # check_inputs: for a given chain, verify that the step size, inverse metric,
  # & inits actually received by the sampling-only run match what the warmup-only
  # run produced (parsed back out of the files CmdStan & cmdstanr wrote to disk)
  check_inputs = function(chain){
    #note: lines 42 & 44 of the output CSV hold the step size & inverse metric
    # for this particular model/config; those indices would need adjusting elsewhere
    sampling_step_size = as.numeric(strsplit(readLines(samples$output_files()[chain])[42],'=')[[1]][2])
    warmup_step_size = warmup$metadata()$step_size_adaptation[chain]
    step_size_ok = dplyr::near(
      sampling_step_size
      , warmup_step_size
    )
    temp = readLines(samples$output_files()[chain])[44]
    temp = stringr::str_replace(temp,'#','')
    sampling_inv_metric = as.numeric(unlist(strsplit(temp,',')))
    warmup_inv_metric = unlist(warmup$inv_metric(matrix=F)[chain])
    inv_metric_ok = all(dplyr::near(
      sampling_inv_metric
      , warmup_inv_metric
    ))
    #parse the inits actually used from the init file cmdstanr wrote for this chain
    samples_inits =
      (
        tibble::tibble(
          v = readLines(samples$metadata()$init[chain])
        )
        %>% dplyr::slice(2:(dplyr::n()-1))
        %>% tidyr::separate(
          v
          , into = c('par','value')
          , sep = ':'
        )
        %>% dplyr::mutate(
          value = as.numeric(stringr::str_replace(value,',',''))
        )
        %>% dplyr::pull(value)
      )
    warmup_inits = as.numeric(get_inits(chain))
    inits_ok = all(dplyr::near(samples_inits,warmup_inits))
    return(all(step_size_ok,inv_metric_ok,inits_ok))
  }
  out = tibble::tibble(
    #checking that all the step_sizes are identical
    identical_step_size = identical(
      warmup$metadata()$step_size_adaptation
      , both$metadata()$step_size_adaptation
    )
    #checking that all the inv_metrics are identical
    , identical_inv_metric = all(purrr::map2_lgl(
      .x = warmup$inv_metric(matrix=F)
      , .y = both$inv_metric(matrix=F)
      , .f = identical
    ))
    , identical_inputs = all(purrr::map_lgl(1:parallel_chains,check_inputs))
    , both_num_diverged = both_num_diverged
    , samples_num_diverged = samples_num_diverged
  )
  sink(NULL)
  print(out)
  pb$tick()
  return(out)
}

# loop over seeds ----
num_seeds = 1e2
pb = progress::progress_bar$new(
  total = num_seeds
  , format = "[:bar] :percent eta: :eta"
)
out = purrr::map_dfr(1:num_seeds,run_seed)

# compute some summaries
summary(out)
(
  out
  %>% tidyr::pivot_longer(
    cols = contains('num_diverged')
  )
  %>% dplyr::group_by(
    name
  )
  %>% dplyr::summarise(
    any_diverged = mean(value>0)
    , mean_num_diverged_when_any = mean(value[value>0])
    , sd_num_diverged_when_any = sd(value[value>0])
    , mean_num_diverged = mean(value)
    , sd_num_diverged = sd(value)
  )
)
(
  out
  %>% tidyr::pivot_longer(
    cols = contains('num_diverged')
  )
  %>% dplyr::mutate(
    name = dplyr::case_when(
      name=='both_num_diverged' ~ 'traditional'
      , name=='samples_num_diverged' ~ 'two-stage'
    )
  )
  %>% ggplot2::ggplot()
  + ggplot2::geom_freqpoly(
    ggplot2::aes(
      x = value
      # , y = after_stat(count)
      , colour = name
      , linetype = name
    )
    , stat = 'bin'
    # , alpha = .5
    , position = 'identity'
  )
  + ggplot2::scale_x_log10()
  + ggplot2::scale_y_log10()
  + ggplot2::labs(
    colour = 'run type'
    , linetype = 'run type'
    , x = '# divergences'
  )
  + ggplot2::theme(
    legend.position = 'top'
  )
)
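For reference, here is the warmup/sampling composition used inside run_seed() distilled to a minimal, standalone form. This is a sketch rather than part of the original script: the model file name ('my_model.stan'), the data list (my_data), and the chain/iteration counts are placeholder assumptions, and only arguments already used in the script above are relied on.

library(cmdstanr)
mod = cmdstan_model('my_model.stan') #placeholder model file
n_warmup = 1e3

# stage 1: warmup only, saving the warmup draws so the final state can seed stage 2
warmup = mod$sample(
  data = my_data #placeholder data list
  , chains = 4
  , seed = 1
  , iter_warmup = n_warmup
  , iter_sampling = 0
  , save_warmup = TRUE
  , sig_figs = 18 #keep full precision when the inits are written back out
)

# stage 2: sampling only, reusing the adapted step size, inverse metric, and
# final warmup draws from stage 1 (adaptation disabled)
samples = mod$sample(
  data = my_data
  , chains = 4
  , seed = 2
  , iter_warmup = 0
  , adapt_engaged = FALSE
  , step_size = warmup$metadata()$step_size_adaptation
  , inv_metric = warmup$inv_metric(matrix = FALSE)
  , init = function(chain_id){
    final = warmup$draws(inc_warmup = TRUE)[n_warmup, chain_id, ]
    init_list = as.list(final)
    names(init_list) = dimnames(final)[[3]]
    init_list[names(init_list) != 'lp__']
  }
  , iter_sampling = 1e3
)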
helper_functions.r
#' Installs any packages not already installed
#' @examples
#' \dontrun{
#' install_if_missing(c('tidyverse','github.com/stan-dev/cmdstanr'))
#' }
install_if_missing = function(pkgs){
  missing_pkgs = NULL
  for(this_pkg in pkgs){
    path = NULL
    try(
      path <- find.package(basename(this_pkg),quiet=T,verbose=F)
      , silent = T
    )
    if(is.null(path)){
      missing_pkgs = c(missing_pkgs,this_pkg)
    }
  }
  cran_missing = missing_pkgs[!grepl('github.com/',fixed=T,missing_pkgs)]
  if(length(cran_missing)>0){
    message('The following required but uninstalled CRAN packages will now be installed:\n',paste(cran_missing,collapse='\n'))
    install.packages(cran_missing)
  }
  github_missing = missing_pkgs[grepl('github.com/',fixed=T,missing_pkgs)]
  github_missing = gsub('github.com/','',github_missing)
  if(length(github_missing)>0){
    message('The following required but uninstalled Github packages will now be installed:\n',paste(github_missing,collapse='\n'))
    remotes::install_github(github_missing)
  }
  invisible()
}
#' Parses the number of divergent transitions from a cmdstanr fit by running
#' CmdStan's diagnose utility and reading its report
get_num_diverged = function(x){
  temp = x$cmdstan_diagnose()$stdout
  temp = strsplit(temp,'Checking sampler transitions for divergences.\n')[[1]][2]
  temp = strsplit(temp,' ')[[1]][1]
  if(temp=='No'){
    return(0)
  }else{
    return(as.numeric(temp))
  }
}
simulate_data = function(seed){
  set.seed(seed)
  sim_pars = tibble::lst(
    num_subj = 10
    , num_trials = 10
    , num_vars = 1
    , num_coef = 2^(num_vars)
    , coef_means = rnorm(num_coef)
    , coef_sds = rweibull(num_coef,2,1)
    , cor_mat = rethinking::rlkjcorr(1,num_coef,eta=1)
    , noise = rweibull(1,2,1)
  )
  #compute the contrast matrix
  contrast_matrix =
    (
      1:sim_pars$num_vars
      %>% purrr::map(.f=function(x){
        factor(c('lo','hi'))
      })
      %>% (function(x){
        names(x) = paste0('v',1:sim_pars$num_vars)
        return(x)
      })
      %>% purrr::cross_df()
      %>% (function(x){
        ezStan::get_contrast_matrix(
          data = x
          , formula = as.formula(paste('~',paste0('v',1:sim_pars$num_vars,collapse='*')))
          , contrast_kind = ezStan::halfsum_contrasts
        )
      })
    )
  subj_coef =
    (
      #subj coefs as mvn
      rethinking::rmvnorm2(
        n = sim_pars$num_subj
        , Mu = sim_pars$coef_means
        , sigma = sim_pars$coef_sds
        , Rho = sim_pars$cor_mat
      )
      %>% (function(x){
        dimnames(x) = list(NULL,paste0('X',1:ncol(x)))
        return(x)
      })
      %>% tibble::as_tibble(.name_repair='unique')
      %>% dplyr::mutate(
        subj = 1:sim_pars$num_subj
      )
    )
  subj_cond =
    (
      subj_coef
      %>% dplyr::group_by(subj)
      %>% dplyr::summarise(
        (function(x){
          out = attr(contrast_matrix,'data')
          out$cond_mean = as.vector(contrast_matrix %*% t(x))
          return(out)
        })(dplyr::cur_data())
        , .groups = 'drop'
      )
    )
  dat =
    (
      subj_cond
      %>% tidyr::expand_grid(trial = 1:sim_pars$num_trials)
      %>% dplyr::mutate(
        obs = rnorm(dplyr::n(),cond_mean,sim_pars$noise)
      )
    )
  # Compute inputs to model ----
  #W: the full trial-by-trial contrast matrix
  W =
    (
      dat
      #use ezStan to get the contrast matrix (wrapper on stats::model.matrix)
      %>% ezStan::get_contrast_matrix(
        formula = ~ v1 # breaks if sim_pars$num_vars!=1
        , contrast_kind = ezStan::halfsum_contrasts
      )
      #convert to tibble
      %>% tibble::as_tibble(.name_repair='unique')
    )
  # get the unique entries in W
  uW = dplyr::distinct(W)
  #for each unique condition specified by uW, the stan model will
  # work out values for that condition for each subject, and we'll need to index
  # into the resulting subject-by-condition matrix. So we need to create our own
  # subject-by-condition matrix and get the indices of the observed data into
  # the array produced when that matrix is flattened.
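  # As a concrete illustration (an added note, assuming the default sim_pars above,
  # i.e. 2 unique rows in uW and 10 subjects): the Stan model stores values in a
  # num_rows_uW-by-num_subj matrix and flattens it column-major via to_vector(),
  # so all conditions for subject 1 come first, then subject 2, and so on. A trial
  # from subject s in the condition matching row r of uW therefore needs
  # obs_index = (s-1)*nrow(uW) + r ; e.g. subject 2, row 2 of uW maps to
  # (2-1)*2 + 2 = 4. The slice/mutate/right_join pipeline below builds exactly
  # this mapping for every trial.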
  obs_index =
    (
      uW
      #first repeat the matrix so there's a copy for each subject
      %>% dplyr::slice(
        rep(
          dplyr::row_number()
          , length(unique(dat$subj))
        )
      )
      #now add the subject labels
      %>% dplyr::mutate(
        subj = rep(sort(unique(dat$subj)),each=nrow(uW))
      )
      #add row identifier
      %>% dplyr::mutate(
        row = 1:dplyr::n()
      )
      # join to the full contrast matrix W
      %>% dplyr::right_join(
        #add the subject column
        dplyr::mutate(W,subj=dat$subj)
        , by = c(names(uW),'subj')
      )
      #pull the row label
      %>% dplyr::pull(row)
    )
  # package for stan
  data_for_stan = tibble::lst( #lst permits later entries to refer to earlier entries
    ####
    # Entries we need to specify ourselves
    ####
    # uW: unique entries in the within predictor matrix
    uW = as.matrix(uW)
    # num_subj: number of subjects
    , num_subj = length(unique(dat$subj))
    # obs: outcome on each trial
    , obs = dat$obs
    # obs_index: index of each trial in flattened version of subject-by-condition value matrix
    , obs_index = obs_index
    ####
    # Entries computable from the above
    ####
    # num_obs_total: total number of observations
    , num_obs_total = length(obs)
    # num_rows_uW: num rows in within predictor matrix uW
    , num_rows_uW = nrow(uW)
    # num_cols_uW: num cols in within predictor matrix uW
    , num_cols_uW = ncol(uW)
  )
  return(data_for_stan)
}
hwg_fast.stan
data{
  // num_obs_total: number of trials
  int<lower=1> num_obs_total ;
  // obs: observation on each trial
  vector[num_obs_total] obs ;
  // num_subj: number of subj
  int<lower=1> num_subj ;
  // num_rows_uW: num rows in uW
  int<lower=1> num_rows_uW ;
  // num_cols_uW: num cols in uW
  int<lower=1> num_cols_uW ;
  // uW: unique entries in the within predictor matrix
  matrix[num_rows_uW,num_cols_uW] uW ;
  // obs_index: index of each trial in flattened subject-by-condition value matrix
  int obs_index[num_obs_total] ;
}
transformed data{
  // obs_mean: mean obs value
  real obs_mean = mean(obs) ;
  // obs_sd: sd of obs
  real obs_sd = sd(obs) ;
  // obs_: observations scaled to have zero mean and unit variance
  vector[num_obs_total] obs_ = (obs-obs_mean)/obs_sd ;
}
parameters{
  // chol_corr: population-level correlations (on cholesky factor scale) amongst within-subject predictors
  cholesky_factor_corr[num_cols_uW] chol_corr ;
  //for parameters below, trailing underscore denotes that they need to be un-scaled in generated quantities
  // mean_coef_: mean (across subj) for each coefficient
  row_vector[num_cols_uW] mean_coef_ ;
  // sd_coef_: sd (across subj) for each coefficient
  vector<lower=0>[num_cols_uW] sd_coef_ ;
  // multi_normal_helper: a helper variable for implementing non-centered parameterization
  matrix[num_cols_uW,num_subj] multi_normal_helper ;
  // noise_: measurement noise
  real<lower=0> noise_ ;
}
model{
  ////
  // Priors
  ////
  // multi_normal_helper must have normal(0,1) prior for non-centered parameterization
  to_vector(multi_normal_helper) ~ std_normal() ;
  // relatively flat prior on correlations
  chol_corr ~ lkj_corr_cholesky(2) ;
  // normal(0,1) priors on all coef_sd
  sd_coef_ ~ std_normal() ;
  // normal(0,1) priors on all coefficients
  mean_coef_ ~ std_normal() ;
  // low-near-zero prior on measurement noise
  noise_ ~ weibull(2,1) ; // weibull(2,1) is peaked around .8

  // compute coefficients for each subject/condition
  matrix[num_subj,num_cols_uW] subj_coef_ = (
    rep_matrix(mean_coef_,num_subj)
    + transpose(
      diag_pre_multiply(sd_coef_,chol_corr)
      * multi_normal_helper
    )
  ) ;
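  // (added note) why this works: the columns of multi_normal_helper are iid
  // std_normal, so diag_pre_multiply(sd_coef_,chol_corr) * multi_normal_helper
  // has columns distributed multi_normal(0, diag(sd_coef_)*Rho*diag(sd_coef_)),
  // where Rho = chol_corr * chol_corr'. Adding mean_coef_ to each (transposed)
  // row thus yields per-subject coefficient vectors distributed
  // multi_normal(mean_coef_, diag(sd_coef_)*Rho*diag(sd_coef_)) without sampling
  // the correlated quantities directly (the non-centered parameterization).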
  // Loop over subj and conditions to compute unique entries in design matrix
  matrix[num_rows_uW,num_subj] value_for_subj_cond ;
  for(this_subj in 1:num_subj){
    for(this_condition in 1:num_rows_uW){
      value_for_subj_cond[this_condition,this_subj] = dot_product(
        subj_coef_[this_subj]
        , uW[this_condition]
      ) ;
    }
    // // slightly less explicit but equally fast:
    // value_for_subj_cond[,this_subj] = rows_dot_product(
    //   rep_matrix(
    //     subj_coef_[this_subj]
    //     , num_rows_uW
    //   )
    //   , uW
    // ) ;
  }
  // Likelihood
  obs_ ~ normal(
    to_vector(value_for_subj_cond)[obs_index]
    , noise_
  ) ;
}
// generated quantities{
//
//   // cor: correlation matrix for the full set of within-subject predictors
//   corr_matrix[num_cols_uW] cor = multiply_lower_tri_self_transpose(chol_corr) ;
//
//   // sd_coef: sd (across subj) for each coefficient
//   vector[num_cols_uW] sd_coef = sd_coef_ * obs_sd ;
//
//   // mean_coef: mean (across subj) for each coefficient
//   row_vector[num_cols_uW] mean_coef = mean_coef_ * obs_sd ;
//   mean_coef[1] = mean_coef[1] + obs_mean ; //adding the intercept
//
//   // noise: measurement noise
//   real noise = noise_ * obs_sd ;
//
//   // tweak cor to avoid rhat false-alarm
//   for(i in 1:num_cols_uW){
//     cor[i,i] += uniform_rng(1e-16, 1e-15) ;
//   }
//
// }