Last active July 22, 2024 17:55
Multi-node-training on slurm with PyTorch

What's this?

  • A short note on how to start multi-node training on the slurm scheduler with PyTorch.
  • Especially useful when the scheduler is too busy to allocate multiple GPUs on a single node, or when you need more than 4 GPUs for a single job.
  • Requirement: you have to use PyTorch DistributedDataParallel (DDP) for this purpose.
  • Warning: you might need to refactor your own code.
  • Warning: you might be secretly condemned by your colleagues for using too many GPUs.

Setup python script

  • create a file for example:
import os
import builtins
import argparse
import torch
import numpy as np 
import random
import torch.distributed as dist

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--net', default='resnet18', type=str)
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size per GPU')
    parser.add_argument('--gpu', default=None, type=int)
    parser.add_argument('--start_epoch', default=0, type=int, 
                        help='start epoch number (useful on restarts)')
    parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run')
    # DDP configs:
    parser.add_argument('--world-size', default=-1, type=int, 
                        help='total number of processes for distributed training')
    parser.add_argument('--rank', default=-1, type=int, 
                        help='global rank of this process for distributed training')
    parser.add_argument('--dist-url', default='env://', type=str, 
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, 
                        help='distributed backend')
    parser.add_argument('--local_rank', default=-1, type=int, 
                        help='local rank for distributed training')
    # used further down in main():
    parser.add_argument('--workers', default=4, type=int,
                        help='number of data loading workers per process')
    parser.add_argument('--resume', default='', type=str,
                        help='path to a checkpoint to resume from')
    args = parser.parse_args()
    return args


def main(args):
    # DDP setting
    if "WORLD_SIZE" in os.environ:
        args.world_size = int(os.environ["WORLD_SIZE"])
    args.distributed = args.world_size > 1
    ngpus_per_node = torch.cuda.device_count()

    if args.distributed:
        if args.local_rank != -1: # for torch.distributed.launch
            args.rank = args.local_rank
            args.gpu = args.local_rank
        elif 'SLURM_PROCID' in os.environ: # for slurm scheduler
            args.rank = int(os.environ['SLURM_PROCID'])
            args.gpu = args.rank % torch.cuda.device_count()
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # suppress printing if not on master gpu
    if args.rank != 0:
        def print_pass(*args, **kwargs):
            pass
        builtins.print = print_pass
    ### model ###
    model = MyModel()
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
            model_without_ddp = model.module
    else:
        raise NotImplementedError("Only DistributedDataParallel is supported.")
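    ### optimizer ###
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    criterion = torch.nn.CrossEntropyLoss()  # placeholder loss, replace with your own
    ### resume training if necessary ###
    if args.resume:
        pass  # load model/optimizer state from the checkpoint here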
    ### data ###
    train_dataset = MyDataset(mode='train')
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
    train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    val_dataset = MyDataset(mode='val')
    val_sampler = None
    val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=True)
    torch.backends.cudnn.benchmark = True
    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        # set the epoch on the sampler so that the shuffling changes every epoch
        # while each gpu still gets a different part of the dataset
        if args.distributed: 
            train_sampler.set_epoch(epoch)
        # adjust lr if needed #
        train_one_epoch(train_loader, model, criterion, optimizer, epoch, args)
        if args.rank == 0: # only val and save on master node
            validate(val_loader, model, criterion, epoch, args)
            # save checkpoint if needed #

def train_one_epoch(train_loader, model, criterion, optimizer, epoch, args):
    # only one gpu is visible here, so you can send cpu data to gpu by 
    # input_data = input_data.cuda() as normal
    pass


def validate(val_loader, model, criterion, epoch, args):
    pass


if __name__ == '__main__':
    args = parse_args()
    main(args)
  • this script can already be run on a single node (e.g. in slurm's interactive mode via salloc, with 2 GPUs) by the command below, where your_script.py stands for the file created above
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
--nproc_per_node=2 your_script.py --net resnet18 \
--lr 1e-3 --epochs 50 --other_args
  • alternatively, it can be run with slurm, see below
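
The sampler comment in the main loop can be made concrete with a small pure-Python sketch of the sharding idea behind DistributedSampler. This is not PyTorch's implementation: `shard_indices` is a hypothetical helper, it uses Python's `random` module instead of a torch generator, and it drops the leftover tail instead of padding. Every rank shuffles with the same seed (base seed plus epoch) and then takes every `world_size`-th index starting at its own rank, so the shards do not overlap and the shuffle depends on the epoch.

```python
import random


def shard_indices(dataset_len, world_size, rank, epoch, seed=0):
    """Return the dataset indices one rank would see in one epoch.

    Hypothetical helper that mimics the idea of DistributedSampler:
    all ranks build the same permutation, then each takes a strided slice.
    """
    indices = list(range(dataset_len))
    # Same seed on every rank, but a different one for every epoch.
    random.Random(seed + epoch).shuffle(indices)
    # Drop the tail so that every rank gets the same number of samples.
    usable = dataset_len - dataset_len % world_size
    return indices[rank:usable:world_size]


if __name__ == "__main__":
    for epoch in range(2):
        shards = [shard_indices(10, world_size=4, rank=r, epoch=epoch)
                  for r in range(4)]
        print(f"epoch {epoch}: {shards}")
```

If the epoch were never passed in, every epoch would reuse the same permutation, which is what calling set_epoch on the real sampler avoids.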

Setup slurm script

  • create a file as follows:
#!/bin/bash
#SBATCH --job-name=your-job-name
#SBATCH --partition=gpu
#SBATCH --time=72:00:00

### e.g. request 4 nodes with 1 gpu each, 4 gpus in total (WORLD_SIZE==4)
### Note: x in --gres=gpu:x should equal ntasks-per-node
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --constraint=p40&gmem24G
#SBATCH --cpus-per-task=8
#SBATCH --mem=64gb
#SBATCH --chdir=/scratch/shared/beegfs/your_dir/
#SBATCH --output=/scratch/shared/beegfs/your_dir/%x-%j.out

### change the 5-digit MASTER_PORT as you wish; the job will raise an error if the port is already taken by someone else
### set WORLD_SIZE to gpus/node * num_nodes
export MASTER_PORT=12340
export WORLD_SIZE=4

### get the first node name as master address - customized for vgg slurm
### e.g. master(gnodee[2-5],gnoded1) == gnodee2
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr

### init virtual environment if needed
source ~/anaconda3/etc/profile.d/
conda activate myenv

### the command to run (your_script.py stands for the python file created above)
srun python your_script.py --net resnet18 \
--lr 1e-3 --epochs 50 --other_args
  • submit the script on the cluster with sbatch
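
To check the numbers in the slurm script before submitting, the following sketch lists the global rank, node and local GPU of every task. `job_layout` is a hypothetical helper, and it assumes Slurm's block distribution (consecutive ranks fill one node before moving to the next) and one GPU per task, as the note on --gres above requires.

```python
def job_layout(nodes, tasks_per_node):
    """Return WORLD_SIZE and a list of (global_rank, node_index, local_gpu).

    Assumes block distribution of tasks and one GPU per task,
    i.e. x in --gres=gpu:x equal to --ntasks-per-node.
    """
    world_size = nodes * tasks_per_node
    layout = []
    for rank in range(world_size):
        node_index = rank // tasks_per_node
        # Same rule as args.rank % torch.cuda.device_count() in the script.
        local_gpu = rank % tasks_per_node
        layout.append((rank, node_index, local_gpu))
    return world_size, layout


if __name__ == "__main__":
    world_size, layout = job_layout(nodes=2, tasks_per_node=2)
    print("WORLD_SIZE =", world_size)
    for rank, node, gpu in layout:
        print(f"rank {rank}: node {node}, gpu {gpu}")
```

For the 4 nodes with 1 GPU each requested above, every rank lands on its own node and uses GPU 0 there.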

Reference & Acknowledgement

Is there any specific reason to call the train_one_epoch function once per epoch?
Can I call the train function just once instead? That shouldn't cause any issues, right?

My concern is that calling it after every epoch could create additional overhead.

@rakshith291 It doesn't matter, you can write your own loop. I haven't tested the additional overhead and that's not the main point of this gist.

royvelich commented Mar 13, 2022

Thanks for sharing. I have a question: why doesn't your example call mp.spawn, as I see in other examples?
Where exactly in the code is the main process forked into multiple processes?

royvelich commented Mar 14, 2022

Did anyone get this kind of error when running the script on multiple gpus?

/opt/conda/envs/deep-signature/lib/python3.9/site-packages/numpy/core/ RuntimeWarning: Mean of empty slice.
  return _methods._mean(a, axis=axis, dtype=dtype,
/opt/conda/envs/deep-signature/lib/python3.9/site-packages/numpy/core/ RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)

The script works with no problem when I run it on a single GPU.

Ieremie commented Mar 22, 2022

@royvelich The slurm script uses 'srun' to run the python script. This means that the python file is run once per task, each in its own process, so there is no need to use the spawn function.

see: for more details
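
The point about srun can be tried out locally without a cluster. The sketch below only imitates srun: `fake_srun` is a hypothetical helper that starts several copies of the same small program with `subprocess` and gives each copy its own SLURM_PROCID. On a real cluster Slurm sets SLURM_PROCID per task, while WORLD_SIZE comes from the export in the sbatch script.

```python
import os
import subprocess
import sys

# What every copy of the training script does first: read its own rank.
CHILD = "import os; print(os.environ['SLURM_PROCID'], os.environ['WORLD_SIZE'])"


def fake_srun(ntasks):
    """Start `ntasks` copies of the same program, one process per task.

    Hypothetical stand-in for srun, for illustration only: each copy gets
    a different SLURM_PROCID, so no mp.spawn is needed inside the program.
    """
    procs = []
    for procid in range(ntasks):
        env = dict(os.environ, SLURM_PROCID=str(procid), WORLD_SIZE=str(ntasks))
        procs.append(subprocess.Popen([sys.executable, "-c", CHILD],
                                      env=env, stdout=subprocess.PIPE, text=True))
    # Collect what each process printed, in rank order.
    return [p.communicate()[0].strip() for p in procs]


if __name__ == "__main__":
    print(fake_srun(4))
```

Each process sees the same code but a different rank, which is all dist.init_process_group needs to tell the processes apart.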

mesllo commented Apr 2, 2022

This has been really helpful and easy to follow but unfortunately, I have not succeeded yet. I'm trying to implement this on a University supercomputer where I'm logging in via ssh using port 22. When I set MASTER_PORT=12340 or some other number on the SLURM script, I obviously get no response since there's nothing happening on it. If I set MASTER_PORT=22, I get a permission denied when the code reaches the dist.init_process_group() method, specifically:

Traceback (most recent call last):
  File "", line 262, in <module>
  File "", line 220, in main
    world_size=opt.world_size, rank=opt.rank)
  File "/home/miniconda3/envs/vit/lib/python3.7/site-packages/torch/distributed/", line 595, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
  File "/home/miniconda3/envs/vit/lib/python3.7/site-packages/torch/distributed/", line 232, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout)
  File "/home/miniconda3/envs/vit/lib/python3.7/site-packages/torch/distributed/", line 161, in _create_c10d_store
    hostname, port, world_size, start_daemon, timeout, multi_tenant=True
RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:22 (errno: 13 - Permission denied). The server socket has failed to bind to (errno: 13 - Permission denied).

What I have tried is rerouting the port 22 traffic to some other port (e.g. 65000), but I get permission denied even for attempting this rerouting. I'm not sure what else I can try at this point. Does anyone have any suggestions?

mesllo commented Apr 4, 2022

@TengdaHan What do we have to do if we connect through ssh on port 22? I get Permission Denied if I specify MASTER_PORT=22.

likejazz commented Jun 6, 2022

torchrun provides a superset of the functionality of torch.distributed.launch, with some additional features. --use_env is now deprecated.

Hello @mesllo , I have exactly the same issue. Did you manage to find a solution yet?

@hendriklohse @mesllo
Hi both, from my experience, the MASTER_PORT has nothing to do with the ssh port. The MASTER_PORT is for the GPU processes to communicate with each other, not for ssh connections to the outside.
Typically you can just choose a five-digit number.
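
One way to pick such a five-digit number without hand-editing it for every job is to derive it from the job ID. This is only a sketch: `master_port_from_job_id` is a hypothetical helper and the range 10000 to 29999 is an arbitrary choice. On Linux, ports below 1024 are privileged and cannot be bound by a normal user, which matches the "Permission denied" reported above for port 22.

```python
import os


def master_port_from_job_id(job_id, low=10000, span=20000):
    """Map a Slurm job ID to a port in [low, low + span).

    Hypothetical helper: every task of one job computes the same port,
    while different jobs usually end up on different ports.
    """
    port = low + int(job_id) % span
    # Stay above the privileged range and inside the valid TCP port range.
    assert 1024 <= port <= 65535
    return port


if __name__ == "__main__":
    # SLURM_JOB_ID is set inside a job; the fallback is only for local runs.
    job_id = os.environ.get("SLURM_JOB_ID", "123456")
    print(master_port_from_job_id(job_id))
```

With init_method 'env://', the result would have to be written to os.environ["MASTER_PORT"] as a string before dist.init_process_group is called. Two jobs can still collide if their IDs differ by a multiple of `span`.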

@TengdaHan Thank you for the code! Is there a reason why destroy_process_group() wasn't used?

@TengdaHan I think it might be good to add torch.nn.SyncBatchNorm.convert_sync_batchnorm(model), as the current implementation means that the batch statistics won't get shared across the ranks.

Regarding the num_workers of the DataLoaders, which value is best for our slurm configuration? I'm asking because I saw another article that suggests setting num_workers = int(os.environ["SLURM_CPUS_PER_TASK"]). In my case, however, doing this makes the training time increase exponentially compared to not setting the DataLoader workers (i.e. leaving it at 0). On the other hand, with this setting both GPUs run each epoch synchronously.

Without setting num_workers, in contrast, the training on the GPUs gets mixed up: each GPU runs a sequence of epochs on its own, as if the training were not synchronized.
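
For reference, reading the value mentioned in the question can be sketched as below. `workers_from_slurm` is a hypothetical helper, and it makes no claim about which value is fastest; that depends on the cluster and the data pipeline. Slurm only sets SLURM_CPUS_PER_TASK when --cpus-per-task was requested, hence the fallback.

```python
import os


def workers_from_slurm(env=None, default=0):
    """Return the per-task CPU count exposed by Slurm, or `default`.

    Hypothetical helper: falls back when SLURM_CPUS_PER_TASK is missing,
    e.g. when running outside a job or without --cpus-per-task.
    """
    env = os.environ if env is None else env
    value = env.get("SLURM_CPUS_PER_TASK")
    return int(value) if value else default


if __name__ == "__main__":
    print("num_workers =", workers_from_slurm())
```

The result could be passed as num_workers to the DataLoader; whether it should be reduced (for example to leave a core for the main process) is something to measure per setup.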