Skip to content

Instantly share code, notes, and snippets.

@janhenriklambrechts
Created March 12, 2021 09:34
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save janhenriklambrechts/67a0faf5dc05d4e8d4d16973d1c03929 to your computer and use it in GitHub Desktop.
Save janhenriklambrechts/67a0faf5dc05d4e8d4d16973d1c03929 to your computer and use it in GitHub Desktop.
MaxMetric for tracking the maximum value of a scalar or tensor in PyTorch Lightning
import torch
from pytorch_lightning.metrics import Metric
from pytorch_lightning.utilities.distributed import gather_all_tensors
from pytorch_lightning.metrics import Accuracy
class MaxMetric(Metric):
    """PyTorch Lightning metric that tracks the largest value seen so far.

    Every value passed to ``update`` (or to the metric call itself) is
    compared element-wise with the stored maximum, and the larger entries are
    kept.  ``_wrap_compute`` is overridden so that ``compute()`` does not reset
    the state; the maximum therefore persists across ``compute()`` calls until
    ``reset()`` is invoked explicitly.

    Args:
        dist_sync_on_step: Forwarded unchanged to ``Metric.__init__``.
    """

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # Start from -inf rather than 0 so that negative values can become the
        # maximum, and so that the state is floating point from the start.
        self.add_state(
            "max_val",
            default=torch.tensor(float("-inf")),
            dist_reduce_fx=self._reduce_max,
        )

    @staticmethod
    def _reduce_max(stacked):
        """Collapse per-process maxima, stacked along dim 0, into one maximum."""
        # torch.max(..., dim=0) returns (values, indices); keep the values.
        return torch.max(stacked, dim=0)[0]

    def _wrap_compute(self, compute):
        """Return a caching, optionally syncing wrapper around ``compute``.

        The wrapper returns the cached result when one exists, synchronises
        the state across processes when requested, and stores the fresh
        result in ``self._computed``.  It deliberately does NOT reset the
        metric afterwards.

        NOTE(review): relies on the private attributes ``_computed``,
        ``_to_sync``, ``dist_sync_fn`` and the ``_sync_dist`` method, which are
        presumably supplied by the ``Metric`` base class -- confirm against the
        installed pytorch_lightning version.

        Args:
            compute: The callable to wrap.

        Returns:
            The wrapping function, accepting the same arguments as ``compute``.
        """

        def wrapped_func(*args, **kwargs):
            # Return the cached value if there is one.
            if self._computed is not None:
                return self._computed

            dist_sync_fn = self.dist_sync_fn
            if (
                dist_sync_fn is None
                and torch.distributed.is_available()
                and torch.distributed.is_initialized()
            ):
                # User provided a bool, so we assume DDP if available.
                dist_sync_fn = gather_all_tensors

            if self._to_sync and dist_sync_fn is not None:
                self._sync_dist(dist_sync_fn)

            self._computed = compute(*args, **kwargs)
            # No reset here: the running maximum must survive compute().
            return self._computed

        return wrapped_func

    def update(self, val):
        """Fold ``val`` into the running maximum.

        Args:
            val: A Python number or a tensor.  It is converted to a tensor
                with the dtype and device of the current state; multi-element
                tensors are compared element-wise (with broadcasting).
        """
        val = torch.as_tensor(
            val, dtype=self.max_val.dtype, device=self.max_val.device
        )
        # Element-wise "keep the larger one".  A NaN entry in ``val`` compares
        # as not-greater and therefore leaves the stored maximum untouched.
        self.max_val = torch.where(val > self.max_val, val, self.max_val)

    def compute(self):
        """Return the largest value seen since the last reset."""
        return self.max_val
if __name__ == "__main__":
    # Small self-check: feed three per-batch accuracies (0.5, 1.0, 0.5) into
    # the tracker and verify that the reported maximum never decreases.
    m = MaxMetric()
    acc = Accuracy()

    preds_1 = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
    preds_2 = torch.tensor([[0.1, 0.9], [0.2, 0.8]])
    preds_3 = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
    labels = torch.tensor([[0, 1], [0, 1]])

    # Use the value returned by the metric call (the accuracy of that batch)
    # instead of acc.compute(), so the check does not depend on whether
    # Accuracy.compute() resets its accumulated state.
    m(acc(preds_1, labels))  # batch accuracy is 0.5 -> max is 0.5
    assert m.compute() == 0.5

    m(acc(preds_2, labels))  # batch accuracy is 1.0 -> max is 1.0
    assert m.compute() == 1.

    m(acc(preds_3, labels))  # batch accuracy is 0.5 -> max stays 1.0
    assert m.compute() == 1.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment