Created
October 10, 2023 18:59
-
-
Save alexdremov/a124c9d82dbbdd8af48e5b0f82c466f6 to your computer and use it in GitHub Desktop.
SymmetricBestDICE
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy | |
def ravel_image(img): | |
""" | |
Разворачивает изображения в одномерный массив с учетом батча | |
""" | |
assert 1 < len(img.shape) < 4 | |
if len(img.shape) == 2: | |
img = img[None] | |
img = img.reshape(len(img), -1) | |
return img | |
def DICE(pred, true): | |
""" | |
DICE metric | |
:param pred: shape (batch, h, w) or (h, w) | |
:param true: same shape as pred | |
:return: | |
""" | |
assert pred.shape == true.shape, "Shape mismatch" | |
pred = (ravel_image(pred) > 0).astype(np.float32) | |
true = (ravel_image(true) > 0).astype(np.float32) | |
sums = pred.sum(-1) + true.sum(-1) | |
sums[sums == 0] = 1 | |
results = (2 * (pred * true).sum(-1) / (sums)) | |
results[sums == 0] = 1 | |
return results.mean() | |
def get_instance_best_matching(pred, true): | |
""" | |
Возвращает наилучшее соответствие масок и матрицу расстояний всех со всеми | |
""" | |
pred, true = np.round(pred), np.round(true) | |
pred_max = int(pred.max()) | |
true_max = int(true.max()) | |
costs = np.zeros((pred_max, true_max)) | |
for i in range(1, pred_max + 1): | |
for j in range(1, true_max + 1): | |
costs[i - 1, j - 1] = DiffFgDICE(pred == i, true == j) | |
preds_corr, true_corr = scipy.optimize.linear_sum_assignment(-costs) | |
return preds_corr, true_corr, costs | |
def SymmetricBestDice(preds, trues): | |
""" | |
Средний dice сегментации при наилучшем соответствии масок. | |
Если количество инстансев различается, то добивается нулями. | |
Все маски объединяются в одну с увеличением индекса маски для разных масок. | |
Например вот такая маска с тремя объектами. Ноль это все еще фон | |
0 0 0 0 0 0 0 0 | |
0 1 1 0 0 0 2 0 | |
0 1 1 0 2 2 2 0 | |
0 0 0 0 0 0 2 0 | |
0 0 3 3 3 0 0 0 | |
0 0 3 3 3 0 0 0 | |
""" | |
if not isinstance(preds, list) and len(preds.shape) == 2: | |
preds = preds[None] | |
trues = trues[None] | |
results = [] | |
for pred, true in zip(preds, trues): | |
preds_corr, true_corr, costs = get_instance_best_matching(pred, true) | |
costs = costs[preds_corr, true_corr].tolist() | |
costs += [0] * int(abs(np.max(pred) - np.max(true))) | |
results.append(np.mean(costs)) | |
return np.mean(results) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment