Calculate the accuracy of merged percentile buckets
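# Credits: https://stackoverflow.com/a/52127761/1585523
import tensorflow as tf
from tensorflow.keras import backend as K

# Define the custom accuracy metric. Integer-dividing the bucket index by 3 merges
# every three consecutive percentile buckets (0-2, 3-5, 6-8, ...) into one group,
# so a prediction counts as correct when it falls in the same group as the label.
def three_class_acc(y_true, y_pred):
    tr = tf.math.floordiv(tf.argmax(y_true, axis=-1), 3)  # merged group of the true bucket
    pr = tf.math.floordiv(tf.argmax(y_pred, axis=-1), 3)  # merged group of the predicted bucket
    return tf.cast(tf.equal(tr, pr), K.floatx())

# Use the metric in a tensorflow.keras model. Metrics are registered in compile(),
# not fit(); the optimizer and loss below are placeholders, use your own.
# model = ... my tensorflow keras model definition ...
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['accuracy', three_class_acc])
model.fit(X_train, y_train, validation_data=(X_test, y_test))
# Training logs now report three_class_acc (and val_three_class_acc)
# alongside the regular accuracy over the individual buckets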
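To see what the metric actually counts, here is a framework-free sketch of the same computation in NumPy. It assumes one-hot labels and nine percentile buckets; `merged_bucket_accuracy` and its `group_size` parameter are hypothetical names used for illustration only, not part of the gist or of TensorFlow.

```python
import numpy as np

def merged_bucket_accuracy(y_true, y_pred, group_size=3):
    """Fraction of samples whose true and predicted buckets share a merged group.

    y_true: one-hot labels, shape (n_samples, n_buckets)
    y_pred: predicted scores or probabilities, same shape
    group_size: number of consecutive buckets merged into one group (3 in the gist)
    """
    true_group = np.argmax(y_true, axis=-1) // group_size
    pred_group = np.argmax(y_pred, axis=-1) // group_size
    return float(np.mean(true_group == pred_group))

# Hypothetical example with 9 buckets (indices 0..8)
y_true = np.eye(9)[[0, 4, 8, 2]]  # true buckets 0, 4, 8, 2
y_pred = np.eye(9)[[2, 3, 6, 5]]  # predicted buckets 2, 3, 6, 5

exact = float(np.mean(np.argmax(y_true, -1) == np.argmax(y_pred, -1)))
merged = merged_bucket_accuracy(y_true, y_pred)
print(exact, merged)  # 0.0 0.75
```

None of the four predictions hits the exact bucket, but buckets 0 and 2 share group 0, 4 and 3 share group 1, and 8 and 6 share group 2, so three of the four count as correct under the merged metric; only the pair (2, 5) crosses a group boundary.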