Last active
June 14, 2019 18:58
-
-
Save malcolmgreaves/648c9365e249df04cf4438cf2460ddd0 to your computer and use it in GitHub Desktop.
Snippets for using custom metrics -- f1, precision, and recall -- in a Keras model.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
from keras import Input, Model | |
from keras import backend as K | |
from keras.engine import InputLayer | |
from keras.layers import Dense, Dropout, Embedding, LSTM, Bidirectional, Lambda | |
from keras.models import Sequential | |
from keras.preprocessing import sequence | |
def bind_conf_mat(pos_pred_indices, neg_index=0, gather_axis=1):
    """Build a confusion-matrix function over a fixed set of label columns.

    Args:
        pos_pred_indices: indices (along ``gather_axis``) of the label columns
            treated as positive classes.
        neg_index: index (along ``gather_axis``) of the single negative-class
            column.
        gather_axis: axis of ``y_true`` / ``y_pred`` that indexes labels.

    Returns:
        A function ``conf_mat(y_true, y_pred)`` returning the tuple
        ``(tp, fp, tn, fn)`` of summed counts for one batch.
    """

    def gather(bin_y):
        # Keep only the positive-class columns.
        return tf.gather(bin_y, pos_pred_indices, axis=gather_axis)

    def neg_val(bin_y):
        # Keep only the negative-class column (as a length-1 slice, so the
        # gathered axis is preserved).
        return tf.gather(bin_y, [neg_index], axis=gather_axis)

    def conf_mat(y_true, y_pred):
        """Count confusion-matrix cells for one batch.

        Args:
            y_true: ground-truth labels; passed through ``binarize``.
            y_pred: model predictions; passed through ``binarize``.

        Returns:
            true_positives  (p=T, a=T) over the positive columns
            false_positives (p=T, a=F) over the positive columns
            true_negatives  (p=F, a=F) over the negative column
            false_negatives (p=F, a=T) over the positive columns
        """
        # NOTE(review): `binarize` and `to_int` are not defined in this
        # snippet. Presumably `binarize` yields a bool tensor (tf.logical_*
        # requires bool inputs) and `to_int` casts bool -> integer so K.sum
        # can count. Confirm against their definitions.
        y_true = binarize(y_true)
        y_pred = binarize(y_pred)

        actual_neg = neg_val(y_true)
        # Bug fix: this previously read neg_val(y_true), so tn compared the
        # labels with themselves instead of with the predictions.
        predict_neg = neg_val(y_pred)
        label_pos_true = gather(y_true)
        label_pos_pred = gather(y_pred)

        # tn: negative column set in both the labels and the predictions.
        tn = K.sum(to_int(tf.logical_and(actual_neg, predict_neg)))
        # tp: positive column set in both the labels and the predictions.
        tp = K.sum(to_int(tf.logical_and(label_pos_true, label_pos_pred)))
        # fp: positive column predicted but not actually set.
        fp = K.sum(to_int(tf.logical_and(tf.logical_not(label_pos_true),
                                         label_pos_pred)))
        # fn: positive column actually set but not predicted (mirror of fp).
        # Computed over the same elements as tp, so tp + fn counts every
        # actual positive label and recall = tp / (tp + fn) is well defined.
        # (The previous expression referenced undefined names `true`/`false`.)
        fn = K.sum(to_int(tf.logical_and(label_pos_true,
                                         tf.logical_not(label_pos_pred))))
        return tp, fp, tn, fn

    return conf_mat
# Label layout along axis 1: column NEG_INDEX is the negative class and the
# columns in POS_PRED_INDICES are the positive classes scored by the metrics.
NEG_INDEX = 0
POS_PRED_INDICES = [1, 2]
# Shared confusion-matrix function used by recall / precision / f1.
CONFUSION_MATRIX = bind_conf_mat(POS_PRED_INDICES, NEG_INDEX, gather_axis=1)
# Constant factor 2 in the F1 harmonic-mean formula.
TWO_FLOAT32 = tf.constant(2.0, dtype=tf.float32)
def recall(y_true, y_pred):
    """Recall metric: tp / (tp + fn).

    Only computes a batch-wise value of recall (per batch, not over the
    whole dataset). Recall is a metric for multi-label classification of how
    many relevant items are selected.

    ``K.epsilon()`` is added to the denominator so a batch with no positive
    labels yields 0 instead of a division by zero.
    """
    tp, _, _, fn = CONFUSION_MATRIX(y_true, y_pred)
    return tf.divide(tf.to_float(tp),
                     tf.add(tf.to_float(tf.add(tp, fn)), K.epsilon()))
def precision(y_true, y_pred):
    """Precision metric: tp / (tp + fp).

    Only computes a batch-wise value of precision (per batch, not over the
    whole dataset). Precision is a metric for multi-label classification of
    how many selected items are relevant.

    ``K.epsilon()`` is added to the denominator so a batch with no positive
    predictions yields 0 instead of a division by zero.
    """
    tp, fp, _, _ = CONFUSION_MATRIX(y_true, y_pred)
    return tf.divide(tf.to_float(tp),
                     tf.add(tf.to_float(tf.add(tp, fp)), K.epsilon()))
def f1(y_true, y_pred):
    """F1 metric.

    Computes the batch-wise F1 score: the harmonic mean of precision and
    recall, ``2 * p * r / (p + r)``. ``K.epsilon()`` is added to the
    denominator so the result is 0 rather than NaN when both precision and
    recall are 0.
    """
    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    # The original expression was truncated after `tf.add(tf.add(p, r),`;
    # completed with the same epsilon guard used by recall/precision.
    return tf.multiply(TWO_FLOAT32,
                       tf.divide(tf.multiply(p, r),
                                 tf.add(tf.add(p, r), K.epsilon())))
# NOTE(review): `model` is not defined in this snippet; it must be a Keras
# model built earlier. Its output presumably has 3 columns, matching the
# target placeholder below and the NEG_INDEX / POS_PRED_INDICES layout.
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=[f1, precision, recall, 'accuracy'],
    # TF1 / Keras-2 API: feed the targets through an explicit placeholder.
    target_tensors=[tf.placeholder(tf.float32, shape=(None, 3))],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment