Last active
November 24, 2023 09:56
-
-
Save the-bass/cae9f3976866776dea17a5049013258d to your computer and use it in GitHub Desktop.
Calculating the confusion matrix between two PyTorch tensors (a batch of predictions) - Last tested with PyTorch 0.4.1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def confusion(prediction, truth):
    """Return the binary confusion counts of `prediction` against `truth`.

    The two tensors are compared element-wise (same shape, or broadcastable),
    and each position is classified by its pair of values:

    - prediction 1, truth 1 -> True Positive
    - prediction 1, truth 0 -> False Positive
    - prediction 0, truth 0 -> True Negative
    - prediction 0, truth 1 -> False Negative

    Positions holding any value other than 0 or 1 are not counted in any
    category. Bool, integer and floating-point tensors are all accepted.

    Returns:
        A tuple ``(true_positives, false_positives, true_negatives,
        false_negatives)`` of Python ints.
    """
    # Explicit boolean masks instead of dividing `prediction / truth`: the
    # division trick depended on inf/nan from dividing by zero and silently
    # counted non-binary pairs such as (2, 2) as true positives.
    predicted_pos = prediction == 1
    predicted_neg = prediction == 0
    actual_pos = truth == 1
    actual_neg = truth == 0

    true_positives = torch.sum(predicted_pos & actual_pos).item()
    false_positives = torch.sum(predicted_pos & actual_neg).item()
    true_negatives = torch.sum(predicted_neg & actual_neg).item()
    false_negatives = torch.sum(predicted_neg & actual_pos).item()
    return true_positives, false_positives, true_negatives, false_negatives
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import torch | |
from aux import confusion | |
class TestConfusion(unittest.TestCase):
    """Unit tests for the binary `confusion` helper."""

    def test_with_valid_tensors(self):
        # A batch of 10 single-output samples: float32 tensors of shape (10, 1).
        prediction = torch.tensor(
            [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        ).view(-1, 1)
        truth = torch.tensor(
            [1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]
        ).view(-1, 1)

        tp, fp, tn, fn = confusion(prediction, truth)

        # Expected counts, checked in TP, FP, TN, FN order.
        for observed, wanted in ((tp, 2), (fp, 1), (tn, 3), (fn, 4)):
            self.assertEqual(observed, wanted)
if __name__ == '__main__':
    # Only when run as a script (not on import): run this module's test cases.
    unittest.main()
Nice work. Could you point me to any hints on doing the same for the multi-class case?
Hi there,
I made a little script for handling the multi-class version. I hope it helps:
import torch
class ConfusionMatrix:
    """Accumulate a multi-class confusion matrix over batches of predictions.

    Counts live in a flat float tensor of length ``n_classes ** 2``. `value`
    exposes them as an ``(n_classes, n_classes)`` matrix in which
    ``value[t, p]`` is the number of samples with true label ``t`` that were
    predicted as class ``p`` (rows: true labels, columns: predictions).

    Usage::

        cm = ConfusionMatrix(3)
        cm += predictions, labels      # same as cm.update(predictions, labels)
        cm += other_matrix             # merge counts from another instance
        merged = cm + other_matrix     # new instance; cm is left unchanged
    """

    # Default device, chosen once when the class is defined.
    _device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __init__(self, n_classes: int = 10):
        """Create an all-zero matrix for `n_classes` classes.

        Raises:
            ValueError: if `n_classes` is not positive.
        """
        if n_classes < 1:
            raise ValueError(f"n_classes must be positive, got {n_classes}")
        self._matrix = torch.zeros(n_classes * n_classes, device=self._device)
        self._n = n_classes

    def cpu(self):
        """Move the accumulated counts to CPU memory and return ``self``."""
        return self.to('cpu')

    def cuda(self):
        """Move the accumulated counts to the current CUDA device and return ``self``."""
        # Tensor.cuda() returns a new tensor rather than moving in place, so
        # the result has to be stored back.
        self._matrix = self._matrix.cuda()
        return self

    def to(self, device: str):
        """Move the accumulated counts to `device` and return ``self``."""
        # Tensor.to() returns a new tensor rather than moving in place.
        self._matrix = self._matrix.to(device)
        return self

    def __iadd__(self, other):
        """Accumulate in place: ``cm += other_cm`` or ``cm += (prediction, label)``."""
        if isinstance(other, ConfusionMatrix):
            if other._n != self._n:
                raise ValueError(
                    f"cannot combine confusion matrices with {self._n} and "
                    f"{other._n} classes"
                )
            self._matrix.add_(other._matrix.to(self._matrix.device))
        elif isinstance(other, tuple):
            self.update(*other)
        else:
            # Returning NotImplemented lets Python raise the standard TypeError
            # (`raise NotImplemented` fails because it is not an exception).
            return NotImplemented
        return self

    def __add__(self, other):
        """Return a new matrix with this matrix's counts plus `other`'s.

        `other` is another ConfusionMatrix or a ``(prediction, label)`` tuple.
        Neither operand is modified; use ``+=`` to accumulate in place.
        """
        if not isinstance(other, (ConfusionMatrix, tuple)):
            return NotImplemented
        result = type(self).__new__(type(self))
        result._n = self._n
        result._matrix = self._matrix.clone()
        result += other
        return result

    def update(self, prediction: torch.Tensor, label: torch.Tensor):
        """Add one batch of class indices to the counts.

        Args:
            prediction: integer tensor of predicted class indices.
            label: integer tensor of true class indices with the same number
                of elements as `prediction`. Both tensors may have any shape;
                they are flattened before counting.

        Raises:
            ValueError: if the tensors differ in element count or contain an
                index outside ``[0, n_classes)``.
        """
        prediction = prediction.reshape(-1)
        label = label.reshape(-1).to(prediction.device)
        if prediction.numel() != label.numel():
            raise ValueError(
                "prediction and label must have the same number of elements, "
                f"got {prediction.numel()} and {label.numel()}"
            )
        if prediction.numel() == 0:
            return
        low = min(prediction.min().item(), label.min().item())
        high = max(prediction.max().item(), label.max().item())
        if low < 0 or high >= self._n:
            raise ValueError(
                f"class indices must lie in [0, {self._n}), "
                f"got values in [{low}, {high}]"
            )
        # Encode each (prediction, label) pair as one index into the flattened
        # n x n matrix, then count every pair with a single bincount.
        conf_data = prediction * self._n + label
        conf = conf_data.bincount(minlength=self._n * self._n)
        # bincount runs on the inputs' device; move the counts to the matrix's
        # device so CPU batches can update a CUDA matrix (and vice versa).
        self._matrix.add_(conf.to(self._matrix.device))

    @property
    def value(self):
        """Counts as an ``(n_classes, n_classes)`` tensor (rows: true labels,
        columns: predictions). This is a view: in-place edits change the counts."""
        return self._matrix.view(self._n, self._n).T
def main():
    """Demonstrate ConfusionMatrix: updating via `+=`, via `update`, and merging."""
    truth = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])

    first = ConfusionMatrix(3)
    guesses = torch.tensor([0, 1, 0, 0, 0, 1, 2, 2, 2])
    first += guesses, truth
    print(first.value)

    # Predictions identical to the labels: every count lands on the diagonal.
    second = ConfusionMatrix(3)
    perfect = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=torch.long)
    second.update(perfect, truth)
    print(second.value)

    # Fold the second matrix's counts into the first.
    first += second
    print(first.value)
if __name__ == '__main__':
    # Run the demonstration only when executed as a script, not on import.
    main()
Could you explain what each row of the tensor represents in terms of TP, FP, TN and FN?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
How do I implement a multi-class version?