Created
March 16, 2020 18:17
-
-
Save odats/58a6d16f606054be70df0882479d8ac6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from ignite.metrics import Metric
from ignite.exceptions import NotComputableError
# These decorators help with distributed settings
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced
# Based on https://pytorch.org/ignite/metrics.html#how-to-create-a-custom-metric
# Could be improved with a top-k variant, cf. https://pytorch.org/ignite/metrics.html#ignite.metrics.TopKCategoricalAccuracy
class PixelToPixelAccuracy(Metric):
    """Fraction of samples whose predicted class map matches the target exactly.

    For every sample the prediction is reduced with ``argmax`` over its class
    axis; the sample counts as correct only if *every* element of that result
    equals the corresponding element of the target. A single mismatching pixel
    makes the whole sample incorrect.
    """

    def __init__(self, output_transform=lambda x: x, device=None):
        # Placeholders only; reset() assigns the real starting values
        # (ignite's Metric constructor calls reset()).
        self._num_correct = None
        self._num_examples = None
        super(PixelToPixelAccuracy, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        """Zero the running counters."""
        self._num_correct = 0
        self._num_examples = 0
        super(PixelToPixelAccuracy, self).reset()

    @reinit__is_reduced
    def update(self, output):
        """Accumulate exact-match counts for one batch.

        Args:
            output: ``(y_pred, y)`` pair. ``y_pred`` is indexed as
                ``(batch, classes, ...)``; ``y`` holds class indices with the
                same number of elements per sample as ``argmax(y_pred)``
                (extra singleton dims such as ``(N, 1, H, W)`` are accepted).
                Assumed layout — confirm against the model's output.
        """
        y_pred_all, y_all = output
        num_samples = y_pred_all.shape[0]
        if num_samples == 0:
            # Nothing to count; also avoids a division by zero below.
            return

        # dim=1 over the batch is the same axis as dim=0 per sample.
        indices = torch.argmax(y_pred_all, dim=1)
        matches = torch.eq(indices, y_all.reshape(indices.shape))

        # Explicit per-sample size (instead of -1) keeps zero-size samples
        # valid; .all() over an empty row is True, same as torch.all(empty).
        per_sample = matches.numel() // num_samples
        sample_correct = matches.reshape(num_samples, per_sample).all(dim=1)

        # One host sync per batch rather than one per sample.
        self._num_correct += sample_correct.sum().item()
        self._num_examples += num_samples

    @sync_all_reduce("_num_examples", "_num_correct")
    def compute(self):
        """Return the fraction of exactly-matching samples since the last reset.

        Raises:
            NotComputableError: if no samples have been accumulated.
        """
        if self._num_examples == 0:
            raise NotComputableError('PixelToPixelAccuracy must have at least one example before it can be computed.')
        return self._num_correct / self._num_examples
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment