Skip to content

Instantly share code, notes, and snippets.

@karnigili
Last active August 17, 2021 20:47
Show Gist options
  • Save karnigili/c5519b3b62ab494dedf5a0a5a4aebdeb to your computer and use it in GitHub Desktop.
Save karnigili/c5519b3b62ab494dedf5a0a5a4aebdeb to your computer and use it in GitHub Desktop.
#!/bin/bash
# Install a temporary miniconda3 for testing, so the test starts from scratch
# (assumes access to the PNI modulefiles).
# Load pnitoolbox, which provides the pni_install_miniconda command.
module load pnitoolbox
# Fail fast if the toolbox did not provide the installer; otherwise every
# later step fails with a less obvious error.
if ! command -v pni_install_miniconda >/dev/null 2>&1; then
    echo "pni_install_miniconda not found; is the pnitoolbox module available?" >&2
    exit 1
fi
# Make a per-user tmp dir to hold the miniconda3 install.
mkdir -p "/tmp/$USER"
pni_install_miniconda "/tmp/$USER"
# If the cmdstan conda environment already exists (from a previous run),
# just activate it; otherwise create it and install cmdstanpy into it.
# NOTE(review): only the env directory is checked, so an env left behind by an
# interrupted earlier run (created but without cmdstanpy) is reused as-is.
if [[ -d "/tmp/$USER/envs/cmdstan" ]]; then
    source "/tmp/$USER/bin/activate" cmdstan
else
    # Activate base first; presumably this also makes the conda shell
    # function available for the conda activate call below.
    source "/tmp/$USER/bin/activate" base
    conda create -n cmdstan -y
    conda activate cmdstan
    conda update -n base -c defaults conda -y
    # NOTE(review): cmdstanpy is not version-pinned, while the hello_world
    # test below follows the v0.9.75 API -- consider pinning if it breaks.
    conda install conda-forge::cmdstanpy -y
fi
# Load the cluster's CmdStan 2.26.1 module.
# Presumably this is what lets cmdstan_path() in the test locate CmdStan -- confirm.
module load cmdstan/2.26.1
# Run the CmdStanPy "hello world" example as a smoke test of the install:
# https://cmdstanpy.readthedocs.io/en/v0.9.75/hello_world.html
# The here-doc delimiter is deliberately unquoted so the shell expands $USER
# in the save_csvfiles() path below. Do not add `$` or backticks to the
# Python code (including its comments): the shell would expand those too.
python << CMDSTAN_HELLO_WORLD
# NOTE(review): this follows the v0.9.75 API (stan_variable(name=...),
# sampler_vars_cols, stan_vars_cols); the install step does not pin the
# cmdstanpy version, so confirm these names exist in the installed release.
import os
from cmdstanpy import cmdstan_path, CmdStanModel

# Instantiate the Stan model, assemble the data ---------------------
# Stan program file: the Bernoulli example under the CmdStan install.
bernoulli_stan = os.path.join(cmdstan_path(), 'examples', 'bernoulli', 'bernoulli.stan')
# Instantiate the model; compiles the Stan program as needed.
bernoulli_model = CmdStanModel(stan_file=bernoulli_stan)
# Inspect the model object.
print(bernoulli_model)

# Run the HMC-NUTS sampler ------------------------------------------
# Data file for the same example.
bernoulli_data = os.path.join(cmdstan_path(), 'examples', 'bernoulli', 'bernoulli.data.json')
# Fit the model; sampler output files go to the current working directory.
bern_fit = bernoulli_model.sample(data=bernoulli_data, output_dir='.')
# Printing the fit reports sampler commands and output files.
print(bern_fit)

# Access the sample -------------------------------------------------
# Results are printed explicitly: unlike the interactive tutorial this was
# adapted from, a bare expression in a script displays nothing.
print(bern_fit.draws().shape)
print(bern_fit.draws(concat_chains=True).shape)
draws_theta = bern_fit.stan_variable(name='theta')
print(draws_theta.shape)
sampler_variables = bern_fit.sampler_vars_cols
stan_variables = bern_fit.stan_vars_cols
print('Sampler variables:\n{}'.format(sampler_variables))
print('Stan variables:\n{}'.format(stan_variables))

# Summarize the results ---------------------------------------------
print(bern_fit.summary())
print(bern_fit.diagnose())

# Save the Stan CSV files into the per-user tmp dir (the shell substitutes
# the user name into this path before Python runs).
bern_fit.save_csvfiles(dir='/tmp/$USER')
CMDSTAN_HELLO_WORLD
# Report where the saved Stan CSV files are and list them.
# Paths are passed as printf arguments, never spliced into the format string.
printf "\nResults are in %s, here is a listing:\n" "/tmp/$USER/bernoulli*.csv"
ls "/tmp/$USER"/bernoulli*.csv
# Remember to clean up! Both cleanup commands are tab-indented the same way.
printf "\nDon't forget to clean up:\n\trm -rf %s\n\trm -f bernoulli*stdout.txt\n" "/tmp/$USER"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment