Last active
March 2, 2025 11:16
-
-
Save neel04/9c26d460793466187b5dd8ffb2e4d90b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
#
# Bootstrap a training host and launch ReAct_Jax's train_model.py.
#
# Steps, in order:
#   1. Verify the required secrets are present in the environment.
#   2. Provide an empty tmpfs ramdisk at RAMDISK_PATH for the Hugging Face
#      caches.
#   3. Export the cache locations and JAX/W&B/HF settings.
#   4. Clone (first run) or update (later runs) the ReAct_Jax repository.
#   5. On first run only, install system packages and build the uv virtual
#      environment; FLAG_FILE records that this has already been done.
#   6. Run train_model.py with TRAIN_ARGS inside the virtual environment.
#
# Required environment variables (never hard-code these in the script):
#   WANDB_API_KEY  Weights & Biases API key.
#   HF_TOKEN       Hugging Face access token.
#
# Usage:
#   WANDB_API_KEY=... HF_TOKEN=... ./this_script.sh

# Options live in `set` rather than the shebang so they also apply when the
# script is started as `bash script.sh`. `-u` is deliberately left off because
# sourced third-party files (~/.profile, the venv activate script) are not
# guaranteed to be nounset-clean.
set -eo pipefail

readonly BRANCH="dev"
readonly REPO_URL="https://github.com/neel04/ReAct_Jax.git"
readonly REPO_DIR="ReAct_Jax"
readonly RAMDISK_SIZE=400G
readonly RAMDISK_PATH="/workspace"
readonly VENV_DIR="main_env"
readonly FLAG_FILE="./env_flag"

# Arguments for train_model.py. Kept as an array so each argument reaches
# Python as exactly one word, with no reliance on unquoted word splitting.
readonly TRAIN_ARGS=(
  --save_dir ./ReAct/outputs/ --dataset owt --group owt_repro --exp_logging
  --log_interval 1500 --save_interval 10000 --seqlen 512 --num_classes 50304
  --num_blocks 13 --width 1024 --n_heads 8 --epochs 1 --max_iters 3
  --batch_size 512 --accum_steps 1 --warmup_steps 1000
  --lr 9e-4 --beta_1 0.9 --beta_2 0.98 --nesterov
  --weight_decay 3e-3 --drop_rate 0.00
  --tune_hyperparams --sweep_metadata _NDR --resume
)

#######################################
# Abort early, before anything destructive happens, if a secret is missing.
# Globals:   WANDB_API_KEY (read), HF_TOKEN (read)
# Returns:   exits non-zero with a message when either variable is unset/empty
#######################################
require_secrets() {
  : "${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment}"
  : "${HF_TOKEN:?HF_TOKEN must be set in the environment}"
}

#######################################
# Provide an empty tmpfs ramdisk at RAMDISK_PATH owned by the current user.
# Globals:   RAMDISK_PATH (read), RAMDISK_SIZE (read)
#######################################
setup_ramdisk() {
  if mountpoint -q "$RAMDISK_PATH"; then
    # An active mount point cannot be removed (rm fails with EBUSY), so
    # start from a clean slate by emptying it instead of deleting it.
    find "${RAMDISK_PATH:?}" -mindepth 1 -delete
    return 0
  fi

  # Not mounted: drop any stale plain directory, then mount a fresh tmpfs.
  # The directory lives directly under / so removing it needs root.
  sudo rm -rf -- "${RAMDISK_PATH:?}"
  sudo mkdir -p "$RAMDISK_PATH"
  sudo mount -t tmpfs -o "size=${RAMDISK_SIZE}" tmpfs "$RAMDISK_PATH"

  # Use the caller's real primary group; a group named after the user is
  # not guaranteed to exist.
  sudo chown "$(id -un):$(id -gn)" "$RAMDISK_PATH"
  sudo chmod 755 "$RAMDISK_PATH"
}

#######################################
# Export cache locations and runtime settings, creating the cache dirs.
# Globals:   RAMDISK_PATH (read); exports JAX_COMPILATION_CACHE_DIR, HF_HOME,
#            HF_DATASETS_CACHE, jax_threefry_partitionable, WANDB_API_KEY,
#            HF_TOKEN, JAX_TRACEBACK_FILTERING
#######################################
export_environment() {
  export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"

  # Hugging Face caches live on the ramdisk.
  export HF_HOME="$RAMDISK_PATH/huggingface"
  export HF_DATASETS_CACHE="$RAMDISK_PATH/huggingface_datasets"
  mkdir -p "$HF_HOME" "$HF_DATASETS_CACHE"

  # NOTE(review): exported in lowercase exactly as before; confirm whether
  # JAX actually reads this spelling or expects an upper-case variable name.
  export jax_threefry_partitionable=1
  export JAX_TRACEBACK_FILTERING=off

  # Secrets are taken from the caller's environment (see require_secrets)
  # and re-exported so child processes inherit them.
  export WANDB_API_KEY HF_TOKEN
}

#######################################
# Clone the repository on first run, then bring it up to date.
# Globals:   BRANCH (read), REPO_URL (read), REPO_DIR (read)
#######################################
sync_repo() {
  git config --global safe.directory '*'

  # `git clone` refuses to clone into an existing non-empty directory, so
  # only clone when there is no checkout yet; re-runs just pull.
  if [[ ! -d "$REPO_DIR/.git" ]]; then
    git clone -b "$BRANCH" "$REPO_URL" "$REPO_DIR"
  fi
  git -C "$REPO_DIR" pull --all
}

#######################################
# One-time host setup: system packages plus the uv virtual environment.
# Skipped when FLAG_FILE exists. The flag is written only after every
# install step has succeeded, so a failed setup is retried on the next run.
# Globals:   FLAG_FILE (read), VENV_DIR (read), PATH (may be extended)
#######################################
install_environment() {
  if [[ -f "$FLAG_FILE" ]]; then
    echo "Reusing existing venv..."
    return 0
  fi

  echo "Setting up environment..."
  sudo apt-get update -y
  sudo apt-get install neovim tmux -y

  # Set default python to python3
  sudo ln -sf /usr/bin/python3 /usr/bin/python

  # Create virtual environment
  pip3 install uv
  # Re-read the login profile so a freshly created user bin directory is
  # picked up; tolerate hosts that have no ~/.profile.
  if [[ -f "$HOME/.profile" ]]; then
    # shellcheck disable=SC1091
    source "$HOME/.profile"
  fi
  # Fallback: pip's user-install location, in case the profile did not add it.
  if ! command -v uv >/dev/null; then
    export PATH="$HOME/.local/bin:$PATH"
  fi

  uv venv "$VENV_DIR" --python 3.11
  # shellcheck disable=SC1091
  source "$VENV_DIR/bin/activate"

  uv pip install --no-cache-dir "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --prerelease allow
  uv pip install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
  uv pip install -U optuna-integration plotly lm-eval pdbpp
  uv pip install git+https://github.com/deepmind/jmp
  uv pip install git+https://github.com/Findus23/jax-array-info.git
  uv pip install -q tensorflow tensorboard-plugin-profile etils importlib_resources "cloud-tpu-profiler>=2.3.0"

  # Create the flag file
  touch "$FLAG_FILE"
}

#######################################
# Activate the venv and run the training entry point from the repo root.
# Globals:   VENV_DIR (read), REPO_DIR (read), TRAIN_ARGS (read)
#######################################
run_training() {
  echo "Executing train_model.py"
  # shellcheck disable=SC1091
  source "$VENV_DIR/bin/activate"

  echo "Executing train_model.py inside uv venv..."
  cd "$REPO_DIR"
  python3 train_model.py "${TRAIN_ARGS[@]}"
  echo "Finished training!"
}

main() {
  require_secrets
  setup_ramdisk
  export_environment
  sync_repo
  install_environment
  run_training
}

main "$@"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
52.35