Skip to content

Instantly share code, notes, and snippets.

@odats
Created March 16, 2020 18:17
Show Gist options
  • Save odats/58a6d16f606054be70df0882479d8ac6 to your computer and use it in GitHub Desktop.
Save odats/58a6d16f606054be70df0882479d8ac6 to your computer and use it in GitHub Desktop.
import torch
from ignite.metrics import Metric
from ignite.exceptions import NotComputableError
# These decorators helps with distributed settings
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced
# Based on https://pytorch.org/ignite/metrics.html#how-to-create-a-custom-metric
# Can be improved by TopK https://pytorch.org/ignite/metrics.html#ignite.metrics.TopKCategoricalAccuracy
class PixelToPixelAccuracy(Metric):
    """Share of samples whose predicted index map equals the target everywhere.

    For each sample in a batch, the prediction is reduced with ``argmax``
    over its first dimension and compared element-wise with the target.
    The sample counts as correct only if *every* element matches, so this
    is an exact-match score per sample rather than a per-pixel average.

    Args:
        output_transform: Callable forwarded to ``Metric``; its result is
            what ``update`` receives as ``output``.
        device: Forwarded unchanged to ``Metric``.
    """

    def __init__(self, output_transform=lambda x: x, device=None):
        # Placeholders only; ``reset`` assigns the real starting values.
        self._num_correct = None
        self._num_examples = None
        super().__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        """Zero both counters and reset the base metric state."""
        self._num_correct = 0
        self._num_examples = 0
        super().reset()

    @reinit__is_reduced
    def update(self, output):
        """Accumulate exact-match counts for one batch.

        Args:
            output: A ``(y_pred_all, y_all)`` pair. The first dimension of
                ``y_pred_all`` is iterated as the batch dimension.
                NOTE(review): presumably ``y_pred_all`` is
                (batch, classes, ...) scores and ``y_all`` holds class
                indices of shape (batch, ...) -- confirm against callers.
        """
        y_pred_all, y_all = output
        for sample_idx, y_pred in enumerate(y_pred_all):
            target = y_all[sample_idx]
            # Collapse the leading (per-sample) dimension to predicted indices.
            predicted = torch.argmax(y_pred, dim=0)
            # A single mismatching element makes the whole sample incorrect.
            if torch.all(torch.eq(predicted, target)):
                self._num_correct += 1
            self._num_examples += 1

    @sync_all_reduce("_num_examples", "_num_correct")
    def compute(self):
        """Return ``_num_correct / _num_examples``.

        Raises:
            NotComputableError: If no example has been seen since the last
                ``reset``.
        """
        if self._num_examples == 0:
            raise NotComputableError('PixelToPixelAccuracy must have at least one example before it can be computed.')
        return self._num_correct / self._num_examples
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment