Skip to content

Instantly share code, notes, and snippets.

@saurabh-kataria
Forked from sgraaf/ddp_example.py
Created February 21, 2021 01:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save saurabh-kataria/4d56f8298314afcc3282f74a4ec1f288 to your computer and use it in GitHub Desktop.
Save saurabh-kataria/4d56f8298314afcc3282f74a4ec1f288 to your computer and use it in GitHub Desktop.
PyTorch Distributed Data Parallel (DDP) example
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from argparse import ArgumentParser

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers import BertForMaskedLM
# Training defaults used by main().
SEED: int = 42  # random seed
BATCH_SIZE: int = 8  # samples per batch in each process's DataLoader
NUM_EPOCHS: int = 3  # number of passes over the training data
class YourDataset(Dataset):
    """Placeholder map-style dataset; replace with your own data loading.

    ``DistributedSampler`` calls ``len(dataset)`` and ``DataLoader`` calls
    ``dataset[i]``, so even the placeholder implements both (a bare stub
    without them makes ``DistributedSampler(dataset)`` raise ``TypeError``).

    Args:
        samples: optional iterable of examples to wrap. It is copied into a
            list, and an empty dataset is created when omitted. The training
            loop moves every element of a collated batch to the GPU, so each
            example should presumably be a tuple of tensors (TODO confirm
            against your model's forward signature).
    """

    def __init__(self, samples=None):
        super().__init__()
        # Materialize into a list so len() and indexing work for any iterable.
        self.samples = [] if samples is None else list(samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]
def main():
    """Train BERT (masked LM) with DistributedDataParallel, one process per GPU.

    Every process started by the distributed launcher runs this same
    function. The processes differ only in their local rank.

    Command-line options:
        --local_rank / --local-rank  GPU index of this process on its node.
            Older launchers pass ``--local_rank`` and newer PyTorch releases
            pass ``--local-rank``. If neither is given, the ``LOCAL_RANK``
            environment variable is used.
        --seed        random seed (default: ``SEED``).
        --batch_size  per-process batch size (default: ``BATCH_SIZE``).
        --num_epochs  number of training epochs (default: ``NUM_EPOCHS``).
    """
    parser = ArgumentParser('DDP usage example')
    # You need this argument in your scripts for DDP to work. Accept both
    # flag spellings and fall back to LOCAL_RANK; -1 means "not provided".
    parser.add_argument(
        '--local_rank', '--local-rank',
        dest='local_rank',
        type=int,
        default=int(os.environ.get('LOCAL_RANK', -1)),
        metavar='N',
        help='Local process rank.',
    )
    parser.add_argument('--seed', type=int, default=SEED, help='Random seed.')
    parser.add_argument(
        '--batch_size', type=int, default=BATCH_SIZE, help='Per-process batch size.'
    )
    parser.add_argument(
        '--num_epochs', type=int, default=NUM_EPOCHS, help='Number of training epochs.'
    )
    args = parser.parse_args()
    if args.local_rank < 0:
        parser.error(
            'no local rank given: start this script with a distributed launcher '
            '(or pass --local_rank / set LOCAL_RANK)'
        )

    # `is_master` is True for local rank 0, i.e. one process per *node*. This is
    # handy for data loading, logging, etc. On multi-node jobs use
    # `dist.get_rank() == 0` if you need exactly one process overall.
    args.is_master = args.local_rank == 0

    # A real torch.device is required here: torch.cuda.device(...) is a context
    # manager and cannot be passed to Tensor.to() / Module.to().
    args.device = torch.device('cuda', args.local_rank)
    # Bind this process to its GPU before creating the NCCL process group.
    torch.cuda.set_device(args.device)

    # Initialize PyTorch distributed from environment variables set by the
    # launcher, so the same script runs unchanged on different machines.
    dist.init_process_group(backend='nccl', init_method='env://')

    # Seed torch's RNGs on all devices. Also seed `random`, `numpy`, etc. if
    # your code uses them.
    torch.manual_seed(args.seed)

    # Initialize the model (BERT in this example), move it to this process's
    # GPU, and wrap it in DDP.
    model = BertForMaskedLM.from_pretrained('bert-base-uncased').to(args.device)
    model = DDP(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
    )

    # Each process iterates over its own subset of the dataset.
    dataset = YourDataset()
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(
        dataset=dataset,
        sampler=sampler,
        batch_size=args.batch_size,
    )

    for epoch in range(args.num_epochs):
        # Without set_epoch() the DistributedSampler yields the same shuffled
        # order in every epoch.
        sampler.set_epoch(epoch)
        model.train()
        # Let all processes sync up before starting a new epoch of training.
        dist.barrier()
        for step, batch in enumerate(dataloader):
            # Assumes each batch is a sequence of tensors in the order the
            # model's forward() expects. TODO: adapt to your dataset.
            batch = tuple(t.to(args.device) for t in batch)
            outputs = model(*batch)
            # Assumes the loss is the first output (i.e. labels were passed).
            loss = outputs[0]
            # Backward pass, optimizer step, logging, etc. go here
            # (`step` and `loss` are kept for that code).

    # Release process-group resources once training is finished.
    dist.destroy_process_group()
if __name__ == '__main__':
    # Script entry point: every launched process executes this file directly.
    # Importing the module does not start training.
    main()
#!/bin/bash
# Launch ddp_example.py on a single node (`NUM_NODES=1`) with 4 GPUs
# (`NUM_GPUS_PER_NODE=4`), one process per GPU.
#
# Any arguments given to this script are forwarded unchanged to
# ddp_example.py, e.g.:
#   bash <this-script> --seed 42

export NUM_NODES=1
export NUM_GPUS_PER_NODE=4
export NODE_RANK=0
export WORLD_SIZE=$((NUM_NODES * NUM_GPUS_PER_NODE))

# Launch your script with `torch.distributed.launch`. Script arguments come
# last via "$@"; a commented-out line after a trailing backslash would
# silently end the command instead.
python -m torch.distributed.launch \
    --nproc_per_node="$NUM_GPUS_PER_NODE" \
    --nnodes="$NUM_NODES" \
    --node_rank="$NODE_RANK" \
    ddp_example.py \
    "$@"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment