Forked from the-bass/confusion_matrix_between_two_pytorch_tensors.py
Created
June 6, 2020 06:43
-
-
Save HaritzPuerto/3f4a6aeeb7d681684da26fc51a78d2f0 to your computer and use it in GitHub Desktop.
Calculating the confusion matrix between two PyTorch tensors (a batch of predictions) - Last tested with PyTorch 0.4.1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def confusion(prediction, truth):
    """Count the confusion-matrix entries between two binary tensors.

    `prediction` and `truth` are compared element-wise (the operators'
    normal broadcasting applies) and the number of positions is counted
    where the values of `prediction` and `truth` are

    - 1 and 1 (True Positive)
    - 1 and 0 (False Positive)
    - 0 and 0 (True Negative)
    - 0 and 1 (False Negative)

    Positions where either value is neither 0 nor 1 (NaN included) do not
    fall into any of the four categories and are left uncounted.

    Args:
        prediction: tensor of predicted labels, expected to contain 0s and 1s.
        truth: tensor of ground-truth labels, expected to contain 0s and 1s.

    Returns:
        Tuple `(true_positives, false_positives, true_negatives,
        false_negatives)` of Python numbers obtained via `.item()`.
    """
    # Explicit masks instead of the `prediction / truth` trick: the ratio
    # depends on float division-by-zero semantics (inf / nan) and miscounts
    # inputs that are not exactly 0 or 1 -- e.g. 2 / 2 == 1 would be reported
    # as a True Positive and a NaN input as a True Negative.
    predicted_positive = prediction == 1
    predicted_negative = prediction == 0
    actual_positive = truth == 1
    actual_negative = truth == 0

    true_positives = torch.sum(predicted_positive & actual_positive).item()
    false_positives = torch.sum(predicted_positive & actual_negative).item()
    true_negatives = torch.sum(predicted_negative & actual_negative).item()
    false_negatives = torch.sum(predicted_negative & actual_positive).item()

    return true_positives, false_positives, true_negatives, false_negatives
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import torch | |
from aux import confusion | |
class TestConfusion(unittest.TestCase):
    """Unit test for `confusion` on a small hand-labelled batch."""

    def test_with_valid_tensors(self):
        """Check all four counts on ten (prediction, truth) pairs."""
        # One entry per sample; reshaped below into 10x1 float columns.
        predicted_labels = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        true_labels = [1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]

        prediction = torch.tensor(predicted_labels).view(-1, 1)
        truth = torch.tensor(true_labels).view(-1, 1)

        tp, fp, tn, fn = confusion(prediction, truth)

        # Checked in the order TP, FP, TN, FN.
        for observed, expected in ((tp, 2), (fp, 1), (tn, 3), (fn, 4)):
            self.assertEqual(observed, expected)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment