Skip to content

Instantly share code, notes, and snippets.

@the-bass
Last active November 24, 2023 09:56
Show Gist options
  • Save the-bass/cae9f3976866776dea17a5049013258d to your computer and use it in GitHub Desktop.
Save the-bass/cae9f3976866776dea17a5049013258d to your computer and use it in GitHub Desktop.
Calculating the confusion matrix between two PyTorch tensors (a batch of predictions) - Last tested with PyTorch 0.4.1
import torch
def confusion(prediction, truth):
    """ Returns the confusion matrix for the values in the `prediction` and `truth`
    tensors, i.e. the amount of positions where the values of `prediction`
    and `truth` are
        - 1 and 1 (True Positive)
        - 1 and 0 (False Positive)
        - 0 and 0 (True Negative)
        - 0 and 1 (False Negative)

    Args:
        prediction: tensor of predicted labels (0 or 1; bool, int or float).
        truth: tensor of ground-truth labels (0 or 1); must be broadcastable
            with `prediction`.

    Returns:
        Tuple ``(true_positives, false_positives, true_negatives,
        false_negatives)`` of Python ints.

    Positions where either tensor holds a value other than 0 or 1 (including
    NaN) are not counted in any of the four cells.
    """
    # Build boolean masks by direct comparison rather than the old division
    # trick (prediction / truth): that trick counted NaN inputs as true
    # negatives (nan / x == nan), counted equal non-binary values such as 2/2
    # as true positives, and relied on how `/` treats integer tensors, which
    # has changed between PyTorch versions.
    pred_pos = prediction == 1
    pred_neg = prediction == 0
    truth_pos = truth == 1
    truth_neg = truth == 0

    true_positives = torch.sum(pred_pos & truth_pos).item()
    false_positives = torch.sum(pred_pos & truth_neg).item()
    true_negatives = torch.sum(pred_neg & truth_neg).item()
    false_negatives = torch.sum(pred_neg & truth_pos).item()
    return true_positives, false_positives, true_negatives, false_negatives
