Skip to content

Instantly share code, notes, and snippets.

@zommiommy
Last active November 10, 2020 12:35
Show Gist options
  • Save zommiommy/0bb4cb37610f59edaa8d903a6d9d6784 to your computer and use it in GitHub Desktop.
Save zommiommy/0bb4cb37610f59edaa8d903a6d9d6784 to your computer and use it in GitHub Desktop.
A collection of confusion-matrix-based metrics for binary classification, implemented as Keras `Metric` subclasses.
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.metrics import Metric, AUC
from tensorflow.keras.backend import epsilon
class ConfusionMatrixMetric(Metric):
    """Base class for streaming binary-classification metrics built on a confusion matrix.

    ``update_state`` binarizes each batch and adds its true-positive,
    false-positive, true-negative and false-negative counts to the ``tp``,
    ``fp``, ``tn`` and ``fn`` weights. ``result`` evaluates
    ``_custom_metric`` -- implemented by each subclass -- on those running
    totals.

    Args:
        name: Metric name; Keras reports the metric value under this name.
        threshold: A prediction counts as positive when, after clipping to
            [0, 1], it is strictly greater than this value. The default 0.5
            is equivalent to rounding the clipped prediction with
            ``K.round`` (``tf.round`` rounds half to even, so 0.5 maps to 0).
        **kwargs: Forwarded to ``tf.keras.metrics.Metric`` (e.g. ``dtype``).
    """

    def __init__(self, name, threshold=0.5, **kwargs):
        super(ConfusionMatrixMetric, self).__init__(name=name, **kwargs)
        self.threshold = threshold
        self.tp = self.add_weight(name='tp', initializer='zeros')
        self.fp = self.add_weight(name='fp', initializer='zeros')
        self.tn = self.add_weight(name='tn', initializer='zeros')
        self.fn = self.add_weight(name='fn', initializer='zeros')

    def get_config(self):
        """Return the constructor config, including ``threshold``, for serialization."""
        config = super(ConfusionMatrixMetric, self).get_config()
        config['threshold'] = self.threshold
        return config

    def result(self):
        """Return the subclass metric evaluated on the accumulated counts."""
        # Computed on demand instead of being cached in an extra weight during
        # update_state: a cached value would itself be a metric weight, so
        # merge_state and cross-replica aggregation would sum it like a count.
        return self._custom_metric()

    def reset_state(self):
        """Zero the accumulated confusion-matrix counts."""
        for count in (self.tp, self.fp, self.tn, self.fn):
            count.assign(0)

    def reset_states(self):
        """Alias of ``reset_state`` for Keras versions that still call this name."""
        self.reset_state()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the confusion-matrix counts of one batch to the running totals.

        Args:
            y_true: Ground-truth labels; clipped to [0, 1] and rounded.
            y_pred: Predicted scores; clipped to [0, 1] and compared with
                ``self.threshold``.
            sample_weight: Optional weights. When given, each element
                contributes its weight instead of 1 to the counts.
        """
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        y_pred_pos = tf.cast(K.clip(y_pred, 0, 1) > self.threshold, self.dtype)
        y_pred_neg = 1 - y_pred_pos
        y_pos = K.round(K.clip(y_true, 0, 1))
        y_neg = 1 - y_pos
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            label_rank = y_pos.shape.rank
            weight_rank = sample_weight.shape.rank
            if label_rank is not None and weight_rank is not None:
                # e.g. (batch,) weights need a trailing axis to line up with
                # (batch, 1) labels; otherwise broadcasting would build a
                # (batch, batch) outer product.
                for _ in range(label_rank - weight_rank):
                    sample_weight = tf.expand_dims(sample_weight, -1)
            y_pos = y_pos * sample_weight
            y_neg = y_neg * sample_weight
        self.tp.assign_add(K.sum(y_pos * y_pred_pos))
        self.fp.assign_add(K.sum(y_neg * y_pred_pos))
        self.fn.assign_add(K.sum(y_pos * y_pred_neg))
        self.tn.assign_add(K.sum(y_neg * y_pred_neg))

    def _custom_metric(self):
        """Compute the metric from ``self.tp``/``fp``/``tn``/``fn``; subclasses must override."""
        raise NotImplementedError("This method should be implemented by subclasses")
class TruePositivesRatio(ConfusionMatrixMetric):
    """Share of all samples that are true positives: TP / (TP + FP + TN + FN)."""

    def _custom_metric(self):
        total = self.tp + self.fp + self.tn + self.fn
        return self.tp / (total + epsilon())
class FalsePositivesRatio(ConfusionMatrixMetric):
    """Share of all samples that are false positives: FP / (TP + FP + TN + FN)."""

    def _custom_metric(self):
        total = self.tp + self.fp + self.tn + self.fn
        return self.fp / (total + epsilon())
class TrueNegativesRatio(ConfusionMatrixMetric):
    """Share of all samples that are true negatives: TN / (TP + FP + TN + FN)."""

    def _custom_metric(self):
        total = self.tp + self.fp + self.tn + self.fn
        return self.tn / (total + epsilon())
class FalseNegativesRatio(ConfusionMatrixMetric):
    """Share of all samples that are false negatives: FN / (TP + FP + TN + FN)."""

    def _custom_metric(self):
        total = self.tp + self.fp + self.tn + self.fn
        return self.fn / (total + epsilon())
class Recall(ConfusionMatrixMetric):
    """Recall (sensitivity, true positive rate): TP / (TP + FN)."""

    def _custom_metric(self):
        actual_positives = self.tp + self.fn
        return self.tp / (actual_positives + epsilon())
class Specificity(ConfusionMatrixMetric):
    """Specificity (true negative rate): TN / (TN + FP)."""

    def _custom_metric(self):
        actual_negatives = self.tn + self.fp
        return self.tn / (actual_negatives + epsilon())
class Precision(ConfusionMatrixMetric):
    """Precision (positive predictive value): TP / (TP + FP)."""

    def _custom_metric(self):
        predicted_positives = self.tp + self.fp
        return self.tp / (predicted_positives + epsilon())
class NegativePredictiveValue(ConfusionMatrixMetric):
    """Negative predictive value: TN / (FN + TN)."""

    def _custom_metric(self):
        predicted_negatives = self.fn + self.tn
        return self.tn / (predicted_negatives + epsilon())
class MissRate(ConfusionMatrixMetric):
    """Miss rate (false negative rate): FN / (FN + TP)."""

    def _custom_metric(self):
        actual_positives = self.fn + self.tp
        return self.fn / (actual_positives + epsilon())
class FallOut(ConfusionMatrixMetric):
    """Fall-out (false positive rate): FP / (FP + TN)."""

    def _custom_metric(self):
        actual_negatives = self.fp + self.tn
        return self.fp / (actual_negatives + epsilon())
class FalseDiscoveryRate(ConfusionMatrixMetric):
    """False discovery rate: FP / (FP + TP)."""

    def _custom_metric(self):
        predicted_positives = self.fp + self.tp
        return self.fp / (predicted_positives + epsilon())
class FalseOmissionRate(ConfusionMatrixMetric):
    """False omission rate: FN / (FN + TN)."""

    def _custom_metric(self):
        predicted_negatives = self.fn + self.tn
        return self.fn / (predicted_negatives + epsilon())
class PrevalenceThreshold(ConfusionMatrixMetric):
    """Prevalence threshold: sqrt(FPR) / (sqrt(TPR) + sqrt(FPR)).

    Algebraically equal to (sqrt(TPR * (1 - TNR)) + TNR - 1) / (TPR + TNR - 1),
    with FPR = 1 - TNR. This form avoids the 0/0 of that expression when
    TPR + TNR = 1 (a chance-level classifier), where the correct value is 0.5.
    It also avoids a near-zero denominator when TPR + TNR - 1 is close to -epsilon.
    """

    def _custom_metric(self):
        tpr = self.tp / (self.tp + self.fn + epsilon())
        fpr = self.fp / (self.fp + self.tn + epsilon())
        sqrt_fpr = tf.math.sqrt(fpr)
        return sqrt_fpr / (tf.math.sqrt(tpr) + sqrt_fpr + epsilon())
class ThreatScore(ConfusionMatrixMetric):
    """Threat score (critical success index): TP / (TP + FN + FP)."""

    def _custom_metric(self):
        relevant = self.tp + self.fn + self.fp
        return self.tp / (relevant + epsilon())
class Accuracy(ConfusionMatrixMetric):
    """Accuracy: (TP + TN) / (TP + TN + FP + FN)."""

    def _custom_metric(self):
        correct = self.tp + self.tn
        total = correct + self.fp + self.fn
        return correct / (total + epsilon())
class BalancedAccuracy(ConfusionMatrixMetric):
    """Balanced accuracy: mean of the true positive and true negative rates."""

    def _custom_metric(self):
        sensitivity = self.tp / (self.tp + self.fn + epsilon())
        specificity = self.tn / (self.tn + self.fp + epsilon())
        rate_sum = sensitivity + specificity
        return rate_sum / 2
class F1Score(ConfusionMatrixMetric):
    """F1 score written in count form: TP / (TP + (FP + FN) / 2)."""

    def _custom_metric(self):
        errors = self.fp + self.fn
        return self.tp / (self.tp + 0.5 * errors + epsilon())
class MatthewsCorrelationCoefficinet(ConfusionMatrixMetric):
    """Matthews correlation coefficient, in [-1, 1].

    MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). When any of the
    four marginal sums is zero the numerator is also zero, so this returns 0.

    The class name is misspelled. It is kept for backward compatibility, and
    ``MatthewsCorrelationCoefficient`` is defined as a correctly spelled alias.
    """

    def _custom_metric(self):
        numerator = self.tp * self.tn - self.fp * self.fn
        # Take the square root of each marginal before multiplying. In float32,
        # the raw product of four marginal sums overflows to inf once they
        # approach ~4e9, which would silently turn the metric into 0.
        denominator = (
            tf.math.sqrt(self.tp + self.fp)
            * tf.math.sqrt(self.tp + self.fn)
            * tf.math.sqrt(self.tn + self.fp)
            * tf.math.sqrt(self.tn + self.fn)
        )
        return numerator / (denominator + epsilon())


MatthewsCorrelationCoefficient = MatthewsCorrelationCoefficinet
class FowlkesMallowsIndex(ConfusionMatrixMetric):
    """Fowlkes-Mallows index: geometric mean of precision and recall, sqrt(PPV * TPR)."""

    def _custom_metric(self):
        tpr = self.tp / (self.tp + self.fn + epsilon())
        ppv = self.tp / (self.tp + self.fp + epsilon())
        return tf.math.sqrt(ppv * tpr)
class Informedness(ConfusionMatrixMetric):
    """Informedness (Youden's J statistic): TPR + TNR - 1, in [-1, 1]."""

    def _custom_metric(self):
        tpr = self.tp / (self.tp + self.fn + epsilon())
        tnr = self.tn / (self.tn + self.fp + epsilon())
        return tpr + tnr - 1
class Markedness(ConfusionMatrixMetric):
    """Markedness: PPV + NPV - 1, in [-1, 1]."""

    def _custom_metric(self):
        positive_predictive_value = self.tp / (self.tp + self.fp + epsilon())
        negative_predictive_value = self.tn / (self.tn + self.fn + epsilon())
        combined = positive_predictive_value + negative_predictive_value
        return combined - 1
class PositiveLikelihoodRatio(ConfusionMatrixMetric):
    """Positive likelihood ratio (LR+): TPR / FPR."""

    def _custom_metric(self):
        sensitivity = self.tp / (self.tp + self.fn + epsilon())
        fall_out = self.fp / (self.fp + self.tn + epsilon())
        return sensitivity / (fall_out + epsilon())
class NegativeLikelihoodRatio(ConfusionMatrixMetric):
    """Negative likelihood ratio (LR-): FNR / TNR."""

    def _custom_metric(self):
        specificity = self.tn / (self.tn + self.fp + epsilon())
        miss_rate = self.fn / (self.fn + self.tp + epsilon())
        return miss_rate / (specificity + epsilon())
class DiagnosticOddsRatio(ConfusionMatrixMetric):
    """Diagnostic odds ratio: LR+ / LR- = (TP * TN) / (FP * FN).

    When FP * FN is zero, the epsilon guard makes the result very large but
    finite rather than infinite.
    """

    def _custom_metric(self):
        return (self.tp * self.tn) / (self.fp * self.fn + epsilon())
def get_all_metrics():
    """Return a new list of metric instances for ``model.compile(metrics=...)``.

    The list contains one fresh instance of each confusion-matrix metric class,
    plus Keras' built-in ``AUC`` for the ROC and precision-recall curves. Each
    ``name`` is the key under which the metric value is reported.
    """
    return [
        Accuracy(name="accuracy"),
        BalancedAccuracy(name="balanced_accuracy"),
        AUC(curve="roc", name="AUROC"),
        AUC(curve="pr", name="AUPRC"),
        F1Score(name="f1_score"),
        MatthewsCorrelationCoefficinet(name="mcc"),
        TruePositivesRatio(name="tp/t"),
        FalsePositivesRatio(name="fp/t"),
        TrueNegativesRatio(name="tn/t"),
        FalseNegativesRatio(name="fn/t"),
        Recall(name="recall"),
        Specificity(name="specificity"),
        Precision(name="precision"),
        MissRate(name="miss_rate"),
        FallOut(name="fall_out"),
        NegativePredictiveValue(name="negative_predictive_value"),
        FalseDiscoveryRate(name="false_discovery_rate"),
        # NOTE: this key used to be "false_omission-rate" (stray hyphen);
        # callbacks monitoring the old key need updating.
        FalseOmissionRate(name="false_omission_rate"),
        PrevalenceThreshold(name="prevalence_threshold"),
        ThreatScore(name="threat_score"),
        FowlkesMallowsIndex(name="fowlkes_mallows_index"),
        Informedness(name="informedness"),
        Markedness(name="markedness"),
        PositiveLikelihoodRatio(name="LR+"),
        NegativeLikelihoodRatio(name="LR-"),
        DiagnosticOddsRatio(name="DOR"),
    ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment