Last active July 8, 2024 22:06
Multi-node-training on slurm with PyTorch


What's this?

  • A simple note on how to start multi-node training on the Slurm scheduler with PyTorch.
  • Useful especially when the scheduler is so busy that you cannot get multiple GPUs allocated on one node, or when you need more than 4 GPUs for a single job.
  • Requirement: you have to use PyTorch DistributedDataParallel (DDP) for this purpose.
  • Warning: you might need to refactor your own code.
  • Warning: you might be secretly condemned by your colleagues for using too many GPUs.

Set up the Python script

  • create a file, for example:
import os
import builtins
import argparse
import torch
import numpy as np 
import random
import torch.distributed as dist

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--net', default='resnet18', type=str)
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size per GPU')
    parser.add_argument('--gpu', default=None, type=int)
    parser.add_argument('--start_epoch', default=0, type=int, 
                        help='start epoch number (useful on restarts)')
    parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run')
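    # DDP configs:
    parser.add_argument('--world-size', default=-1, type=int, 
                        help='total number of processes (GPUs) for distributed training')
    parser.add_argument('--rank', default=-1, type=int, 
                        help='global rank of this process for distributed training')
    parser.add_argument('--dist-url', default='env://', type=str, 
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, 
                        help='distributed backend')
    # torch.distributed.launch in PyTorch >= 2.0 passes --local-rank instead of --local_rank
    parser.add_argument('--local_rank', '--local-rank', default=-1, type=int, 
                        help='local rank for distributed training')
    # data loading / checkpointing options used in main():
    parser.add_argument('--workers', default=4, type=int,
                        help='number of DataLoader workers per process')
    parser.add_argument('--resume', default='', type=str,
                        help='path to a checkpoint to resume from')
    args = parser.parse_args()
    return args
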
def main(args):
    # DDP setting
    if "WORLD_SIZE" in os.environ:
        args.world_size = int(os.environ["WORLD_SIZE"])
    args.distributed = args.world_size > 1
    ngpus_per_node = torch.cuda.device_count()

    if args.distributed:
        if args.local_rank != -1: # for torch.distributed.launch
            args.rank = args.local_rank
            args.gpu = args.local_rank
        elif 'SLURM_PROCID' in os.environ: # for slurm scheduler
            args.rank = int(os.environ['SLURM_PROCID'])
            args.gpu = args.rank % torch.cuda.device_count()
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # suppress printing if not on master gpu
    if args.distributed and args.rank != 0:
        def print_pass(*args, **kwargs):
            pass
        builtins.print = print_pass
    ### model ###
    model = MyModel()  # placeholder: your own nn.Module
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
            model_without_ddp = model.module
    else:
        raise NotImplementedError("Only DistributedDataParallel is supported.")
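    ### loss & optimizer ###
    criterion = torch.nn.CrossEntropyLoss()  # placeholder: use the loss for your task
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    ### resume training if necessary ###
    if args.resume:
        pass  # e.g. load model_without_ddp / optimizer state dicts from args.resume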
    ### data ###
    train_dataset = MyDataset(mode='train')
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
    train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    val_dataset = MyDataset(mode='val')
    val_sampler = None
    val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=True)
    torch.backends.cudnn.benchmark = True
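    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        # reseed the sampler's shuffle every epoch; all ranks share the seed,
        # so each gpu still gets a different (disjoint) part of the dataset
        if args.distributed: 
            train_loader.sampler.set_epoch(epoch)
        # adjust lr if needed #
        train_one_epoch(train_loader, model, criterion, optimizer, epoch, args)
        if args.rank == 0: # only val and save on master node
            # use the unwrapped model: a DDP forward pass run on rank 0 alone can
            # hang waiting for a buffer broadcast (e.g. BatchNorm stats) that the
            # other ranks never join
            validate(val_loader, model_without_ddp, criterion, epoch, args)
            # save checkpoint if needed #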

def train_one_epoch(train_loader, model, criterion, optimizer, epoch, args):
    # each process drives a single gpu (args.gpu), so you can send cpu data
    # to it by input_data = input_data.cuda(args.gpu) as normal
    pass

def validate(val_loader, model, criterion, epoch, args):
    pass

if __name__ == '__main__':
    args = parse_args()
    main(args)
  • this script can already be run on a single node (e.g. in Slurm's interactive mode via salloc, say with 2 GPUs) with:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
--nproc_per_node=2 --net resnet18 \
--lr 1e-3 --epochs 50 --other_args
  • alternatively, it can be launched through Slurm, as described below
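DistributedSampler is what hands each rank its own part of the dataset, and the per-epoch "fix sampling seed" step in the loop amounts to calling train_loader.sampler.set_epoch(epoch). The sketch below shows that behaviour on CPU without a process group by passing num_replicas and rank explicitly; the 10-sample TensorDataset and the two-rank setup are illustrative assumptions, not part of the script.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset (illustrative assumption): 10 samples with values 0..9.
dataset = TensorDataset(torch.arange(10))


def shard_indices(rank, epoch, world_size=2):
    """Indices that `rank` would see in `epoch`.

    No process group is needed because num_replicas and rank are given explicitly.
    """
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=True, seed=0)
    sampler.set_epoch(epoch)  # the call the training loop should make every epoch
    return list(sampler)


for epoch in range(2):
    r0, r1 = shard_indices(0, epoch), shard_indices(1, epoch)
    # Same seed + epoch on every rank, so the two shards never overlap.
    print(f"epoch {epoch}: rank0={r0} rank1={r1} overlap={set(r0) & set(r1)}")
```

Every rank builds the same permutation from seed + epoch and then takes every world_size-th index starting at its own rank, which is why the shards are disjoint; without set_epoch(), every epoch would reuse the epoch-0 order.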

Set up the Slurm script

  • create a file as follows:
#!/bin/bash
#SBATCH --job-name=your-job-name
#SBATCH --partition=gpu
#SBATCH --time=72:00:00

### e.g. request 4 nodes with 1 gpu each, 4 gpus in total (WORLD_SIZE==4)
### Note: the x in --gres=gpu:x should equal ntasks-per-node
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --constraint=p40&gmem24G
#SBATCH --cpus-per-task=8
#SBATCH --mem=64gb
#SBATCH --chdir=/scratch/shared/beegfs/your_dir/
#SBATCH --output=/scratch/shared/beegfs/your_dir/%x-%j.out

### change the 5-digit MASTER_PORT as you wish; the rendezvous fails if another job already uses it on the master node
### set WORLD_SIZE to gpus/node * num_nodes
export MASTER_PORT=12340
export WORLD_SIZE=4

### get the first node name as master address - customized for vgg slurm
### e.g. master(gnodee[2-5],gnoded1) == gnodee2
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr

### init virtual environment if needed
source ~/anaconda3/etc/profile.d/
conda activate myenv

### the command to run
srun python --net resnet18 \
--lr 1e-3 --epochs 50 --other_args
  • submit the job on the cluster with sbatch
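The sbatch script takes MASTER_ADDR from the first host in $SLURM_JOB_NODELIST, and main() takes each process's rank from SLURM_PROCID (set by srun for every task) and its GPU as rank % GPUs-per-node. The plain-Python sketch below mirrors both steps so they can be checked offline. first_host() and rank_and_gpu() are hypothetical helpers: first_host() understands only simple node lists like the gnodee[2-5],gnoded1 example (scontrol show hostnames remains the real expander), and falling back to SLURM_NTASKS for the world size is an assumption, not what the script does.

```python
import re


def first_host(nodelist):
    """Hypothetical helper: first hostname of a *simple* Slurm node list.

    Handles comma-separated entries with at most one bracketed range/list each,
    e.g. "gnodee[2-5],gnoded1" or "node[07-09,12]".
    """
    # Cut at the first top-level comma (commas inside [...] do not split).
    depth, end = 0, len(nodelist)
    for i, ch in enumerate(nodelist):
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        elif ch == "," and depth == 0:
            end = i
            break
    entry = nodelist[:end]
    m = re.fullmatch(r"([^\[]+)\[([^\]]+)\]", entry)
    if m is None:
        return entry  # plain hostname such as "gnoded1"
    prefix, ranges = m.groups()
    first = ranges.split(",")[0].split("-")[0]  # keeps zero padding, e.g. "07"
    return prefix + first


def rank_and_gpu(env, ngpus_per_node):
    """Hypothetical helper mirroring main(): (rank, world_size, local gpu)."""
    rank = int(env["SLURM_PROCID"])
    # Assumption: prefer WORLD_SIZE (exported by the sbatch script), else SLURM_NTASKS.
    world_size = int(env.get("WORLD_SIZE", env.get("SLURM_NTASKS", 1)))
    return rank, world_size, rank % ngpus_per_node


print(first_host("gnodee[2-5],gnoded1"))  # gnodee2, as in the script's comment
print(rank_and_gpu({"SLURM_PROCID": "3", "SLURM_NTASKS": "4"}, ngpus_per_node=1))  # (3, 4, 0)
```

Note that rank % GPUs-per-node equals the node-local index only when Slurm places consecutive ranks on the same node (its default block distribution); SLURM_LOCALID gives that node-local task index directly.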



mesllo commented Apr 2, 2022

This has been really helpful and easy to follow, but unfortunately I have not succeeded yet. I'm trying to implement this on a university supercomputer that I log in to via ssh on port 22. When I set MASTER_PORT=12340 or some other number in the SLURM script, I obviously get no response, since there's nothing happening on it. If I set MASTER_PORT=22, I get a permission denied error when the code reaches the dist.init_process_group() call, specifically:

Traceback (most recent call last):
  File "", line 262, in <module>
  File "", line 220, in main
    world_size=opt.world_size, rank=opt.rank)
  File "/home/miniconda3/envs/vit/lib/python3.7/site-packages/torch/distributed/", line 595, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
  File "/home/miniconda3/envs/vit/lib/python3.7/site-packages/torch/distributed/", line 232, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout)
  File "/home/miniconda3/envs/vit/lib/python3.7/site-packages/torch/distributed/", line 161, in _create_c10d_store
    hostname, port, world_size, start_daemon, timeout, multi_tenant=True
RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:22 (errno: 13 - Permission denied). The server socket has failed to bind to (errno: 13 - Permission denied).

I have also tried rerouting port 22 traffic to some other port (e.g. 65000), but I get permission denied even when attempting the rerouting. I'm not sure what else I can try at this point; does anyone have any suggestions?


mesllo commented Apr 4, 2022

@TengdaHan What do we have to do if we connect via ssh on port 22? I get Permission Denied if I specify MASTER_PORT=22.


likejazz commented Jun 6, 2022

torchrun provides a superset of the functionality of torch.distributed.launch, with additional features. --use_env is now deprecated.


Hello @mesllo, I have exactly the same issue. Did you manage to find a solution yet?


@hendriklohse @mesllo
Hi both, in my experience the MASTER_PORT has nothing to do with the ssh port. MASTER_PORT is used by the training processes (GPUs) to communicate with each other, not by ssh from outside.
Typically you can just choose any free five-digit number (up to 65535).
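One detail behind the errno 13 above: on Linux, binding to a port below 1024 (such as 22) requires root privileges, so MASTER_PORT=22 fails before any networking happens. A simple way to pick an unprivileged port that is currently unused is to let the OS choose one by binding to port 0. The sketch below is an assumption-laden convenience, not part of the gist: it only checks the machine it runs on, whereas the port must be free on the master node, and another process can still take it between the check and the rendezvous.

```python
import socket


def find_free_port(host=""):
    """Ask the OS for an unused TCP port (binding to port 0 picks one)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        return s.getsockname()[1]


def is_port_free(port, host=""):
    """Return True if `port` can currently be bound on this machine."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
        except OSError:  # e.g. EACCES for a privileged port, EADDRINUSE if taken
            return False
        return True


port = find_free_port()
print(port, is_port_free(port))
```

Run on the master node, the printed number could then be used as MASTER_PORT in the sbatch script.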


@TengdaHan Thank you for the code! Is there a reason why destroy_process_group() wasn't used?
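On the destroy_process_group() question: the script never tears the process group down. One way to do it, sketched below, is to wrap the work in try/finally so the group is destroyed even if training raises. To stay runnable on a CPU-only machine, the example assumes a single process (world_size=1), the gloo backend and a throwaway file:// rendezvous instead of the script's nccl/env:// settings; dummy_train() is a hypothetical stand-in for the epoch loop.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def run_with_cleanup(train_fn):
    """Single-process demo: init a process group, run train_fn, always destroy it."""
    # Demo assumptions: gloo on CPU, world_size=1, fresh file:// store.
    store_path = os.path.join(tempfile.mkdtemp(), "ddp_store")
    dist.init_process_group(backend="gloo", init_method=f"file://{store_path}",
                            rank=0, world_size=1)
    try:
        return train_fn()
    finally:
        dist.destroy_process_group()  # runs even if train_fn raises


def dummy_train():
    t = torch.ones(3)
    dist.all_reduce(t)  # sums across ranks; unchanged with a single rank
    return t.tolist()


print(run_with_cleanup(dummy_train))  # [1.0, 1.0, 1.0]
print(dist.is_initialized())          # False
```

In the script itself, the equivalent would be a dist.destroy_process_group() call after the epoch loop in main(), or in a finally block around it.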


@TengdaHan I think it might be good to add torch.nn.SyncBatchNorm.convert_sync_batchnorm(model), as the current implementation means that the batch statistics won't get shared across the ranks.
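As the comment notes, plain BatchNorm layers compute statistics per process. torch.nn.SyncBatchNorm.convert_sync_batchnorm() swaps every BatchNorm*D layer for SyncBatchNorm; in the script it would go right after model = MyModel(), before the model is moved to the GPU and wrapped in DistributedDataParallel. The toy Sequential model below is an assumption standing in for MyModel(); the conversion itself needs no process group, so it can be checked on CPU.

```python
import torch.nn as nn

# Toy model standing in for MyModel() (illustrative assumption).
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())

# Replace every BatchNorm*D with SyncBatchNorm so that, during distributed
# training, batch statistics are reduced across all ranks of the default group.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

print(type(model[1]).__name__)  # SyncBatchNorm
```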


Regarding num_workers for the DataLoaders, which value is best for this Slurm configuration? I'm asking because I saw another article suggesting num_workers = int(os.environ["SLURM_CPUS_PER_TASK"]). However, in my case, doing this increases the training time dramatically compared with not setting the DataLoader workers (i.e. leaving it at 0); on the other hand, with it set, both GPUs run each epoch synchronously.

Whereas without setting num_workers, the training on the two GPUs gets interleaved over sequences of epochs, as if it were not synchronized.
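A common pattern, offered here as an assumption rather than something the gist prescribes, is to derive num_workers from the CPUs Slurm actually granted each task, with a fallback when the variable is absent (e.g. outside an sbatch/srun job). The hypothetical helper below reads SLURM_CPUS_PER_TASK, which Slurm only sets when --cpus-per-task is requested (as in the script above), and keeps one core for the main training process; that reservation is a heuristic, so benchmark a few values on your cluster.

```python
import os


def default_num_workers(env=None, fallback=4, reserve=1):
    """Hypothetical helper: DataLoader workers per process from Slurm's CPU grant.

    Uses SLURM_CPUS_PER_TASK when present, keeps `reserve` cores for the main
    process, and never returns a negative number.
    """
    env = os.environ if env is None else env
    cpus = env.get("SLURM_CPUS_PER_TASK")
    if cpus is None:
        return fallback
    return max(int(cpus) - reserve, 0)


print(default_num_workers({"SLURM_CPUS_PER_TASK": "8"}))  # 7
print(default_num_workers({}))                            # 4 (fallback)
```

The result would be what the script passes as num_workers (args.workers).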
