Created March 12, 2021 09:34
-
-
Save janhenriklambrechts/67a0faf5dc05d4e8d4d16973d1c03929 to your computer and use it in GitHub Desktop.
MaxMetric: tracking the maximum value of a scalar or tensor in PyTorch Lightning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools

import torch
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics import Metric
from pytorch_lightning.utilities.distributed import gather_all_tensors
class MaxMetric(Metric):
    """
    PyTorch Lightning Metric that tracks the maximum value of a scalar across an experiment.

    Unlike the stock ``Metric``, ``compute()`` does not reset the state afterwards.
    The running maximum therefore survives across epochs; call ``reset()``
    (inherited from ``Metric``) to start tracking from scratch.

    Args:
        dist_sync_on_step: forwarded to ``Metric``; whether to synchronize the
            state across processes on every ``forward`` call.
    """

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # Start at -inf (as a float tensor) so the first update always wins.
        # A default of 0 would silently hide every negative value.
        # When states are gathered across processes they are stacked into one
        # tensor; ``torch.max`` reduces that back to the single global maximum
        # instead of leaving a (world_size,) tensor behind as the state.
        self.add_state(
            "max_val",
            default=torch.tensor(float("-inf")),
            dist_reduce_fx=torch.max,
        )

    def _wrap_compute(self, compute):
        """Wrap ``compute`` like ``Metric`` does, but without the automatic reset.

        The wrapper returns the cached result if there was no ``update`` since
        the last call. It syncs state across processes when distributed
        training is active, then calls ``compute``. The state is left intact so
        the maximum keeps accumulating.
        """

        @functools.wraps(compute)
        def wrapped_func(*args, **kwargs):
            # return cached value
            if self._computed is not None:
                return self._computed

            dist_sync_fn = self.dist_sync_fn
            if (
                dist_sync_fn is None
                and torch.distributed.is_available()
                and torch.distributed.is_initialized()
            ):
                # User provided a bool, so we assume DDP if available
                dist_sync_fn = gather_all_tensors

            if self._to_sync and dist_sync_fn is not None:
                # The synced state is the global max (dist_reduce_fx=torch.max).
                # Keeping it as the local state is harmless, because taking the
                # max again with later values gives the same result.
                self._sync_dist(dist_sync_fn)

            self._computed = compute(*args, **kwargs)
            # Deliberately no self.reset() here: the max must persist across epochs.
            return self._computed

        return wrapped_func

    def update(self, val):
        """Fold ``val`` into the running maximum.

        Args:
            val: a Python number or a single-element tensor. A multi-element
                tensor cannot be compared as a scalar and will raise.
        """
        # Keep the state a tensor on the metric's device, so device moves and
        # DDP gathering work. Detach it so it does not keep an autograd graph alive.
        val = torch.as_tensor(val, device=self.max_val.device).detach()
        if val > self.max_val:
            self.max_val = val

    def compute(self):
        """Return the largest value seen since construction or the last ``reset()``.

        Returns ``-inf`` if ``update`` has not been called since then.
        """
        return self.max_val
if __name__ == "__main__":
    # Demo: feed a sequence of batch accuracies into MaxMetric and check that
    # it reports the best value seen so far rather than the latest one.
    tracker = MaxMetric()
    accuracy = Accuracy()
    target = torch.tensor([[0, 1], [0, 1]])

    # Each entry pairs a batch of predictions with the expected running maximum.
    # The per-batch accuracies are 0.5, 1.0 and 0.5 respectively.
    steps = (
        (torch.tensor([[0.9, 0.1], [0.2, 0.8]]), 0.5),
        (torch.tensor([[0.1, 0.9], [0.2, 0.8]]), 1.0),
        (torch.tensor([[0.1, 0.9], [0.8, 0.2]]), 1.0),
    )
    for batch_preds, best_so_far in steps:
        accuracy(batch_preds, target)
        tracker(accuracy.compute())
        assert tracker.compute() == best_so_far
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.