@xmodar
Created October 8, 2021 23:54
Differentiable mask for logits before a softmax operation
import torch

__all__ = ['softmax_mask']


class SoftmaxMask(torch.autograd.Function):
    """Differentiable mask for logits before a softmax operation"""

    @staticmethod
    def forward(ctx, *args, **kwargs):
        inputs, = args
        if ctx.needs_input_grad[0]:
            ctx.save_for_backward(inputs)
        ctx.set_materialize_grads(False)
        # hard mask `1 - 1 / (inputs > 0)`: 0 where inputs > 0, -inf elsewhere
        out = torch.empty_like(inputs)
        torch.greater(inputs, 0, out=out)
        return out.reciprocal_().neg_().add_(1)

    @staticmethod
    def backward(ctx, *grad_outputs):
        grad = None
        if grad_outputs[0] is not None and ctx.needs_input_grad[0]:
            # surrogate gradient: derivative of the soft mask `-exp(-inputs)`
            grad = (-ctx.saved_tensors[0]).exp_()
            if grad_outputs[0].requires_grad:  # higher order derivatives
                grad = grad * grad_outputs[0]
            else:
                grad.mul_(grad_outputs[0])
        return grad
def softmax_mask(inputs, scale=1, hard=True):
    """Get a differentiable mask for logits before a softmax operation

    To mask out selected softmax elements, their logit values have to be -inf.
    This can be achieved through an additive mask `softmax(logits + mask)`;
    the mask values are -inf for the masked elements and zero everywhere else.
    Given a binary mask `x`, we can compute our additive mask as `1 - 1 / x`.
    Now, to make this mask differentiable, we replace the hard mask `x` with
    a soft version `sigmoid(inputs)` where `x = (inputs > 0)`.

    Args:
        inputs: tensor of any shape
        scale: positive scalar that multiplies `inputs` (controls decisiveness)
        hard: whether to use the hard mask while keeping it differentiable

    Returns:
        Computed softmax mask
    """
    if hard:
        # `hard = 1 - 1 / (inputs > 0)` (uses the derivative of `soft`)
        return SoftmaxMask.apply(scale * inputs)
    # `soft = 1 - 1 / sigmoid(inputs) = -exp(-inputs)`
    return -(-scale * inputs).exp_()
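
A minimal usage sketch (not part of the gist; the tensor names and the toy loss below are illustrative). It adds the returned mask to some logits before a softmax, so masked positions get exactly zero probability while the gating logits still receive gradients through the surrogate backward pass:

import torch

logits = torch.randn(2, 5)
gate_logits = torch.tensor([[1.5, -0.5, 2.0, -1.0, 0.3],
                            [0.2, 1.0, -2.0, 0.7, -0.1]],
                           requires_grad=True)  # > 0 keeps an element, <= 0 masks it

hard_mask = softmax_mask(gate_logits)              # entries are 0 or -inf
soft_mask = softmax_mask(gate_logits, hard=False)  # entries are -exp(-gate_logits)

probs = torch.softmax(logits + hard_mask, dim=-1)  # masked entries get probability 0

loss = (probs * torch.arange(5.0)).sum()  # any loss that depends on `probs`
loss.backward()
print(gate_logits.grad)  # gradients flow back into the gating logits via exp(-gate_logits)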
xmodar commented Oct 9, 2021

An alternative would be to use the derivative of g(x / c), where g(x) = x - exp(-x) and c is some constant.
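
This suggestion is terse; one possible reading (an assumption on my part, not spelled out in the gist) is to keep the same hard forward pass but replace the surrogate gradient exp(-x) with the derivative of g(x / c), i.e. (1 + exp(-x / c)) / c, which never drops below 1 / c and so keeps some gradient flowing even for confidently kept elements. A sketch under that assumption (SoftmaxMaskAlt and c are hypothetical names):

import torch


class SoftmaxMaskAlt(torch.autograd.Function):
    """Hard mask whose backward uses d/dx g(x / c) with g(x) = x - exp(-x)."""

    @staticmethod
    def forward(ctx, inputs, c=1.0):
        ctx.save_for_backward(inputs)
        ctx.c = c
        keep = (inputs > 0).to(inputs.dtype)
        return 1 - keep.reciprocal()  # 0 where kept, -inf where masked

    @staticmethod
    def backward(ctx, grad_output):
        inputs, = ctx.saved_tensors
        # d/dx g(x / c) = (1 + exp(-x / c)) / c  >=  1 / c
        surrogate = (1 + torch.exp(-inputs / ctx.c)) / ctx.c
        return grad_output * surrogate, None  # no gradient for `c`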