import unittest
import torch
from aux import confusion
class TestConfusion(unittest.TestCase):
    """Checks `confusion` against a hand-counted batch of ten binary samples."""

    def test_with_valid_tensors(self):
        # Both tensors are 10x1 column vectors of float32 values.
        predicted = torch.tensor(
            [[1.0], [1.0], [1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]
        )
        expected = torch.tensor(
            [[1.0], [1.0], [0.0], [0.0], [1.0], [0.0], [0.0], [1.0], [1.0], [1.0]]
        )

        tp, fp, tn, fn = confusion(predicted, expected)

        # Checked in order TP, FP, TN, FN; the first mismatch fails the test.
        for observed, wanted in ((tp, 2), (fp, 1), (tn, 3), (fn, 4)):
            self.assertEqual(observed, wanted)


if __name__ == '__main__':
    unittest.main()
@Z-Zheng
Copy link

Z-Zheng commented Oct 12, 2018

How would one implement a multi-class version?

@muaz-git
Copy link

Nice work. Could you point me to any hints on doing the same for a multi-class version?

@AminJun
Copy link

AminJun commented Mar 25, 2022

Hi there,
I made a little script that handles the multi-class case. I hope it helps:

import torch


class ConfusionMatrix:
    """Accumulates a multi-class confusion matrix from batches of class indices.

    Counts are stored as a flat float tensor of length ``n_classes ** 2`` in
    which index ``prediction * n_classes + label`` counts samples whose true
    class is ``label`` and that were predicted as ``prediction``.
    :attr:`value` exposes them as an ``(n_classes, n_classes)`` matrix where
    row ``i``, column ``j`` is the number of samples with true label ``i``
    predicted as ``j`` (rows = truth, columns = prediction).

    One-vs-rest counts for a class ``c`` can be read from ``value``:
    TP = ``value[c, c]``, FN = rest of row ``c``, FP = rest of column ``c``,
    TN = everything outside row ``c`` and column ``c``.
    """

    # Default device for new matrices, chosen once when the class is defined.
    _device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __init__(self, n_classes: int = 10):
        """Create an all-zero matrix for ``n_classes`` classes.

        Raises:
            ValueError: if ``n_classes`` is not positive.
        """
        if n_classes < 1:
            raise ValueError(f"n_classes must be positive, got {n_classes}")
        self._matrix = torch.zeros(n_classes * n_classes).to(self._device)
        self._n = n_classes

    # Tensor.cpu()/.cuda()/.to() return a NEW tensor instead of moving the
    # existing one, so the result has to be stored back on the instance.
    def cpu(self) -> "ConfusionMatrix":
        """Move the counts to CPU memory and return ``self``."""
        self._matrix = self._matrix.cpu()
        return self

    def cuda(self) -> "ConfusionMatrix":
        """Move the counts to the current CUDA device and return ``self``."""
        self._matrix = self._matrix.cuda()
        return self

    def to(self, device: str) -> "ConfusionMatrix":
        """Move the counts to ``device`` (e.g. ``'cpu'``) and return ``self``."""
        self._matrix = self._matrix.to(device)
        return self

    def _clone(self) -> "ConfusionMatrix":
        """Return an independent copy with the same class count and counts."""
        twin = type(self).__new__(type(self))
        twin._n = self._n
        twin._matrix = self._matrix.clone()
        return twin

    def __iadd__(self, other):
        """``self += other`` where ``other`` is a ConfusionMatrix or a
        ``(prediction, label)`` tuple; updates ``self`` in place.

        Raises:
            ValueError: if ``other`` is a ConfusionMatrix with a different
                number of classes.
        """
        if isinstance(other, ConfusionMatrix):
            if other._n != self._n:
                raise ValueError(
                    f"cannot add a {other._n}-class matrix to a {self._n}-class matrix"
                )
            # The other matrix may live on a different device.
            self._matrix.add_(other._matrix.to(self._matrix.device))
        elif isinstance(other, tuple):
            self.update(*other)
        else:
            # Signal "unsupported operand" so Python raises a proper TypeError.
            return NotImplemented
        return self

    def __add__(self, other):
        """Return a new matrix equal to ``self`` plus ``other``; ``self`` is unchanged."""
        if not isinstance(other, (ConfusionMatrix, tuple)):
            return NotImplemented
        result = self._clone()
        result += other
        return result

    def update(self, prediction: torch.Tensor, label: torch.Tensor):
        """Add one batch of predictions to the counts.

        Args:
            prediction: integer tensor of predicted class indices in
                ``[0, n_classes)``; any shape.
            label: integer tensor of true class indices in ``[0, n_classes)``,
                broadcastable with ``prediction``.

        Raises:
            ValueError: if any index lies outside ``[0, n_classes)``. Such an
                index would otherwise land in another cell (or overflow the
                matrix) without any error.
        """
        for name, values in (("prediction", prediction), ("label", label)):
            if values.numel() and (values.min() < 0 or values.max() >= self._n):
                raise ValueError(f"{name} values must lie in [0, {self._n})")
        # bincount needs a 1-D input, so flatten after combining the indices.
        conf_data = (prediction * self._n + label).reshape(-1)
        conf = conf_data.bincount(minlength=self._n * self._n)
        # bincount runs on the inputs' device; bring the result to the matrix.
        self._matrix.add_(conf.to(self._matrix.device))

    @property
    def value(self) -> torch.Tensor:
        """The ``(n_classes, n_classes)`` counts: row = true label, column = prediction."""
        return self._matrix.view(self._n, self._n).T


def main():
    """Demo: build two 3-class confusion matrices, print each, then print their sum."""
    truth = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])

    first = ConfusionMatrix(3)
    first += torch.tensor([0, 1, 0, 0, 0, 1, 2, 2, 2]), truth
    print(first.value)

    second = ConfusionMatrix(3)
    # Integer literals already yield an int64 tensor, same as Tensor(...).long().
    perfect = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
    second.update(perfect, truth)
    print(second.value)

    first += second
    print(first.value)


if __name__ == '__main__':
    main()

@americanexplorer13
Copy link

@AminJun

Could you explain how the rows of the tensor correspond to TP, FP, TN and FN?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment