Install pyhf with the JAX backend in a virtual environment, either locally from the Git repo:
python -m pip install --editable .[jax]
or from TestPyPI:
python -m pip install --upgrade --extra-index-url https://test.pypi.org/simple/ --pre pyhf[jax]
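After either install finishes, a quick way to confirm what ended up in the environment is to query the package metadata. This is a minimal sketch that uses only the standard library (importlib.metadata, Python 3.8+). The helper name installed_versions is made up for illustration and is not part of pyhf.

```python
# Report the installed versions of the packages this benchmark relies on.
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(["pyhf", "jax", "jaxlib"]).items():
        print(f"{name}: {version if version else 'NOT INSTALLED'}")
```

Re-running it after installing the GPU build of jaxlib should show the jaxlib version from that wheel. Inside Python, pyhf.set_backend("jax") is the pyhf call that switches pyhf itself to the JAX backend.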
You then need to install JAX with GPU support, using something like the following adapted to your system-specific configuration:
$ cat install_JAX_GPU.sh
# install jaxlib
PYTHON_VERSION=cp38 # alternatives: cp36, cp37, cp38
CUDA_VERSION=cuda101 # alternatives: cuda100, cuda101, cuda102, cuda110
PLATFORM=manylinux2010_x86_64 # alternatives: manylinux2010_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install --upgrade "$BASE_URL/$CUDA_VERSION/jaxlib-0.1.52-$PYTHON_VERSION-none-$PLATFORM.whl"
pip install --upgrade jax # install jax
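The wheel URL in install_JAX_GPU.sh is assembled from the three variables at the top of the script. The Python sketch below reproduces that string and rejects tags outside the alternatives listed in the script's comments, which can help catch a mismatched Python or CUDA tag before pip tries to download anything. The function names are hypothetical, and the accepted tag sets only cover the jaxlib 0.1.52 alternatives named in the script.

```python
# Rebuild the jaxlib wheel URL that install_JAX_GPU.sh passes to pip, so the
# tags can be checked against the supported alternatives before downloading.
import sys

BASE_URL = "https://storage.googleapis.com/jax-releases"
PYTHON_VERSIONS = {"cp36", "cp37", "cp38"}
CUDA_VERSIONS = {"cuda100", "cuda101", "cuda102", "cuda110"}
PLATFORMS = {"manylinux2010_x86_64"}


def current_python_tag():
    """Return the CPython tag (e.g. 'cp38') of the interpreter running this code."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def jaxlib_wheel_url(
    python_version,
    cuda_version,
    platform="manylinux2010_x86_64",
    jaxlib_version="0.1.52",
):
    """Mirror "$BASE_URL/$CUDA_VERSION/jaxlib-<ver>-$PYTHON_VERSION-none-$PLATFORM.whl"."""
    if python_version not in PYTHON_VERSIONS:
        raise ValueError(f"unsupported Python tag: {python_version}")
    if cuda_version not in CUDA_VERSIONS:
        raise ValueError(f"unsupported CUDA tag: {cuda_version}")
    if platform not in PLATFORMS:
        raise ValueError(f"unsupported platform tag: {platform}")
    wheel = f"jaxlib-{jaxlib_version}-{python_version}-none-{platform}.whl"
    return f"{BASE_URL}/{cuda_version}/{wheel}"


if __name__ == "__main__":
    print("This interpreter's tag:", current_python_tag())
    print(jaxlib_wheel_url("cp38", "cuda101"))
```

For the values in the script above (cp38, cuda101), this yields the same URL as the shell expansion. PYTHON_VERSION should match the tag printed for the virtual environment's interpreter, and CUDA_VERSION should match the locally installed CUDA toolkit.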
To benchmark on the CPU, edit bench.py to have
xla_bridge_backend = "CPU" # CPU or GPU
and then run
python bench.py
To benchmark on the GPU, edit bench.py to have
xla_bridge_backend = "GPU" # CPU or GPU
and then run
python bench.py
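bench.py itself is not reproduced here, so how it consumes xla_bridge_backend is not shown. As one possible approach, the hypothetical helper select_jax_platform below turns the "CPU"/"GPU" string into environment variables that JAX reads when it is imported: JAX_PLATFORM_NAME (older releases, such as those contemporary with jaxlib 0.1.52) and JAX_PLATFORMS (newer releases). This is a sketch under that assumption, not bench.py's actual code.

```python
import os

VALID_BACKENDS = ("CPU", "GPU")


def select_jax_platform(xla_bridge_backend):
    """Hypothetical helper: route a "CPU"/"GPU" choice to JAX via the environment.

    Must be called before `import jax`, since JAX reads these variables on import.
    """
    backend = xla_bridge_backend.upper()
    if backend not in VALID_BACKENDS:
        raise ValueError(
            f"xla_bridge_backend must be one of {VALID_BACKENDS}, "
            f"got {xla_bridge_backend!r}"
        )
    platform = backend.lower()
    os.environ["JAX_PLATFORM_NAME"] = platform  # read by older JAX releases
    os.environ["JAX_PLATFORMS"] = platform  # read by newer JAX releases
    return platform


if __name__ == "__main__":
    xla_bridge_backend = "CPU"  # CPU or GPU
    platform = select_jax_platform(xla_bridge_backend)
    try:
        import jax  # imported only after the environment variables are set
    except ImportError:
        print(f"Requested {platform}, but jax is not installed")
    else:
        print(f"JAX default device platform: {jax.devices()[0].platform}")
```

Setting the platform through the environment keeps the choice in one place and avoids touching JAX internals, but it only takes effect if it runs before the first import of jax in the process.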
For the ROOT benchmark, run
benchmark_ROOT.sh
On Illinois's Beast deep learning desktop
$ nvidia-smi -L && nvidia-smi | head -n 4 | tail -n 3
GPU 0: GeForce RTX 2080 Ti
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02 Driver Version: 450.80.02 CUDA Version: 11.0 |
|-------------------------------+----------------------+----------------------+
the bench.py CPU and GPU runs were timed as follows:
(pyhf-dev) feickert@Beast:~/testarea/lukas-bench-test$ time python bench.py
Running with JAX on cpu
real 132m56.390s
user 253m5.689s
sys 70m39.498s
(pyhf-dev) feickert@Beast:~/testarea/lukas-bench-test$ time python bench.py
Running with JAX on gpu
real 29m50.816s
user 30m41.722s
sys 0m24.069s
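The two runs can be compared directly by converting the real (wall-clock) times reported by the shell's time builtin into seconds. This small sketch does that with a regular expression that assumes the XmY.YYYs format shown in the output above.

```python
import re


def parse_time(value):
    """Convert a bash `time` field such as '132m56.390s' to seconds."""
    match = re.fullmatch(r"(\d+)m([\d.]+)s", value)
    if match is None:
        raise ValueError(f"unrecognized time format: {value!r}")
    minutes, seconds = match.groups()
    return int(minutes) * 60 + float(seconds)


# Real (wall-clock) times from the two bench.py runs above
cpu_real = parse_time("132m56.390s")
gpu_real = parse_time("29m50.816s")
print(f"CPU: {cpu_real:.3f} s, GPU: {gpu_real:.3f} s, speedup: {cpu_real / gpu_real:.2f}x")
```

With the numbers above this gives a wall-clock speedup of about 4.45x for the GPU run (7976.390 s versus 1790.816 s). The CPU run's user time exceeding its real time indicates that several cores were busy in parallel during that run.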