Last active
June 5, 2023 22:43
-
-
Save trevor-m/b804620933ae6fc1e2df7068d9d1aa8a to your computer and use it in GitHub Desktop.
PAXML + PJRT Container
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# PAXML + IREE PJRT plugin container.
#
# Build:
#   docker build -t pjrt .
#
# Run inside the container (the final WORKDIR is /opt/paxml/paxml):
#   1 device:
#     CUDA_VISIBLE_DEVICES=0 JAX_PLATFORMS=iree_cuda python main.py --exp=tasks.lm.params.nvidia.NVIDIA1_3BPmap --job_log_dir=log_NVIDIA1_3BPmap
#   2 devices (2 MPI processes):
#     CUDA_VISIBLE_DEVICES=0,1 JAX_PLATFORMS=iree_cuda mpirun --allow-run-as-root -np 2 python main.py --exp=tasks.lm.params.nvidia.NVIDIA1_3BPmap --job_log_dir=log_NVIDIA1_3BPmap --mode=eval --multiprocess_gpu

# Base image is overridable with --build-arg. Pinning it by digest would make
# rebuilds fully reproducible, since a nightly tag can in principle be re-pushed.
ARG PAX_IMAGE=ghcr.io/nvidia/pax:nightly-2023-05-23
FROM ${PAX_IMAGE}

# bash is required (later steps use `source`). With pipefail, any piped RUN
# fails on an upstream error instead of reporting only the last command's status.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
# CUDA toolkit root; passed to the PJRT plugin's configure step below.
ENV CUDA_SDK_DIR=/usr/local/cuda

