Skip to content

Instantly share code, notes, and snippets.

@matthewfeickert
Last active November 23, 2020 22:27
Show Gist options
  • Save matthewfeickert/2eabe99f1779344582d6635c70681f43 to your computer and use it in GitHub Desktop.
Save matthewfeickert/2eabe99f1779344582d6635c70681f43 to your computer and use it in GitHub Desktop.
Example GPU benchmarking script

Benchmarking of pyhf pseudoexperiments on CPU and GPU vs. ROOT

Setup

Install pyhf with the JAX backend in a virtual environment either locally from the Git repo

python -m pip install --editable .[jax]

or from TestPyPI

python -m pip install --upgrade --extra-index-url https://test.pypi.org/simple/ --pre pyhf[jax]

GPU Support

You then need to install JAX with GPU support, using something like the following, adapted to your system-specific configuration

$ cat install_JAX_GPU.sh
# install jaxlib
PYTHON_VERSION=cp38  # alternatives: cp36, cp37, cp38
CUDA_VERSION=cuda101  # alternatives: cuda100, cuda101, cuda102, cuda110
PLATFORM=manylinux2010_x86_64  # alternatives: manylinux2010_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install --upgrade "$BASE_URL/$CUDA_VERSION/jaxlib-0.1.52-$PYTHON_VERSION-none-$PLATFORM.whl"

pip install --upgrade jax  # install jax

Running

pyhf with JAX CPU

Edit bench.py to have

    xla_bridge_backend = "CPU"  # CPU or GPU

and then run

python bench.py

pyhf with JAX GPU

Edit bench.py to have

    xla_bridge_backend = "GPU"  # CPU or GPU

and then run

python bench.py

ROOT

Run

benchmark_ROOT.sh

Results

On Illinois's Beast deep learning desktop

$ nvidia-smi -L && nvidia-smi | head -n 4 | tail -n 3
GPU 0: GeForce RTX 2080 Ti
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+

benchmarked as

(pyhf-dev) feickert@Beast:~/testarea/lukas-bench-test$ time python bench.py

Running with JAX on cpu


real	132m56.390s
user	253m5.689s
sys	70m39.498s
(pyhf-dev) feickert@Beast:~/testarea/lukas-bench-test$ time python bench.py

Running with JAX on gpu


real	29m50.816s
user	30m41.722s
sys	0m24.069s
import numpy as np
import pyhf
import jax
import json
from jax.lib import xla_bridge
def model_from_spec(nbins, nsamples, seed=0):
    """Build a single-channel pyhf model from seeded Gaussian pseudodata.

    One sample is created per mean in ``np.linspace(-1, 1, nsamples)``. The
    data of each sample is the histogram (``nbins`` equal-width bins spanning
    [-1, 1]) of 100000 normal draws of width 0.2 around that mean. Every
    sample carries its own ``normsys`` modifier; the first sample additionally
    carries the ``mu`` ``normfactor`` modifier.

    Args:
        nbins: Number of histogram bins per sample.
        nsamples: Number of samples in the channel.
        seed: Value passed to ``np.random.seed`` before any draws are made.

    Returns:
        The ``pyhf.Model`` constructed from the generated specification.
    """
    np.random.seed(seed)
    bin_edges = np.linspace(-1, 1, nbins + 1)
    sample_means = np.linspace(-1, 1, nsamples)

    samples = []
    for position, mean in enumerate(sample_means):
        # Samples and their modifiers are numbered from 1, zero-padded to 4 digits.
        label = str(position + 1).zfill(4)

        draws = np.random.normal(mean, 0.2, size=100000)
        counts = np.histogram(draws, bins=bin_edges)[0]

        modifiers = [
            {
                "name": f"norm_{label}",
                "type": "normsys",
                "data": {"hi": 0.95, "lo": 1.05},
            },
        ]
        # Only the first sample is scaled by the "mu" normalisation factor.
        if position == 0:
            modifiers.append({"name": "mu", "type": "normfactor", "data": None})

        samples.append(
            {
                "name": f"sample_{label}",
                "data": counts.tolist(),
                "modifiers": modifiers,
            }
        )

    channel = {"name": "single_channel", "samples": samples}
    return pyhf.Model({"channels": [channel]})
