@justheuristic
Last active May 8, 2021 14:58
import tensorflow as tf
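# NOTE: written for TensorFlow 1.x graph mode; keep_dims and tf.to_int64 below are TF1 APIs
# (renamed or removed in TF2), so running under TF2 needs tf.compat.v1 or minor edits.
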
def entmax15(inputs, axis=-1):
    """
    Entmax 1.5 implementation, heavily inspired by
     * paper: https://arxiv.org/pdf/1905.05702.pdf
     * pytorch code: https://github.com/deep-spin/entmax
    :param inputs: similar to softmax logits, but for entmax1.5
    :param axis: entmax1.5 outputs will sum to 1 over this axis
    :return: entmax activations of same shape as inputs
    """
    @tf.custom_gradient
    def _entmax_inner(inputs):
        with tf.name_scope('entmax'):
            inputs = inputs / 2  # divide by 2 so as to solve actual entmax
            inputs -= tf.reduce_max(inputs, axis, keep_dims=True)  # subtract max for stability

            threshold, _ = entmax_threshold_and_support(inputs, axis)
            outputs_sqrt = tf.nn.relu(inputs - threshold)
            outputs = tf.square(outputs_sqrt)

        def grad_fn(d_outputs):
            with tf.name_scope('entmax_grad'):
                d_inputs = d_outputs * outputs_sqrt
                q = tf.reduce_sum(d_inputs, axis=axis, keep_dims=True)
                q = q / tf.reduce_sum(outputs_sqrt, axis=axis, keep_dims=True)
                d_inputs -= q * outputs_sqrt
                return d_inputs

        return outputs, grad_fn

    return _entmax_inner(inputs)

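# Hedged usage sketch (not part of the original gist; tensor names are illustrative).
# Wrapped in a function so nothing is added to the graph at import time.
def _entmax15_usage_example():
    logits = tf.placeholder(tf.float32, [None, 10])   # [batch, num_classes] scores
    probs = entmax15(logits, axis=-1)                 # sparse probabilities, sum to 1 over axis
    with tf.Session() as sess:
        return sess.run(probs, {logits: [[0., 1., 2., 5., 5., 0., 0., 0., 0., 0.]]})
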
@tf.custom_gradient
def sparse_entmax15_loss_with_logits(labels, logits):
    """
    Computes sample-wise entmax1.5 loss
    :param labels: reference answers vector int64[batch_size] \in [0, num_classes)
    :param logits: output matrix float32[batch_size, num_classes] (not actually logits :)
    :returns: elementwise loss, float32[batch_size]
    """
    assert logits.shape.ndims == 2 and labels.shape.ndims == 1
    with tf.name_scope('entmax_loss'):
        p_star = entmax15(logits, axis=-1)
        omega_entmax15 = (1 - (tf.reduce_sum(p_star * tf.sqrt(p_star), axis=-1))) / 0.75
        p_incr = p_star - tf.one_hot(labels, depth=tf.shape(logits)[-1], axis=-1)
        loss = omega_entmax15 + tf.einsum("ij,ij->i", p_incr, logits)

    def grad_fn(grad_output):
        with tf.name_scope('entmax_loss_grad'):
            return None, grad_output[..., None] * p_incr

    return loss, grad_fn

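# Hedged usage sketch for the sparse (integer-label) loss; names are illustrative.
def _sparse_entmax15_loss_example():
    labels = tf.placeholder(tf.int64, [None])           # class indices in [0, num_classes)
    logits = tf.placeholder(tf.float32, [None, 10])
    per_sample_loss = sparse_entmax15_loss_with_logits(labels, logits)
    return tf.reduce_mean(per_sample_loss)               # average over the batch before optimizing
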
@tf.custom_gradient
def entmax15_loss_with_logits(labels, logits):
    """
    Computes sample-wise entmax1.5 loss
    :param labels: reference answers indicators, float32[batch_size, num_classes]
    :param logits: "logits" matrix float32[batch_size, num_classes]
    :returns: elementwise loss, float32[batch_size]
    WARNING: this function does not propagate gradients through :labels:
    This behavior is the same as in softmax_cross_entropy_with_logits (v1).
    It may become an issue if you do something like co-distillation.
    """
    assert labels.shape.ndims == logits.shape.ndims == 2
    with tf.name_scope('entmax_loss'):
        p_star = entmax15(logits, axis=-1)
        omega_entmax15 = (1 - (tf.reduce_sum(p_star * tf.sqrt(p_star), axis=-1))) / 0.75
        p_incr = p_star - labels
        loss = omega_entmax15 + tf.einsum("ij,ij->i", p_incr, logits)

    def grad_fn(grad_output):
        with tf.name_scope('entmax_loss_grad'):
            return None, grad_output[..., None] * p_incr

    return loss, grad_fn

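# Hedged sketch of the dense-label variant, e.g. one-hot or smoothed targets (illustrative).
def _entmax15_loss_example():
    labels = tf.placeholder(tf.int64, [None])
    logits = tf.placeholder(tf.float32, [None, 10])
    labels_onehot = tf.one_hot(labels, depth=10, dtype=tf.float32)
    # per the warning above, gradients do not flow into labels_onehot
    return entmax15_loss_with_logits(labels_onehot, logits)
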
def top_k_over_axis(inputs, k, axis=-1, **kwargs):
    """ performs tf.nn.top_k over any chosen axis """
    with tf.name_scope('top_k_along_axis'):
        if axis == -1:
            return tf.nn.top_k(inputs, k, **kwargs)

        perm_order = list(range(inputs.shape.ndims))
        perm_order.append(perm_order.pop(axis))
        inv_order = [perm_order.index(i) for i in range(len(perm_order))]

        input_perm = tf.transpose(inputs, perm_order)
        input_perm_sorted, sort_indices_perm = tf.nn.top_k(
            input_perm, k=k, **kwargs)

        input_sorted = tf.transpose(input_perm_sorted, inv_order)
        sort_indices = tf.transpose(sort_indices_perm, inv_order)
        return input_sorted, sort_indices

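# Illustrative behaviour (assumed shapes): for x of shape [4, 7, 5],
#   top_k_over_axis(x, k=2, axis=1, sorted=True)
# returns values and indices of shape [4, 2, 5], i.e. tf.nn.top_k applied along axis 1.
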
def _make_ix_like(inputs, axis=-1):
    """ creates indices 1, ..., input[axis] unsqueezed to input dimensions """
    assert inputs.shape.ndims is not None
    rho = tf.cast(tf.range(1, tf.shape(inputs)[axis] + 1), dtype=inputs.dtype)
    view = [1] * inputs.shape.ndims
    view[axis] = -1
    return tf.reshape(rho, view)

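# Illustrative: _make_ix_like(tf.zeros([3, 5]), axis=-1) -> [[1., 2., 3., 4., 5.]] (shape [1, 5]);
# this is the running count k that divides the cumulative sums in entmax_threshold_and_support.
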
def gather_over_axis(values, indices, gather_axis):
    """
    replicates the behavior of torch.gather for tf<=1.8;
    for newer versions use tf.gather with batch_dims
    :param values: tensor [d0, ..., dn]
    :param indices: int64 tensor of same shape as values except for gather_axis
    :param gather_axis: performs gather along this axis
    :returns: gathered values, same shape as values except for gather_axis
        If gather_axis == 2
        gathered_values[i, j, k, ...] = values[i, j, indices[i, j, k, ...], ...]
        see torch.gather for more details
    """
    assert indices.shape.ndims is not None
    assert indices.shape.ndims == values.shape.ndims

    ndims = indices.shape.ndims
    gather_axis = gather_axis % ndims
    shape = tf.shape(indices)

    selectors = []
    for axis_i in range(ndims):
        if axis_i == gather_axis:
            selectors.append(indices)
        else:
            index_i = tf.range(tf.cast(shape[axis_i], dtype=indices.dtype), dtype=indices.dtype)
            index_i = tf.reshape(index_i, [-1 if i == axis_i else 1 for i in range(ndims)])
            index_i = tf.tile(index_i, [shape[i] if i != axis_i else 1 for i in range(ndims)])
            selectors.append(index_i)

    return tf.gather_nd(values, tf.stack(selectors, axis=-1))

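# Illustrative: for values of shape [batch, num_classes] and int64 indices of shape [batch, 1],
#   gather_over_axis(values, indices, gather_axis=-1)
# returns a [batch, 1] tensor whose row i is values[i, indices[i, 0]], mirroring torch.gather.
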
def entmax_threshold_and_support(inputs, axis=-1):
    """
    Computes clipping threshold for entmax1.5 over specified axis
    NOTE this implementation uses the same heuristic as
    the original code: https://tinyurl.com/pytorch-entmax-line-203
    :param inputs: (entmax1.5 inputs - max) / 2
    :param axis: entmax1.5 outputs will sum to 1 over this axis
    """
    with tf.name_scope('entmax_threshold_and_support'):
        num_outcomes = tf.shape(inputs)[axis]
        inputs_sorted, _ = top_k_over_axis(inputs, k=num_outcomes, axis=axis, sorted=True)

        rho = _make_ix_like(inputs, axis=axis)

        mean = tf.cumsum(inputs_sorted, axis=axis) / rho
        mean_sq = tf.cumsum(tf.square(inputs_sorted), axis=axis) / rho
        delta = (1 - rho * (mean_sq - tf.square(mean))) / rho

        delta_nz = tf.nn.relu(delta)
        tau = mean - tf.sqrt(delta_nz)

        support_size = tf.reduce_sum(tf.to_int64(tf.less_equal(tau, inputs_sorted)), axis=axis, keep_dims=True)
        tau_star = gather_over_axis(tau, support_size - 1, axis)
    return tau_star, support_size
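
# Hedged self-check (illustrative, not part of the original gist): with z = (inputs - max) / 2,
# entmax15 computes relu(z - tau_star) ** 2, where tau_star is chosen so each row sums to 1
# and support_size counts the non-zero entries.
def _entmax_threshold_sanity_check():
    z = tf.random_normal([3, 8]) / 2
    z -= tf.reduce_max(z, axis=-1, keep_dims=True)
    tau_star, support_size = entmax_threshold_and_support(z, axis=-1)
    p = tf.square(tf.nn.relu(z - tau_star))
    return tf.reduce_sum(p, axis=-1)   # expected to evaluate to ~1.0 in every row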

This is free and unencumbered software released into the public domain.

Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means.

In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

For more information, please refer to http://unlicense.org/

@BenjaminWegener commented May 6, 2021

@justheuristic (Author) commented May 7, 2021

great job! 👍
