Created
March 13, 2021 16:42
-
-
Save JakeColor/d2a25749bf0f51f74919b918797a4f66 to your computer and use it in GitHub Desktop.
PyTorch Lightning MetricTrackerCallback
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pytorch_lightning.callbacks.base import Callback | |
from torch import lt, gt | |
from pytorch_lightning.utilities.exceptions import MisconfigurationException | |
class MetricTrackerCallback(Callback):
    """A callback outline that takes action each time a tracked variable improves.

    Subclass this and override :meth:`do_something` to run custom logic
    whenever the tracked metric beats the best value seen so far.
    """

    def __init__(self, metric: str, mode: str = 'min'):
        """
        :param metric: the metric that should be tracked (the key looked up in
            ``trainer.callback_metrics``)
        :param mode: the direction of improvement. 'min' means a smaller metric
            value is better, 'max' means a larger metric value is better
        :raises MisconfigurationException: if ``mode`` is neither 'min' nor 'max'
        """
        super().__init__()
        if mode not in ['min', 'max']:
            raise MisconfigurationException("mode argument must be either 'min' or 'max'")
        self.mode = mode
        self.metric = metric
        # Comparison returning True when its first argument is better than
        # its second for the chosen direction of improvement.
        self.metric_op = {"min": lt, "max": gt}[self.mode]
        # Best metric value observed so far; None until the first validation
        # epoch has been seen.
        self.best = None

    def do_something(self):
        """The action to take when the metric improves"""
        pass

    def on_validation_epoch_end(self, trainer, pl_module):
        """Compare the current metric value with the best one and react on improvement.

        :param trainer: object exposing ``callback_metrics``, a mapping from
            metric name to its current value
        :param pl_module: unused here
        :raises KeyError: if the tracked metric is not in ``trainer.callback_metrics``
        """
        mets = trainer.callback_metrics
        # NOTE(review): assumes the logged value is a scalar torch tensor,
        # since it is compared with torch.lt / torch.gt below — confirm.
        current_metric_value = mets[self.metric]
        # Check explicitly for None: a best value of 0 is falsy, so a plain
        # truthiness test would treat it as "no best recorded yet" and report
        # every later epoch as an improvement.
        current_better_than_best = (
            self.best is None
            or self.metric_op(current_metric_value, self.best).item()
        )
        if current_better_than_best:
            self.do_something()
            self.best = current_metric_value
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment