Skip to content

Instantly share code, notes, and snippets.

@doleron
Created April 9, 2023 12:05
Show Gist options
  • Save doleron/03ff2dadf90ad841877aca76f69e7f43 to your computer and use it in GitHub Desktop.
Save doleron/03ff2dadf90ad841877aca76f69e7f43 to your computer and use it in GitHub Desktop.
def _below_threshold(y_pred, threshold=0.5):
    """Binarize predictions so that values strictly below ``threshold`` become 1.

    Args:
        y_pred: Predictions (tensor, array or nested list of numbers).
        threshold: Cut-off; ``y_pred < threshold`` maps to 1, everything else
            (including ``y_pred == threshold``) maps to 0. Defaults to 0.5.

    Returns:
        A tensor of 0/1 values with the same dtype as ``y_pred`` after
        conversion to a tensor.

    NOTE(review): this inverts the usual "score above 0.5 is positive"
    convention. Presumably the model outputs a distance/dissimilarity where
    small values mean the positive class -- confirm against the model.
    """
    # Accept plain Python lists / NumPy arrays as well as tensors; a raw list
    # has no ``.dtype`` and would otherwise raise AttributeError below.
    y_pred = tf.convert_to_tensor(y_pred)
    is_positive = tf.math.less(y_pred, threshold)
    return tf.cast(is_positive, y_pred.dtype)


class Custom_Precision(tf.keras.metrics.Precision):
    """Precision metric that treats predictions below 0.5 as positive."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Binarize ``y_pred`` (below 0.5 -> 1) and accumulate precision.

        The parent's own thresholding (default 0.5) then only sees 0/1 values,
        so it keeps the binarization produced here.
        """
        return super().update_state(
            y_true, _below_threshold(y_pred), sample_weight
        )


class Custom_Recall(tf.keras.metrics.Recall):
    """Recall metric that treats predictions below 0.5 as positive."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Binarize ``y_pred`` (below 0.5 -> 1) and accumulate recall.

        The parent's own thresholding (default 0.5) then only sees 0/1 values,
        so it keeps the binarization produced here.
        """
        return super().update_state(
            y_true, _below_threshold(y_pred), sample_weight
        )


class Custom_Accuracy(tf.keras.metrics.Accuracy):
    """Accuracy metric that treats predictions below 0.5 as positive.

    ``tf.keras.metrics.Accuracy`` counts exact matches, so ``y_true`` is
    expected to hold 0/1 labels -- TODO confirm with the data pipeline.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Binarize ``y_pred`` (below 0.5 -> 1) and accumulate exact-match accuracy."""
        return super().update_state(
            y_true, _below_threshold(y_pred), sample_weight
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment