@hsahovic
Created March 18, 2020 23:53
A custom TensorFlow / Keras loss implementing online hard example mining (OHEM, https://arxiv.org/abs/1604.03540) on top of cross-entropy.
import tensorflow as tf
import tensorflow_probability as tfp  # provides tfp.stats.percentile
from tensorflow.keras.losses import categorical_crossentropy


@tf.function
def ohem_crossentropy_loss(y_true, y_pred):
    # OHEM works with any per-example loss; to use another one,
    # change the base loss below.
    cross_entropy = categorical_crossentropy(y_true, y_pred)
    # The `80` below sets how many examples are rejected: as written, the 80%
    # of examples with the smallest loss are dropped and contribute nothing
    # to backpropagation.
    min_loss = tfp.stats.percentile(cross_entropy, 80, interpolation='midpoint')
    # Keep only the per-example losses strictly above the threshold.
    return tf.boolean_mask(cross_entropy, cross_entropy > min_loss)
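To see which examples survive the percentile cut, the selection step can be sketched in plain Python. This is an illustration only, not how `tfp.stats.percentile` is implemented: the helper names `midpoint_percentile` and `select_hard_examples` are hypothetical, and the midpoint rule here simply averages the two nearest ranks.

```python
import math


def midpoint_percentile(values, q):
    """q-th percentile, averaging the two nearest ranks (assumed midpoint rule)."""
    ordered = sorted(values)
    position = (len(ordered) - 1) * q / 100.0
    lower = ordered[math.floor(position)]
    upper = ordered[math.ceil(position)]
    return (lower + upper) / 2.0


def select_hard_examples(losses, reject_percent=80):
    """Keep only the losses strictly above the reject_percent-th percentile."""
    threshold = midpoint_percentile(losses, reject_percent)
    return [loss for loss in losses if loss > threshold]


# Ten per-example losses: with reject_percent=80 the threshold falls
# between 0.8 and 0.9, so only the two hardest examples are kept.
losses = [0.5, 0.1, 0.9, 0.3, 1.0, 0.2, 0.7, 0.4, 0.8, 0.6]
print(select_hard_examples(losses))  # [0.9, 1.0]
```

Because the comparison is strict, a batch in which every example has the same loss would keep nothing; that edge case is worth checking before relying on this kind of filter.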