Skip to content

Instantly share code, notes, and snippets.

@mberr
Last active March 31, 2022 08:46
Show Gist options
  • Save mberr/78559dfec160d5bf7245d674ca18f42f to your computer and use it in GitHub Desktop.
Save mberr/78559dfec160d5bf7245d674ca18f42f to your computer and use it in GitHub Desktop.
Determine optimal threshold for Macro F1 score
"""Determine optimal threshold for Macro F1 score."""
from typing import Tuple
import numpy
from sklearn.metrics._ranking import _binary_clf_curve
def f1_scores(
    precision: numpy.ndarray,
    recall: numpy.ndarray,
) -> numpy.ndarray:
    """Compute element-wise F1 scores from precision and recall values.

    The F1 score is the harmonic mean ``2 * p * r / (p + r)``. Wherever
    ``p + r`` equals zero, the denominator is replaced by one instead of
    dividing by zero, so the score at those positions is ``0.0`` when the
    product ``p * r`` is zero as well.

    :param precision:
        The precision values. Arrays, array-likes (e.g. lists) and scalars
        are accepted.
    :param recall:
        The recall values; must broadcast against ``precision``.

    :return:
        The F1 scores, with the broadcast shape of both inputs. NaN inputs
        propagate to NaN outputs.
    """
    # Coerce first: ``+`` on plain lists would concatenate, and scalars do not
    # support the masked assignment that an in-place fix-up would need.
    p = numpy.asarray(precision)
    r = numpy.asarray(recall)
    total = p + r
    # Build a division-safe denominator without mutating anything in place.
    safe_total = numpy.where(total == 0.0, 1.0, total)
    return 2 * (p * r) / safe_total
def recall(
    tps: numpy.ndarray,
    tps_fns: numpy.ndarray,
) -> numpy.ndarray:
    """Compute recall as the share of actual positives that were found.

    :param tps:
        The true positive counts.
    :param tps_fns:
        The sum of true positives and false negatives, i.e. the denominator.

    :return:
        The element-wise quotient ``tps / tps_fns``. Zero denominators are
        not special-cased here.
    """
    ratio = tps / tps_fns
    return ratio
def precision(
    tps: numpy.ndarray,
    tps_fps: numpy.ndarray,
) -> numpy.ndarray:
    """Compute precision as the share of predicted positives that are correct.

    :param tps:
        The true positive counts.
    :param tps_fps:
        The sum of true positives and false positives, i.e. the denominator.

    :return:
        The element-wise quotient ``tps / tps_fps``. Zero denominators are
        not special-cased here.
    """
    quotient = tps / tps_fps
    return quotient
def all_f1_scores(
    y_true: numpy.ndarray,
    y_score: numpy.ndarray,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
    """Compute the macro-averaged F1 score for every candidate threshold.

    The cumulative false/true positive counts per threshold are obtained from
    scikit-learn's private ``_binary_clf_curve`` helper. From them, the F1
    score of the positive class and of the negative class are derived, and
    their unweighted mean is returned.

    :param y_true:
        The binary ground-truth labels, passed on to ``_binary_clf_curve``.
    :param y_score:
        The predicted scores, passed on to ``_binary_clf_curve``.

    :return:
        A pair ``(thresholds, f1)`` of equally long arrays, where ``f1[i]`` is
        the macro F1 score at ``thresholds[i]``.

    .. note::
        NOTE(review): if ``y_true`` contains a single class only, the recall
        of the missing class is presumably ``0 / 0`` and the scores are NaN;
        that degenerate case is deliberately left untouched.
    """
    # cf. https://stats.stackexchange.com/questions/518616/how-to-find-the-optimal-threshold-for-the-weighted-f1-score-in-a-binary-classifi
    # cf. https://arxiv.org/abs/1911.03347
    # compute TP, FP, FN, TN for all thresholds
    fps, tps, thresholds = _binary_clf_curve(y_true, y_score)
    tns = fps[-1] - fps
    fns = tps[-1] - tps

    # F1-scores positive class
    f1_pos = f1_scores(
        precision=precision(tps=tps, tps_fps=tps + fps),
        recall=recall(tps=tps, tps_fns=tps[-1]),  # tps + fns = tps + (tps[-1] - tps) = tps[-1]
    )

    # F1-scores negative class
    # By construction the last entries of `tns` and `fns` are both zero, i.e.
    # nothing is predicted negative there. Dividing 0 by 0 would emit a
    # RuntimeWarning and poison the final score with NaN, although the F1 of
    # the negative class is simply 0 in that situation (no true negatives).
    # Substituting 1 for the empty denominator yields exactly that 0.
    predicted_neg = tns + fns
    safe_predicted_neg = numpy.where(predicted_neg == 0, 1, predicted_neg)
    f1_neg = f1_scores(
        precision=precision(tps=tns, tps_fps=safe_predicted_neg),
        recall=recall(tps=tns, tps_fns=tns + fps),  # tns + fps = fps[-1] - fps + fps = fps[-1]
    )

    # macro average
    f1 = 0.5 * (f1_pos + f1_neg)
    return thresholds, f1
def optimal_f1_score(
    y_true: numpy.ndarray,
    y_score: numpy.ndarray,
) -> Tuple[float, float]:
    """Find the threshold which maximizes the macro F1 score.

    :param y_true:
        The binary ground-truth labels, forwarded to ``all_f1_scores``.
    :param y_score:
        The predicted scores, forwarded to ``all_f1_scores``.

    :return:
        A pair ``(threshold, f1)`` with the best threshold and the macro F1
        score achieved at it.

    :raises ValueError:
        Raised by ``numpy.nanargmax`` if every candidate score is NaN.
    """
    candidates, scores = all_f1_scores(y_true=y_true, y_score=y_score)
    # NaN scores are ignored when searching for the maximum.
    best = numpy.nanargmax(scores)
    best_threshold = candidates[best]
    best_score = scores[best]
    return best_threshold, best_score
if __name__ == "__main__":
    # Demo: fit a small classifier on the two-moons toy data and visualize the
    # macro F1 score over all thresholds, marking the optimum.
    from matplotlib import pyplot as plt
    from sklearn.datasets import make_moons
    from sklearn.neural_network import MLPClassifier

    model = MLPClassifier()
    features, labels = make_moons()
    model.fit(features, labels)

    # Column 1 of the probability matrix is used as the score.
    positive_proba = model.predict_proba(features)[:, 1]

    all_thresholds, all_scores = all_f1_scores(y_true=labels, y_score=positive_proba)
    best_threshold, best_score = optimal_f1_score(y_true=labels, y_score=positive_proba)

    figure, axis = plt.subplots()
    axis.plot(all_thresholds, all_scores)

    # Dashed cross-hair through the optimum.
    guide_style = dict(ls="dashed", color="black")
    axis.axvline(best_threshold, **guide_style)
    axis.axhline(best_score, **guide_style)

    axis.set_xlabel("score")
    axis.set_ylabel("$F_1$")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment