Skip to content

Instantly share code, notes, and snippets.

@neel04
Last active March 2, 2025 11:16
Show Gist options
  • Save neel04/9c26d460793466187b5dd8ffb2e4d90b to your computer and use it in GitHub Desktop.
Save neel04/9c26d460793466187b5dd8ffb2e4d90b to your computer and use it in GitHub Desktop.
#!/bin/bash -e
# NOTE: the -e in the shebang only applies when this file is executed directly
# (./script.sh); it is dropped when the file is run as `bash script.sh`.

# Branch of the training repository to clone.
BRANCH="dev"

# Create and mount ramdisk
RAMDISK_SIZE=400G
RAMDISK_PATH="/workspace"

# Start every run from an empty workspace.
# A mounted filesystem's root directory cannot be removed: `rm -rf` on it
# deletes the contents, then fails on the mountpoint itself and returns
# non-zero, which aborts the script under -e. So when the ramdisk is already
# mounted, only its contents are cleared and the mount is kept.
# The ${VAR:?} guard aborts instead of deleting from "/" if the path is empty.
if mountpoint -q "$RAMDISK_PATH"; then
  find "${RAMDISK_PATH:?}" -mindepth 1 -delete
else
  # sudo: removing an entry directly under / needs root, matching the
  # `sudo mkdir` used to create it.
  sudo rm -rf -- "${RAMDISK_PATH:?}"
fi

# JAX compilation cache directory; located under /tmp, i.e. outside the
# workspace that is wiped above.
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"
# Mount a tmpfs ramdisk at $RAMDISK_PATH unless something is already mounted
# there.
if ! mountpoint -q "$RAMDISK_PATH"; then
  # Create the mount point if it doesn't exist
  sudo mkdir -p -- "$RAMDISK_PATH"
  # Mount a tmpfs ramdisk of the configured size
  sudo mount -t tmpfs -o "size=${RAMDISK_SIZE}" tmpfs "$RAMDISK_PATH"
  # Hand the ramdisk to the invoking user. Numeric uid:gid is used because a
  # group named after the user is not guaranteed to exist; the command
  # substitutions run in this shell, before sudo, so they describe the caller.
  sudo chown "$(id -u):$(id -g)" "$RAMDISK_PATH"
  sudo chmod 755 "$RAMDISK_PATH"
fi
# Export environment variables pointing to the ramdisk
export HF_HOME="$RAMDISK_PATH/huggingface"
export HF_DATASETS_CACHE="$RAMDISK_PATH/huggingface_datasets"
# Ensure cache directories exist
mkdir -p -- "$HF_HOME" "$HF_DATASETS_CACHE"

# Other environment variables
# JAX reads its config flags from UPPER-CASE environment variables, so the
# lower-case name alone has no effect on JAX. It is still exported in case
# anything else reads it.
export jax_threefry_partitionable=1
export JAX_THREEFRY_PARTITIONABLE=1
export JAX_TRACEBACK_FILTERING=off

# Credentials must come from the caller's environment, never from this file.
# SECURITY: the keys that used to be hard-coded here were published with the
# script and should be treated as compromised and revoked.
: "${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment}"
: "${HF_TOKEN:?HF_TOKEN must be set in the environment}"
export WANDB_API_KEY HF_TOKEN
# arguments for train_model.py
TRAIN_ARGS="--save_dir ./ReAct/outputs/ --dataset owt --group owt_repro --exp_logging \
--log_interval 1500 --save_interval 10000 --seqlen 512 --num_classes 50304 \
--num_blocks 13 --width 1024 --n_heads 8 --epochs 1 --max_iters 3 \
--batch_size 512 --accum_steps 1 --warmup_steps 1000 \
--lr 9e-4 --beta_1 0.9 --beta_2 0.98 --nesterov \
--weight_decay 3e-3 --drop_rate 0.00 \
--tune_hyperparams --sweep_metadata _NDR --resume"
# Fetch the training code into ./ReAct_Jax.
# `git clone` fails when the target directory already exists, so clone only
# when there is no checkout yet; reruns go straight to the update step.
if [[ ! -d ReAct_Jax/.git ]]; then
  git clone -b "$BRANCH" https://github.com/neel04/ReAct_Jax.git
fi

# Marker file recording that the one-time environment setup has completed.
FLAG_FILE="./env_flag"

# Trust every repository path regardless of which user owns it.
git config --global safe.directory '*'

# Bring the checkout up to date; -C runs git inside the repo without changing
# this script's working directory.
git -C ReAct_Jax pull --all
# One-time environment setup: system packages, uv, the 'main_env' virtual
# environment and all Python dependencies.
# Setup runs when the flag file is missing OR the venv's activate script is
# missing, so a stale flag with no venv rebuilds the environment here instead
# of failing at activation later.
if [[ ! -f "$FLAG_FILE" || ! -f main_env/bin/activate ]]; then
  echo "Setting up environment..."
  sudo apt-get update -y
  sudo apt-get install neovim tmux -y
  # Set default python to python3
  sudo ln -sf /usr/bin/python3 /usr/bin/python
  # Create virtual environment
  pip3 install uv
  # Re-read the login profile (presumably so a user-level uv install ends up
  # on PATH - confirm). Guarded so a missing ~/.profile does not abort setup.
  if [[ -f ~/.profile ]]; then
    # shellcheck source=/dev/null
    source ~/.profile
  fi
  uv venv 'main_env' --python 3.11
  # shellcheck source=/dev/null
  source main_env/bin/activate
  uv pip install --no-cache-dir "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --prerelease allow
  uv pip install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
  uv pip install -U optuna-integration plotly lm-eval pdbpp
  uv pip install git+https://github.com/deepmind/jmp
  uv pip install git+https://github.com/Findus23/jax-array-info.git
  uv pip install -q tensorflow tensorboard-plugin-profile etils importlib_resources "cloud-tpu-profiler>=2.3.0"
  # ------------------
  # Create the flag file only after every install step above has been issued
  touch "$FLAG_FILE"
else
  echo "Reusing existing venv..."
fi
# Activate the virtual environment and launch training from inside the repo.
echo "Executing train_model.py"
# shellcheck source=/dev/null
source main_env/bin/activate
echo "Executing train_model.py inside uv venv..."
cd ReAct_Jax/
# TRAIN_ARGS is a flat, whitespace-separated string. Split it into an array
# with `read` so the words are NOT subject to pathname expansion (an unquoted
# $TRAIN_ARGS would also glob-expand any word containing *, ? or [).
# `read -d ''` consumes the whole string and always returns non-zero at EOF,
# hence the deliberate `|| true`.
read -r -d '' -a train_argv <<<"$TRAIN_ARGS" || true
python3 train_model.py "${train_argv[@]}"
echo "Finished training!"
@neel04
Copy link
Author

neel04 commented Feb 6, 2025

52.35

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment