@taesiri
Last active August 29, 2022 12:43
Dockerfile for JAX (CUDA) + PyTorch
FROM nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04
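# Install Python 3, pip, and basic tools (apt-get is the script-stable apt CLI;
# noninteractive mode avoids install prompts, and clearing the apt lists keeps the layer small)
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y python3-pip htop curl wget git && \
    rm -rf /var/lib/apt/lists/*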
# Point python and pip at the Python 3 binaries
RUN ln -sf /usr/bin/python3 /usr/bin/python && \
    ln -sf /usr/bin/pip3 /usr/bin/pip
RUN pip --no-cache-dir install --upgrade pip setuptools_rust
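# Install ML packages built with CUDA 11 support
# Only needed when CUDA lives under /usr/lib/cuda; the nvidia/cuda base image
# already provides the toolkit under /usr/local/cuda:
# RUN ln -s /usr/lib/cuda /usr/local/cuda-11.6.1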
RUN pip --no-cache-dir install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip --no-cache-dir install tensorflow trax
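# Optional addition, assuming several of these frameworks share one GPU in a single process:
# JAX preallocates most of the GPU memory on first use, and TensorFlow maps nearly all of it
# by default. These documented variables make both allocate GPU memory on demand instead.
ENV XLA_PYTHON_CLIENT_PREALLOCATE=false \
    TF_FORCE_GPU_ALLOW_GROWTH=true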
RUN pip --no-cache-dir install jupyterlab ipython matplotlib ipywidgets pandas tqdm
RUN pip --no-cache-dir install git+https://github.com/google-research/vision_transformer
# Install PyTorch wheels built against CUDA 11.6 to match the base image
RUN pip --no-cache-dir install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
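
After the image is built, a quick way to confirm that both the `jax[cuda]` wheel and the `cu116` PyTorch wheel can see the GPU from the same interpreter is a small script like the one below. It is a minimal sketch, assuming it runs inside the container started with GPU access (for example `docker run --gpus all` with the NVIDIA Container Toolkit). `gpu_report` is a hypothetical helper name; `torch.cuda.is_available()`, `torch.cuda.device_count()`, `torch.version.cuda` and `jax.devices()` are standard calls. A framework that cannot be imported is reported instead of raising.

```python
import importlib.util
from typing import Dict


def gpu_report() -> Dict[str, str]:
    """Summarize which GPUs PyTorch and JAX can see (hypothetical helper)."""
    report: Dict[str, str] = {}

    # PyTorch: a CUDA wheel also reports the CUDA version it was built against.
    if importlib.util.find_spec("torch") is None:
        report["torch"] = "not installed"
    else:
        import torch

        if torch.cuda.is_available():
            report["torch"] = (
                f"{torch.cuda.device_count()} GPU(s), built for CUDA {torch.version.cuda}"
            )
        else:
            report["torch"] = "CPU only"

    # JAX: requesting the "gpu" backend raises RuntimeError when no CUDA device is usable.
    if importlib.util.find_spec("jax") is None:
        report["jax"] = "not installed"
    else:
        import jax

        try:
            gpus = jax.devices("gpu")
        except RuntimeError:
            gpus = []
        report["jax"] = f"{len(gpus)} GPU(s)" if gpus else "CPU only"

    return report


if __name__ == "__main__":
    for framework, status in gpu_report().items():
        print(f"{framework}: {status}")
```

Run with the image's `python`; a "CPU only" entry for either framework usually means the container was started without GPU access or the wheel fell back to a CPU build.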