@0xnurl
Created April 24, 2017 11:51
Keras metric for accuracy on non-null targets, mainly for Named Entity Recognition models
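To see why the null label is excluded, here is a minimal pure-Python sketch. The toy label ids and the helper names `plain_accuracy` and `non_null_accuracy` are hypothetical, chosen for illustration. When 8 of 10 tokens carry the null label 0, a model that always predicts the null label gets 80% plain accuracy but 0% on the tokens that actually matter.

```python
# Hypothetical toy data: one label id per token, where 0 is the null label.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]
y_pred = [0] * 10  # a degenerate model that always predicts the null label


def plain_accuracy(true, pred):
    """Fraction of all tokens whose predicted label id matches the target."""
    return sum(t == p for t, p in zip(true, pred)) / len(true)


def non_null_accuracy(true, pred, null_label=0):
    """Fraction of tokens with a non-null target whose prediction matches it."""
    pairs = [(t, p) for t, p in zip(true, pred) if t != null_label]
    if not pairs:
        return float("nan")  # undefined when there are no non-null targets
    return sum(t == p for t, p in pairs) / len(pairs)


print(plain_accuracy(y_true, y_pred))     # 0.8
print(non_null_accuracy(y_true, y_pred))  # 0.0
```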
import tensorflow as tf
from keras import backend as K


def non_null_label_accuracy(y_true, y_pred):
    """Calculate accuracy excluding targets that are the null label (at index 0).

    Useful when the null target is over-represented in the data, as in Named Entity Recognition tasks.
    Typical y shape: (batch_size, sentence_length, num_labels)
    """
    # argmax over the label axis removes it ==> (batch_size, sentence_length)
    y_true_argmax = K.argmax(y_true, -1)
    y_pred_argmax = K.argmax(y_pred, -1)
    # Flatten to 1-D so the non-null positions can be gathered directly
    y_true_argmax_flat = tf.reshape(y_true_argmax, [-1])
    y_pred_argmax_flat = tf.reshape(y_pred_argmax, [-1])
    # Keep only the positions whose target is not the null label (index 0)
    non_null_targets_bool = K.not_equal(y_true_argmax_flat, K.zeros_like(y_true_argmax_flat))
    non_null_target_idx = K.flatten(K.cast(tf.where(non_null_targets_bool), 'int32'))
    y_true_without_null = K.gather(y_true_argmax_flat, non_null_target_idx)
    y_pred_without_null = K.gather(y_pred_argmax_flat, non_null_target_idx)
    # Fraction of non-null targets predicted correctly (NaN if the batch has none)
    mean = K.mean(K.cast(K.equal(y_pred_without_null, y_true_without_null), K.floatx()))
    # If the model uses masking, Keras forces the metric output to have the same shape as y
    # (minus the label axis), so broadcast the scalar mean to (batch_size, sentence_length)
    fake_shape_mean = K.ones_like(y_true_argmax, K.floatx()) * mean
    return fake_shape_mean
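The tensor steps in the metric (argmax, flatten, drop null targets, average, broadcast) can be checked on small inputs with a pure-Python mirror. This is a sketch, not part of the original gist: `non_null_label_accuracy_ref` and `argmax` are hypothetical helpers, and nested lists stand in for tensors of shape (batch_size, sentence_length, num_labels).

```python
from typing import List

# Nested lists stand in for tensors of shape (batch_size, sentence_length, num_labels).
Tensor3D = List[List[List[float]]]


def argmax(scores: List[float]) -> int:
    """Index of the largest score (the first such index wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def non_null_label_accuracy_ref(y_true: Tensor3D, y_pred: Tensor3D) -> List[List[float]]:
    """Pure-Python mirror of the Keras metric, for checking it on small inputs."""
    # Step 1: argmax over the label axis -> (batch_size, sentence_length) label ids.
    true_ids = [[argmax(tok) for tok in sent] for sent in y_true]
    pred_ids = [[argmax(tok) for tok in sent] for sent in y_pred]
    # Step 2: flatten and keep only positions whose target is not the null label 0.
    pairs = [(t, p)
             for t_sent, p_sent in zip(true_ids, pred_ids)
             for t, p in zip(t_sent, p_sent)
             if t != 0]
    # Step 3: accuracy over the surviving positions (NaN if there are none).
    mean = sum(t == p for t, p in pairs) / len(pairs) if pairs else float("nan")
    # Step 4: repeat the scalar to shape (batch_size, sentence_length), like fake_shape_mean.
    return [[mean] * len(sent) for sent in true_ids]


# One sentence of 3 tokens with 3 labels (label 0 is the null label).
y_true = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]
y_pred = [[[0.9, 0.05, 0.05], [0.2, 0.7, 0.1], [0.1, 0.8, 0.1]]]
print(non_null_label_accuracy_ref(y_true, y_pred))  # [[0.5, 0.5, 0.5]]
```

Here the target ids are 0, 1, 2 and the predicted ids are 0, 1, 1, so only the two non-null positions count and the score is 0.5, whereas plain accuracy would give 2/3. The scalar is repeated once per token, matching the per-timestep shape the Keras version returns.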