Last active
November 10, 2020 12:35
-
-
Save zommiommy/0bb4cb37610f59edaa8d903a6d9d6784 to your computer and use it in GitHub Desktop.
A bunch of metrics for binary classification tasks implemented in Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import tensorflow.keras.backend as K | |
from tensorflow.keras.metrics import Metric, AUC | |
from tensorflow.keras.backend import epsilon | |
class ConfusionMatrixMetric(Metric):
    """Streaming binary-classification metric built from confusion-matrix counts.

    Accumulates TP, FP, TN and FN across batches; subclasses implement
    `_custom_metric` to reduce the four counters to the metric value.
    Predictions and labels are clipped to [0, 1] and rounded, i.e. hard
    thresholding at 0.5.

    NOTE(review): `sample_weight` is accepted but ignored by `update_state`.
    """

    def __init__(self, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # One scalar accumulator per confusion-matrix cell.
        self.tp = self.add_weight(name='tp', initializer='zeros')
        self.fp = self.add_weight(name='fp', initializer='zeros')
        self.tn = self.add_weight(name='tn', initializer='zeros')
        self.fn = self.add_weight(name='fn', initializer='zeros')
        # Cached metric value, refreshed on every update_state() call.
        self._result = self.add_weight(name=name, initializer='zeros')

    def result(self):
        # The value was already computed during the last update_state().
        return self._result

    def reset_states(self):
        # Zero every accumulator (start of a new epoch / evaluation run).
        for accumulator in (self.tp, self.fp, self.tn, self.fn, self._result):
            accumulator.assign(0)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        # Hard-threshold both predictions and labels at 0.5.
        pred_pos = K.round(K.clip(y_pred, 0, 1))
        true_pos = K.round(K.clip(y_true, 0, 1))
        pred_neg = 1 - pred_pos
        true_neg = 1 - true_pos
        self.tp.assign_add(K.sum(true_pos * pred_pos))
        self.fp.assign_add(K.sum(true_neg * pred_pos))
        self.fn.assign_add(K.sum(true_pos * pred_neg))
        self.tn.assign_add(K.sum(true_neg * pred_neg))
        # Refresh the cached result so result() stays cheap.
        self._result.assign(self._custom_metric())

    def _custom_metric(self):
        # Subclasses map (tp, fp, tn, fn) to a scalar metric value.
        raise NotImplementedError("This method should be implemented by subclasses")
class TruePositivesRatio(ConfusionMatrixMetric):
    """Share of all samples that are true positives: TP / (TP + FP + TN + FN)."""

    def _custom_metric(self):
        total = self.tp + self.fp + self.tn + self.fn
        return self.tp / (total + epsilon())
class FalsePositivesRatio(ConfusionMatrixMetric):
    """Share of all samples that are false positives: FP / (TP + FP + TN + FN)."""

    def _custom_metric(self):
        total = self.tp + self.fp + self.tn + self.fn
        return self.fp / (total + epsilon())
class TrueNegativesRatio(ConfusionMatrixMetric):
    """Share of all samples that are true negatives: TN / (TP + FP + TN + FN)."""

    def _custom_metric(self):
        total = self.tp + self.fp + self.tn + self.fn
        return self.tn / (total + epsilon())
class FalseNegativesRatio(ConfusionMatrixMetric):
    """Share of all samples that are false negatives: FN / (TP + FP + TN + FN)."""

    def _custom_metric(self):
        total = self.tp + self.fp + self.tn + self.fn
        return self.fn / (total + epsilon())
class Recall(ConfusionMatrixMetric):
    """Sensitivity / true-positive rate: TP / (TP + FN)."""

    def _custom_metric(self):
        actual_positives = self.tp + self.fn
        return self.tp / (actual_positives + epsilon())
class Specificity(ConfusionMatrixMetric):
    """True-negative rate: TN / (TN + FP)."""

    def _custom_metric(self):
        actual_negatives = self.tn + self.fp
        return self.tn / (actual_negatives + epsilon())
class Precision(ConfusionMatrixMetric):
    """Positive predictive value: TP / (TP + FP)."""

    def _custom_metric(self):
        predicted_positives = self.tp + self.fp
        return self.tp / (predicted_positives + epsilon())
class NegativePredictiveValue(ConfusionMatrixMetric):
    """NPV: TN / (FN + TN)."""

    def _custom_metric(self):
        predicted_negatives = self.fn + self.tn
        return self.tn / (predicted_negatives + epsilon())
class MissRate(ConfusionMatrixMetric):
    """False-negative rate: FN / (FN + TP). Complement of recall."""

    def _custom_metric(self):
        actual_positives = self.fn + self.tp
        return self.fn / (actual_positives + epsilon())
class FallOut(ConfusionMatrixMetric):
    """False-positive rate: FP / (FP + TN). Complement of specificity."""

    def _custom_metric(self):
        actual_negatives = self.fp + self.tn
        return self.fp / (actual_negatives + epsilon())
class FalseDiscoveryRate(ConfusionMatrixMetric):
    """FDR: FP / (FP + TP). Complement of precision."""

    def _custom_metric(self):
        predicted_positives = self.fp + self.tp
        return self.fp / (predicted_positives + epsilon())
class FalseOmissionRate(ConfusionMatrixMetric):
    """FOR: FN / (FN + TN). Complement of negative predictive value."""

    def _custom_metric(self):
        predicted_negatives = self.fn + self.tn
        return self.fn / (predicted_negatives + epsilon())
class PrevalenceThreshold(ConfusionMatrixMetric):
    """Prevalence threshold: (sqrt(TPR * FPR) - FPR) / (TPR - FPR).

    Written here in the equivalent form using TNR = 1 - FPR:
    (sqrt(TPR * (1 - TNR)) + TNR - 1) / (TPR + TNR - 1).
    """

    def _custom_metric(self):
        sensitivity = self.tp / (self.tp + self.fn + epsilon())
        specificity = self.tn / (self.tn + self.fp + epsilon())
        numerator = tf.math.sqrt(sensitivity * (1 - specificity)) + specificity - 1
        denominator = sensitivity + specificity - 1 + epsilon()
        return numerator / denominator
class ThreatScore(ConfusionMatrixMetric):
    """Critical success index / Jaccard index: TP / (TP + FN + FP)."""

    def _custom_metric(self):
        union = self.tp + self.fn + self.fp
        return self.tp / (union + epsilon())
class Accuracy(ConfusionMatrixMetric):
    """Overall accuracy: (TP + TN) / (TP + TN + FP + FN)."""

    def _custom_metric(self):
        correct = self.tp + self.tn
        total = correct + self.fp + self.fn
        return correct / (total + epsilon())
class BalancedAccuracy(ConfusionMatrixMetric):
    """Mean of sensitivity and specificity: (TPR + TNR) / 2."""

    def _custom_metric(self):
        sensitivity = self.tp / (self.tp + self.fn + epsilon())
        specificity = self.tn / (self.tn + self.fp + epsilon())
        return (sensitivity + specificity) / 2
class F1Score(ConfusionMatrixMetric):
    """F1 score (harmonic mean of precision and recall): TP / (TP + (FP + FN)/2)."""

    def _custom_metric(self):
        denominator = self.tp + 0.5 * (self.fp + self.fn)
        return self.tp / (denominator + epsilon())
class MatthewsCorrelationCoefficinet(ConfusionMatrixMetric):
    """Matthews correlation coefficient.

    MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)).

    NOTE(review): class name misspells "Coefficient"; kept as-is for
    backward compatibility with existing callers.
    """

    def _custom_metric(self):
        covariance = self.tp * self.tn - self.fp * self.fn
        marginals = (
            (self.tp + self.fp)
            * (self.tp + self.fn)
            * (self.tn + self.fp)
            * (self.tn + self.fn)
        )
        return covariance / (tf.math.sqrt(marginals) + epsilon())
class FowlkesMallowsIndex(ConfusionMatrixMetric):
    """Fowlkes-Mallows index: FM = sqrt(PPV * TPR).

    The geometric mean of precision (PPV) and recall (TPR).

    BUGFIX: the previous implementation returned (TPR + TNR) / 2, which is
    balanced accuracy, not the Fowlkes-Mallows index.
    """

    def _custom_metric(self):
        ppv = self.tp / (self.tp + self.fp + epsilon())  # precision
        tpr = self.tp / (self.tp + self.fn + epsilon())  # recall
        return tf.math.sqrt(ppv * tpr)
class Informedness(ConfusionMatrixMetric):
    """Informedness (Youden's J statistic): TPR + TNR - 1.

    BUGFIX: the previous implementation returned sqrt(PPV * TPR), which is
    the Fowlkes-Mallows index, not informedness.
    """

    def _custom_metric(self):
        tpr = self.tp / (self.tp + self.fn + epsilon())  # sensitivity
        tnr = self.tn / (self.tn + self.fp + epsilon())  # specificity
        return tpr + tnr - 1
class Markedness(ConfusionMatrixMetric):
    """Markedness: PPV + NPV - 1."""

    def _custom_metric(self):
        precision = self.tp / (self.tp + self.fp + epsilon())
        neg_predictive_value = self.tn / (self.tn + self.fn + epsilon())
        return precision + neg_predictive_value - 1
class PositiveLikelihoodRatio(ConfusionMatrixMetric):
    """LR+: TPR / FPR."""

    def _custom_metric(self):
        sensitivity = self.tp / (self.tp + self.fn + epsilon())
        fall_out = self.fp / (self.fp + self.tn + epsilon())
        return sensitivity / (fall_out + epsilon())
class NegativeLikelihoodRatio(ConfusionMatrixMetric):
    """LR-: FNR / TNR."""

    def _custom_metric(self):
        miss_rate = self.fn / (self.fn + self.tp + epsilon())
        specificity = self.tn / (self.tn + self.fp + epsilon())
        return miss_rate / (specificity + epsilon())
class DiagnosticOddsRatio(ConfusionMatrixMetric):
    """Diagnostic odds ratio: DOR = LR+ / LR- = (TP * TN) / (FP * FN).

    BUGFIX: the previous implementation returned (TPR + TNR) / (FPR + FNR),
    which is not the diagnostic odds ratio; the DOR is the ratio of the
    positive and negative likelihood ratios.
    """

    def _custom_metric(self):
        tpr = self.tp / (self.tp + self.fn + epsilon())  # sensitivity
        fpr = self.fp / (self.fp + self.tn + epsilon())  # fall-out
        tnr = self.tn / (self.tn + self.fp + epsilon())  # specificity
        fnr = self.fn / (self.fn + self.tp + epsilon())  # miss rate
        lr_positive = tpr / (fpr + epsilon())
        lr_negative = fnr / (tnr + epsilon())
        return lr_positive / (lr_negative + epsilon())
def get_all_metrics():
    """Return one instance of every metric in this module, plus Keras AUROC/AUPRC.

    Intended to be passed directly as `metrics=get_all_metrics()` when
    compiling a Keras model for a binary classification task.

    BUGFIX: the false-omission-rate metric was named "false_omission-rate"
    (mixed hyphen/underscore); renamed to "false_omission_rate" for
    consistency with every other snake_case metric name.
    """
    return [
        Accuracy(name="accuracy"),
        BalancedAccuracy(name="balanced_accuracy"),
        AUC(curve="roc", name="AUROC"),
        AUC(curve="pr", name="AUPRC"),
        F1Score(name="f1_score"),
        MatthewsCorrelationCoefficinet(name="mcc"),
        TruePositivesRatio(name="tp/t"),
        FalsePositivesRatio(name="fp/t"),
        TrueNegativesRatio(name="tn/t"),
        FalseNegativesRatio(name="fn/t"),
        Recall(name="recall"),
        Specificity(name="specificity"),
        Precision(name="precision"),
        MissRate(name="miss_rate"),
        FallOut(name="fall_out"),
        NegativePredictiveValue(name="negative_predictive_value"),
        FalseDiscoveryRate(name="false_discovery_rate"),
        FalseOmissionRate(name="false_omission_rate"),
        PrevalenceThreshold(name="prevalence_threshold"),
        ThreatScore(name="threat_score"),
        FowlkesMallowsIndex(name="fowlkes_mallows_index"),
        Informedness(name="informedness"),
        Markedness(name="markedness"),
        PositiveLikelihoodRatio(name="LR+"),
        NegativeLikelihoodRatio(name="LR-"),
        DiagnosticOddsRatio(name="DOR")
    ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment