Skip to content

Instantly share code, notes, and snippets.

@freifrauvonbleifrei
Last active July 1, 2019 10:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save freifrauvonbleifrei/9b24ea1a715e2493d68a0660f8ab180e to your computer and use it in GitHub Desktop.
Save freifrauvonbleifrei/9b24ea1a715e2493d68a0660f8ab180e to your computer and use it in GitHub Desktop.
Normalized Risk-Averting Error Loss in tensorflow
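# The two helpers below are copied from TensorFlow internals.
# NOTE(review): `math_ops`, `array_ops`, `ops` and `dtypes` are not imported in
# this snippet; presumably they are the `tensorflow.python` modules of the same
# names and must be imported before use -- confirm.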
def _safe_mean(losses, num_present):
    """Returns the sum of `losses` divided by `num_present`, guarding zero.

    Args:
      losses: `Tensor` of individual loss measurements.
      num_present: The number of measurable elements in `losses`.

    Returns:
      A scalar holding the mean of `losses`, or zero when `num_present`
      is zero.
    """
    # div_no_nan yields zero instead of NaN/inf for a zero denominator.
    return math_ops.div_no_nan(
        math_ops.reduce_sum(losses), num_present, name="value")
def _num_elements(losses):
    """Returns the element count of `losses`, cast to the dtype of `losses`."""
    with ops.name_scope(None, "num_elements", values=[losses]) as scope:
        element_count = array_ops.size(losses, name=scope)
        # Cast so the count can be used directly in arithmetic with `losses`.
        return math_ops.cast(element_count, dtype=losses.dtype)
# NRAE (Normalized Risk-Averting Error), see
# http://www.math.umbc.edu/~jameslo/papers/isnn12nrae.pdf
def loss_nrae(labels, predictions, l_lambda=1e1):
    """Computes the Normalized Risk-Averting Error (NRAE) loss.

    The result is

        1 / l_lambda * log(mean(exp(l_lambda * (predictions - labels)**2)))

    with the mean taken over all elements. The log-mean-exp is evaluated as
    `reduce_logsumexp(x) - log(N)` instead of exponentiating first, so large
    weighted errors do not overflow float32 in the intermediate `exp`.

    Args:
      labels: The ground-truth values; cast to float32.
      predictions: The predicted values; cast to float32. Its static shape
        must be compatible with that of `labels`.
      l_lambda: Factor the squared errors are multiplied by before the
        exponential, and by which the logarithm is divided afterwards.
        Defaults to 1e1, the value that used to be hard-coded.

    Returns:
      A float32 `Tensor` holding the NRAE loss (a scalar for a scalar
      `l_lambda`).
    """
    # Same preamble as the mean_squared_error definition.
    predictions = math_ops.cast(predictions, dtype=dtypes.float32)
    labels = math_ops.cast(labels, dtype=dtypes.float32)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    squared_errors = math_ops.squared_difference(predictions, labels)

    # Weigh the squared errors by lambda.
    weights = math_ops.cast(l_lambda, dtype=dtypes.float32)
    scaled_errors = math_ops.multiply(squared_errors, weights)

    # log(mean(exp(x))) == logsumexp(x) - log(N).
    log_sum_exp = math_ops.reduce_logsumexp(scaled_errors)
    # Clamp the count at 1 so an empty input does not put log(0) into the
    # normalisation term.
    num_elements = math_ops.maximum(_num_elements(scaled_errors), 1.0)
    log_mean_exp = log_sum_exp - math_ops.log(num_elements)

    # Re-normalize by lambda. Everything above is float32, so no final cast
    # is needed.
    return log_mean_exp / weights
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment