@michaelchughes
Created November 29, 2022 00:01

Install notes for the BDL2022f env, installed on 2022-11-28

  • For CUDA Toolkit 11.3, which works on older devices but may not be optimal:
  1. Set up a basic conda env without any torch or jax packages, then activate it:
     conda env create -f bdl_2022f.yml
  2. Install PyTorch against a specific cudatoolkit version:
     conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
  3. Install JAX against the same CUDA version:
     conda install -c "nvidia/label/cuda-11.3.1" cuda-nvcc
     pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
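One way to confirm that the PyTorch and JAX installs above picked up GPU builds is to ask each framework directly. The script below is a minimal sketch that is not part of the original notes. It assumes the new env is activated, so that `python` resolves to the env's interpreter. The helper name `run_check` is hypothetical. A failing import is reported rather than aborting the script.

```shell
#!/usr/bin/env bash
# Sketch: post-install sanity check (not from the original notes).
# Assumes the conda env is activated so `python` is the env's interpreter.

run_check() {
  # $1 = label, $2 = Python one-liner to run
  local label="$1" code="$2" out
  if out="$(python -c "$code" 2>&1)"; then
    echo "[ok]   $label: $out"
  else
    # Show only the last line of the traceback / error message
    echo "[fail] $label: $(printf '%s\n' "$out" | tail -n 1)"
    return 1
  fi
}

status=0

# PyTorch reports the CUDA version it was built with (expect 11.3 here)
run_check torch "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())" || status=1

# JAX reports its default backend and visible devices ("gpu" if the CUDA wheel loaded)
run_check jax "import jax; print(jax.__version__, jax.default_backend(), jax.devices())" || status=1

if [ "$status" -eq 0 ]; then
  echo "all checks passed"
else
  echo "some checks failed" >&2
fi
```

If JAX prints `cpu` as its backend even though PyTorch sees the GPU, the JAX CUDA wheel or its toolchain (for example `ptxas`) is the likely culprit rather than the driver.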

The cuda-nvcc install line (the first of the two JAX commands) is needed to ensure that the correct version of NVIDIA's cuda-nvcc tools is installed. Without that line, I got an error about a missing "ptxas" binary, which I solved via this helpful thread: google/jax#6843 (reply in thread)
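To catch the missing-`ptxas` problem before running any JAX code, a small shell check like the following can look for the binary. This is a sketch under an assumption the notes do not state: that the `cuda-nvcc` package from the nvidia conda channel places `ptxas` in `$CONDA_PREFIX/bin` of the active env. The function name `find_ptxas` is hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: locate the ptxas binary that JAX needs for GPU compilation.
# Assumption: cuda-nvcc from the nvidia conda channel installs ptxas
# into $CONDA_PREFIX/bin of the active env.

find_ptxas() {
  # Prefer whatever ptxas is already on PATH
  if command -v ptxas >/dev/null 2>&1; then
    command -v ptxas
    return 0
  fi
  # Fall back to the active conda env's bin directory, if one is set
  if [ -n "${CONDA_PREFIX:-}" ] && [ -x "$CONDA_PREFIX/bin/ptxas" ]; then
    echo "$CONDA_PREFIX/bin/ptxas"
    return 0
  fi
  echo "ptxas not found; try: conda install -c \"nvidia/label/cuda-11.3.1\" cuda-nvcc" >&2
  return 1
}

if ptxas_path="$(find_ptxas)"; then
  echo "found ptxas at $ptxas_path"
  # The last line of the version banner names the CUDA release
  "$ptxas_path" --version | tail -n 1
fi
```

If the reported release is not 11.3, a different CUDA install earlier on PATH is probably shadowing the one in the env.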