Skip to content

Instantly share code, notes, and snippets.

@JakeColor
Created March 13, 2021 16:42
Show Gist options
  • Save JakeColor/d2a25749bf0f51f74919b918797a4f66 to your computer and use it in GitHub Desktop.
Save JakeColor/d2a25749bf0f51f74919b918797a4f66 to your computer and use it in GitHub Desktop.
PyTorch Lightning MetricTrackerCallback
from pytorch_lightning.callbacks.base import Callback
from torch import lt, gt
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class MetricTrackerCallback(Callback):
    """Callback that takes an action each time a tracked metric improves.

    Subclass and override :meth:`do_something` to define the action. The
    metric is read from ``trainer.callback_metrics`` at the end of every
    validation epoch.
    """

    def __init__(self, metric: str, mode: str = 'min'):
        """
        :param metric: key in ``trainer.callback_metrics`` to track
        :param mode: direction of improvement; ``'min'`` means a smaller
            metric value is better, ``'max'`` means larger is better
        :raises MisconfigurationException: if ``mode`` is not ``'min'``/``'max'``
        """
        super().__init__()
        if mode not in ('min', 'max'):
            raise MisconfigurationException("mode argument must be either 'min' or 'max'")
        self.mode = mode
        self.metric = metric
        # torch.lt / torch.gt: callback_metrics values are (scalar) tensors,
        # so the comparison result supports .item() below.
        self.metric_op = {"min": lt, "max": gt}[self.mode]
        # Best value seen so far; None until the first validation epoch.
        self.best = None

    def do_something(self):
        """The action to take when the metric improves; override in subclasses."""
        pass

    def on_validation_epoch_end(self, trainer, pl_module):
        mets = trainer.callback_metrics
        # Raises KeyError if the metric was never logged — surfacing the typo early.
        current_metric_value = mets[self.metric]
        # BUG FIX: the original used `not self.best` to detect "no best yet",
        # which wrongly treats a legitimate best value of 0 (falsy tensor) as
        # unset and re-triggers the action every epoch. Compare to None instead.
        current_better_than_best = self.best is None or \
            self.metric_op(current_metric_value, self.best).item()
        if current_better_than_best:
            self.do_something()
            self.best = current_metric_value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment