Skip to content

Instantly share code, notes, and snippets.

@Anaphory
Created May 30, 2022 13:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Anaphory/f5abb332c6f8acfd43c55a2fa891c2dc to your computer and use it in GitHub Desktop.
Save Anaphory/f5abb332c6f8acfd43c55a2fa891c2dc to your computer and use it in GitHub Desktop.
import typing
from collections import Counter
import numpy
def match_pixels(
    prediction: numpy.ndarray, ground_truth: numpy.ndarray
) -> typing.Tuple[typing.Mapping, int]:
    """Greedily match the labels of two rasters and count the disagreements.

    Every (prediction value, ground truth value) pair occurring at the same
    position is counted.  The pairs are then visited from the most frequent to
    the least frequent; a pair is accepted as part of the mapping when neither
    of its two values has been used by an earlier, bigger pair.  Pairs with
    equal counts are visited in the order in which they first occur in the
    flattened rasters, so ties depend on the order of the pixels.

    Caveat
    ======
    The greedy strategy is not optimal, but generally good enough.  Take

    >>> x = numpy.array([1,1,2,2,3,3,3,3,4,4,4,5,5])
    >>> y = numpy.array([1,3,1,2,2,3,4,5,3,4,5,1,2])

    Then this function returns

    >>> mapping, mismatches = match_pixels(x, y)
    >>> {int(p): (g if g is None else int(g)) for p, g in mapping.items()}
    {1: 1, 2: 2, 3: 3, 4: 4, 5: None}
    >>> mismatches
    9

    But the optimal matching would be

    >>> optimal = {1: 3, 2: 1, 3: 5, 4: 4, 5: 2}

    which would have 8 mismatches.

    Parameters
    ==========
    prediction: the raster whose values are to be mapped
    ground_truth: the raster to map onto; must have the same shape as
        `prediction`

    Both arguments are passed through `numpy.asanyarray`, so nested sequences
    are accepted as well as arrays.

    Returns
    =======
    mapping: a mapping from the prediction values to the ground_truth values;
        a prediction value for which no ground_truth value was left is mapped
        to None
    mismatches: the number of correspondences not explained by that mapping

    Raises
    ======
    AssertionError: if the two rasters do not have the same shape
    """
    prediction = numpy.asanyarray(prediction)
    ground_truth = numpy.asanyarray(ground_truth)
    if prediction.shape != ground_truth.shape:
        # Raised explicitly instead of through an `assert` statement so that
        # the check is not stripped under `python -O`; without it, `zip` below
        # would silently truncate to the shorter raster.
        raise AssertionError("prediction and ground_truth must have the same shape")

    # How often does each (prediction value, ground truth value) pair co-occur?
    counts = Counter(zip(prediction.flat, ground_truth.flat))

    # Values on either side that have not yet been claimed by a mapping entry.
    unmatched_predictions = {p for p, _ in counts}
    unmatched_ground_truths = {g for _, g in counts}
    mapping = {p: None for p in unmatched_predictions}
    mismatches = 0

    # Iterate over the pairwise overlaps in order of size.
    for (prediction_value, ground_truth_value), count in counts.most_common():
        if (
            prediction_value in unmatched_predictions
            and ground_truth_value in unmatched_ground_truths
        ):
            # Both values are still free, so this is the biggest remaining
            # candidate for either of them.
            mapping[prediction_value] = ground_truth_value
            unmatched_predictions.remove(prediction_value)
            unmatched_ground_truths.remove(ground_truth_value)
        else:
            # This potential mapping between prediction and ground_truth is
            # shadowed by a bigger one, so all pixels counted for this pair
            # are mismatches.
            mismatches += count
    return mapping, mismatches
def test_matching_pixels_identical():
    """Matching a raster against itself must report zero mismatches."""
    raster = numpy.array([0, 1, 1, 2, 2, 2, 2, 3, 3])
    # The same array object serves as both prediction and ground truth.
    _, mismatch_count = match_pixels(raster, raster)
    assert mismatch_count == 0
def test_matching_pixels_mapping():
    """A pure relabelling is recovered exactly, with zero mismatches."""
    relabelling = {0: 4, 1: 5, 2: 0, 3: 1}
    predicted = numpy.array([0, 1, 1, 2, 2, 2, 2, 3, 3])
    # Build the ground truth by pushing every predicted label through the
    # known relabelling.
    truth = numpy.array([relabelling[label] for label in predicted])
    found, mismatch_count = match_pixels(predicted, truth)
    assert mismatch_count == 0
    assert found == relabelling
def test_matching_pixels_nonmatching():
    """Two pixels that deviate from the identity mapping count as mismatches."""
    predicted = numpy.array([0, 1, 1, 2, 2, 2, 2, 3, 3])
    # Same as `predicted`, except positions 2 and 7 carry label 2.
    truth = numpy.array([0, 1, 2, 2, 2, 2, 2, 2, 3])
    found, mismatch_count = match_pixels(predicted, truth)
    assert mismatch_count == 2
    assert found == {label: label for label in range(4)}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment