Soft argmax in Python
# argmax is not differentiable, so a common trick is to approximate it as
# softmax(beta * x) . arange(len(x))
# numpy
import numpy as np

beta = 12
y_est = np.array([[1.1, 3.0, 1.1, 1.3, 0.8]])
# multiplying by a large constant beta makes the resulting
# distribution more peaky near the max
a = np.exp(beta * y_est)
b = np.sum(np.exp(beta * y_est))
softmax = a / b
# computing the max value in a soft way
ymax = np.sum(softmax * y_est)
print(ymax)
pos = np.arange(y_est.size)
# computing the argmax (the index of the max) in a soft way
softargmax = np.sum(softmax * pos)
print(softargmax)
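
# A quick sanity check (a sketch, not part of the original gist): as beta
# grows, the soft argmax approaches the hard argmax. For y_est above,
# np.argmax(y_est) == 1, and the soft estimate tightens toward 1.0:
for temp in (1, 12, 100):
    w = np.exp(temp * y_est) / np.sum(np.exp(temp * y_est))
    print(temp, np.sum(w * np.arange(y_est.size)))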
# pytorch
import torch


class SoftArgmax1D(torch.nn.Module):
    """
    Implementation of a 1D soft argmax function as an nn.Module, so that we
    can differentiate through argmax operations.
    """
    def __init__(self, base_index=0, step_size=1):
        """
        The "arguments" are base_index, base_index + step_size,
        base_index + 2 * step_size, ... for input positions 0, 1, 2, ....
        Assumes that the input to this layer is a batch of 1D tensors
        (so a 2D tensor).

        :param base_index: Index assigned to the first position of the input
        :param step_size: Spacing between indices of consecutive positions
        """
        super(SoftArgmax1D, self).__init__()
        self.base_index = base_index
        self.step_size = step_size
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """
        Compute the forward pass of the 1D soft argmax function, defined as:
            SoftArgmax(x) = sum_i(i * softmax(x)_i)

        :param x: The input to the soft argmax layer (batch_size x n)
        :return: Output of the soft argmax layer (batch_size,)
        """
        smax = self.softmax(x)
        end_index = self.base_index + x.size()[1] * self.step_size
        # indices must match x's dtype; matmul of a float tensor with an
        # integer arange raises a type error
        indices = torch.arange(start=self.base_index, end=end_index,
                               step=self.step_size, dtype=x.dtype)
        return torch.matmul(smax, indices)
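
# Example usage (a minimal sketch, not part of the original gist): a batch of
# two 1D score vectors whose hard argmax indices are 1 and 3. The soft
# estimates land near, but not exactly at, those indices; scaling the scores
# by a large beta before the layer would sharpen them toward 1.0 and 3.0.
layer = SoftArgmax1D()
scores = torch.tensor([[1.1, 3.0, 1.1, 1.3, 0.8],
                       [0.2, 0.1, 0.5, 2.0, 0.3]])
print(layer(scores))  # roughly tensor([1.44, 2.51])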