@svp19 · Created September 15, 2020
Playground code for distributed training in PyTorch. While the docs and tutorials out there are great, I felt a simple example like this was much needed.
'''
Tutorial code for distributed training in PyTorch that trains
an inception_v3 model on dummy data.

*Installation:*
Use pip/conda to install the following libraries
    - torch
    - torchvision
    - tqdm
(argparse is part of the Python standard library and needs no installation.)

*Run using:*
    `python torch_distributed.py -g 4 --batch_size 128`
where,
    -g: number of GPUs
    --batch_size: increase for higher memory usage (default: 128, ~14 GB)
'''
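# A hypothetical multi-node launch (an assumption, not part of the original gist):
# the same script would be started once per node, with -nr giving each node its
# index and MASTER_ADDR/MASTER_PORT pointing at the rank-0 node, e.g. for two
# 4-GPU nodes:
#   node 0: python torch_distributed.py -n 2 -g 4 -nr 0
#   node 1: python torch_distributed.py -n 2 -g 4 -nr 1
# Note that main() below hard-codes MASTER_ADDR='localhost' and MASTER_PORT='12355',
# which only works single-node; a real multi-node run would set these to the
# rank-0 node's address instead.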
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import os
import argparse
from itertools import chain
from tqdm import tqdm
class ExampleDataset(Dataset):
    def __init__(self, return_shape=(3, 32, 224, 224), return_len=100, return_target=0):
        '''Example dataset for the playground: returns random tensors of a fixed shape.'''
        self.shape = return_shape
        self.len = return_len
        self.target = return_target

    def __getitem__(self, idx):
        return torch.rand(self.shape), self.target

    def __len__(self):
        return self.len


def cleanup():
    dist.destroy_process_group()

#--------------------------------------------------------------------
#Train
def train(gpu, args):
    # Initialize the process group; the global rank is derived from the node
    # index (args.nr) and the local GPU index
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=args.world_size, rank=rank)
    # Fix the random seed so every process starts from the same weights
    torch.manual_seed(0)
    # Define the model
    model = models.inception_v3(init_weights=False)
    # Handle the auxiliary net
    num_ftrs = model.AuxLogits.fc.in_features
    model.AuxLogits.fc = nn.Linear(num_ftrs, 1000)
    # Handle the primary net
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 1000)
    # Send the model to this process's GPU
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    # Wrap the model for distributed training
    model = DDP(model, device_ids=[gpu])
    # Dataset (inception_v3 expects 299x299 inputs)
    dataset = ExampleDataset(return_shape=(3, 299, 299), return_len=10000)
    # Sampler that restricts each process to its own shard of the dataset
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=args.world_size,
        rank=rank
    )
    # Dataloader (shuffle must stay False; shuffling is handled by the sampler)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=0, pin_memory=True, sampler=sampler)
    # Loss, optimizer and scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    # Training mode (gradients are enabled by default); inception_v3 only
    # returns the auxiliary logits in train mode
    model.train()
    for e in range(args.epochs):
        # Reshuffle the sampler's shard assignment each epoch
        sampler.set_epoch(e)
        with tqdm(desc='Epoch %d - ' % e, unit='it', total=len(dataloader)) as pbar:
            for inputs, labels in dataloader:
                # Move the batch to this process's GPU
                inputs = inputs.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)
                # Forward
                outputs, aux_outputs = model(inputs)
                _, predictions = torch.max(outputs, 1)
                # Backward and optimize
                loss1 = criterion(outputs, labels)
                loss2 = criterion(aux_outputs, labels)
                loss = loss1 + 0.4 * loss2
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update()
        # StepLR decays the learning rate every `step_size` epochs
        scheduler.step()
    cleanup()
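
# Illustrative rank layout (an assumption for clarity, not part of the original gist):
# with --nodes 2 and -g 4, world_size = 2 * 4 = 8 and each process gets
#     rank = node_rank * gpus_per_node + local_gpu
# so node 0 owns ranks 0-3 and node 1 owns ranks 4-7; DistributedSampler then
# hands each rank a disjoint 1/8th of the dataset per epoch.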
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
                        help='number of nodes (default: 1)')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking of this node among all nodes')
    parser.add_argument('--epochs', default=1, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--batch_size', default=128, type=int, metavar='N',
                        help='batch size per process; increase to use more memory')
    args = parser.parse_args()
    # One process per GPU across all nodes
    args.world_size = args.gpus * args.nodes
    # Address and port of the rank-0 process that coordinates the group
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    num_gpus_available = torch.cuda.device_count()
    print('Torch found %d GPU(s)' % num_gpus_available)
    if args.gpus > num_gpus_available:
        msg = 'Found only ' + str(num_gpus_available) + \
              ' GPU(s), but expected ' + str(args.gpus)
        raise Exception(msg)
    # Spawn one training process per GPU on this node
    mp.spawn(train, nprocs=args.gpus, args=(args,))


if __name__ == '__main__':
    main()
    print('Done')