def main(nbins=100, nsamples=100):
    """Run a single toy-based pyhf hypothesis test on the JAX backend.

    The function selects the JAX platform, builds the model with
    ``model_from_spec``, generates data from the model with the parameter at
    ``poi_index`` set to zero, writes the resulting workspace to
    ``workspace_{nbins}bins_{nsamples}samples.json`` in the current directory,
    and then prints the result of ``pyhf.infer.hypotest`` at a test value of
    1.0.

    Args:
        nbins: Number of bins forwarded to ``model_from_spec``.
        nsamples: Number of samples forwarded to ``model_from_spec``.
    """
    xla_bridge_backend = "GPU"  # CPU or GPU

    # Global flag to set a specific platform. It must be set at startup,
    # before any other JAX command runs, or it has no effect.
    jax.config.update("jax_platform_name", xla_bridge_backend.lower())
    pyhf.set_backend("jax")
    print(f"\nRunning with JAX on {xla_bridge.get_backend().platform}\n")

    model = model_from_spec(nbins=nbins, nsamples=nsamples)

    # Background-only expectation: suggested initial parameters with the
    # parameter of interest forced to zero.
    bkg_only_pars = model.config.suggested_init()
    bkg_only_pars[model.config.poi_index] = 0.0
    data = pyhf.tensorlib.astensor(model.expected_data(bkg_only_pars))

    # Serialize the workspace; small models are pretty-printed, larger ones
    # are written compactly.
    workspace = pyhf.Workspace.build(model, data, name="measurement")
    if nbins > 10:
        indent = None
    else:
        indent = 4
    with open(f"workspace_{nbins}bins_{nsamples}samples.json", "w") as serialization:
        json.dump(workspace, serialization, indent=indent, sort_keys=True)

    # Replace the model's logpdf with a JIT-compiled version before the test.
    model.logpdf = jax.jit(model.logpdf)

    # NOTE(review): this tensor is not used below; the call is retained so the
    # sequence of backend operations is unchanged.
    init_pars = pyhf.tensorlib.astensor(model.config.suggested_init())

    hypotest_result = pyhf.infer.hypotest(
        1.0,
        data,
        model,
        calctype="toybased",
        qtilde=True,
        track_progress=False,
    )
    print(f"hypotest result: {hypotest_result}")
# Script entry point: run the benchmark with the default model size.
if __name__ == "__main__":
    main()
#!/usr/bin/env bash
# Time one run of bench.py and store its output, including the timing report,
# under benchmarks/gpu/. A previous result file is kept as a timestamped .bak
# copy. The backend name below is only used to label the output file.

# backend="pytorch"
backend="jax"
# backend="tensorflow"

output_dir="benchmarks/gpu"
output_file="${output_dir}/gpu_${backend}.txt"

# mkdir -p does nothing when the directory already exists
mkdir -p "${output_dir}"

# Back up an earlier result as <name>_<timestamp>.bak before overwriting it
if [[ -f "${output_file}" ]]; then
  backup_file="${output_file%.txt}_$(date '+%Y-%m-%d-%H%M%S').bak"
  cp "${output_file}" "${backup_file}"
fi

# Start the result file afresh with a header naming the timed command
printf "# time python bench.py\n\n" > "${output_file}" 2>&1

# wrap in () to capture the output of time along with the script's own output
(time python bench.py) >> "${output_file}" 2>&1
#!/usr/bin/env bash
# Benchmark ROOT on CPU.
#
# Steps:
#   1. Convert the pyhf JSON workspace to HistFactory XML (pyhf json2xml).
#   2. Build the ROOT workspace with hist2workspace inside a Docker container.
#   3. Time run_toys.py inside the same image and write the output, including
#      the timing report, to benchmarks/ROOT/cpu_ROOT.txt.
# A previous result file is kept as a timestamped .bak copy.

output_dir="benchmarks/ROOT"
output_file="${output_dir}/cpu_ROOT.txt"
# n_toys=2000
n_toys=200

if [[ ! -d "${output_dir}" ]]; then
  mkdir -p "${output_dir}"
fi
if [[ -f "${output_file}" ]]; then
  cp "${output_file}" "${output_file::-4}_$(date '+%Y-%m-%d-%H%M%S').bak"
fi

workspace_json="workspace_100bins_100samples.json"
root_workspace_dir="root_workspace"
# ROOT file handed to run_toys.py below. The guard before hist2workspace tests
# this same path so that the check and the consumer cannot drift apart.
combined_model="${root_workspace_dir}/config/FitConfig_combined_measurement_model.root"

# Always regenerate the XML workspace from scratch
if [[ -d "${root_workspace_dir}" ]]; then
  rm -r "${root_workspace_dir}"
fi
if [[ ! -d "${root_workspace_dir}" ]]; then
  mkdir "${root_workspace_dir}"
  pyhf json2xml --output-dir "${root_workspace_dir}" "${workspace_json}"
fi

docker_image="atlasamglab/stats-base:root6.22.02"
docker pull "${docker_image}"

# Run the container as the invoking user so files created in the bind mount
# are owned by that user. Quoted so that a working directory containing
# whitespace is passed to docker as a single argument.
user_spec="$(id -u):$(id -g)"

if [[ ! -f "${combined_model}" ]]; then
  # NOTE(review): the command is passed as one string, as before; this
  # presumably relies on the image's entrypoint being a shell -- confirm.
  docker run \
    --rm \
    -u "${user_spec}" \
    -v "${PWD}:${PWD}" \
    -w "${PWD}" \
    "${docker_image}" \
    "hist2workspace ${root_workspace_dir}/FitConfig.xml"
fi

# Start the result file afresh with a header describing the timed command
printf "# time docker run --rm -v \$PWD:\$PWD -w \$PWD ${docker_image} 'python run_toys.py ${combined_model} ${n_toys}'\n\n" &> "${output_file}"

# wrap in () to capture the output of time along with the container's output
(time docker run \
  --rm \
  -u "${user_spec}" \
  -v "${PWD}:${PWD}" \
  -w "${PWD}" \
  "${docker_image}" \
  "python run_toys.py ${combined_model} ${n_toys}") >> "${output_file}" 2>&1
#!/bin/bash
# Install a CUDA build of jaxlib from the jax-releases bucket, then install jax.
# Adjust the three selectors below to match the local Python, CUDA and platform.

# Point XLA at the CUDA installation directory
# export XLA_FLAGS="--xla_gpu_cuda_data_dir=$(command -v cuda-gdb)"
export XLA_FLAGS="--xla_gpu_cuda_data_dir=/usr/lib/cuda/"
echo "XLA_FLAGS: ${XLA_FLAGS}"

# Wheel selectors
PYTHON_VERSION=cp38            # alternatives: cp36, cp37, cp38
CUDA_VERSION=cuda101           # alternatives: cuda100, cuda101, cuda102, cuda110
PLATFORM=manylinux2010_x86_64  # alternatives: manylinux2010_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'

# install jaxlib
JAXLIB_WHEEL="${BASE_URL}/${CUDA_VERSION}/jaxlib-0.1.52-${PYTHON_VERSION}-none-${PLATFORM}.whl"
pip install --upgrade "${JAXLIB_WHEEL}"

# install jax
pip install --upgrade jax
import sys
import json
import ROOT
def run_toys_ROOT(infile, ntoys):
    """Fit the background-only and signal+background pdfs to the observed data.

    Opens the ROOT file at ``infile``, reads the ``combined`` workspace with
    its ``obsData`` dataset and ``ModelConfig``, derives a clone named
    ``bonly`` whose snapshot has the first parameter of interest set to 0,
    and prints the return value of ``fitTo`` for each model's pdf.

    Args:
        infile: Path of the ROOT file containing the ``combined`` workspace.
        ntoys: Number of toys. Only referenced by the disabled toy-based
            calculation kept in the comment at the end of the function.
    """
    root_file = ROOT.TFile.Open(infile)
    workspace = root_file.Get("combined")
    data = workspace.data("obsData")

    # Signal+background model: snapshot the POI at its current value.
    sb_model = workspace.obj("ModelConfig")
    poi = sb_model.GetParametersOfInterest().first()
    sb_model.SetSnapshot(ROOT.RooArgSet(poi))

    # Background-only model: clone first, then snapshot with the POI at zero.
    bkg_model = sb_model.Clone()
    bkg_model.SetName("bonly")
    poi.setVal(0)
    bkg_model.SetSnapshot(ROOT.RooArgSet(poi))

    bkg_fit = bkg_model.GetPdf().fitTo(data)
    print(f"bkg_pdf.fitTo(data): {bkg_fit}")

    sb_fit = sb_model.GetPdf().fitTo(data)
    print(f"sb_pdf.fitTo(data): {sb_fit}")

    # Disabled toy-based hypothesis test (the only consumer of ``ntoys``):
    #
    # hypotest_calc = ROOT.RooStats.FrequentistCalculator(data, bkg_model, sb_model)
    # hypotest_calc.SetToys(ntoys, ntoys)
    #
    # test_stat = ROOT.RooStats.ProfileLikelihoodTestStat(bkg_model.GetPdf())
    # test_stat.SetOneSidedDiscovery(False)  # not q0
    # test_stat.SetOneSided(True)  # qtilde
    # hypotest_calc.GetTestStatSampler().SetTestStatistic(test_stat)
    #
    # hypotest_inverter = ROOT.RooStats.HypoTestInverter(hypotest_calc)
    # hypotest_inverter.SetConfidenceLevel(0.95)
    # hypotest_inverter.UseCLs(True)
    #
    # # npoints = 1
    # # hypotest_inverter.RunFixedScan(npoints, 1.0, 1.0)
    # test_poi = 1.0
    # hypotest_inverter.RunOnePoint(test_poi)
    #
    # result = hypotest_inverter.GetInterval()
    # print(f"result: {result}")
# Script entry point. Usage: python run_toys.py <workspace.root> [ntoys]
if __name__ == "__main__":
    # The number of toys is optional on the command line and defaults to 2000.
    if len(sys.argv) > 2:
        requested_toys = int(sys.argv[2])
    else:
        requested_toys = 2000
    run_toys_ROOT(infile=sys.argv[1], ntoys=requested_toys)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment