Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save jonnyli1125/5384bb9a41caaac983f1cd737359c6c2 to your computer and use it in GitHub Desktop.
Save jonnyli1125/5384bb9a41caaac983f1cd737359c6c2 to your computer and use it in GitHub Desktop.
SparseCategoricalCrossentropy with class weights for Keras/TensorFlow 2
"""
Since Model.fit doesn't support class_weight when using multiple outputs,
this custom loss subclass may be useful.
Relevant issues:
https://github.com/keras-team/keras/issues/11735
https://github.com/tensorflow/tensorflow/issues/40457
https://github.com/tensorflow/tensorflow/issues/41448
"""
import tensorflow as tf
from tensorflow import keras
class WeightedSCCE(keras.losses.Loss):
    """Sparse categorical cross-entropy with per-class weights.

    Each element's cross-entropy is multiplied by the weight of its true
    class (``class_weight[y_true]``), then by ``sample_weight`` if one is
    passed, and finally reduced according to ``reduction``. With the
    default ``Reduction.NONE`` the call returns the per-element weighted
    loss, shaped like the inner ``SparseCategoricalCrossentropy`` output.

    Args:
        class_weight: Per-class weights. Either a sequence indexed by class
            id, or a dict ``{class_id: weight}`` as accepted by
            ``Model.fit``. Ids missing from a dict get weight 1.0, but only
            ids up to the largest key are covered, so the dict should
            include the highest class id (see ``tf.gather`` for what happens
            with out-of-range ids). ``None`` or all-ones disables weighting.
        from_logits: Whether ``y_pred`` holds logits instead of
            probabilities; forwarded to ``SparseCategoricalCrossentropy``.
        name: Name of the loss (also used for the inner cross-entropy).
        reduction: Keras reduction applied after weighting. Defaults to
            ``keras.losses.Reduction.NONE`` (unreduced output).
            NOTE(review): when an unreduced loss is used in ``Model.compile``
            the per-sample values may be summed for the gradient; consider
            ``SUM_OVER_BATCH_SIZE`` there -- confirm for your TF version.

    Raises:
        ValueError: If a dict ``class_weight`` has a negative class id.
    """

    def __init__(self, class_weight, from_logits=False, name='weighted_scce',
                 reduction=keras.losses.Reduction.NONE):
        # The base class sets self.reduction, self.name and the name scope.
        super().__init__(reduction=reduction, name=name)
        self.from_logits = from_logits
        # Plain-Python copy of the weights, kept for get_config().
        self._class_weight_list = self._dense_class_weight(class_weight)
        if self._class_weight_list is None:
            self.class_weight = None
        else:
            self.class_weight = tf.convert_to_tensor(
                self._class_weight_list, dtype=tf.float32)
        # Always unreduced: weighting must happen per element, before the
        # base class applies sample_weight and the configured reduction.
        self.unreduced_scce = keras.losses.SparseCategoricalCrossentropy(
            from_logits=from_logits, name=name,
            reduction=keras.losses.Reduction.NONE)

    @staticmethod
    def _dense_class_weight(class_weight):
        """Return ``class_weight`` as a list of floats, or None if unweighted."""
        if class_weight is None:
            return None
        if isinstance(class_weight, dict):
            class_ids = [int(key) for key in class_weight]
            if any(class_id < 0 for class_id in class_ids):
                raise ValueError(
                    'class_weight keys must be non-negative class ids.')
            dense = [1.0] * (max(class_ids, default=-1) + 1)
            for key, weight in class_weight.items():
                dense[int(key)] = float(weight)
        else:
            dense = [float(weight) for weight in class_weight]
        # All-ones weights are a no-op; skip the gather/multiply entirely.
        if all(weight == 1.0 for weight in dense):
            return None
        return dense

    def call(self, y_true, y_pred):
        """Return the class-weighted, unreduced cross-entropy.

        ``sample_weight`` and ``reduction`` are applied afterwards by the
        inherited ``keras.losses.Loss.__call__``.
        """
        loss = self.unreduced_scce(y_true, y_pred)
        if self.class_weight is None:
            return loss
        # tf.gather needs integer indices; labels may arrive as floats.
        class_ids = tf.cast(y_true, tf.int32)
        weight_mask = tf.gather(self.class_weight, class_ids)
        # Labels are often (batch, 1) while the loss is (batch,); reshaping
        # to the loss shape avoids broadcasting the product to (batch, batch).
        weight_mask = tf.reshape(weight_mask, tf.shape(loss))
        # Match the loss dtype (e.g. float16 under mixed precision).
        return tf.math.multiply(loss, tf.cast(weight_mask, loss.dtype))

    def get_config(self):
        """Return a config from which ``from_config`` can rebuild the loss."""
        config = super().get_config()
        config.update(class_weight=self._class_weight_list,
                      from_logits=self.from_logits)
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment