Created
December 25, 2018 19:57
-
-
Save vfdev-5/7559dcc1de9ef868531ee5612a652d41 to your computer and use it in GitHub Desktop.
Custom ignite metrics: Accuracy/Precision/Recall with a multilabel option
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import torch | |
from ignite.metrics.metric import Metric | |
from ignite.exceptions import NotComputableError | |
class _BaseClassification(Metric):
    """Shared input validation for the classification metrics in this module.

    Subclasses call :meth:`_check_shape` to normalize ``(y_pred, y)`` and then
    :meth:`_check_type` to classify the update as ``"binary"``, ``"multiclass"``
    or ``"multilabel"``. The first detected type is stored in ``self._type``;
    a later update of a different type raises ``RuntimeError``.

    Args:
        output_transform (callable, optional): forwarded to ``Metric``; maps the
            engine output to ``(y_pred, y)``. By default, the identity.
        is_multilabel (bool, optional): if True, inputs are validated and
            classified as multilabel data. By default, False.
    """

    def __init__(self, output_transform=lambda x: x, is_multilabel=False):
        # Set before calling the parent constructor.
        # NOTE(review): presumably Metric.__init__ calls reset() (as ignite does),
        # and subclass reset() implementations read these attributes.
        self._is_multilabel = is_multilabel
        self._type = None
        super(_BaseClassification, self).__init__(output_transform=output_transform)

    def _check_shape(self, output):
        """Validate and normalize the shapes of ``(y_pred, y)``.

        A singleton dimension 1 is squeezed from either tensor, so
        ``(N, 1, ...)`` becomes ``(N, ...)``.

        Returns:
            tuple: the (possibly squeezed) ``(y_pred, y)``.

        Raises:
            ValueError: if ``y_pred`` is neither the same rank as ``y`` nor one
                rank higher, if the non-category dimensions differ, or if
                ``is_multilabel`` is set and the tensors are not both
                ``(batch_size, num_categories, ...)`` with the same shape.
        """
        y_pred, y = output

        if y.ndimension() > 1 and y.shape[1] == 1:
            # (N, 1, ...) -> (N, ...)
            y = y.squeeze(dim=1)

        if y_pred.ndimension() > 1 and y_pred.shape[1] == 1:
            # (N, 1, ...) -> (N, ...)
            y_pred = y_pred.squeeze(dim=1)

        if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
            raise ValueError("y must have shape of (batch_size, ...) and y_pred must have "
                             "shape of (batch_size, num_categories, ...) or (batch_size, ...), "
                             "but given {} vs {}".format(y.shape, y_pred.shape))

        y_shape = y.shape
        y_pred_shape = y_pred.shape

        if y.ndimension() + 1 == y_pred.ndimension():
            # Drop y_pred's category axis (dim 1) before comparing with y.
            y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]

        if not (y_shape == y_pred_shape):
            raise ValueError("y and y_pred must have compatible shapes.")

        if self._is_multilabel and not (y.shape == y_pred.shape and y.ndimension() > 1 and y.shape[1] != 1):
            raise ValueError("y and y_pred must have same shape of (batch_size, num_categories, ...).")

        return y_pred, y

    def _check_type(self, output):
        """Detect the update type and check it matches earlier updates.

        ``y_pred`` one rank higher than ``y`` means ``"multiclass"``. Equal rank
        means ``"multilabel"`` when ``is_multilabel`` is set, else ``"binary"``.
        In both equal-rank cases, all values must be 0 or 1.

        Raises:
            ValueError: if binary/multilabel inputs contain values other than
                0 and 1 (checked via ``x == x ** 2``).
            RuntimeError: on unexpected ranks, or if the detected type differs
                from the type recorded by a previous update.
        """
        y_pred, y = output

        if y.ndimension() + 1 == y_pred.ndimension():
            update_type = "multiclass"
        elif y.ndimension() == y_pred.ndimension():
            # x == x ** 2 holds elementwise only for 0 and 1.
            if not torch.equal(y, y ** 2):
                raise ValueError("For binary cases, y must be comprised of 0's and 1's.")

            if not torch.equal(y_pred, y_pred ** 2):
                raise ValueError("For binary cases, y_pred must be comprised of 0's and 1's.")

            if self._is_multilabel:
                update_type = "multilabel"
            else:
                update_type = "binary"
        else:
            raise RuntimeError("Invalid shapes of y (shape={}) and y_pred (shape={}), check documentation"
                               " for expected shapes of y and y_pred.".format(y.shape, y_pred.shape))

        if self._type is None:
            self._type = update_type
        else:
            if self._type != update_type:
                raise RuntimeError("Input data type has changed from {} to {}.".format(self._type, update_type))
class Accuracy(_BaseClassification):
    """
    Calculates the accuracy for binary, multiclass and multilabel data.

    - `update` must receive output of the form `(y_pred, y)`.
    - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).
    - `y` must be in the following shape (batch_size, ...).
    - `y` and `y_pred` must be in the following shape of (batch_size, num_categories, ...) for multilabel cases.

    In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of
    predictions can be done as below:

    .. code-block:: python

        def thresholded_output_transform(output):
            y_pred, y = output
            y_pred = torch.round(y_pred)
            return y_pred, y

        binary_accuracy = Accuracy(thresholded_output_transform)

    In the multilabel case a sample counts as correct only if all of its labels match. Any dimensions after
    `num_categories` are flattened into additional samples.

    Args:
        output_transform (callable, optional): a callable used to transform the engine's output into the form
            `(y_pred, y)`. By default, the identity.
        is_multilabel (bool, optional): flag to use in multilabel case. By default, False.
    """

    def reset(self):
        """Clear the running counts and the detected input type."""
        self._num_correct = 0
        self._num_examples = 0
        # Clear the detected type so the metric can be reused on a different
        # kind of data after a reset. Otherwise _check_type would raise
        # "Input data type has changed" against the previous run's type.
        self._type = None

    def update(self, output):
        """Accumulate the number of correct predictions and of examples.

        Args:
            output (tuple): `(y_pred, y)` with shapes as described in the class docstring.

        Raises:
            ValueError, RuntimeError: propagated from the shape/type checks.
        """
        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        if self._type == "binary":
            correct = torch.eq(y_pred.type(y.type()), y).view(-1)
        elif self._type == "multiclass":
            # Predicted class = index of the max score along the category axis.
            indices = torch.max(y_pred, dim=1)[1]
            correct = torch.eq(indices, y).view(-1)
        elif self._type == "multilabel":
            if y_pred.ndimension() > 2:
                # (N, C, ...) -> (N x ..., C): move categories last, then flatten
                # every other dimension into rows.
                num_classes = y_pred.size(1)
                last_dim = y_pred.ndimension()
                y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
                y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
            # A row is correct only if every label matches.
            correct = torch.all(y == y_pred.type_as(y), dim=-1)

        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0]

    def compute(self):
        """Return the fraction of correct examples seen since the last reset.

        Raises:
            NotComputableError: if no example has been accumulated.
        """
        if self._num_examples == 0:
            raise NotComputableError('Accuracy must have at least one example before it can be computed')
        return self._num_correct / self._num_examples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import torch | |
from custom_ignite.metrics.accuracy import _BaseClassification | |
from ignite.exceptions import NotComputableError | |
from ignite._utils import to_onehot | |
class _BasePrecisionRecall(_BaseClassification):
    """Common state and ``compute`` for :class:`Precision` and :class:`Recall`.

    Subclasses accumulate ``self._true_positives`` and ``self._positives``.
    Outside the multilabel case they are summed with ``+=``. In the multilabel
    case they are 1-D double tensors that grow by concatenation.

    Args:
        output_transform (callable, optional): forwarded to the base class.
        average (bool, optional): if True, ``compute`` returns the mean of the
            per-element ratios as a Python float; otherwise it returns the tensor.
        is_multilabel (bool, optional): forwarded to the base class.
    """

    def __init__(self, output_transform=lambda x: x, average=False, is_multilabel=False):
        self._average = average
        super(_BasePrecisionRecall, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel)
        # Added to the denominator in compute() to avoid division by zero.
        self.eps = 1e-20

    def reset(self):
        """Clear the accumulated counts and the detected input type."""
        # Multilabel updates concatenate per-sample values, so they start from
        # empty double tensors. Other cases start from 0 and accumulate with +=.
        self._true_positives = torch.zeros(0, dtype=torch.float64) if self._is_multilabel else 0
        self._positives = torch.zeros(0, dtype=torch.float64) if self._is_multilabel else 0
        # Let the data type be re-detected after a reset.
        self._type = None

    def compute(self):
        """Return ``true_positives / (positives + eps)``, averaged if requested.

        Returns:
            float if ``average`` is True, otherwise a double tensor.

        Raises:
            NotComputableError: if ``update`` has not been called since the last reset.
        """
        if isinstance(self._positives, torch.Tensor):
            # Multilabel state is always a tensor (empty right after reset), so
            # emptiness rather than the type tells whether any data was seen.
            no_data = self._positives.numel() == 0
        else:
            # Non-multilabel state stays the int 0 until the first update.
            no_data = not self._positives > 0
        if no_data:
            raise NotComputableError("{} must have at least one example before"
                                     " it can be computed".format(self.__class__.__name__))

        result = self._true_positives / (self._positives + self.eps)

        if self._average:
            return result.mean().item()
        else:
            return result
