Created
May 30, 2022 13:34
-
-
Save Anaphory/f5abb332c6f8acfd43c55a2fa891c2dc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
from collections import Counter | |
import numpy | |
def match_pixels(
    prediction: numpy.ndarray, ground_truth: numpy.ndarray
) -> typing.Tuple[typing.Mapping, int]:
    """Greedily map prediction labels onto ground-truth labels.

    Every (prediction label, ground-truth label) pair that occurs at the same
    pixel is counted.  The pairs are then visited from the most frequent to
    the least frequent, and a pair is accepted into the mapping if neither of
    its two labels has been used by an earlier, bigger pair.  Pairs with equal
    counts are visited in the order in which they first occur in the
    flattened (row-major) rasters.

    Caveat
    ======
    The greedy strategy is not optimal, but generally good enough.  Take

    >>> x = numpy.array([1,1,2,2,3,3,3,3,4,4,4,5,5])
    >>> y = numpy.array([1,3,1,2,2,3,4,5,3,4,5,1,2])

    Then this function returns

    >>> match_pixels(x, y)
    ({1: 1, 2: 2, 3: 3, 4: 4, 5: None}, 9)

    But the optimal matching would be

    >>> optimal = {1: 3, 2: 1, 3: 5, 4: 4, 5: 2}

    which would have only 8 mismatches.

    Parameters
    ==========
    prediction: raster of predicted labels (any array-like is accepted)
    ground_truth: raster of reference labels, of the same shape

    Returns
    =======
    mapping: a dict from each prediction value to the ground_truth value it
        was matched with, or to None if it could not be matched.  Keys and
        values are plain Python scalars, in order of first appearance.
    mismatches: the number of pixels not explained by that mapping

    Raises
    ======
    AssertionError: if the two rasters do not have the same shape
    """
    prediction = numpy.asanyarray(prediction)
    ground_truth = numpy.asanyarray(ground_truth)
    if prediction.shape != ground_truth.shape:
        # Raised explicitly instead of via an `assert` statement, so the check
        # still runs under `python -O`; otherwise zip() below would silently
        # truncate to the shorter raster.  The exception type is kept.
        raise AssertionError("prediction and ground_truth must have the same shape")

    # tolist() converts to native Python scalars in one pass.  That is faster
    # than iterating `.flat`, and the returned labels print as plain values.
    counts = Counter(
        zip(prediction.ravel().tolist(), ground_truth.ravel().tolist())
    )

    # Every prediction label starts out unmatched (None).
    mapping = dict.fromkeys(p for p, _ in counts)
    unmatched_predictions = set(mapping)
    unmatched_ground_truths = {g for _, g in counts}

    mismatches = 0
    # Iterate over the pairwise overlaps in order of size.
    for (prediction_label, ground_truth_label), count in counts.most_common():
        if (
            prediction_label in unmatched_predictions
            and ground_truth_label in unmatched_ground_truths
        ):
            # Both labels are still free: this is the biggest remaining
            # candidate for either of them, so accept it.
            mapping[prediction_label] = ground_truth_label
            unmatched_predictions.remove(prediction_label)
            unmatched_ground_truths.remove(ground_truth_label)
        else:
            # This potential mapping is shadowed by a bigger one, so all
            # pixels counted for it are mismatches.
            mismatches += count
    return mapping, mismatches
def test_matching_pixels_identical():
    """Matching a raster against itself must leave no mismatches."""
    labels = numpy.array([0, 1, 1, 2, 2, 2, 2, 3, 3])
    # The same array object serves as both prediction and ground truth.
    _, mismatch_count = match_pixels(labels, labels)
    assert mismatch_count == 0
def test_matching_pixels_mapping():
    """A pure relabelling is recovered exactly, with zero mismatches."""
    expected = {0: 4, 1: 5, 2: 0, 3: 1}
    predicted = numpy.array([0, 1, 1, 2, 2, 2, 2, 3, 3])
    # Build the ground truth by pushing every predicted label through the
    # known relabelling.
    relabelled = numpy.array([expected[label] for label in predicted])
    found, mismatch_count = match_pixels(predicted, relabelled)
    assert mismatch_count == 0
    assert found == expected
def test_matching_pixels_nonmatching():
    """Two deviating pixels are reported while the identity mapping is kept."""
    predicted = numpy.array([0, 1, 1, 2, 2, 2, 2, 3, 3])
    # Positions 2 and 7 carry label 2 here but labels 1 and 3 in `predicted`.
    truth = numpy.array([0, 1, 2, 2, 2, 2, 2, 2, 3])
    found, mismatch_count = match_pixels(predicted, truth)
    assert mismatch_count == 2
    assert found == {0: 0, 1: 1, 2: 2, 3: 3}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment