hsahovic/cf6b929ee95910f3d22d3899b3a999ce
Created March 18, 2020 23:53
A custom TensorFlow / Keras loss implementing online hard example mining (OHEM, https://arxiv.org/abs/1604.03540) with cross-entropy.
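Here is a minimal pure-Python sketch of what the loss does to a single batch. The helper names `midpoint_percentile` and `ohem_select` are hypothetical. The midpoint rule used here averages the two order statistics around position (n - 1) · q / 100, and it is assumed to match the convention of `tfp.stats.percentile(..., interpolation='midpoint')`.

```python
import math


def midpoint_percentile(values, q):
    """q-th percentile with 'midpoint' interpolation (NumPy-style convention, assumed)."""
    s = sorted(values)
    pos = (len(s) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    # Average the two neighbouring order statistics.
    return (s[lo] + s[hi]) / 2.0


def ohem_select(losses, q=80):
    """Keep only the losses strictly above the q-th percentile, like the TF loss."""
    threshold = midpoint_percentile(losses, q)
    return [loss for loss in losses if loss > threshold]


# Ten per-example losses from a hypothetical batch.
batch_losses = [0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 1.20, 2.50]
threshold = midpoint_percentile(batch_losses, 80)
kept = ohem_select(batch_losses, 80)
```

With ten examples the threshold falls at 0.95, so only the two hardest examples (losses 1.2 and 2.5) survive. Because the comparison is strict, a batch with a single example, or one in which every loss is equal, keeps nothing.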
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.losses import categorical_crossentropy


@tf.function
def ohem_crossentropy_loss(y_true, y_pred):
    # You can apply OHEM with any loss you want.
    # To do so, just change the base loss below.
    cross_entropy = categorical_crossentropy(y_true, y_pred)

    # Tune how many examples are rejected by changing the `80` below.
    # If left unchanged, the 80% of examples with the smallest loss
    # are not used during backpropagation.
    min_loss = tfp.stats.percentile(cross_entropy, 80, interpolation='midpoint')

    # Keep only the examples whose loss is strictly above the threshold.
    return tf.boolean_mask(cross_entropy, cross_entropy > min_loss)
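The comments in the loss note that any base loss works and that the 80th-percentile cutoff is tunable. The following is a minimal sketch of a reusable variant, assuming TensorFlow 2.x. The names `make_ohem_loss` and `keep_fraction` are hypothetical and not part of the gist. It swaps the percentile threshold for `tf.math.top_k`, so a fixed number of the hardest examples is kept per batch (always at least one), and it no longer needs `tensorflow_probability`.

```python
import tensorflow as tf


def make_ohem_loss(base_loss=tf.keras.losses.categorical_crossentropy,
                   keep_fraction=0.2):
    """Hypothetical helper: wrap a per-example loss so only the hardest examples count.

    keep_fraction=0.2 keeps roughly the top 20% of losses, similar in spirit
    to discarding everything below the 80th percentile.
    """

    def ohem_loss(y_true, y_pred):
        # Flatten so per-example (or per-pixel) losses form one 1-D tensor.
        losses = tf.reshape(base_loss(y_true, y_pred), [-1])
        n = tf.cast(tf.size(losses), tf.float32)
        # Number of examples to keep; at least one so the selection is never empty.
        k = tf.maximum(tf.cast(tf.math.ceil(keep_fraction * n), tf.int32), 1)
        # top_k returns the k largest losses, i.e. the hardest examples.
        hardest, _ = tf.math.top_k(losses, k=k)
        return tf.reduce_mean(hardest)

    return ohem_loss


# Toy batch: five one-hot labels and predicted class probabilities.
y_true = tf.constant([[1., 0.], [0., 1.], [1., 0.], [0., 1.], [1., 0.]])
y_pred = tf.constant([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.5, 0.5], [0.1, 0.9]])

loss_fn = make_ohem_loss(keep_fraction=0.4)  # keeps the 2 hardest of 5
loss_value = loss_fn(y_true, y_pred)
```

With an existing Keras `model`, this would be used as `model.compile(optimizer="adam", loss=make_ohem_loss(keep_fraction=0.2))`. Unlike the boolean mask, `top_k` always returns at least one loss, so ties or a batch of size one still produce a gradient. The wrapper returns an already-averaged scalar rather than a variable-length vector.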