@sadimanna
Last active April 17, 2024 17:01
import warnings

import tensorflow as tf
# Keras internals used by `from_config`; these import paths match the Keras
# bundled with TensorFlow 2.12+ and may differ in other versions.
from keras.saving import saving_lib
from tensorflow.keras.losses import get


@tf.keras.saving.register_keras_serializable(name="WeightedCategoricalCrossentropy")
class WeightedCategoricalCrossentropy:
    def __init__(
        self,
        weights,
        label_smoothing=0.0,
        axis=-1,
        name="weighted_categorical_crossentropy",
        fn=None,
    ):
"""Initializes `WeightedCategoricalCrossentropy` instance.
Args:
from_logits: Whether to interpret `y_pred` as a tensor of
[logit](https://en.wikipedia.org/wiki/Logit) values. By default, we
assume that `y_pred` contains probabilities (i.e., values in [0,
1]).
label_smoothing: Float in [0, 1]. When 0, no smoothing occurs. When >
0, we compute the loss between the predicted labels and a smoothed
version of the true labels, where the smoothing squeezes the labels
towards 0.5. Larger values of `label_smoothing` correspond to
heavier smoothing.
axis: The axis along which to compute crossentropy (the features
axis). Defaults to -1.
name: Name for the op. Defaults to 'weighted_categorical_crossentropy'.
"""
        super().__init__()
        self.weights = weights  # tf.reshape(tf.convert_to_tensor(weights), (1, -1))
        self.label_smoothing = label_smoothing
        self.name = name
        # Defaults to the module-level `weighted_categorical_crossentropy`
        # function (see the sketch after this class).
        self.fn = weighted_categorical_crossentropy if fn is None else fn

    def __call__(self, y_true, y_pred, axis=-1):
        if isinstance(axis, bool):
            raise ValueError(
                "`axis` must be of type `int`. "
                f"Received: axis={axis} of type {type(axis)}"
            )
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        # Work with a local tensor so repeated calls do not overwrite the
        # configured Python float stored on the instance.
        label_smoothing = tf.convert_to_tensor(
            self.label_smoothing, dtype=y_pred.dtype
        )
        if y_pred.shape[-1] == 1:
            warnings.warn(
                "In loss weighted_categorical_crossentropy, expected "
                "y_pred.shape to be (batch_size, num_classes) "
                f"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. "
                "Consider using 'binary_crossentropy' if you only have 2 "
                "classes.",
                SyntaxWarning,
                stacklevel=2,
            )

        def _smooth_labels():
            num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype)
            return y_true * (1.0 - label_smoothing) + (
                label_smoothing / num_classes
            )

        y_true = tf.__internal__.smart_cond.smart_cond(
            label_smoothing, _smooth_labels, lambda: y_true
        )
        return tf.reduce_mean(self.fn(y_true, y_pred, self.weights, axis=axis))

    def get_config(self):
        config = {
            "name": self.name,
            "weights": self.weights,
            "label_smoothing": self.label_smoothing,
            "fn": self.fn,
        }
        # base_config = super().get_config()
        return config

    @classmethod
    def from_config(cls, config):
        """Instantiates a `WeightedCategoricalCrossentropy` from its config
        (output of `get_config()`).

        Args:
            config: Output of `get_config()`.
        """
        if saving_lib.saving_v3_enabled():
            fn_name = config.pop("fn", None)
            if fn_name:  # and cls is LossFunctionWrapper:
                config["fn"] = get(fn_name)
        return cls(**config)
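
The class delegates the per-example computation to a module-level weighted_categorical_crossentropy function that is not shown in this snippet. A minimal sketch of such a function, assuming y_pred holds probabilities (not logits) and weights is a 1-D per-class weight vector, could look like this:

@tf.keras.saving.register_keras_serializable(name="weighted_categorical_crossentropy")
def weighted_categorical_crossentropy(y_true, y_pred, weights, axis=-1):
    """Per-example weighted categorical crossentropy (illustrative sketch)."""
    weights = tf.reshape(tf.convert_to_tensor(weights, dtype=y_pred.dtype), (1, -1))
    # Normalize predictions so each example's class probabilities sum to 1,
    # then clip to avoid log(0).
    y_pred = y_pred / tf.reduce_sum(y_pred, axis=axis, keepdims=True)
    epsilon = tf.keras.backend.epsilon()
    y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
    # Standard categorical crossentropy with a per-class weight on each term.
    return -tf.reduce_sum(weights * y_true * tf.math.log(y_pred), axis=axis)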
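
A usage sketch follows; the class weights, tensors, and model architecture below are illustrative placeholders. The loss can be called directly on one-hot labels and predicted probabilities, or passed to model.compile like a built-in Keras loss. Because both the class and the function above are registered as Keras serializables, saving and reloading a model in the .keras format is the intent of get_config/from_config, though the exact round-trip behavior depends on the Keras version.

# Direct call on a small batch (3 classes, one-hot labels).
y_true = tf.constant([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
y_pred = tf.constant([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]])
wcce = WeightedCategoricalCrossentropy(weights=[1.0, 2.0, 0.5], label_smoothing=0.1)
print(wcce(y_true, y_pred))  # scalar: mean weighted crossentropy over the batch

# Using it as the loss of a model (architecture is a placeholder).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(20,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam", loss=wcce)

# Saving and reloading should resolve the custom loss via the registered
# serializables on TF 2.12+.
model.save("weighted_cce_model.keras")
reloaded = tf.keras.models.load_model("weighted_cce_model.keras")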