@albertbuchard
Last active April 18, 2023
This code defines a PyTorch implementation of the Sparsemax activation function. Sparsemax is an alternative to softmax that produces sparse probability distributions by computing the Euclidean projection of the input onto the probability simplex. The implementation is provided as a PyTorch nn.Module, making it easy to integrate into any architecture.
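For reference, the map the forward pass implements is the simplex projection defined by Martins & Astudillo (2016; the paper linked in the code below):

\[
\operatorname{sparsemax}(z) \;=\; \operatorname*{argmin}_{p \,\in\, \Delta^{K-1}} \lVert p - z \rVert_2^2,
\qquad
\Delta^{K-1} = \{\, p \in \mathbb{R}^K : p \ge 0,\ \mathbf{1}^\top p = 1 \,\}
\]

The solution has the closed form sparsemax_i(z) = max(z_i − τ(z), 0), where, writing z_(1) ≥ … ≥ z_(K) for the scores sorted in descending order,

\[
k(z) = \max\bigl\{\, k \;:\; 1 + k\, z_{(k)} > \textstyle\sum_{j \le k} z_{(j)} \,\bigr\},
\qquad
\tau(z) = \frac{\sum_{j \le k(z)} z_{(j)} - 1}{k(z)}
\]

The forward pass below computes k(z) as k_z and τ(z) as thresholds.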
import torch
import torch.nn as nn


class Sparsemax(nn.Module):
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # Move the dimension to apply Sparsemax to the last axis
        x = x.transpose(self.dim, -1)
        # Sort the input in descending order and take its cumulative sum
        z, _ = torch.sort(x, dim=-1, descending=True)
        cumsums = torch.cumsum(z, dim=-1)
        # Project onto the simplex; see details in https://arxiv.org/pdf/1602.02068.pdf
        K = torch.arange(1, x.shape[-1] + 1, device=x.device)
        K = K.repeat(*x.shape[:-1], 1)
        # k(z): the largest k such that 1 + k * z_(k) > sum_{j <= k} z_(j)
        support = 1 + K * z - cumsums > 0
        k_z = (K * support).max(dim=-1, keepdim=True).values
        # Compute the threshold tau(z) and apply it to the input;
        # (k_z - 1) corrects for the 1-based indexing used in the paper
        cumsums_element = torch.gather(cumsums, dim=-1, index=k_z - 1)
        thresholds = (cumsums_element - 1) / k_z
        output = torch.clamp(x - thresholds, min=0)
        # Move the Sparsemax dimension back to its original position
        output = output.transpose(self.dim, -1)
        return output
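As a quick sanity check (a minimal usage sketch; the input values are illustrative, not part of the original gist), each slice along the chosen dimension sums to 1, and low-scoring entries are driven exactly to zero:

sparsemax = Sparsemax(dim=-1)
logits = torch.tensor([[1.0, 2.5, 0.1, 2.4],
                       [0.0, 0.0, 0.0, 5.0]])
probs = sparsemax(logits)
print(probs)          # tensor([[0.0000, 0.5500, 0.0000, 0.4500],
                      #         [0.0000, 0.0000, 0.0000, 1.0000]])
print(probs.sum(-1))  # tensor([1., 1.]) -- each row is a valid distribution

Unlike softmax, entries whose score falls below the threshold τ(z) come out exactly zero rather than merely small.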