import tensorflow as tf


def entmax15(inputs, axis=-1):
    """
    Entmax 1.5 implementation, heavily inspired by
     * paper: https://arxiv.org/pdf/1905.05702.pdf
     * pytorch code: https://github.com/deep-spin/entmax
    :param inputs: similar to softmax logits, but for entmax1.5
    :param axis: entmax1.5 outputs will sum to 1 over this axis
    :return: entmax activations of same shape as inputs
    """
    @tf.custom_gradient
    def _entmax_inner(inputs):
        with tf.name_scope('entmax'):
            inputs = inputs / 2  # divide by 2 so as to solve the actual entmax problem
            inputs -= tf.reduce_max(inputs, axis, keepdims=True)  # subtract max for numerical stability

            threshold, _ = entmax_threshold_and_support(inputs, axis)
            outputs_sqrt = tf.nn.relu(inputs - threshold)
            outputs = tf.square(outputs_sqrt)

        def grad_fn(d_outputs):
            with tf.name_scope('entmax_grad'):
                d_inputs = d_outputs * outputs_sqrt
                q = tf.reduce_sum(d_inputs, axis=axis, keepdims=True)
                q = q / tf.reduce_sum(outputs_sqrt, axis=axis, keepdims=True)
                d_inputs -= q * outputs_sqrt
                return d_inputs

        return outputs, grad_fn

    return _entmax_inner(inputs)
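

# Usage sketch (illustrative, not part of the original code): entmax 1.5 acts as a
# sparser drop-in replacement for softmax, so part of the output can be exactly zero.
# The logits below are made-up numbers purely for demonstration.
def _entmax15_usage_example():
    logits = tf.constant([[3.0, 1.5, -2.0, -4.0]])
    probs = entmax15(logits, axis=-1)        # rows sum to 1; tail entries may be exactly 0
    dense = tf.nn.softmax(logits, axis=-1)   # softmax for comparison: strictly positive
    return probs, dense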


@tf.custom_gradient
def sparse_entmax15_loss_with_logits(labels, logits):
    """
    Computes sample-wise entmax1.5 loss
    :param labels: reference answers vector int64[batch_size] in [0, num_classes)
    :param logits: output matrix float32[batch_size, num_classes] (not actually logits :)
    :returns: elementwise loss, float32[batch_size]
    """
    assert logits.shape.ndims == 2 and labels.shape.ndims == 1
    with tf.name_scope('entmax_loss'):
        p_star = entmax15(logits, axis=-1)
        omega_entmax15 = (1 - (tf.reduce_sum(p_star * tf.sqrt(p_star), axis=-1))) / 0.75
        p_incr = p_star - tf.one_hot(labels, depth=tf.shape(logits)[-1], axis=-1)
        loss = omega_entmax15 + tf.einsum("ij,ij->i", p_incr, logits)

    def grad_fn(grad_output):
        with tf.name_scope('entmax_loss_grad'):
            # no gradient w.r.t. labels; gradient w.r.t. logits is p_star - one_hot(labels)
            return None, grad_output[..., None] * p_incr

    return loss, grad_fn
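

# Illustrative usage (shapes and values are assumptions, not from the original code):
# the sparse variant takes integer class ids, mirroring
# tf.nn.sparse_softmax_cross_entropy_with_logits.
def _sparse_entmax15_loss_example():
    logits = tf.constant([[2.0, 0.5, -1.0],
                          [0.0, 1.0, 3.0]])       # float32[batch_size=2, num_classes=3]
    labels = tf.constant([0, 2], dtype=tf.int64)  # one class id per example
    loss = sparse_entmax15_loss_with_logits(labels, logits)  # float32[batch_size]
    return tf.reduce_mean(loss)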


@tf.custom_gradient
def entmax15_loss_with_logits(labels, logits):
    """
    Computes sample-wise entmax1.5 loss
    :param labels: reference answer indicators, float32[batch_size, num_classes]
    :param logits: "logits" matrix float32[batch_size, num_classes]
    :returns: elementwise loss, float32[batch_size]

    WARNING: this function does not propagate gradients through :labels:
    This behavior is the same as in tf.nn.softmax_cross_entropy_with_logits (v1).
    It may become an issue if you do something like co-distillation.
    """
    assert labels.shape.ndims == logits.shape.ndims == 2
    with tf.name_scope('entmax_loss'):
        p_star = entmax15(logits, axis=-1)
        omega_entmax15 = (1 - (tf.reduce_sum(p_star * tf.sqrt(p_star), axis=-1))) / 0.75
        p_incr = p_star - labels
        loss = omega_entmax15 + tf.einsum("ij,ij->i", p_incr, logits)

    def grad_fn(grad_output):
        with tf.name_scope('entmax_loss_grad'):
            return None, grad_output[..., None] * p_incr

    return loss, grad_fn
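

# Counterpart of the sparse example above (illustrative, assumed shapes): the dense
# variant takes float indicator (or soft) labels and, as the docstring warns, does not
# propagate gradients through them.
def _entmax15_loss_example():
    logits = tf.zeros([4, 3])                                # float32[batch_size, num_classes]
    labels = tf.one_hot(tf.constant([0, 2, 1, 1]), depth=3)  # float32 indicators
    loss = entmax15_loss_with_logits(labels, logits)         # float32[batch_size]
    return tf.reduce_mean(loss)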


def top_k_over_axis(inputs, k, axis=-1, **kwargs):
    """ performs tf.nn.top_k over any chosen axis """
    with tf.name_scope('top_k_along_axis'):
        if axis == -1:
            return tf.nn.top_k(inputs, k, **kwargs)

        perm_order = list(range(inputs.shape.ndims))
        perm_order.append(perm_order.pop(axis))
        inv_order = [perm_order.index(i) for i in range(len(perm_order))]

        input_perm = tf.transpose(inputs, perm_order)
        input_perm_sorted, sort_indices_perm = tf.nn.top_k(
            input_perm, k=k, **kwargs)

        input_sorted = tf.transpose(input_perm_sorted, inv_order)
        sort_indices = tf.transpose(sort_indices_perm, inv_order)
        return input_sorted, sort_indices
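

# Illustrative example (not from the original code): the chosen axis is moved to the
# last position, tf.nn.top_k runs there, and the permutation is undone, so the result
# keeps the original axis order.
def _top_k_over_axis_example():
    x = tf.constant([[1.0, 5.0],
                     [4.0, 2.0],
                     [3.0, 6.0]])
    values, indices = top_k_over_axis(x, k=2, axis=0)
    # expected values: [[4., 6.], [3., 5.]] -- the two largest entries of each column
    return values, indices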


def _make_ix_like(inputs, axis=-1):
    """ creates indices 1, ..., inputs.shape[axis], reshaped to broadcast against inputs """
    assert inputs.shape.ndims is not None
    rho = tf.cast(tf.range(1, tf.shape(inputs)[axis] + 1), dtype=inputs.dtype)
    view = [1] * inputs.shape.ndims
    view[axis] = -1
    return tf.reshape(rho, view)


def gather_over_axis(values, indices, gather_axis):
    """
    replicates the behavior of torch.gather for tf<=1.8;
    for newer versions use tf.gather with batch_dims
    :param values: tensor [d0, ..., dn]
    :param indices: int64 tensor of same shape as values except for gather_axis
    :param gather_axis: performs gather along this axis
    :returns: gathered values, same shape as values except for gather_axis
        If gather_axis == 2,
        gathered_values[i, j, k, ...] = values[i, j, indices[i, j, k, ...], ...]
    see torch.gather for more details
    """
    assert indices.shape.ndims is not None
    assert indices.shape.ndims == values.shape.ndims

    ndims = indices.shape.ndims
    gather_axis = gather_axis % ndims
    shape = tf.shape(indices)

    selectors = []
    for axis_i in range(ndims):
        if axis_i == gather_axis:
            selectors.append(indices)
        else:
            # build a broadcastable index grid for every non-gather axis
            index_i = tf.range(tf.cast(shape[axis_i], dtype=indices.dtype), dtype=indices.dtype)
            index_i = tf.reshape(index_i, [-1 if i == axis_i else 1 for i in range(ndims)])
            index_i = tf.tile(index_i, [shape[i] if i != axis_i else 1 for i in range(ndims)])
            selectors.append(index_i)

    return tf.gather_nd(values, tf.stack(selectors, axis=-1))
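

# Quick illustration of the gather semantics above (made-up values): along axis 1,
# out[i, j] = values[i, indices[i, j]], matching torch.gather.
def _gather_over_axis_example():
    values = tf.constant([[10.0, 11.0, 12.0],
                          [20.0, 21.0, 22.0]])
    indices = tf.constant([[2, 0, 0],
                           [1, 1, 2]], dtype=tf.int64)
    gathered = gather_over_axis(values, indices, gather_axis=1)
    # expected: [[12., 10., 10.], [21., 21., 22.]]
    return gathered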


def entmax_threshold_and_support(inputs, axis=-1):
    """
    Computes clipping threshold for entmax1.5 over specified axis
    NOTE this implementation uses the same heuristic as
    the original code: https://tinyurl.com/pytorch-entmax-line-203
    :param inputs: (entmax1.5 inputs - max) / 2
    :param axis: entmax1.5 outputs will sum to 1 over this axis
    """
    with tf.name_scope('entmax_threshold_and_support'):
        num_outcomes = tf.shape(inputs)[axis]
        inputs_sorted, _ = top_k_over_axis(inputs, k=num_outcomes, axis=axis, sorted=True)

        rho = _make_ix_like(inputs, axis=axis)

        mean = tf.cumsum(inputs_sorted, axis=axis) / rho
        mean_sq = tf.cumsum(tf.square(inputs_sorted), axis=axis) / rho
        # for each candidate support size rho, the threshold solves
        # sum_{i <= rho} (z_i - tau)^2 = 1, i.e. tau = mean - sqrt(delta)
        # with delta = (1 - rho * variance) / rho
        delta = (1 - rho * (mean_sq - tf.square(mean))) / rho

        delta_nz = tf.nn.relu(delta)
        tau = mean - tf.sqrt(delta_nz)

        # support size = number of sorted entries that stay above their candidate threshold
        support_size = tf.reduce_sum(
            tf.cast(tf.less_equal(tau, inputs_sorted), tf.int64), axis=axis, keepdims=True)

        tau_star = gather_over_axis(tau, support_size - 1, axis)
        return tau_star, support_size


# Note: a variant of the entmax15 function updated for TF 2.0+ is available at
# https://gist.github.com/BenjaminWegener/8fad40ffd80fbe9087d13ad464a48ca9
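

# Minimal sanity check (an assumed entry point, not part of the original gist):
# verifies that entmax15 outputs are non-negative and sum to 1 along the last axis.
if __name__ == '__main__':
    example_logits = tf.constant([[2.0, 1.0, 0.1, -3.0],
                                  [0.0, 0.0, 0.0, 0.0]])
    example_probs = entmax15(example_logits, axis=-1)
    example_sums = tf.reduce_sum(example_probs, axis=-1)
    # Under TF 2.x eager execution these print actual values; under TF 1.x graph mode
    # they would need to be evaluated inside a tf.Session instead.
    print(example_probs)
    print(example_sums)  # expected to be close to [1., 1.]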