# System packages:
#   - OpenMPI, for the multi-process `mpirun` workflow;
#   - git and wget, which later build steps use;
#   - curl, rsync and vim, for interactive use.
# DEBIAN_FRONTEND is set inline only, so it does not leak into the runtime
# environment. The apt lists are removed in the same layer so they do not
# bloat the image.
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y \
        curl \
        git \
        libopenmpi-dev \
        openmpi-bin \
        openmpi-common \
        rsync \
        vim \
        wget \
    && rm -rf /var/lib/apt/lists/*
# Install bazel: the bazelisk launcher is installed as /usr/local/bin/bazel.
# The version is overridable with --build-arg.
# TODO(review): verify the download against a published sha256 before use.
ARG BAZELISK_VERSION=v1.17.0
RUN wget -q -O /usr/local/bin/bazel \
        "https://github.com/bazelbuild/bazelisk/releases/download/${BAZELISK_VERSION}/bazelisk-linux-amd64" \
    && chmod +x /usr/local/bin/bazel
# Build PJRT: the IREE PJRT plugins (cpu + cuda) and the OpenXLA partitioner,
# from the `use-mpi` branch.
# The resulting .so files are referenced through bazel-bin by the ENV settings
# at the end of this file. bazel-bin is a symlink into Bazel's output base
# (under ~/.cache/bazel), so that cache must NOT be cleaned up here.
# WORKDIR is used instead of `mkdir && cd`; it also tolerates an existing
# /workspace.
# NOTE(review): a branch checkout is not reproducible; pin a commit if needed.
WORKDIR /workspace
RUN git clone -b use-mpi https://github.com/okkwon/openxla-pjrt-plugin.git
WORKDIR /workspace/openxla-pjrt-plugin
# configure.py writes .env.sh, which must be sourced in the same shell as
# `bazel build`.
RUN python ./sync_deps.py \
    && python -m pip install --no-cache-dir -U -r requirements.txt \
    && python ./configure.py --cc=clang --cxx=clang++ --cuda-sdk-dir="$CUDA_SDK_DIR" \
    && source .env.sh \
    && bazel build iree/integrations/pjrt/... partitioner/...
# Install NCCL 2.18.1 (required by IREE).
# The package name is overridable with --build-arg. The .deb is fetched into
# /tmp and removed in the same layer, so the step does not depend on the
# current WORKDIR and the archive does not remain in the image.
ARG NCCL_DEB=libnccl2_2.18.1-1+cuda12.1_amd64.deb
RUN wget -q -O "/tmp/${NCCL_DEB}" \
        "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/${NCCL_DEB}" \
    && dpkg -i "/tmp/${NCCL_DEB}" \
    && rm -f "/tmp/${NCCL_DEB}"
# Temporary patch for mhlo.rng_bit_generator: switch praxis' default PRNG from
# 'rbg' to 'threefry2x32'. Presumably the IREE plugin cannot handle
# rng_bit_generator yet — TODO confirm, and drop once it is supported.
# The grep fails the build if the expected line is absent afterwards, so a
# stale patch is noticed instead of silently doing nothing.
RUN sed -i "s/jax_default_prng_impl', 'rbg'/jax_default_prng_impl', 'threefry2x32'/" /opt/praxis/praxis/py_utils.py \
    && grep -qF "jax_default_prng_impl', 'threefry2x32'" /opt/praxis/praxis/py_utils.py

# Temporary fix for JAX multi-client initialization with MPI.
# The multi-line call
#   jax.distributed.initialize(coordinator_address, num_processes, process_id)
# is rewritten to an argument-less jax.distributed.initialize(). Presumably this
# lets JAX detect these values from the MPI environment — TODO confirm.
# -0777 slurps the whole file so the regex can span lines. `or die` fails the
# build when no substitution was made.
RUN perl -0777 -i -pe 's/jax.distributed.initialize\(jax_distributed_options.coordinator_address,\n[ \t]+jax_distributed_options.num_processes,\n[ \t]+jax_distributed_options.process_id\)/jax.distributed.initialize\(\)/ or die "setup_jax.py: jax.distributed.initialize call site not found\n"' /opt/paxml/paxml/setup_jax.py
# Add a pmap config to paxml: append a registered NVIDIA1_3B variant.
# The variant clears both mesh shapes, uses a per-core batch of 1, and runs
# forward propagation in float32.
# printf emits one argument per line: the leading '' inserts a blank separator
# line, and the trailing '' ends the file with an extra blank line.
# NOTE(review): relies on experiment_registry, NVIDIA1_3B and jnp already being
# in scope in nvidia.py — confirm when bumping the base image.
RUN printf '%s\n' \
      '' \
      '@experiment_registry.register' \
      'class NVIDIA1_3BPmap(NVIDIA1_3B):' \
      '  DCN_MESH_SHAPE = None' \
      '  ICI_MESH_SHAPE = None' \
      '  PERCORE_BATCH_SIZE = 1' \
      '  FPROP_DTYPE = jnp.float32' \
      '' \
    >> /opt/paxml/paxml/tasks/lm/params/nvidia.py
# Runtime configuration that points JAX/PJRT at the IREE plugins built above.
# PJRT_PLUGIN_BIN is build-only (ARG), so it does not leak into the runtime env.
ARG PJRT_PLUGIN_BIN=/workspace/openxla-pjrt-plugin/bazel-bin
# NOTE(review): the compiler library path hard-codes the base image's Python
# 3.10 dist-packages directory.
# LD_LIBRARY_PATH: ${VAR:+...} avoids a leading ':' (an empty entry, which the
# dynamic loader treats as the current directory) when the base image leaves it
# unset. /usr/lib/x86_64-linux-gnu is presumably appended for the NCCL library
# installed above — TODO confirm.
ENV IREE_PJRT_COMPILER_LIB_PATH="/usr/local/lib/python3.10/dist-packages/iree/compiler/_mlir_libs/libIREECompiler.so" \
    IREE_PJRT_PARTITIONER_LIB_PATH="${PJRT_PLUGIN_BIN}/partitioner/libOpenXLAPartitioner.so" \
    PJRT_NAMES_AND_LIBRARY_PATHS="iree_cpu:${PJRT_PLUGIN_BIN}/iree/integrations/pjrt/cpu/pjrt_plugin_iree_cpu.so,iree_cuda:${PJRT_PLUGIN_BIN}/iree/integrations/pjrt/cuda/pjrt_plugin_iree_cuda.so" \
    IREE_CUDA_DEPS_DIR="/usr/local/cuda" \
    LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/usr/lib/x86_64-linux-gnu"
WORKDIR /opt/paxml/paxml
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment