# GitHub gist svp19/7456f6da5cb5e8b748fdc05821178c13, created September 15, 2020 02:24.
# Save it to your computer or open it in GitHub Desktop.
# Playground code for distributed training in PyTorch. While the docs and tutorials
# out there are great, I felt a simple example like this was much needed.
# Note (from GitHub): this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open the
# file in an editor that reveals hidden Unicode characters.
# (Learn more about bidirectional Unicode characters.)
'''
Tutorial code for distributed training in PyTorch that trains
an inception_v3 model on dummy data.

*Installation:*
Use pip/conda to install the following libraries
    - torch
    - torchvision
    - argparse (ships with the Python standard library; no install needed)
    - tqdm

*Run using:*
    `python torch_distributed.py -g 4 --batch_size 128`
where
    -g: number of GPUs to use on this node
    --batch_size: per-GPU batch size; increase for higher memory usage (default: 128, ~14GB)
'''
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision.models as models | |
import torch.multiprocessing as mp | |
import torch.distributed as dist | |
from torch.optim import lr_scheduler | |
from torch.utils.data import Dataset, DataLoader | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
import os | |
import argparse | |
from itertools import chain | |
from tqdm import tqdm | |
class ExampleDataset(Dataset):
    '''Example Dataset for Playground.

    Every item is a freshly sampled ``torch.rand`` tensor of shape
    ``return_shape`` paired with the constant label ``return_target``;
    the index only selects *whether* an item exists, not its content.

    Args:
        return_shape: shape of each sample tensor.
        return_len: number of items reported by ``len()``.
        return_target: label returned alongside every sample.
    '''

    def __init__(self, return_shape=(3, 32, 224, 224), return_len=100, return_target=0):
        self.shape = return_shape
        self.len = return_len
        self.target = return_target

    def __getitem__(self, idx):
        # Bounds-check so plain iteration (``for x in dataset``), which falls back to
        # the __getitem__ sequence protocol, stops after len(self) items instead of
        # looping forever. Negative indices follow the usual Python convention.
        if not -self.len <= idx < self.len:
            raise IndexError('index %d out of range for dataset of length %d' % (idx, self.len))
        return torch.rand(self.shape), self.target

    def __len__(self):
        return self.len
def cleanup():
    '''Destroy the default process group created by ``dist.init_process_group``.

    Does nothing when torch.distributed is unavailable or no process group has
    been initialised, so it is safe to call unconditionally from a ``finally``.
    '''
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
#-------------------------------------------------------------------- | |
#Train | |
def train(gpu, args): | |
#Init process | |
rank = args.nr * args.gpus + gpu | |
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank) | |
#Random seed for distributed | |
torch.manual_seed(0) | |
#Define the model | |
model = models.inception_v3(init_weights=False) | |
# Handle the auxilary net | |
num_ftrs = model.AuxLogits.fc.in_features | |
model.AuxLogits.fc = nn.Linear(num_ftrs, 1000) | |
# Handle the primary net | |
num_ftrs = model.fc.in_features | |
model.fc = nn.Linear(num_ftrs,1000) | |
#Send Model to GPU | |
torch.cuda.set_device(gpu) | |
model.cuda(gpu) | |
# Wrap the model for distribution | |
model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu]) | |
#Dataset | |
dataset = ExampleDataset(return_shape=(3, 299, 299), return_len=10000) | |
#Datasampler | |
sampler = torch.utils.data.distributed.DistributedSampler( | |
dataset, | |
num_replicas=args.world_size, | |
rank=rank | |
) | |
#Dataloader | |
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=sampler) | |
#loss, optimizer and scheduler | |
criterion = nn.CrossEntropyLoss() | |
optimize_parameters = model.parameters() | |
optimizer = optim.SGD(optimize_parameters, lr=0.001, momentum=0.9) | |
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) | |
with torch.set_grad_enabled(True): | |
for e in range(args.epochs): | |
with tqdm(desc='Epoch %d - ' % e, unit='it', total=len(dataloader)) as pbar: | |
for inputs, labels in dataloader: | |
#Shift to gpu | |
inputs = inputs.cuda(non_blocking=True) | |
labels = labels.cuda(non_blocking=True) | |
#Forward | |
outputs, aux_outputs = model(inputs) | |
_, predictions = torch.max(outputs, 1) | |
#Backward, optimize and scheduler step | |
loss1 = criterion(outputs, labels) | |
loss2 = criterion(aux_outputs, labels) | |
loss = loss1 + 0.4*loss2 | |
loss.backward() | |
optimizer.step() | |
scheduler.step() | |
pbar.update() | |
def main(argv=None):
    '''Parse CLI arguments, validate them and spawn one ``train`` process per GPU.

    Args:
        argv: optional argument list; ``None`` (the default) parses ``sys.argv[1:]``.

    Raises:
        SystemExit: (via ``parser.error``) if ``--nodes``/``--gpus`` < 1 or
            ``--nr`` is not in ``[0, nodes)``.
        RuntimeError: if more GPUs are requested than ``torch.cuda.device_count()``.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
                        help='number of nodes (default: 1)')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    parser.add_argument('--epochs', default=1, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--batch_size', default=128, type=int, metavar='N',
                        help='batch_size, increase to use more memory')
    args = parser.parse_args(argv)

    if args.nodes < 1 or args.gpus < 1:
        parser.error('--nodes and --gpus must both be at least 1')
    if not 0 <= args.nr < args.nodes:
        parser.error('--nr must be in [0, nodes)')

    args.world_size = args.gpus * args.nodes
    # Keep a rendezvous address/port that is already exported (required for
    # multi-node runs); otherwise fall back to a single-machine default.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    num_gpus_available = torch.cuda.device_count()
    print('Torch found', num_gpus_available, 'GPUs')
    if args.gpus > num_gpus_available:
        raise RuntimeError('Could communicate with only %d GPU(s), but expected %d'
                           % (num_gpus_available, args.gpus))
    mp.spawn(train, nprocs=args.gpus, args=(args,))
if __name__ == '__main__':
    # Script entry point: run the distributed job, then report completion.
    main()
    print('Done')
# (GitHub) Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.