Skip to content

Instantly share code, notes, and snippets.

@trevor-m
Last active June 5, 2023 22:43
Show Gist options
  • Save trevor-m/b804620933ae6fc1e2df7068d9d1aa8a to your computer and use it in GitHub Desktop.
Save trevor-m/b804620933ae6fc1e2df7068d9d1aa8a to your computer and use it in GitHub Desktop.
PAXML + PJRT Container
# PAXML + IREE PJRT plugin container.
#
# Build using `docker build -t pjrt .`
# (a different base image can be selected with `--build-arg PAX_IMAGE=<image>`).
#
# Inside the container, run using:
# 1 device:
#   `CUDA_VISIBLE_DEVICES=0 JAX_PLATFORMS=iree_cuda python main.py --exp=tasks.lm.params.nvidia.NVIDIA1_3BPmap --job_log_dir=log_NVIDIA1_3BPmap`
# 2 devices:
#   `CUDA_VISIBLE_DEVICES=0,1 JAX_PLATFORMS=iree_cuda mpirun --allow-run-as-root -np 2 python main.py --exp=tasks.lm.params.nvidia.NVIDIA1_3BPmap --job_log_dir=log_NVIDIA1_3BPmap --mode=eval --multiprocess_gpu`
ARG PAX_IMAGE=ghcr.io/nvidia/pax:nightly-2023-05-23
FROM ${PAX_IMAGE}
# bash is required because the PJRT build step uses `source`; pipefail makes a
# piped RUN command fail when any stage of the pipe fails, not just the last.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
# CUDA toolkit location handed to the PJRT plugin's configure.py.
ENV CUDA_SDK_DIR=/usr/local/cuda
# OpenMPI (used by the multi-device `mpirun` launch) plus convenience tools.
# wget is listed explicitly because later steps download files with it.
# The apt package lists are deleted in this same layer so they do not persist
# in the image (deleting them in a later RUN would not shrink it).
RUN apt-get update \
 && DEBIAN_FRONTEND=noninteractive apt-get install -y \
      curl \
      libopenmpi-dev \
      openmpi-bin \
      openmpi-common \
      rsync \
      vim \
      wget \
 && rm -rf /var/lib/apt/lists/*
# Install Bazelisk as /usr/local/bin/bazel so the `bazel build` in the PJRT
# step resolves to it. The release is selectable with
# `--build-arg BAZELISK_VERSION=<tag>`.
ARG BAZELISK_VERSION=v1.17.0
RUN wget -nv -O /usr/local/bin/bazel \
      "https://github.com/bazelbuild/bazelisk/releases/download/${BAZELISK_VERSION}/bazelisk-linux-amd64" \
 && chmod +x /usr/local/bin/bazel
# Build the IREE PJRT plugins (CPU and CUDA) and the OpenXLA partitioner.
# WORKDIR creates /workspace when it is missing, and unlike a bare
# `mkdir /workspace` it does not fail if the base image already has it.
# The checkout must stay at /workspace/openxla-pjrt-plugin: the runtime ENV
# settings at the end of this file point into its bazel-bin/ tree.
ARG PJRT_PLUGIN_REPO=https://github.com/okkwon/openxla-pjrt-plugin.git
ARG PJRT_PLUGIN_BRANCH=use-mpi
WORKDIR /workspace
RUN git clone -b "${PJRT_PLUGIN_BRANCH}" "${PJRT_PLUGIN_REPO}" openxla-pjrt-plugin
WORKDIR /workspace/openxla-pjrt-plugin
# `source .env.sh` only affects this RUN's shell, so configure and build must
# stay chained in a single RUN. --no-cache-dir keeps pip's download cache out
# of the layer.
RUN python ./sync_deps.py \
 && python -m pip install --no-cache-dir -U -r requirements.txt \
 && python ./configure.py --cc=clang --cxx=clang++ --cuda-sdk-dir="${CUDA_SDK_DIR}" \
 && source .env.sh \
 && bazel build iree/integrations/pjrt/... partitioner/...
# Install NCCL 2.18.1 (required by IREE). The .deb is fetched to /tmp and
# removed in the same layer so it does not remain in the image.
# Versions can be overridden with --build-arg NCCL_VERSION / NCCL_CUDA_VERSION.
ARG NCCL_VERSION=2.18.1-1
ARG NCCL_CUDA_VERSION=12.1
RUN nccl_deb="libnccl2_${NCCL_VERSION}+cuda${NCCL_CUDA_VERSION}_amd64.deb" \
 && wget -nv -O "/tmp/${nccl_deb}" \
      "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/${nccl_deb}" \
 && dpkg -i "/tmp/${nccl_deb}" \
 && rm "/tmp/${nccl_deb}"
# Temporary patch for mhlo.rng_bit_generator: switch praxis's
# jax_default_prng_impl from 'rbg' to 'threefry2x32'.
# `sed -i` exits 0 even when nothing matches, so the grep fails the build if
# the replacement is not present afterwards (e.g. upstream changed this line).
RUN sed -i "s/jax_default_prng_impl', 'rbg'/jax_default_prng_impl', 'threefry2x32'/" /opt/praxis/praxis/py_utils.py \
 && grep -qF "jax_default_prng_impl', 'threefry2x32'" /opt/praxis/praxis/py_utils.py
# Temporary fix for JAX multi-client initialization with MPI: rewrite paxml's
# jax.distributed.initialize(coordinator_address, num_processes, process_id)
# call into an argument-less jax.distributed.initialize()
# (presumably so JAX picks these up from the MPI environment itself — confirm).
# `perl -i` exits 0 even when the regex does not match, so the grep fails the
# build if no argument-less call is present afterwards.
RUN perl -0777 -i -pe 's/jax.distributed.initialize\(jax_distributed_options.coordinator_address,\n[ \t]+jax_distributed_options.num_processes,\n[ \t]+jax_distributed_options.process_id\)/jax.distributed.initialize\(\)/' /opt/paxml/paxml/setup_jax.py \
 && grep -qF 'jax.distributed.initialize()' /opt/paxml/paxml/setup_jax.py
# Add a pmap config (NVIDIA1_3BPmap) to paxml's NVIDIA experiment registry.
# Each Python line is a separate quoted printf argument, so the class-body
# indentation lives inside the strings. It does not depend on how the
# Dockerfile parser treats leading whitespace on continuation lines.
# The appended code relies on nvidia.py already defining NVIDIA1_3B and having
# experiment_registry and jnp in scope (assumed from upstream paxml — confirm).
# The final python call parses the edited file so a syntax error fails the
# build instead of surfacing at experiment import time.
RUN printf '%s\n' \
      '' \
      '@experiment_registry.register' \
      'class NVIDIA1_3BPmap(NVIDIA1_3B):' \
      '  DCN_MESH_SHAPE = None' \
      '  ICI_MESH_SHAPE = None' \
      '  PERCORE_BATCH_SIZE = 1' \
      '  FPROP_DTYPE = jnp.float32' \
      >> /opt/paxml/paxml/tasks/lm/params/nvidia.py \
 && python -c 'import ast, sys; ast.parse(open(sys.argv[1]).read(), sys.argv[1])' /opt/paxml/paxml/tasks/lm/params/nvidia.py
# Runtime configuration for the IREE PJRT plugins.
# PJRT_NAMES_AND_LIBRARY_PATHS maps the platform names iree_cpu / iree_cuda
# (selected with JAX_PLATFORMS=..., see the usage notes at the top) to the
# plugin libraries built under /workspace/openxla-pjrt-plugin/bazel-bin.
# IREE_CUDA_DEPS_DIR reuses CUDA_SDK_DIR so the CUDA location is set in one place.
ENV IREE_PJRT_COMPILER_LIB_PATH="/usr/local/lib/python3.10/dist-packages/iree/compiler/_mlir_libs/libIREECompiler.so" \
    IREE_PJRT_PARTITIONER_LIB_PATH="/workspace/openxla-pjrt-plugin/bazel-bin/partitioner/libOpenXLAPartitioner.so" \
    PJRT_NAMES_AND_LIBRARY_PATHS="iree_cpu:/workspace/openxla-pjrt-plugin/bazel-bin/iree/integrations/pjrt/cpu/pjrt_plugin_iree_cpu.so,iree_cuda:/workspace/openxla-pjrt-plugin/bazel-bin/iree/integrations/pjrt/cuda/pjrt_plugin_iree_cuda.so" \
    IREE_CUDA_DEPS_DIR="${CUDA_SDK_DIR}"
# Append /usr/lib/x86_64-linux-gnu to the library search path. The
# ${var:+...} form adds the separator only when LD_LIBRARY_PATH is already set.
# This avoids an empty leading entry, which glibc's loader treats as the
# current directory.
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/usr/lib/x86_64-linux-gnu"
# main.py for the usage commands at the top of this file is run from here.
WORKDIR /opt/paxml/paxml
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment