Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save HaritzPuerto/3f4a6aeeb7d681684da26fc51a78d2f0 to your computer and use it in GitHub Desktop.
Save HaritzPuerto/3f4a6aeeb7d681684da26fc51a78d2f0 to your computer and use it in GitHub Desktop.
Calculating the confusion matrix (true/false positive and negative counts) between two PyTorch tensors holding a batch of binary predictions and their ground-truth labels. Last tested with PyTorch 0.4.1.
import torch
def confusion(prediction, truth):
    """Return the confusion-matrix counts for the `prediction` and `truth` tensors.

    Counts the number of positions where the values of `prediction` and
    `truth` are

    - 1 and 1 (True Positive)
    - 1 and 0 (False Positive)
    - 0 and 0 (True Negative)
    - 0 and 1 (False Negative)

    Positions holding any other value (e.g. 2, 0.5, NaN) fall into none of
    the four categories and are not counted.

    The counts are built from element-wise boolean masks instead of dividing
    the two tensors. This works the same for integer, floating-point and
    boolean tensors: integer division by zero cannot raise, and no inf/NaN
    sentinels are needed. The masks are combined element-wise, so the two
    tensors must have the same shape or be broadcastable against each other.

    Args:
        prediction: Tensor of predicted labels, expected to contain 0s and 1s.
        truth: Tensor of ground-truth labels, expected to contain 0s and 1s.

    Returns:
        A tuple ``(true_positives, false_positives, true_negatives,
        false_negatives)`` of Python ints.
    """
    # Compare against exact 0/1 so that only the documented combinations are
    # counted; a ratio-based check would also accept e.g. 2/0 as a false positive.
    predicted_pos = prediction == 1
    predicted_neg = prediction == 0
    actual_pos = truth == 1
    actual_neg = truth == 0

    true_positives = torch.sum(predicted_pos & actual_pos).item()
    false_positives = torch.sum(predicted_pos & actual_neg).item()
    true_negatives = torch.sum(predicted_neg & actual_neg).item()
    false_negatives = torch.sum(predicted_neg & actual_pos).item()
    return true_positives, false_positives, true_negatives, false_negatives
import unittest
import torch
from aux import confusion
class TestConfusion(unittest.TestCase):
    """Unit tests for `confusion` on a small, hand-counted batch."""

    def test_with_valid_tensors(self):
        # Ten samples laid out as column vectors of shape (10, 1). The single
        # float literal in each list makes torch.tensor infer a floating-point
        # dtype for the whole tensor.
        prediction_rows = [[1], [1.0], [1], [0], [0], [0], [0], [0], [0], [0]]
        truth_rows = [[1.0], [1], [0], [0], [1], [0], [0], [1], [1], [1]]
        prediction = torch.tensor(prediction_rows)
        truth = torch.tensor(truth_rows)

        tp, fp, tn, fn = confusion(prediction, truth)

        # Hand count: TP at rows 0-1, FP at row 2, TN at rows 3/5/6,
        # FN at rows 4/7/8/9.
        for actual, wanted in zip((tp, fp, tn, fn), (2, 1, 3, 4)):
            self.assertEqual(actual, wanted)
if __name__ == '__main__':
    # Run this module's unittest test cases when the file is executed directly.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment