Skip to content

Instantly share code, notes, and snippets.

@the-bass
Last active November 24, 2023 09:56
Show Gist options
  • Star 24 You must be signed in to star a gist
  • Fork 5 You must be signed in to fork a gist
  • Save the-bass/cae9f3976866776dea17a5049013258d to your computer and use it in GitHub Desktop.
Save the-bass/cae9f3976866776dea17a5049013258d to your computer and use it in GitHub Desktop.
Calculating the confusion matrix between two PyTorch tensors (a batch of predictions) - Last tested with PyTorch 0.4.1
import torch
def confusion(prediction, truth):
    """Return the confusion-matrix counts for two binary label tensors.

    Compares ``prediction`` and ``truth`` element-wise and counts the
    positions where the values of ``prediction`` and ``truth`` are

    - 1 and 1 (true positive)
    - 1 and 0 (false positive)
    - 0 and 0 (true negative)
    - 0 and 1 (false negative)

    Args:
        prediction: Tensor of predicted labels containing 0s and 1s
            (bool, integer or floating-point dtype).
        truth: Tensor of ground-truth labels with the same shape as
            ``prediction`` (or broadcastable to it).

    Returns:
        Tuple ``(true_positives, false_positives, true_negatives,
        false_negatives)`` of Python ints.

    Elements whose values are neither 0 nor 1 are not counted in any
    of the four categories.
    """
    # Explicit boolean masks are used instead of dividing the tensors and
    # decoding 1 / inf / nan / 0 sentinels. Division only behaves that way
    # for floating-point dtypes, and it maps non-binary values such as
    # 0 / 2 onto a category (false negative) they do not belong to.
    predicted_one = prediction == 1
    predicted_zero = prediction == 0
    actual_one = truth == 1
    actual_zero = truth == 0

    true_positives = torch.sum(predicted_one & actual_one).item()
    false_positives = torch.sum(predicted_one & actual_zero).item()
    true_negatives = torch.sum(predicted_zero & actual_zero).item()
    false_negatives = torch.sum(predicted_zero & actual_one).item()

    return true_positives, false_positives, true_negatives, false_negatives
import unittest
import torch
from aux import confusion
class TestConfusion(unittest.TestCase):
    """Unit tests for the ``confusion`` helper."""

    def test_with_valid_tensors(self):
        # Ten (prediction, truth) pairs stored as column vectors of shape
        # (10, 1); float literals keep both tensors floating-point.
        predicted_labels = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        actual_labels = [1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]
        prediction = torch.tensor(predicted_labels).view(-1, 1)
        truth = torch.tensor(actual_labels).view(-1, 1)

        tp, fp, tn, fn = confusion(prediction, truth)

        # Expected counts: 2 TP, 1 FP, 3 TN, 4 FN.
        for observed, wanted in ((tp, 2), (fp, 1), (tn, 3), (fn, 4)):
            self.assertEqual(observed, wanted)


if __name__ == '__main__':
    unittest.main()
@americanexplorer13
Copy link

@AminJun

Could you explain which rows of the tensor represent TP, FP, TN and FN?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment