# Last active: March 12, 2021 21:38
# Save seo-95/f3f95553dfe81f7e780be304449f7879 to your computer and use it in GitHub Desktop.
# DDP Memory inspection
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
class MyTrainer():
    """Launch and run a DistributedDataParallel (DDP) job on CUDA GPUs.

    One worker process is spawned per local GPU. Each worker joins an NCCL
    process group, wraps ``MyModel`` in DDP and reads ``MyDataset`` through a
    ``DistributedSampler`` so that every rank iterates over its own shard.

    Args:
        gpus_n: number of GPUs (= worker processes) on this node.
        nodes_n: total number of nodes taking part in the job.
        node_rank: index of this node, expected in ``[0, nodes_n)``.
    """

    def __init__(self, gpus_n, nodes_n, node_rank):
        self._gpus_n = gpus_n
        self._nodes_n = nodes_n
        self._node_rank = node_rank
        self._world_size = gpus_n * nodes_n
        # Global batch size; it is divided across all ranks (see _per_replica_batch_size).
        self._batch_size = 128
        # DataLoader worker processes per rank.
        self._workers = 1
        self._epochs = 15

    def train(self):
        """Spawn one training process per local GPU and wait for all of them.

        Raises:
            RuntimeError: if the world size is not greater than 1 or CUDA is
                not available.
        """
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if self._world_size <= 1 or not torch.cuda.is_available():
            raise RuntimeError(
                'DDP training needs world_size > 1 and CUDA '
                '(world_size={}, cuda_available={})'.format(
                    self._world_size, torch.cuda.is_available()))
        # mp.spawn calls _distributed_training(i) with i in [0, nprocs); the bound
        # method (and therefore ``self``) is sent to each child process.
        mp.spawn(self._distributed_training,
                 nprocs=self._gpus_n)

    def _per_replica_batch_size(self):
        """Return the per-rank batch size: the global batch split across ranks.

        Clamped to at least 1 because DataLoader rejects ``batch_size=0``, which
        plain floor division would produce when world_size > global batch size.
        """
        return max(1, self._batch_size // self._world_size)

    def _distributed_training(self, gpu):
        """Per-process entry point; ``gpu`` is the local process index from mp.spawn."""
        # Global rank = all GPUs on the preceding nodes + local index.
        self._rank = self._node_rank * self._gpus_n + gpu
        self._device = torch.device('cuda:{}'.format(gpu))
        # Select this process's GPU before any CUDA work so later allocations
        # default to it instead of cuda:0.
        torch.cuda.set_device(gpu)
        self._model = MyModel()
        self._model.cuda()
        # init_method='env://' takes the rendezvous address from the environment
        # (MASTER_ADDR / MASTER_PORT must be set by the launcher).
        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=self._world_size, rank=self._rank)
        try:
            self._model = torch.nn.parallel.DistributedDataParallel(self._model, device_ids=[gpu])
            self._build_loaders()
            self.train_loop()
        finally:
            # Release the process group (NCCL communicators) even if training fails.
            dist.destroy_process_group()

    def _build_loaders(self):
        """Create the train/dev datasets, their DistributedSamplers and DataLoaders."""
        params = {'batch_size': self._per_replica_batch_size(),
                  # DataLoader forbids shuffle=True together with a sampler; the
                  # DistributedSampler does the shuffling itself (reseeded per
                  # epoch through set_epoch in train_loop).
                  'shuffle': False,
                  'num_workers': self._workers,
                  'pin_memory': True}
        collate_fn = self._model.module.collate_fn
        self._trainset = MyDataset()
        self._devset = MyDataset()
        self._tr_sampler = torch.utils.data.distributed.DistributedSampler(
            self._trainset, num_replicas=self._world_size, rank=self._rank)
        self._dev_sampler = torch.utils.data.distributed.DistributedSampler(
            self._devset, num_replicas=self._world_size, rank=self._rank)
        self._trloader = DataLoader(self._trainset, **params, sampler=self._tr_sampler,
                                    collate_fn=collate_fn)
        self._devloader = DataLoader(self._devset, **params, sampler=self._dev_sampler,
                                     collate_fn=collate_fn)

    def train_loop(self):
        """Run ``self._epochs`` epochs, each a training pass followed by a dev pass."""
        for epoch in range(self._epochs):
            # Give the training sampler a different permutation every epoch.
            self._tr_sampler.set_epoch(epoch)
            self._model.train()
            if self._rank == 0:
                print('Current epoch: {}'.format(epoch))
            for batch in self._trloader:
                batch = batch.to(self._device, non_blocking=True)
                out = self._model(batch)
                # .. loss and optimizer stuffs ... (``out`` is the hook for them)
            self._model.eval()
            with torch.no_grad():
                for batch in self._devloader:
                    batch = batch.to(self._device, non_blocking=True)
                    out = self._model(batch)
class MyModel(nn.Module):
    """A deep stack of square ``nn.Linear`` layers with no activations in between.

    Also exposes ``collate_fn`` for turning a list of samples into a batch.
    """

    _WIDTH = 700   # in/out features of every layer
    _DEPTH = 50    # number of stacked layers

    def __init__(self):
        super().__init__()
        layers = []
        for _ in range(self._DEPTH):
            layers.append(nn.Linear(self._WIDTH, self._WIDTH))
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Feed ``input`` through every layer in order."""
        return self.net(input)

    def collate_fn(self, batch):
        """Stack a list of equally-shaped sample tensors along a new leading dim."""
        stacked = torch.stack(batch, dim=0)
        return stacked
class MyDataset(Dataset):
    """In-memory dataset of random samples drawn once with ``torch.randn``.

    Holds a (1500, 700) tensor; item ``i`` is row ``i``.
    """

    _NUM_SAMPLES = 1500
    _NUM_FEATURES = 700

    def __init__(self):
        self.data = torch.randn(self._NUM_SAMPLES, self._NUM_FEATURES)

    def __len__(self):
        # len() of a tensor is the size of its first dimension.
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return sample
if __name__ == '__main__':
    # Command-line launcher. Every option now has a default so a single-node run
    # needs no flags; previously omitting any of them crashed later with a
    # TypeError (None * None) inside MyTrainer.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-n',
        '--nodes',
        type=int,
        default=1,
        help='number of nodes taking part in the job'
    )
    parser.add_argument(
        '-g',
        '--gpus',
        type=int,
        default=torch.cuda.device_count(),
        help='number of gpus per node'
    )
    parser.add_argument(
        '-nr',
        '--nr',
        type=int,
        default=0,
        help='Rank of the node within all the nodes (goes from 0 to nodes-1)'
    )
    args = parser.parse_args()
    # Reject impossible topologies up front with a usage message instead of
    # failing inside the spawned workers.
    if args.nodes < 1:
        parser.error('--nodes must be >= 1, got {}'.format(args.nodes))
    if args.gpus < 1:
        parser.error('--gpus must be >= 1 (no visible CUDA device?), got {}'.format(args.gpus))
    if not 0 <= args.nr < args.nodes:
        parser.error('--nr must be in [0, {}), got {}'.format(args.nodes, args.nr))
    trainer = MyTrainer(gpus_n=args.gpus, nodes_n=args.nodes, node_rank=args.nr)
    trainer.train()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.