Skip to content

Instantly share code, notes, and snippets.

@Kshitij09
Last active June 20, 2020 12:31
Show Gist options
  • Save Kshitij09/eb95bd6667e3b29c649bf95ebb9e1fc7 to your computer and use it in GitHub Desktop.
Save Kshitij09/eb95bd6667e3b29c649bf95ebb9e1fc7 to your computer and use it in GitHub Desktop.
`PrintCallback` for PyTorch Lightning, which renders a table of logged metrics in a Jupyter notebook
import torch
from pytorch_lightning import Callback
from IPython.display import display, clear_output
import copy
import pandas as pd
def unwrap(x):
    """Return a plain Python value for ``x``.

    A ``torch.Tensor`` is converted with ``Tensor.item()``, which only works
    for single-element tensors. Any other object is returned unchanged.
    """
    if not isinstance(x, torch.Tensor):
        return x
    return x.item()
class PrintCallback(Callback):
    """Show a running table of per-epoch metrics in a Jupyter notebook.

    At the end of every epoch, the entries of ``trainer.callback_metrics``
    (except ``'loss'``) are converted to plain Python values with ``unwrap``
    and appended to ``self.metrics``. The cell output is then cleared and
    the whole history is re-displayed as a pandas DataFrame.

    Args:
        columns: Column names, in display order, for the table. Recorded
            metrics that are not listed are kept in ``self.metrics`` but not
            shown. Listed names that were never logged appear as NaN. Pass
            ``None`` to show every metric that has been recorded. The default
            is the original example layout; adjust it to your own logging keys.
    """

    DEFAULT_COLUMNS = ('epoch', 'train_loss', 'val_loss', 'accuracy', 'f1_score')

    def __init__(self, columns=DEFAULT_COLUMNS):
        super().__init__()
        # One {metric_name: value} dict per finished epoch.
        self.metrics = []
        # None -> let pandas use every key found in self.metrics.
        self.columns = None if columns is None else list(columns)

    def on_epoch_end(self, trainer, pl_module):
        # wait=True postpones the clear until new output replaces it.
        clear_output(wait=True)
        # Build a fresh dict rather than deep-copying callback_metrics.
        # unwrap() already yields independent Python scalars for tensors, so
        # the trainer's dict is left untouched and no tensors are deep-copied.
        # Filtering (instead of `del`) also means an epoch without a 'loss'
        # entry no longer raises KeyError.
        epoch_metrics = {
            name: unwrap(value)
            for name, value in trainer.callback_metrics.items()
            if name != 'loss'
        }
        self.metrics.append(epoch_metrics)
        metrics_df = pd.DataFrame.from_records(self.metrics, columns=self.columns)
        display(metrics_df)
@Kshitij09
Copy link
Author

This creates tabular logs as follows:

[image: screenshot of the resulting metrics table — not preserved]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment