# Gist by @elistevens, created April 20, 2020 18:32.
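# A short overview plus one possible launch command; the environment-variable
# values and the filename min_ddp.py below are illustrative placeholders.
"""Minimal single-node DistributedDataParallel throughput benchmark.

Trains a ResNet-50 on torchvision's FakeData, profiles the training loop with
cProfile, and prints wall-clock time plus per-GPU throughput. All configuration
is read from environment variables, for example:

    NODES=1 GPUS=2 BATCH_SIZE=64 EPOCH_SIZE=2048 EPOCHS=3 OMP_NUM_THREADS=1 python min_ddp.py
"""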
import datetime
import math
import os
import time
import torch
import torch.distributed
import torch.multiprocessing
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel
#from apex.parallel import DistributedDataParallel
import torchvision
# Cluster topology, supplied via environment variables.
num_nodes = int(os.environ['NODES'])
num_gpus = int(os.environ['GPUS'])
def main(ddp_wrapper=None, sampler_cls=None, gpu_ndx=0):
    # Synthetic dataset, so the benchmark measures compute and communication
    # rather than disk I/O.
    ds = torchvision.datasets.FakeData(
        int(os.environ['EPOCH_SIZE']),
        num_classes=100,
        transform=torchvision.transforms.ToTensor(),
    )
    dl = DataLoader(
        ds,
        batch_size=int(os.environ['BATCH_SIZE']),
        num_workers=4,
        pin_memory=True,
        sampler=sampler_cls(ds) if sampler_cls else None,
    )

    model = torchvision.models.resnet50()
    model = model.to('cuda')
    if ddp_wrapper:
        model = ddp_wrapper(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Profile the whole training loop.
    import cProfile, pstats
    pr = cProfile.Profile()
    pr.enable()
    start_ts = time.time()

    for epoch_ndx in range(1, int(os.environ['EPOCHS']) + 1):
        print(datetime.datetime.now(), f"Epoch {epoch_ndx}, dl: {len(dl)}")
        for batch_ndx, batch_tup in enumerate(dl):
            optimizer.zero_grad()

            x, y = batch_tup
            x = x.to('cuda')
            y = y.to('cuda')

            y_hat = model(x)
            loss_var = F.cross_entropy(y_hat, y)
            loss_var.backward()
            optimizer.step()

    end_ts = time.time()
    pr.disable()

    # Only the first GPU's process dumps and prints the profile.
    if gpu_ndx == 0:
        pr.dump_stats('/tmp/min_profile.out')
        # pstats.Stats(pr).sort_stats('cumulative').print_stats()
        pstats.Stats(pr).sort_stats('tottime').print_stats()

    print(datetime.datetime.now(), f"training loop time: {end_ts - start_ts} seconds")

    # Summary block: run labels, the configuration env vars, wall time,
    # per-GPU throughput (samples/sec), and the same throughput divided by a
    # fixed reference constant.
    samples_per_gpu_sec = (
        int(os.environ['EPOCH_SIZE']) * int(os.environ['EPOCHS'])
        / (end_ts - start_ts)
        / int(os.environ['GPUS'])
    )
    print('\n'.join(
        ['min ddp', 'cluster']
        + [os.environ[x] for x in ['NODES', 'GPUS', 'BATCH_SIZE', 'EPOCH_SIZE', 'EPOCHS', 'OMP_NUM_THREADS']]
        + [f'{end_ts - start_ts}']
        + [f'{samples_per_gpu_sec}']
        + [f'{samples_per_gpu_sec / 1.737005}']
    ))
def ddp_spawn(gpu_ndx):
    # Single-node run: process rank maps directly onto the local GPU index.
    node_rank = 0
    rank = num_gpus * node_rank + gpu_ndx
    world_size = num_nodes * num_gpus
    print(datetime.datetime.now(), f"torch.cuda.set_device({gpu_ndx}); torch.distributed.init_process_group('nccl', rank={rank}, world_size={world_size})")
    torch.cuda.set_device(gpu_ndx)
    torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size)

    main(
        ddp_wrapper=lambda m: DistributedDataParallel(m, device_ids=[gpu_ndx]),
        sampler_cls=torch.utils.data.distributed.DistributedSampler,
        gpu_ndx=gpu_ndx,
    )
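# Note: if the commented-out apex import above were used instead of
# torch.nn.parallel.DistributedDataParallel, the wrapper would be constructed
# without device_ids (apex's wrapper uses the current CUDA device set via
# torch.cuda.set_device). That variant is an assumption about apex's API and is
# not exercised by the script as written.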
if __name__ == '__main__':
    # Rendezvous address for torch.distributed; one worker process per GPU.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '1234'
    torch.multiprocessing.spawn(ddp_spawn, nprocs=num_gpus, args=())
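# For a non-distributed single-GPU baseline, main() can also be called directly
# with its defaults (no ddp_wrapper, no DistributedSampler), e.g.:
#
#     if __name__ == '__main__':
#         main()
#
# That is a sketch based on main()'s signature above; the script as written
# always launches via torch.multiprocessing.spawn.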