class Precision(_BasePrecisionRecall):
    """
    Calculates precision for binary, multiclass and multilabel data.

    - `update` must receive output of the form `(y_pred, y)`.
    - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).
    - `y` must be in the following shape (batch_size, ...).
    - `y` and `y_pred` must be in the following shape of (batch_size, num_categories, ...) for multilabel cases.

    In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of
    predictions can be done as below:

    .. code-block:: python

        def thresholded_output_transform(output):
            y_pred, y = output
            y_pred = torch.round(y_pred)
            return y_pred, y

        binary_precision = Precision(output_transform=thresholded_output_transform)

    Args:
        output_transform (callable, optional): a callable used to transform the engine's output into the form
            `(y_pred, y)`. By default, the identity.
        average (bool, optional): if True, precision is computed as the unweighted average (across all classes
            in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case).
        is_multilabel (bool, optional): flag to use in multilabel case. By default, value is False. If True,
            precision is computed per sample (over its labels), and the average parameter should be True so that
            the result is averaged across samples instead of classes. With average=False, `compute` returns the
            tensor of per-sample precisions.
    """

    def update(self, output):
        """Accumulate true positives and predicted positives for one batch.

        Args:
            output (tuple): `(y_pred, y)` with shapes as described in the class docstring.
        """
        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        if self._type == "binary":
            y_pred = y_pred.view(-1)
            y = y.view(-1)
        elif self._type == "multiclass":
            # One-hot both targets and argmax predictions: (N x ..., C).
            num_classes = y_pred.size(1)
            y = to_onehot(y.view(-1), num_classes=num_classes)
            indices = torch.max(y_pred, dim=1)[1].view(-1)
            y_pred = to_onehot(indices, num_classes=num_classes)
        elif self._type == "multilabel":
            # if y, y_pred shape is (N, C, ...) -> (C, N x ...), so that summing
            # over dim 0 below yields one value per sample.
            num_classes = y_pred.size(1)
            y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
            y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

        y = y.type_as(y_pred)
        correct = y * y_pred
        all_positives = y_pred.sum(dim=0).type(torch.DoubleTensor)  # Convert from int cuda/cpu to double cpu

        if correct.sum() == 0:
            # Nothing matched: zeros with the same shape/dtype as all_positives.
            true_positives = torch.zeros_like(all_positives)
        else:
            true_positives = correct.sum(dim=0)
        # Convert from int cuda/cpu to double cpu
        # We need double precision for the division true_positives / all_positives
        true_positives = true_positives.type(torch.DoubleTensor)

        if self._type == "multilabel":
            # Keep per-sample values; compute() averages over all of them.
            self._true_positives = torch.cat([self._true_positives, true_positives], dim=0)
            self._positives = torch.cat([self._positives, all_positives], dim=0)
        else:
            self._true_positives += true_positives
            self._positives += all_positives
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import torch | |
from custom_ignite.metrics.precision import _BasePrecisionRecall | |
from ignite._utils import to_onehot | |
class Recall(_BasePrecisionRecall):
    """
    Calculates recall for binary, multiclass and multilabel data.

    - `update` must receive output of the form `(y_pred, y)`.
    - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).
    - `y` must be in the following shape (batch_size, ...).
    - `y` and `y_pred` must be in the following shape of (batch_size, num_categories, ...) for multilabel cases.

    In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of
    predictions can be done as below:

    .. code-block:: python

        def thresholded_output_transform(output):
            y_pred, y = output
            y_pred = torch.round(y_pred)
            return y_pred, y

        binary_recall = Recall(output_transform=thresholded_output_transform)

    Args:
        output_transform (callable, optional): a callable used to transform the engine's output into the form
            `(y_pred, y)`. By default, the identity.
        average (bool, optional): if True, recall is computed as the unweighted average (across all classes
            in multiclass case), otherwise, returns a tensor with the recall (for each class in multiclass case).
        is_multilabel (bool, optional): flag to use in multilabel case. By default, value is False. If True,
            recall is computed per sample (over its labels), and the average parameter should be True so that
            the result is averaged across samples instead of classes. With average=False, `compute` returns the
            tensor of per-sample recalls.
    """

    def update(self, output):
        """Accumulate true positives and actual (ground-truth) positives for one batch.

        Args:
            output (tuple): `(y_pred, y)` with shapes as described in the class docstring.
        """
        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        if self._type == "binary":
            y_pred = y_pred.view(-1)
            y = y.view(-1)
        elif self._type == "multiclass":
            # One-hot both targets and argmax predictions: (N x ..., C).
            num_classes = y_pred.size(1)
            y = to_onehot(y.view(-1), num_classes=num_classes)
            indices = torch.max(y_pred, dim=1)[1].view(-1)
            y_pred = to_onehot(indices, num_classes=num_classes)
        elif self._type == "multilabel":
            # if y, y_pred shape is (N, C, ...) -> (C, N x ...), so that summing
            # over dim 0 below yields one value per sample.
            num_classes = y_pred.size(1)
            y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
            y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

        y = y.type_as(y_pred)
        correct = y * y_pred
        actual_positives = y.sum(dim=0).type(torch.DoubleTensor)  # Convert from int cuda/cpu to double cpu

        if correct.sum() == 0:
            # Nothing matched: zeros with the same shape/dtype as actual_positives.
            true_positives = torch.zeros_like(actual_positives)
        else:
            true_positives = correct.sum(dim=0)
        # Convert from int cuda/cpu to double cpu
        # We need double precision for the division true_positives / actual_positives
        true_positives = true_positives.type(torch.DoubleTensor)

        if self._type == "multilabel":
            # Keep per-sample values; compute() averages over all of them.
            self._true_positives = torch.cat([self._true_positives, true_positives], dim=0)
            self._positives = torch.cat([self._positives, actual_positives], dim=0)
        else:
            self._true_positives += true_positives
            self._positives += actual_positives
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment