Skip to content

Instantly share code, notes, and snippets.

@mike-lawrence
Last active April 19, 2017 19:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mike-lawrence/7cc9459a2b5730de51ec8325a2f35762 to your computer and use it in GitHub Desktop.
Save mike-lawrence/7cc9459a2b5730de51ec8325a2f35762 to your computer and use it in GitHub Desktop.
Functions facilitating running Stan models on many-cpu systems
# Code below may not be up to date, see ezStan (https://github.com/mike-lawrence/ezStan/blob/master/R/bigStan.R) for latest version
# todo:
# during-sampling: effective sample size and rhat for each parameter
# during-sampling: diagnostics?
#usage:
# #compile the model using rstan::stan_model:
# mod = rstan::stan_model('my_model.stan')
# #start the chains:
# startBigStan(
# stanMod = mod
# , stanData = data_for_stan
# )
# #watch their progress:
# watchBigStan()
# #when done, get the samples
# fromStan = collectBigStan()
# #delete temporary files
# cleanBigStan()
# #if necessary, kill the R processes of a broken bigStan run (this runs `killall R`, so it targets every process named R):
# killBigStan()
# Start one background Rscript process per core, each sampling a single chain
# from a compiled Stan model.
#
# Arguments:
#   stanMod   - compiled model; handed to rstan::sampling() as `object` in
#               every worker script.
#   stanData  - data handed to rstan::sampling() as `data`.
#   iter      - value written to each worker script as `iter`.
#   cores     - number of worker processes (one chain each) to launch.
#   stanArgs  - optional character vector of extra `name = value` argument
#               text spliced verbatim into the generated rstan::sampling() call.
#   seedStart - seed of the first chain; chain i uses seedStart - 1 + i.
#   warmup    - value written as `warmup`; defaults to iter/2 when NULL.
#
# Side effects: (re)creates the `bigStanTemp` directory tree in the working
# directory, saves the model/data and a bookkeeping list (`bigStanStuff`)
# there, writes one R script per chain and launches it with Rscript without
# waiting, redirecting its stdout/stderr to files.
#
# Returns NULL invisibly.
startBigStan <- function(
  stanMod
  , stanData
  , iter = 2e3
  , cores = parallel::detectCores()
  , stanArgs = NULL
  , seedStart = 1
  , warmup = NULL
){
  # parallel::detectCores() can return NA, and 1:cores would iterate over
  # c(1, 0) for cores = 0, so reject unusable values up front.
  if (!is.numeric(cores) || length(cores) != 1 || is.na(cores) || cores < 1) {
    stop("'cores' must be a single positive number", call. = FALSE)
  }
  cores <- as.integer(cores)

  if (dir.exists('bigStanTemp')) {
    cleanBigStan(reportAllClean = FALSE)
  }
  dir.create('bigStanTemp')
  if (is.null(warmup)) {
    warmup <- iter/2
  }
  cat("\nStarting chains...")
  save(stanData, stanMod, file = 'bigStanTemp/stanData.rda')
  for (subDir in c('r', 'samples', 'stdout', 'stderr', 'rdas')) {
    subPath <- paste0('bigStanTemp/', subDir)
    if (!dir.exists(subPath)) {
      dir.create(subPath)
    }
  }

  # Extra sampling() arguments: drop empty entries, separate several entries
  # with commas, and emit nothing at all when none were supplied so the
  # generated call never contains an empty argument.
  stanArgs <- stanArgs[nzchar(stanArgs)]
  extraArgs <- ''
  if (length(stanArgs) > 0) {
    extraArgs <- paste0(paste(stanArgs, collapse = ', '), ', ')
  }

  # Zero-pad chain numbers to the number of digits in `cores`
  # (floor(log10(n)) + 1 is also right at exact powers of ten).
  nameWidth <- floor(log10(cores)) + 1
  chainIds <- seq_len(cores)
  chainNames <- sprintf(paste0("chain%0", nameWidth, "d"), chainIds)

  bigStanStuff <- list(cores = cores, iter = iter, warmup = warmup)
  bigStanStuff$rFileList <- as.list(paste0('bigStanTemp/r/', chainIds, '.r'))
  bigStanStuff$chainNameList <- as.list(chainNames)
  bigStanStuff$sampleFileList <- as.list(paste0('bigStanTemp/samples/', chainNames, '.txt'))
  bigStanStuff$rdaFileList <- as.list(paste0('bigStanTemp/rdas/', chainNames, '.rda'))
  bigStanStuff$stdoutFileList <- as.list(paste0('bigStanTemp/stdout/', chainNames, '.txt'))
  bigStanStuff$stderrFileList <- as.list(paste0('bigStanTemp/stderr/', chainNames, '.txt'))
  bigStanStuff$progressList <- rep(list(0), cores)
  bigStanStuff$samplesList <- list()

  for (i in chainIds) {
    chainName <- bigStanStuff$chainNameList[[i]]
    # The worker script retries sampling until it yields a non-NULL result,
    # then saves the fit under the chain's name.
    workerScript <- paste0(
      'seed = ', seedStart - 1 + i, '\n'
      , 'iter = ', iter, '\n'
      , 'warmup = ', warmup, '\n'
      , 'suppressMessages(library(rstan,quietly=T))\n'
      , 'load("bigStanTemp/stanData.rda")\n'
      , chainName, ' = NULL\n'
      , 'while(is.null(', chainName, ')){\n'
      , 'try(', chainName, '<-rstan::sampling(\n'
      , 'object = stanMod\n'
      , ', data = stanData\n'
      , ', seed = seed\n'
      , ', iter = iter\n'
      , ', warmup = warmup\n'
      , ', refresh = 0\n'
      , ', init = 0\n'
      , ', chains = 1\n'
      , ', cores = 1\n'
      , ', ', extraArgs
      , "sample_file = '", bigStanStuff$sampleFileList[[i]], "'\n"
      , '))}\n'
      , "save(", chainName, ",file='", bigStanStuff$rdaFileList[[i]], "')\n"
    )
    cat(workerScript, file = bigStanStuff$rFileList[[i]], append = FALSE)
    system2(
      command = "Rscript"
      , args = c('--vanilla', bigStanStuff$rFileList[[i]])
      , stdout = bigStanStuff$stdoutFileList[[i]]
      , stderr = bigStanStuff$stderrFileList[[i]]
      , wait = FALSE
    )
    Sys.sleep(.1)
  }
  bigStanStuff$startTime <- Sys.time()
  save(bigStanStuff, file = 'bigStanTemp/bigStanStuff.rda')
  cat("\nChains started. Run watchBigStan() to watch progress")
  invisible(NULL)
}
# Poll the files written by a run started with startBigStan() and print a
# progress summary until every chain has saved its result.
#
# Arguments:
#   updateInterval     - seconds to sleep between polls (passed to Sys.sleep()).
#   one_line_per_chain - passed through to appendString() for each status item.
#   spacing            - passed through to appendString() for each status item.
#
# Progress of a chain is the number of lines in its sample file that are
# non-empty and start with neither "#" nor "lp__". A run counts as finished
# once `bigStanTemp/rdas` holds as many files as there are chains.
#
# Returns NULL invisibly.
watchBigStan <- function(updateInterval=1,one_line_per_chain=TRUE,spacing=3){
  stuffFile <- 'bigStanTemp/bigStanStuff.rda'
  if (!file.exists(stuffFile)) {
    stop("'", stuffFile, "' not found; run startBigStan() first", call. = FALSE)
  }
  # Load into a scratch environment so the saved object cannot silently
  # overwrite anything local to this function.
  stuffEnv <- new.env()
  load(stuffFile, envir = stuffEnv)
  bigStanStuff <- stuffEnv$bigStanStuff

  # Indices (among `chainIds`) of chains whose stderr file exists and is non-empty.
  chainsWithStderr <- function(chainIds) {
    hasOutput <- vapply(
      chainIds
      , function(k) {
        stderrFile <- bigStanStuff$stderrFileList[[k]]
        file.exists(stderrFile) && length(readLines(stderrFile, warn = FALSE)) > 0
      }
      , logical(1)
    )
    chainIds[hasOutput]
  }

  # Elapsed wall time in seconds. Plain `Sys.time() - start` picks its own
  # units (secs/mins/hours/...), so the units must be pinned explicitly.
  elapsedSeconds <- function() {
    as.numeric(difftime(Sys.time(), bigStanStuff$startTime, units = 'secs'))
  }

  bigStanStuff$numDone <- length(list.files(path = "bigStanTemp/rdas"))
  while (bigStanStuff$numDone < bigStanStuff$cores) {
    bigStanStuff$numDone <- length(list.files(path = "bigStanTemp/rdas"))
    Sys.sleep(updateInterval)

    # Only chains that have not reached `iter` yet are inspected.
    unfinished <- which(unlist(bigStanStuff$progressList) < bigStanStuff$iter)
    chains_with_stderr <- chainsWithStderr(unfinished)
    for (i in unfinished) {
      sampleFile <- bigStanStuff$sampleFileList[[i]]
      if (file.exists(sampleFile)) {
        # warn = FALSE: the file is still being written, so an incomplete
        # final line is expected.
        sampleLines <- readLines(sampleFile, warn = FALSE)
        isDraw <- substr(sampleLines, 1, 1) != "#" &
          substr(sampleLines, 1, 4) != "lp__" &
          sampleLines != ''
        newProgress <- sum(isDraw)
        if (newProgress != bigStanStuff$progressList[[i]]) {
          bigStanStuff$progressList[[i]] <- newProgress
          save(bigStanStuff, file = stuffFile)
        }
      }
    }

    # Build and print the status once per poll, after all chains were read.
    minDone <- min(unlist(bigStanStuff$progressList))
    timeLeft <- "?"
    if (is.finite(minDone) && minDone > 0) {
      timeLeft <- timeAsString(elapsedSeconds()/minDone*(bigStanStuff$iter - minDone))
    }
    statusItems <- c(
      paste0(unlist(bigStanStuff$chainNameList), ': ', unlist(bigStanStuff$progressList), '/', bigStanStuff$iter)
      , paste0('Chains complete: ', bigStanStuff$numDone, '/', bigStanStuff$cores)
      , paste0('Estimated time remaining: ', timeLeft)
    )
    if (length(chains_with_stderr) > 0) {
      statusItems <- c(
        statusItems
        , paste0('chains with errors: ', paste(chains_with_stderr, collapse = ', '))
      )
    }
    updateTextToPrint <- '\r'
    for (item in statusItems) {
      updateTextToPrint <- appendString(updateTextToPrint, item, spacing, one_line_per_chain)
    }
    cat(updateTextToPrint)
    utils::flush.console()
  }

  # Final summary: every chain's stderr file is checked, finished or not.
  chains_with_stderr <- chainsWithStderr(seq_len(bigStanStuff$cores))
  updateTextToPrint <- '\n'
  temp <- paste0('All done! Elapsed time: ', timeAsString(elapsedSeconds()))
  updateTextToPrint <- appendString(updateTextToPrint, temp, spacing, one_line_per_chain)
  if (length(chains_with_stderr) > 0) {
    temp <- paste0('chains with messages from Stan: ', paste(chains_with_stderr, collapse = ', '))
    updateTextToPrint <- appendString(updateTextToPrint, temp, spacing, one_line_per_chain)
  }
  cat(updateTextToPrint)
  utils::flush.console()
  invisible(NULL)
}
# Load the per-chain results saved by the worker scripts and merge them.
#
# Reads `bigStanTemp/bigStanStuff.rda` for the list of chain names and .rda
# files, loads each chain's saved object, and combines them with
# rstan::sflist2stanfit().
#
# Returns the value of rstan::sflist2stanfit() applied to the list of chains.
# Stops with an error if the bookkeeping file or any chain file is missing.
collectBigStan <- function(){
  stuffFile <- 'bigStanTemp/bigStanStuff.rda'
  if (!file.exists(stuffFile)) {
    stop("'", stuffFile, "' not found; run startBigStan() first", call. = FALSE)
  }
  # Scratch environments keep loaded objects from clobbering local variables.
  stuffEnv <- new.env()
  load(stuffFile, envir = stuffEnv)
  bigStanStuff <- stuffEnv$bigStanStuff

  rdaFiles <- unlist(bigStanStuff$rdaFileList)
  missingFiles <- rdaFiles[!file.exists(rdaFiles)]
  if (length(missingFiles) > 0) {
    stop(
      "missing chain result file(s): "
      , paste(missingFiles, collapse = ', ')
      , call. = FALSE
    )
  }

  rdaList <- lapply(seq_len(bigStanStuff$cores), function(i) {
    chainEnv <- new.env()
    load(bigStanStuff$rdaFileList[[i]], envir = chainEnv)
    get(bigStanStuff$chainNameList[[i]], envir = chainEnv, inherits = FALSE)
  })
  # Namespace-qualified so this works without rstan being attached, matching
  # the rstan:: calls used elsewhere in this file.
  rstan::sflist2stanfit(rdaList)
}
# Run the system command `killall R`, i.e. ask the OS to signal every process
# named "R". Returns whatever system2() returns for that command.
# NOTE(review): this is not limited to processes started by startBigStan();
# confirm that is acceptable on the machine where it is used.
killBigStan <- function() {
  system2("killall", args = "R")
}
# Delete the `bigStanTemp` directory (and everything inside it) from the
# working directory.
#
# Arguments:
#   reportAllClean - when TRUE, print "All clean!" afterwards.
#
# Returns NULL invisibly. Does nothing (apart from the optional message) when
# the directory does not exist.
cleanBigStan <- function(reportAllClean=TRUE){
  if (dir.exists('bigStanTemp')) {
    # unlink(recursive = TRUE) removes files and directories in one step;
    # file.remove() is not documented to delete directories on every platform.
    unlink('bigStanTemp', recursive = TRUE)
  }
  if (reportAllClean) {
    cat("\r","All clean!")
    utils::flush.console()
  }
  invisible(NULL)
}
# Build a single string made of `i` back-to-back copies of `x`.
str_rep <- function(x, i) {
  copies <- rep.int(x, times = i)
  paste0(copies, collapse = "")
}
# Turn a duration in seconds into lower-case text: round to whole seconds,
# convert with lubridate::seconds_to_period(), then lower-case the result.
timeAsString <- function(x) {
  wholeSeconds <- round(x)
  period <- lubridate::seconds_to_period(wholeSeconds)
  tolower(period)
}
# Append item `i` to string `s`, followed by space padding.
#
# Arguments:
#   s                  - text accumulated so far.
#   i                  - item to append.
#   spacing            - number of extra spaces after the item.
#   one_line_per_chain - when TRUE, pad the item out to the console width
#                        (getOption("width")) plus `spacing`; when FALSE, pad
#                        with `spacing` spaces only.
#
# Returns the extended string.
appendString <- function(s,i,spacing,one_line_per_chain){
  if (one_line_per_chain) {
    padWidth <- getOption("width") - nchar(i) + spacing
  } else {
    padWidth <- spacing
  }
  # An item longer than the console would give a negative count, which
  # rep.int() rejects; clamp to zero so long items are simply not padded.
  paste0(s, i, str_rep(' ', max(0, padWidth)))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment