
@eminorhan
Last active May 7, 2021 09:19
A minimal example demonstrating how to do multi-node distributed training with PyTorch on a Slurm cluster.

The following code is intentionally skeletal. Please feel free to flesh out the details according to your own needs.

import os
import builtins
import argparse
import torch
import torch.distributed as dist
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

parser = argparse.ArgumentParser(description='Train a model on multiple machines')
parser.add_argument('data', metavar='DIR', help='path to data')
parser.add_argument('--epochs', default=100, type=int, help='number of training epochs')
parser.add_argument('--batch_size', default=8, type=int, help='batch size per GPU')
parser.add_argument('--gpu', default=None, type=int, help='GPU id to use')
parser.add_argument('--world-size', default=-1, type=int, help='total number of processes (GPUs) for distributed training')
parser.add_argument('--rank', default=-1, type=int, help='global process rank for distributed training')
parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
# torch.distributed.launch passes --local_rank (older PyTorch) or --local-rank (PyTorch >= 2.0), so accept both
parser.add_argument('--local-rank', '--local_rank', default=-1, type=int, help='local rank for distributed training')

args = parser.parse_args()

if "WORLD_SIZE" in os.environ:
    args.world_size = int(os.environ["WORLD_SIZE"])
    
args.distributed = args.world_size > 1
ngpus_per_node = torch.cuda.device_count()

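if args.distributed:
    if args.local_rank != -1:  # for torch.distributed.launch
        # the launcher exports the global RANK; local_rank is only this process's GPU index on its node
        args.rank = int(os.environ.get("RANK", args.local_rank))
        args.gpu = args.local_rank
    elif "SLURM_PROCID" in os.environ:  # for slurm scheduler
        args.rank = int(os.environ["SLURM_PROCID"])
        # SLURM_LOCALID is the task's index on its own node; fall back to rank % GPUs per node
        args.gpu = int(os.environ.get("SLURM_LOCALID", args.rank % ngpus_per_node))
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)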

# suppress printing if not on the master process
if args.distributed and args.rank != 0:
    def print_pass(*print_args, **print_kwargs):  # accept keyword args such as end= or flush=
        pass
    builtins.print = print_pass

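# prepare training data. The DistributedSampler shuffles the data itself (reseeded each epoch via
# set_epoch), so shuffle must be False in the DataLoader, which rejects shuffle=True together with a sampler.
# ImageFolder yields PIL images by default, which the default collate function cannot batch;
# the transform below is only an example, replace it with your own preprocessing.
from torchvision import transforms
train_dataset = ImageFolder(args.data, transform=transforms.Compose([transforms.RandomResizedCrop(224),
                                                                     transforms.ToTensor()]))
train_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True,
                          sampler=train_sampler)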

# model, optimizer, criterion, etc.
model = ...
optimizer = ...
criterion = ...

if args.distributed:
    # For multiprocessing distributed, DDP constructor should always set the single device scope, 
    # otherwise DDP will use all available devices.
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
        model_without_ddp = model.module
else:
    raise NotImplementedError("Only DistributedDataParallel is supported.")

# training loop
for epoch in range(args.epochs):
    if args.distributed:
        train_sampler.set_epoch(epoch)

    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch, args)  # implement your training loop here

    # save model on the master process
    if args.rank == 0:
        torch.save({'model_state_dict': model_without_ddp.state_dict(), 
                    'optimizer_state_dict': optimizer.state_dict()}, 'chkpt.pt')

Assuming this is your main training script, saved as train.py, submit the following Slurm batch script (e.g. with sbatch) for multi-node training:

#!/bin/bash

#SBATCH --account=cds
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:rtx8000:4
#SBATCH --cpus-per-task=4
#SBATCH --mem=320GB
#SBATCH --time=48:00:00
#SBATCH --array=0
#SBATCH --job-name=train
#SBATCH --output=train_%A_%a.out

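## the batch script itself runs on the first allocated node, so its hostname serves as the master address
export MASTER_ADDR=$(hostname -s)
## pick a random port; it is exported, so every srun task sees the same value
export MASTER_PORT=$(shuf -i 10000-65500 -n 1)
## set WORLD_SIZE to (GPUs per node) * (number of nodes)
export WORLD_SIZE=8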

module purge
module load cuda/11.1.74

## note that --batch_size below is the per-GPU batch size
srun python -u train.py /DATA/DIR --batch_size 8

echo "Done"

Here, we have requested 2 nodes (machines) with 4 Quadro RTX 8000 GPUs each (8 GPUs in total). You can change this configuration according to your needs. Keep in mind that --ntasks-per-node should equal the number of GPUs X you request per node in --gres=gpu:rtx8000:X, because the script runs one process per GPU. Also, if you change the total number of processes requested (i.e., the total number of GPUs), be sure to update the WORLD_SIZE variable accordingly.
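Instead of relying only on a hand-edited WORLD_SIZE, each task can also check the layout Slurm actually launched. The following is a minimal sketch, assuming srun exports the standard per-task variables SLURM_PROCID, SLURM_LOCALID, SLURM_NTASKS and SLURM_NNODES (it normally does for jobs launched as above, but check your cluster). `slurm_dist_info` is a hypothetical helper, not part of PyTorch or Slurm; it takes the environment as a plain dict so it can be tried off-cluster.

```python
import math


def slurm_dist_info(environ, gpus_per_node):
    """Return (rank, local_rank, world_size) for the current srun task.

    Hypothetical helper: raises ValueError when the launch configuration
    does not match one process per GPU.
    """
    rank = int(environ["SLURM_PROCID"])         # global rank: 0 .. ntasks-1
    local_rank = int(environ["SLURM_LOCALID"])  # rank within this node, used as the GPU index
    world_size = int(environ["SLURM_NTASKS"])   # total tasks = nodes * ntasks-per-node
    nodes = int(environ["SLURM_NNODES"])

    # More tasks per node than GPUs: --ntasks-per-node exceeds the gres GPU count.
    if math.ceil(world_size / nodes) > gpus_per_node or local_rank >= gpus_per_node:
        raise ValueError(
            f"{world_size} tasks on {nodes} node(s) but only {gpus_per_node} GPU(s) per node")
    # A hard-coded WORLD_SIZE that disagrees with the number of tasks Slurm started.
    if "WORLD_SIZE" in environ and int(environ["WORLD_SIZE"]) != world_size:
        raise ValueError(
            f"WORLD_SIZE={environ['WORLD_SIZE']} but Slurm launched {world_size} tasks")
    return rank, local_rank, world_size


if __name__ == "__main__":
    # Simulated environment of global rank 5 in the 2-node x 4-GPU job above.
    env = {"SLURM_PROCID": "5", "SLURM_LOCALID": "1", "SLURM_NTASKS": "8",
           "SLURM_NNODES": "2", "WORLD_SIZE": "8"}
    print(slurm_dist_info(env, gpus_per_node=4))  # (5, 1, 8)
```

Inside train.py such a helper could be called as `slurm_dist_info(os.environ, torch.cuda.device_count())` before `dist.init_process_group`. Using SLURM_LOCALID as the GPU index avoids assuming that global ranks are laid out contiguously on each node.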

On the NYU Greene cluster, there is currently a limit of 16 GPUs per user, although I understand that this limit is dynamic and can change depending on demand. This means that you should be able to request up to 4 nodes with 4 GPUs on each for a total of 16 GPUs. I have, however, never tried requesting this many GPUs, as I suspect it would take quite a bit of time to get the requested resources in this case. 2 nodes and 8 GPUs total is a much more feasible configuration in my experience. I have usually been able to get this particular configuration within 12 hours even during times of peak demand.
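Whatever node and GPU count you end up with, the DistributedSampler in train.py splits the dataset across all WORLD_SIZE processes, and each process then batches its own slice with the per-GPU --batch_size. The pure-Python sketch below imitates that split to show why the DataLoader is built with `shuffle=False`, why `train_sampler.set_epoch(epoch)` is called every epoch, and how many iterations each process runs. It is an approximation for illustration only: `shard_indices` and `steps_per_epoch` are hypothetical helpers, and the real sampler shuffles with `torch.randperm`, so its concrete orderings differ, although the padding and striding follow the same scheme (default `drop_last=False`).

```python
import math
import random


def shard_indices(dataset_len, world_size, rank, epoch, seed=0, shuffle=True):
    """Approximate the per-rank index split of torch's DistributedSampler.

    Illustration only: the real sampler shuffles with torch.randperm, so the
    concrete permutation differs, but padding and striding work the same way.
    """
    indices = list(range(dataset_len))
    if shuffle:
        # Every rank uses the same seed, so all ranks agree on one permutation;
        # seeding with seed + epoch is why set_epoch() must be called each epoch.
        random.Random(seed + epoch).shuffle(indices)

    # Pad by repeating indices from the front so every rank gets the same count.
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]

    # Rank r takes every world_size-th index starting at position r.
    return indices[rank:total_size:world_size]


def steps_per_epoch(dataset_len, world_size, per_gpu_batch):
    """Iterations each rank runs per epoch (DataLoader drop_last=False)."""
    return math.ceil(math.ceil(dataset_len / world_size) / per_gpu_batch)


if __name__ == "__main__":
    # 10 samples over 4 ranks: ceil(10 / 4) = 3 each, so 2 indices are duplicated.
    for r in range(4):
        print(r, shard_indices(10, 4, r, epoch=0, shuffle=False))
    # Configuration from the batch script: WORLD_SIZE=8, --batch_size 8.
    print("samples per optimizer step:", 8 * 8)
    print("steps/epoch, hypothetical 10,000 images:", steps_per_epoch(10_000, 8, 8))
```

With `shuffle=False` the four ranks receive `[0, 4, 8]`, `[1, 5, 9]`, `[2, 6, 0]` and `[3, 7, 1]`. Because DDP averages gradients across all processes, each optimizer step in the configuration above effectively covers 8 × 8 = 64 samples, and each process runs 157 iterations per epoch on a hypothetical 10,000-image dataset.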

Acknowledgements: I benefited greatly from this similar gist by Tengda Han in preparing this gist.
