# Author: Mathias Müller / mmueller@cl.uzh.ch
import mxnet as mx
import numpy as np
from typing import Optional, List

SOFTMAX_NAME = "softmax"
WEIGHTED_CROSS_ENTROPY_NAME = "weighted_cross_entropy"
NORMALIZATION_CHOICES = ["null", "valid", "batch"]


def WeightedCrossEntropyLoss(logits: mx.sym.Symbol,
                             labels: mx.sym.Symbol,
                             weights: mx.sym.Symbol,
                             grad_scale: float = 1.0,
                             ignore_label: int = 0,
                             use_ignore: bool = False,
                             use_weight: bool = False,
                             normalization: str = "null",
                             smooth_alpha: float = 0.0,
                             num_classes: Optional[int] = None,
                             softmax_name: Optional[str] = SOFTMAX_NAME,
                             cross_entropy_name: Optional[str] = WEIGHTED_CROSS_ENTROPY_NAME) -> List[mx.sym.Symbol]:
"""
:param logits: Predictions as input for loss. Assumes softmax was *not* applied to logits yet.
Shape: (batch_size * target_seq_len, target_vocab_size)
:param labels: Labels as input for loss.
Shape: (batch_size * target_seq_len)
:param weights: Float weights, one for each item in the batch.
Shape: (batch_size * target_seq_len)
"""
    assert normalization in NORMALIZATION_CHOICES

    if smooth_alpha > 0.0 or use_ignore:
        assert isinstance(num_classes, int), "If label smoothing or use_ignore is enabled, num_classes must be set."

    # probs: (batch_size * target_seq_len, target_vocab_size)
    probs = mx.sym.softmax(data=logits, axis=1, name="step_softmax")

    # logprobs: (batch_size * target_seq_len, target_vocab_size)
    logprobs = mx.sym.log(mx.sym.maximum(1e-10, probs), name="step_log")

    if smooth_alpha > 0.0:
        on_value = 1.0 - smooth_alpha
        off_value = smooth_alpha / (num_classes - 1.0)
        smoothed_labels = mx.sym.one_hot(indices=mx.sym.cast(data=labels, dtype='int32'),
                                         depth=num_classes,
                                         on_value=on_value,
                                         off_value=off_value,
                                         name="step_one_hot")

        # ignore PAD_ID
        if use_ignore:
            smoothed_labels = mx.sym.where(labels != ignore_label,
                                           smoothed_labels,
                                           mx.sym.zeros_like(smoothed_labels),
                                           name="step_where")

        ce = smoothed_labels * -logprobs
        ce = mx.sym.sum(data=ce, axis=1)
    else:
        ce = -mx.sym.pick(logprobs, labels, name="step_pick")

        # ignore PAD_ID
        if use_ignore:
            ce = mx.sym.where(labels != ignore_label, ce, mx.sym.zeros_like(ce), name="step_where")

    if use_weight:
        ce = ce * weights

    loss_value = mx.sym.MakeLoss(data=ce,
                                 grad_scale=grad_scale,
                                 normalization=normalization,
                                 name=cross_entropy_name)

    weights = mx.sym.BlockGrad(weights)
    probs = mx.sym.BlockGrad(probs, name=softmax_name)

    return [loss_value, probs]
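

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist): a minimal, hypothetical example
# of how this loss could be bound and evaluated on toy shapes. The variable
# names ("logits", "labels", "weights"), toy sizes, and dummy values are
# illustrative assumptions, not taken from any particular training setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_times_len = 6  # e.g. batch_size=2, target_seq_len=3, flattened
    vocab_size = 5

    logits_sym = mx.sym.Variable("logits")
    labels_sym = mx.sym.Variable("labels")
    weights_sym = mx.sym.Variable("weights")

    loss_sym, probs_sym = WeightedCrossEntropyLoss(logits_sym, labels_sym, weights_sym,
                                                   use_ignore=True,
                                                   use_weight=True,
                                                   normalization="valid",
                                                   smooth_alpha=0.1,
                                                   num_classes=vocab_size)

    # Group the loss and the gradient-blocked softmax probabilities into one symbol
    # and bind with concrete shapes; grad_req='null' because we only run a forward pass here.
    group = mx.sym.Group([loss_sym, probs_sym])
    executor = group.simple_bind(ctx=mx.cpu(),
                                 grad_req='null',
                                 logits=(batch_times_len, vocab_size),
                                 labels=(batch_times_len,),
                                 weights=(batch_times_len,))

    outputs = executor.forward(is_train=False,
                               logits=mx.nd.random.normal(shape=(batch_times_len, vocab_size)),
                               labels=mx.nd.array([1, 2, 0, 3, 4, 0]),  # 0 = padding id, ignored
                               weights=mx.nd.array([1.0, 1.0, 1.0, 0.5, 0.5, 0.5]))

    print("per-position weighted cross-entropy:", outputs[0].asnumpy())
    print("softmax probabilities shape:", outputs[1].shape)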