Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active August 28, 2023 18:19
Show Gist options
  • Save sadimanna/f5fe9096fbd1c69208ab51416c3e6e3b to your computer and use it in GitHub Desktop.
@tf.keras.saving.register_keras_serializable(name="weighted_categorical_crossentropy")
def weighted_categorical_crossentropy(target, output, weights, axis=-1):
    """Compute class-weighted categorical cross-entropy from probabilities.

    For each sample the loss is
    ``-sum_c weights[c] * target[..., c] * log(p[..., c])``, where ``p`` is
    ``output`` renormalized to sum to 1 along ``axis`` and then clipped to
    ``[epsilon, 1 - epsilon]``.

    Keras calls a loss as ``loss(y_true, y_pred)``, so ``weights`` has to be
    bound before the function is handed to ``model.compile`` (for example
    with ``functools.partial(weighted_categorical_crossentropy, weights=w)``).

    Args:
        target: Ground-truth class probabilities (typically one-hot). Must be
            shape-compatible with ``output``; cast to ``output``'s dtype.
        output: Predicted class probabilities along ``axis``. Assumed
            non-negative (e.g. softmax output) — TODO confirm for callers.
        weights: One weight per class. Its length must match the size of
            ``output`` along ``axis``.
        axis: The class axis. Defaults to the last axis.

    Returns:
        Tensor of per-sample losses: ``output``'s shape with ``axis`` reduced.
    """
    output = tf.convert_to_tensor(output)
    # Cast the labels to the prediction dtype so integer one-hot labels (or a
    # float64/float32 mismatch) do not break the products below.
    target = tf.cast(tf.convert_to_tensor(target), output.dtype)
    target.shape.assert_is_compatible_with(output.shape)

    # Lay the per-class weights out along `axis` so they broadcast against
    # `target` for any rank and any class axis, not just the last one.
    weights = tf.reshape(tf.cast(tf.convert_to_tensor(weights), output.dtype), [-1])
    rank = output.shape.rank
    if rank is not None and rank > 1:
        weight_shape = [1] * rank
        weight_shape[axis] = -1
        weights = tf.reshape(weights, weight_shape)
    # With unknown or rank-1 inputs the 1-D weights broadcast along the last axis.

    # Renormalize so each sample's class probabilities sum to 1 along `axis`;
    # the cross-entropy is only meaningful for a proper distribution.
    output = output / tf.reduce_sum(output, axis, keepdims=True)
    # Clip away from 0 and 1 so log() stays finite.
    epsilon_ = tf.constant(tf.keras.backend.epsilon(), output.dtype.base_dtype)
    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
    return -tf.reduce_sum(weights * target * tf.math.log(output), axis=axis)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment