Created
April 16, 2020 16:17
-
-
Save alexeykarnachev/47de06b93a717ab0664eded42ed2826a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections | |
import gc | |
import resource | |
import pytorch_lightning as pl | |
import pytorch_lightning.loggers | |
import torch | |
import torch.nn as nn | |
import torch.optim | |
import torch.utils.data | |
def get_num_of_tensors():
    """Count live zero-dimensional (scalar) tensors tracked by the GC.

    Walks every object the garbage collector knows about, tallies tensor
    shapes, and returns how many scalar tensors (shape ``torch.Size([])``)
    are currently alive.  The script uses this per training step to watch
    for a tensor leak.

    Returns:
        int: number of live scalar tensors (0 if none).
    """
    sizes = collections.Counter()
    for obj in gc.get_objects():
        # Some half-collected or exotic objects raise on attribute access
        # or on torch.is_tensor; skip those rather than abort the scan.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                sizes[obj.size()] += 1
        except Exception:
            pass
    # Counter returns 0 for a missing key, so this is safe with no tensors.
    return sizes[torch.Size([])]
def get_cpu_mem():
    """Return this process's peak resident set size (``ru_maxrss``)."""
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_maxrss
class Model(nn.Module):
    """Toy module holding 1000 scalar parameters.

    ``forward`` ignores its input and returns every parameter as a
    separate "loss", which is what makes this a useful leak probe.
    """

    def __init__(self):
        super().__init__()
        scalars = (
            nn.Parameter(torch.tensor(1.0), requires_grad=True)
            for _ in range(1000)
        )
        self._losses = nn.ParameterList(scalars)

    def forward(self, data):
        # Input is unused on purpose; hand back all scalar parameters.
        return [p for p in self._losses]
class PLModule(pl.LightningModule):
    """LightningModule wrapping ``Model``.

    Each training step logs the live scalar-tensor count and peak CPU
    memory so a leak shows up as a monotonically growing curve.
    """

    def __init__(self):
        super().__init__()
        self._model = Model()

    def forward(self, data):
        return self._model(data)

    def training_step(self, batch, batch_idx):
        step_losses = self.forward(batch)
        tensor_count = get_num_of_tensors()
        log = {'Num-of-tensors': tensor_count, 'Cpu-mem-usg': get_cpu_mem()}
        # One log entry per scalar parameter "loss".
        log.update({f'loss{i}': value for i, value in enumerate(step_losses)})
        print(tensor_count)
        return {'loss': step_losses[0], 'log': log}

    def validation_step(self, batch, batch_idx):
        return {'val_loss': self.forward(batch)[0]}

    def validation_epoch_end(self, outputs):
        mean_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': mean_loss, 'log': {'Loss/valid': mean_loss, }}

    @staticmethod
    def _get_dl():
        # 1000 dummy integer samples, batched by 4.
        values = torch.tensor(list(range(1000)))
        dataset = torch.utils.data.TensorDataset(values)
        return torch.utils.data.DataLoader(dataset, batch_size=4)

    def train_dataloader(self):
        return self._get_dl()

    def val_dataloader(self):
        return self._get_dl()

    def configure_optimizers(self):
        # Tiny LR: we only want the training loop to run, not to learn.
        return torch.optim.SGD(self._model.parameters(), lr=1e-10)
class Logger(pytorch_lightning.loggers.TensorBoardLogger):
    """TensorBoard logger that drops the per-parameter ``loss*`` metrics.

    Only non-loss metrics (tensor count, CPU memory, GPU memory) are
    forwarded to TensorBoard, keeping the event files small.
    """

    def log_metrics(self, metrics, step) -> None:
        for name, value in metrics.items():
            if 'loss' in name:
                continue  # skip the 1000 individual loss scalars
            scalar = value.item() if isinstance(value, torch.Tensor) else value
            self.experiment.add_scalar(name, scalar, step)
def main():
    """Train ``PLModule`` on GPU 2 with fp16/AMP, logging to ./tb_logs."""
    logger = Logger(
        save_dir='./tb_logs',
        name='gpu-mem-leak',
    )
    trainer = pl.Trainer(
        log_gpu_memory='all',
        gpus=[2],
        logger=logger,
        amp_level='O2',
        precision=16,
    )
    trainer.fit(PLModule())


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment