@louismullie
Created September 15, 2020 16:50
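import tensorflow as tf  # TensorFlow 1.x only: tf.contrib was removed in TF 2.x


def auc_roc(y_true, y_pred):
    # can be any tensorflow metric; streaming_auc takes (predictions, labels)
    value, update_op = tf.contrib.metrics.streaming_auc(y_pred, y_true)

    # find all variables created for this metric; the [1] index expects
    # names like '<outer_scope>/auc_roc/...', so skip unscoped names
    metric_vars = [i for i in tf.local_variables()
                   if len(i.name.split('/')) > 1 and 'auc_roc' in i.name.split('/')[1]]

    # Add metric variables to the GLOBAL_VARIABLES collection
    # so that they are initialized in a new session.
    for v in metric_vars:
        tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, v)

    # force the metric values to update every time the value is read
    with tf.control_dependencies([update_op]):
        value = tf.identity(value)
    return value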
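To sanity-check the numbers this metric reports, it can help to compare them against an exact ROC AUC on a small batch. The sketch below is not part of the gist and does not use TensorFlow: it swaps in the rank-based definition (the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counted as one half). `exact_roc_auc` is a hypothetical helper name. Because `streaming_auc` approximates the curve with a fixed set of thresholds (200 by default), expect small differences rather than an exact match.

```python
from typing import Sequence


def exact_roc_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """Exact ROC AUC via pairwise comparison (Mann-Whitney formulation).

    Hypothetical helper for sanity checks; labels must be 0 or 1.
    """
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")

    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # ties count as half a win
    return wins / (len(pos) * len(neg))


print(exact_roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

The pairwise loop is O(P·N), which is fine for a quick check on a few thousand examples but not for full datasets.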