Skip to content

Instantly share code, notes, and snippets.

@tejaskhot
Last active February 11, 2023 18:33
Show Gist options
  • Save tejaskhot/fde6ca39209a2a6b6f1ebfad0d99f5ce to your computer and use it in GitHub Desktop.
Save tejaskhot/fde6ca39209a2a6b6f1ebfad0d99f5ce to your computer and use it in GitHub Desktop.
soft argmax in python
# argmax is not differentiable, so the usual workaround is a "soft" argmax:
#     softargmax(x) = softmax(beta * x)^T . [0, 1, ..., n-1]
# numpy
import numpy as np

beta = 12
y_est = np.array([[1.1, 3.0, 1.1, 1.3, 0.8]])
# Multiplying by a large constant beta makes the resulting distribution
# more peaky near the max.
scaled = beta * y_est
# Subtract the largest score before exponentiating: softmax is invariant to
# a constant shift, and this keeps np.exp from overflowing for large beta * y.
exp_scores = np.exp(scaled - np.max(scaled))
softmax = exp_scores / np.sum(exp_scores)
# Computing the max value in a soft way: the softmax-weighted mean of the values.
ymax = np.sum(softmax * y_est)
print(ymax)
# Positions 0..n-1 that the softmax weights are averaged over.
pos = np.arange(y_est.size)
# Computing the argmax value in a soft way: the softmax-weighted mean position.
softargmax = np.sum(softmax * pos)
print(softargmax)
# pytorch
class SoftArgmax1D(torch.nn.Module):
    """
    Implementation of a 1d soft arg-max function as an nn.Module, so that we
    can differentiate through arg-max operations.
    """

    def __init__(self, base_index=0, step_size=1):
        """
        The "arguments" are base_index, base_index+step_size,
        base_index+2*step_size, ... and so on for arguments at indices
        0, 1, 2, ....
        Assumes that the input to this layer will be a batch of 1D tensors
        (so a 2D tensor).

        :param base_index: Remember a base index for 'indices' for the input
        :param step_size: Step size for 'indices' from the input
        """
        super().__init__()
        self.base_index = base_index
        self.step_size = step_size
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """
        Compute the forward pass of the 1D soft arg-max function as defined below:

            SoftArgMax(x) = sum_i (index_i * softmax(x)_i)

        where index_i = base_index + i * step_size.

        :param x: The input to the soft arg-max layer; softmax is taken over dim 1
        :return: Output of the soft arg-max layer (one value per row of x)
        """
        smax = self.softmax(x)
        # Build the positions with the same dtype and device as the softmax
        # output: torch.matmul rejects mixed dtypes (float vs. the int64 that
        # arange returns by default) and mixed devices.
        positions = torch.arange(x.size(1), dtype=smax.dtype, device=smax.device)
        # Scaling an integer-valued range (instead of arange(start, end, step))
        # guarantees exactly x.size(1) indices even for fractional step sizes.
        indices = self.base_index + self.step_size * positions
        return torch.matmul(smax, indices)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment