Skip to content

Instantly share code, notes, and snippets.

@dtrizna
Last active April 12, 2023 08:01
Show Gist options
  • Save dtrizna/cfce45a85171602abf3b49335d67148f to your computer and use it in GitHub Desktop.
Save dtrizna/cfce45a85171602abf3b49335d67148f to your computer and use it in GitHub Desktop.
import numpy as np
from sklearn.metrics import roc_curve, det_curve
def get_threshold_from_rate(thresholds, rate_array, rate):
    """Return the threshold at the first position whose rate reaches ``rate``.

    Args:
        thresholds: Sequence of thresholds, aligned index-for-index with
            ``rate_array``.
        rate_array: Sequence of rates; a NumPy array or a plain list.
        rate: Target rate to reach.

    Returns:
        ``thresholds[i]`` where ``i`` is the first index for which
        ``rate_array[i] >= rate``.

    Raises:
        IndexError: If no element of ``rate_array`` is ``>= rate``
            (this includes an empty ``rate_array``).
    """
    # asarray makes the element-wise comparison work for plain lists too.
    hits = np.where(np.asarray(rate_array) >= rate)[0]
    if hits.size == 0:
        raise IndexError(
            f"no entry in rate_array reaches the requested rate {rate!r}"
        )
    # The match is ">=", so the rate at the chosen position can be above
    # the requested target rather than exactly equal to it.
    return thresholds[hits[0]]
def get_value_from_threshold(values, thresholds, threshold):
    """Return the value at the first position whose threshold is ``<= threshold``.

    Args:
        values: Sequence of values, aligned index-for-index with
            ``thresholds``.
        thresholds: Sequence of thresholds; a NumPy array or a plain list.
        threshold: Cut-off to look up.

    Returns:
        ``values[i]`` where ``i`` is the first index for which
        ``thresholds[i] <= threshold``. When no such index exists,
        ``values[0]`` is returned.
    """
    # asarray makes the element-wise comparison work for plain lists too.
    hits = np.where(np.asarray(thresholds) <= threshold)[0]
    # Explicit emptiness check instead of catching IndexError; the fallback
    # index (0) is the same one the lookup has always used.
    # NOTE(review): for thresholds sorted in descending order, a cut-off below
    # every threshold arguably maps to the last position, not the first --
    # confirm the intended fallback before relying on it.
    index = hits[0] if hits.size else 0
    return values[index]
# NOTE(review): y_true, preds and encoding are not defined in this snippet;
# they are presumably provided by the surrounding notebook/script -- confirm.

# Keep the curve arrays together in one dict, keyed by name.
metrics = {}
metrics["fpr"], metrics["tpr"], metrics["threshold_roc"] = roc_curve(y_true, preds)
# The first array returned by det_curve is discarded; only the second and
# third are stored.
_, metrics["fnr"], metrics["threshold_det"] = det_curve(y_true, preds)

print("---" * 35)
# For each target false-positive rate, look up the ROC threshold at the first
# point whose FPR reaches the target, then the TPR at that threshold.
for fpr_rate in [0.00005, 0.0001, 0.0005, 0.001]:
    threshold = get_threshold_from_rate(metrics["threshold_roc"], metrics["fpr"], fpr_rate)
    tpr_rate = get_value_from_threshold(metrics["tpr"], metrics["threshold_roc"], threshold)
    # The printed false-positive rate is the requested target, not the rate
    # measured at the selected threshold.
    print(f"{encoding:>20} : False Positive rate: {fpr_rate*100:>5.3f}% | Detection rate: {tpr_rate*100:>5.2f}% | Threshold: {threshold:>5.4f} ")
print("---" * 35)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment