Skip to content

Instantly share code, notes, and snippets.

@biggzlar
Last active March 8, 2022 10:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save biggzlar/39199226dfb3d0923aaf72acb76e6978 to your computer and use it in GitHub Desktop.
Save biggzlar/39199226dfb3d0923aaf72acb76e6978 to your computer and use it in GitHub Desktop.
Argmax vs. softargmax (play with beta)
import numpy as np
import matplotlib.pyplot as plt
def softmax(x, beta=1.0):
    """Return the softmax of ``beta * x``, normalized over all elements of ``x``.

    Args:
        x: Array of scores. The normalization runs over every element.
        beta: Inverse temperature. For positive values, a larger beta pushes
            the output toward a one-hot vector at the (unique) maximum of ``x``.

    Returns:
        Array with the same shape as ``x`` whose entries are non-negative and
        sum to 1. An empty ``x`` yields an empty array.
    """
    z = beta * np.asarray(x)
    if z.size == 0:
        # Nothing to normalize; np.max would raise on an empty array.
        return np.exp(z)
    # Subtracting the max leaves the result mathematically unchanged but keeps
    # np.exp from overflowing to inf (which would turn the output into nan)
    # when beta * x is large.
    e = np.exp(z - np.max(z))
    return e / np.sum(e)
def softargmax(x, beta=1.0):
    """Return a smooth (differentiable) estimate of the index of the max of ``x``.

    Computes the expected index under ``softmax(beta * x)``, i.e.
    ``sum_i softmax(beta * x)[i] * i``. For positive beta and a unique maximum,
    the result approaches ``np.argmax(x)`` as beta grows. At ``beta == 0``, every
    weight is equal and the result is the mean index ``(len(x) - 1) / 2``.

    Args:
        x: Scores, expected to be 1-D (indices are taken from
            ``np.arange(len(x))``).
        beta: Inverse temperature.

    Returns:
        A float between 0 and ``len(x) - 1``. An empty ``x`` yields 0.0
        (an empty sum).
    """
    z = beta * np.asarray(x)
    if z.size == 0:
        return np.float64(0.0)
    # Max-shift before exponentiating: the normalized weights are identical,
    # but np.exp can no longer overflow to inf and produce nan.
    weights = np.exp(z - np.max(z))
    weights = weights / np.sum(weights)
    return np.sum(weights * np.arange(len(z)))
dim = 8  # length of each random input vector
samples = 32  # maximum number of matching inputs drawn per plot panel
beta = 1  # inverse temperature passed to softmax / softargmax
num_trials = 10000  # number of random candidate vectors compared to the reference
tolerance = 0.005  # max |difference| for two softargmax outputs to count as equal

# Reference input that every candidate is compared against.
x = np.random.random(dim)

# The reference outputs do not change inside the loop, so compute them once.
reference_argmax = np.argmax(softmax(x, beta))
reference_softargmax = softargmax(x, beta)

softmax_a = []  # candidates whose (hard) argmax equals the reference's
softargmax_a = []  # candidates whose softargmax is within `tolerance` of the reference's
for _ in range(num_trials):
    b = np.random.random(dim)
    # argmax returns an integer index, so exact equality is the right test.
    if np.argmax(softmax(b, beta)) == reference_argmax:
        softmax_a.append(b)
    if np.abs(softargmax(b, beta) - reference_softargmax) < tolerance:
        softargmax_a.append(b)

# reshape(-1, dim) instead of squeeze(): the result stays 2-D (n_matches, dim)
# even when zero or one candidate matched, which the transpose-and-imshow
# plotting code needs.
softmax_a = np.array(softmax_a).reshape(-1, dim)
softargmax_a = np.array(softargmax_a).reshape(-1, dim)
# Two stacked panels:
#   - top: inputs sharing the reference argmax;
#   - bottom: inputs sharing (within tolerance) the reference softargmax.
# Each image column is one input vector (rows index its dimensions), and at
# most `samples` inputs are drawn per panel. The figure is written to a PNG.
fig, axs = plt.subplots(2)
fig.suptitle(rf'Illustration of inputs with equal outputs @$\beta={beta}$')
panels = (
    ('Inputs with equal argmax', softmax_a),
    (r'Inputs with equal $\bf{soft}$argmax', softargmax_a),
)
for ax, (title, inputs) in zip(axs, panels):
    ax.set_title(title)
    ax.imshow(inputs[:samples].T)
plt.tight_layout()
plt.savefig("soft_arg_max.png", dpi=300)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment