DDP Memory inspection: a minimal multi-node, multi-GPU PyTorch DistributedDataParallel training skeleton for inspecting per-process GPU memory usage.
import argparse
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
class MyTrainer():

    def __init__(self, gpus_n, nodes_n, node_rank):
        self._gpus_n = gpus_n
        self._nodes_n = nodes_n
        self._node_rank = node_rank
        self._world_size = gpus_n * nodes_n
        self._batch_size = 128  # global batch size, split evenly across processes
        self._workers = 1
        self._epochs = 15

    def train(self):
        assert self._world_size > 1 and torch.cuda.is_available()
        # spawn one training process per GPU on this node
        mp.spawn(self._distributed_training, nprocs=self._gpus_n)
    def _distributed_training(self, gpu):
        # compute the global rank of this process and bind it to its GPU
        self._rank = self._node_rank * self._gpus_n + gpu
        self._device = torch.device('cuda:{}'.format(gpu))
        torch.cuda.set_device(gpu)
        self._model = MyModel()
        self._model.cuda()
        # initialize the process group; init_method='env://' expects MASTER_ADDR and
        # MASTER_PORT to be set in the environment before launching the script
        dist.init_process_group(backend='nccl', init_method='env://', world_size=self._world_size, rank=self._rank)
        self._model = torch.nn.parallel.DistributedDataParallel(self._model, device_ids=[gpu])
        # prepare the data loaders: each process sees batch_size // world_size samples per step
        params = {'batch_size': self._batch_size // self._world_size,
                  'shuffle': False,  # the sampler option is mutually exclusive with shuffle; DistributedSampler shuffles instead
                  'num_workers': self._workers,
                  'pin_memory': True}
        self._trainset = MyDataset()
        self._devset = MyDataset()
        self._tr_sampler = torch.utils.data.distributed.DistributedSampler(self._trainset, num_replicas=self._world_size, rank=self._rank)
        self._dev_sampler = torch.utils.data.distributed.DistributedSampler(self._devset, num_replicas=self._world_size, rank=self._rank)
        self._trloader = DataLoader(self._trainset, **params, sampler=self._tr_sampler, collate_fn=self._model.module.collate_fn)
        self._devloader = DataLoader(self._devset, **params, sampler=self._dev_sampler, collate_fn=self._model.module.collate_fn)
        self.train_loop()
    def train_loop(self):
        for epoch in range(self._epochs):
            self._tr_sampler.set_epoch(epoch)  # make the sampler reshuffle differently every epoch
            self._model.train()
            if self._rank == 0:
                print('Current epoch: {}'.format(epoch))
            for batch in self._trloader:
                batch = batch.to(self._device, non_blocking=True)
                out = self._model(batch)
                # ... loss computation and optimizer step ...
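            # Not part of the original gist: a sketch of per-rank memory logging after each
            # training epoch, in line with the memory inspection this gist is about.
            # max_memory_allocated/memory_reserved report bytes tracked by PyTorch's
            # caching allocator on this process's device.
            print('rank {} | epoch {} | max allocated: {:.1f} MiB | reserved: {:.1f} MiB'.format(
                self._rank, epoch,
                torch.cuda.max_memory_allocated(self._device) / 2**20,
                torch.cuda.memory_reserved(self._device) / 2**20))
            torch.cuda.reset_peak_memory_stats(self._device)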
            self._model.eval()
            with torch.no_grad():
                for batch in self._devloader:
                    batch = batch.to(self._device, non_blocking=True)
                    out = self._model(batch)
class MyModel(nn.Module):

    def __init__(self):
        super(MyModel, self).__init__()
        self.net = nn.Sequential(*[nn.Linear(700, 700) for _ in range(50)])

    def forward(self, input):
        return self.net(input)

    def collate_fn(self, batch):
        # here prepare the batch ...
        return torch.stack(batch)
class MyDataset(Dataset):

    def __init__(self):
        # 1500 random samples of dimension 700
        self.data = torch.randn(1500, 700)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        return self.data[index]
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-n',
        '--nodes',
        type=int,
        help='number of nodes'
    )
    parser.add_argument(
        '-g',
        '--gpus',
        type=int,
        help='number of gpus per node'
    )
    parser.add_argument(
        '-nr',
        '--nr',
        type=int,
        help='rank of the node within all the nodes (goes from 0 to nodes-1)'
    )
    args = parser.parse_args()

    trainer = MyTrainer(gpus_n=args.gpus, nodes_n=args.nodes, node_rank=args.nr)
    trainer.train()
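A minimal launch sketch, assuming a single node with two GPUs; the script name, master address, and port are placeholders, since init_method='env://' reads MASTER_ADDR and MASTER_PORT from the environment:

    MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 python ddp_memory_inspection.py --nodes 1 --gpus 2 --nr 0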