@garrett361
Last active May 16, 2024 12:41
torchrun Sunspot
#!/bin/bash -l
# Minimal torchrun-based launch script
# See https://docs.alcf.anl.gov/aurora/data-science/frameworks/pytorch for more recommendations.
# Usage:
#
# # qsub -v SCRIPT_PATH=your_script_path[,ARGS=...][,NPROC_PER_NODE=...] launch_torch.sh
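#
# # A concrete example with hypothetical (placeholder) paths:
# # qsub -v SCRIPT_PATH=/path/to/all_gather_test.py,NPROC_PER_NODE=12 launch_torch.sh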
#PBS -A Aurora_deployment
#PBS -l select=1
#PBS -l place=scatter
#PBS -l walltime=00:30:00
#PBS -q workq
#PBS -j oe
#PBS -k doe
#####################################################################
# This block configures the total number of ranks, discovering
# it from PBS variables.
# 12 ranks per node if using one rank per tile.
#####################################################################
NNODES=`wc -l < $PBS_NODEFILE`
export NPROC_PER_NODE="${NPROC_PER_NODE:-12}"
# TODO: affinities as in https://github.com/pytorch/pytorch/issues/115305#issuecomment-1845957682
# CPU_BIND values recommended by Corey Adams (ANL): https://github.com/coreyjadams/CosmicTagger/blob/master/example_submission_scripts/sunspot/train_pt_single_tile_ddp.sh
# export CPU_BIND="verbose,list:0-7,104-111:8-15,112-119:16-23,120-127:24-31,128-135:32-39,136-143:40-47,144-151:52-59,156-163:60-67,164-171:68-75,172-179:76-83,180-187:84-91,188-195:92-99,196-203"
ulimit -c 0
# Set the MASTER_ADDR by reading from the nodefile
MASTER_ADDR=$(cat $PBS_NODEFILE | head -n 1)
# Arbitrarily chosen port:
MASTER_PORT=29500
# Use pbsdsh to launch on each node.
# Need to pass NPROC_PER_NODE through to the wrapper script which will set CCL_LOCAL_SIZE to this
# value. This is needed for torchrun to work w/ CCL. See
# https://github.com/huggingface/accelerate/pull/2339
# Note: attempting to initialize just with --master-addr and --master-port (which also forces
# --rdzv-backend static) leads to a timeout error, for yet-unknown reasons.
pbsdsh -v -- path/to/torchrun_wrapper.sh $NPROC_PER_NODE \
--nnodes $NNODES --nproc-per-node $NPROC_PER_NODE \
--rdzv-backend c10d --rdzv-endpoint ${MASTER_ADDR}:$MASTER_PORT \
$SCRIPT_PATH ${ARGS:-}
"""
Minimal distributed script. Performs an all-gather and checks for correctness.
"""
import os
import intel_extension_for_pytorch as ipex # noqa
import oneccl_bindings_for_pytorch # noqa
import torch
import torch.distributed as dist
RANK = int(os.getenv("RANK", 0))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
MASTER_ADDR = os.getenv("MASTER_ADDR")
MASTER_PORT = os.getenv("MASTER_PORT")
DEVICE = torch.device(f"xpu:{LOCAL_RANK}")
# Env vars needed if launching with torchrun. See
# https://github.com/huggingface/accelerate/pull/2339
os.environ["CCL_LOCAL_RANK"] = str(LOCAL_RANK)
torch.xpu.set_device(DEVICE)
def main() -> None:
    t = torch.tensor([RANK], device=DEVICE)
    tensor_list = [torch.empty_like(t) for _ in range(WORLD_SIZE)]
    dist.all_gather(tensor_list, t)
    tensor_list_t = torch.cat(tensor_list, dim=0)
    # Print for debugging:
    print(f"[{RANK=}]: {t=}, {tensor_list_t=}")
    expected_tensor_list_t = torch.arange(WORLD_SIZE, device=DEVICE)
    torch.testing.assert_close(tensor_list_t, expected_tensor_list_t)


if __name__ == "__main__":
    try:
        dist.init_process_group(backend="ccl")
        main()
        if not RANK:
            print("PASSED TEST")
    finally:
        dist.destroy_process_group()
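For a quick single-node sanity check, the test above can also be launched directly with torchrun, skipping pbsdsh. The sketch below assumes the script is saved as all_gather_test.py (a placeholder name), the frameworks module is already loaded, and the CCL-related variables are set inline the same way the wrapper script below exports them:

    # Placeholder script name; 12 ranks = one rank per tile on a node.
    CCL_PROCESS_LAUNCHER=none CCL_ATL_TRANSPORT=ofi CCL_LOCAL_SIZE=12 \
        torchrun --standalone --nproc-per-node 12 all_gather_test.py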
#!/bin/bash -l
# Note: the -l flag after the bash shebang is necessary to avoid a "module: command not found" error.
module use /soft/modulefiles
module load frameworks/2023.12.15.001
# Other exports needed for torchrun to work w/ CCL.
# See https://github.com/huggingface/accelerate/pull/2339
# Note: these must be exported *after* loading the modules, since the modules reset
# CCL_PROCESS_LAUNCHER to pmix.
export CCL_PROCESS_LAUNCHER=none
export CCL_LOCAL_SIZE=$1
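# Drop the NPROC_PER_NODE argument so that only torchrun's own arguments remain in "$@".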
shift
# Also switch CCL_ATL_TRANSPORT away from its mpi default (this would otherwise be done automatically).
export CCL_ATL_TRANSPORT=ofi
torchrun "$@"
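For reference, the pbsdsh line in the launch script above invokes this wrapper on every node roughly as follows (a sketch with a placeholder node count, head-node hostname, and script path):

    # Hypothetical expansion of the pbsdsh command for a 2-node job:
    ./torchrun_wrapper.sh 12 --nnodes 2 --nproc-per-node 12 \
        --rdzv-backend c10d --rdzv-endpoint <head-node-hostname>:29500 /path/to/all_gather_test.py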