Skip to content

Instantly share code, notes, and snippets.

@atamborrino
Last active August 27, 2018 12:13
Show Gist options
  • Save atamborrino/d9fe0ca806a64f87cdf46b0b1a9ea20c to your computer and use it in GitHub Desktop.
Save atamborrino/d9fe0ca806a64f87cdf46b0b1a9ea20c to your computer and use it in GitHub Desktop.
# src: https://www.kaggle.com/aglotero/another-iou-metric
def iou_metric(y_true_in, y_pred_in, print_table=False):
    """Score one predicted binary mask against its ground-truth mask.

    The foreground of each mask is treated as a single object. Its
    intersection-over-union (IoU) is compared against the thresholds
    ``np.arange(0.5, 1.0, 0.05)``; at each threshold the precision
    ``TP / (TP + FP + FN)`` is computed and the mean over all thresholds
    is returned.

    Values are binned over the fixed range [0, 1]: anything below 0.5 counts
    as background, 0.5 up to and including 1 counts as foreground, and values
    outside that range are ignored by the histograms.

    Args:
        y_true_in: Ground-truth mask (array-like of 0/1 values).
        y_pred_in: Predicted mask with the same number of elements.
        print_table: When true, print the per-threshold TP/FP/FN/precision
            table followed by the averaged score.

    Returns:
        The mean precision over the IoU thresholds (between 0.0 and 1.0).
        Two masks that both contain no foreground score 1.0.
    """
    labels = np.asarray(y_true_in)
    y_pred = np.asarray(y_pred_in)

    true_objects = 2
    pred_objects = 2

    # Pin the bin edges to [0, 0.5, 1]. Without an explicit range NumPy
    # derives the edges from the data, so a constant mask (all 0 or all 1)
    # gets the edges [v - 0.5, v, v + 0.5] and every pixel lands in the
    # second ("foreground") bin -- an all-background mask would then be
    # scored as if it were entirely foreground.
    mask_range = (0, 1)

    intersection = np.histogram2d(
        labels.flatten(),
        y_pred.flatten(),
        bins=(true_objects, pred_objects),
        range=(mask_range, mask_range),
    )[0]

    # Compute areas (needed for finding the union between all objects)
    area_true = np.histogram(labels, bins=true_objects, range=mask_range)[0]
    area_pred = np.histogram(y_pred, bins=pred_objects, range=mask_range)[0]
    both_empty = area_true[1] == 0 and area_pred[1] == 0
    area_true = np.expand_dims(area_true, -1)
    area_pred = np.expand_dims(area_pred, 0)

    # Compute union
    union = area_true + area_pred - intersection

    # Exclude background from the analysis
    intersection = intersection[1:, 1:]
    union = union[1:, 1:]
    union[union == 0] = 1e-9  # avoid dividing by zero

    # Compute the intersection over union
    iou = intersection / union
    if both_empty:
        # Nothing to find and nothing predicted: treat it as a perfect match
        # so the score is 1.0 rather than 0/0.
        iou = np.ones_like(iou)

    # Precision helper function
    def precision_at(threshold, iou):
        # Rows of `matches` are true objects, columns are predicted objects.
        matches = iou > threshold
        true_positives = np.sum(matches, axis=1) == 1   # Correct objects
        false_positives = np.sum(matches, axis=0) == 0  # Extra (unmatched predicted) objects
        false_negatives = np.sum(matches, axis=1) == 0  # Missed (unmatched true) objects
        tp, fp, fn = np.sum(true_positives), np.sum(false_positives), np.sum(false_negatives)
        return tp, fp, fn

    # Loop over IoU thresholds
    prec = []
    if print_table:
        print("Thresh\tTP\tFP\tFN\tPrec.")
    for t in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(t, iou)
        if (tp + fp + fn) > 0:
            p = tp / (tp + fp + fn)
        else:
            p = 0
        if print_table:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tp, fp, fn, p))
        prec.append(p)

    if print_table:
        print("AP\t-\t-\t-\t{:1.3f}".format(np.mean(prec)))
    return np.mean(prec)
def iou_metric_batch(y_true_in, y_pred_in):
    """Average ``iou_metric`` over the first axis of a batch of masks.

    Args:
        y_true_in: Ground-truth masks; ``shape[0]`` gives the number of
            samples to score.
        y_pred_in: Predicted masks, indexed in step with ``y_true_in``.

    Returns:
        The mean of the per-sample ``iou_metric`` scores.
    """
    n_samples = y_true_in.shape[0]
    per_sample = [
        iou_metric(y_true_in[idx], y_pred_in[idx]) for idx in range(n_samples)
    ]
    return np.mean(per_sample)
def best_iou_and_threshold(y_true, y_pred, plot=False):
    """Find the binarisation threshold that maximises the batch IoU score.

    Fifty evenly spaced thresholds in [0, 1] are tried. For each one,
    ``y_pred > threshold`` is cast to int32 and scored against ``y_true``
    with ``iou_metric_batch``. The best threshold is searched only among
    candidate indices 9 through 39, i.e. the extremes are skipped.

    Args:
        y_true: Ground-truth masks, passed through to ``iou_metric_batch``.
        y_pred: Predicted values to be thresholded.
        plot: When true, draw the threshold/IoU curve with ``plt`` and print
            the selected threshold and score.

    Returns:
        A tuple ``(iou_best, threshold_best)``.
    """
    candidates = np.linspace(0, 1, 50)

    scores = []
    for cutoff in candidates:
        binarised = np.int32(y_pred > cutoff)
        scores.append(iou_metric_batch(y_true, binarised))
    scores = np.array(scores)

    # Search the window [9:-10] only, then shift back to a full-array index.
    best_index = np.argmax(scores[9:-10]) + 9
    iou_best = scores[best_index]
    threshold_best = candidates[best_index]

    # NOTE(review): the original formatting was flattened; the two prints are
    # assumed to belong to the plot branch -- confirm against the author's copy.
    if plot:
        plt.plot(candidates, scores)
        plt.plot(threshold_best, iou_best, "xr", label="Best threshold")
        plt.xlabel("Threshold")
        plt.ylabel("IoU")
        plt.title("Threshold vs IoU ({}, {})".format(threshold_best, iou_best))
        plt.legend()
        print(f'threshold_best: {threshold_best}')
        print(f'iou_best: {iou_best}')

    return iou_best, threshold_best
class ValGlobalMetrics(keras.callbacks.Callback):
    """Keras callback that records the best validation IoU after each epoch.

    The validation data is read from the module-level names ``x_valid`` and
    ``y_valid``. The score is stored in ``logs`` under ``'val_best_iou'``.
    """

    def on_epoch_end(self, batch, logs=None):
        """Predict on the validation set and store the best IoU in ``logs``.

        Args:
            batch: Positional index supplied by the caller (Keras passes the
                epoch number here); not used by this method.
            logs: Dict of metrics for the epoch. ``'val_best_iou'`` is added
                to it in place. A fresh dict is used when omitted, so no
                state is shared between calls through a default argument.
        """
        if logs is None:
            logs = {}
        predict = np.asarray(self.model.predict(x_valid))
        targ = y_valid
        best_iou, _ = best_iou_and_threshold(y_true=targ, y_pred=predict)
        logs['val_best_iou'] = best_iou
        print(f' - val_best_iou: {best_iou}')
# The three stock callbacks below all monitor 'val_best_iou' (higher is
# better), which is the key ValGlobalMetrics writes into the epoch logs.
early_stopping = EarlyStopping(
    monitor='val_best_iou',
    mode='max',
    patience=10,
    verbose=1,
)
model_checkpoint = ModelCheckpoint(
    "./keras.model",
    monitor='val_best_iou',
    mode='max',
    save_best_only=True,
    verbose=1,
)
reduce_lr = ReduceLROnPlateau(
    monitor='val_best_iou',
    mode='max',
    factor=0.1,
    patience=5,
    min_lr=0.00001,
    verbose=1,
)
# ValGlobalMetrics is listed first; presumably it must run before the
# callbacks that read its metric -- confirm against Keras callback ordering.
keras_callbacks = [
    ValGlobalMetrics(),
    early_stopping,
    model_checkpoint,
    reduce_lr,
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment