This snippet compares the sigmoid function's response and derivative with the Gumbel-sigmoid's.
import torch
import matplotlib.pyplot as plt

def sample_gumbel_diff(*shape):
    """Sample the difference of two independent standard Gumbel variables."""
    eps = 1e-20
    u1 = torch.rand(shape)
    u2 = torch.rand(shape)
    # G1 - G2 = -log(-log u1) + log(-log u2) = log(log(u2) / log(u1))
    diff = torch.log(torch.log(u2 + eps) / torch.log(u1 + eps) + eps)
    return diff

def gumbel_sigmoid(logits, T=1.0, hard=False):
    """Sigmoid of logits perturbed by Gumbel noise, with temperature T.

    With hard=True, outputs are rounded to 0/1 in the forward pass while the
    straight-through estimator keeps the soft gradient for the backward pass.
    """
    g = sample_gumbel_diff(*logits.shape)
    g = g.to(logits.device)
    y = (g + logits) / T
    s = torch.sigmoid(y)
    if hard:
        # forward: hard 0/1 values; backward: gradient of the soft sigmoid
        s_hard = s.round()
        s = (s_hard - s).detach() + s
    return s
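
# Illustrative sanity check (a minimal sketch, using the function above):
# with hard=True the samples are exactly 0/1, yet gradients still reach the
# logits through the straight-through estimator.
demo_logits = torch.zeros(5, requires_grad=True)
demo_s = gumbel_sigmoid(demo_logits, T=1.0, hard=True)
demo_s.sum().backward()
assert set(demo_s.detach().tolist()) <= {0.0, 1.0}  # binary forward pass
assert demo_logits.grad is not None                 # gradients flow back
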
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
# Sigmoid
x = torch.linspace(-8, 8, 100)      # input x in [-8, 8]
x.requires_grad = True              # track gradients through x
y_sig = torch.sigmoid(x)            # sigmoid response
y_sig.backward(torch.ones_like(x))  # backpropagate to populate x.grad
ax[0].plot(x.detach().numpy(),
           y_sig.detach().numpy(),
           label="sigmoid(x)", c="r")
ax[1].plot(x.detach().numpy(),
           x.grad.detach().numpy(),
           label="sigmoid'(x)", c="r")
# Sigmoid with increased slope
x = torch.linspace(-8, 8, 100)      # fresh input so gradients don't accumulate
x.requires_grad = True
y_sig = torch.sigmoid(x * 10)       # sigmoid with a 10x steeper slope
y_sig.backward(torch.ones_like(x))
ax[2].plot(x.detach().numpy(),
           y_sig.detach().numpy(),
           label="sigmoid(x*10)", c="g")
ax[2].plot(x.detach().numpy(),
           x.grad.detach().numpy(),
           label="sigmoid'(x*10)", c="g", linestyle="--")
# Gumbel-sigmoid
x = torch.linspace(-8, 8, 100).repeat(200, 1)  # 200 copies of the input
x.requires_grad = True                         # to visualize the stochasticity
yg = gumbel_sigmoid(x, T=1, hard=False)        # default temperature, soft samples
yg.backward(torch.ones_like(yg))
ax[0].scatter(x.reshape(-1).detach(),
              yg.reshape(-1).detach(),
              label="gumbel(x)", c="b", alpha=0.05)
ax[1].scatter(x.reshape(-1).detach(),
              x.grad.detach().reshape(-1),
              label="gumbel'(x)", c="b", alpha=0.05)
ax[0].legend()
ax[1].legend()
ax[2].legend()
plt.show()
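
A minimal follow-up sketch, assuming the gumbel_sigmoid above is in scope: sweeping the temperature T shows how lower temperatures push samples toward binary values, which is the regime the hard/straight-through mode exploits.

logits = torch.zeros(10000)
for T in (5.0, 1.0, 0.1):
    s = gumbel_sigmoid(logits, T=T)
    frac = ((s < 0.05) | (s > 0.95)).float().mean().item()
    print(f"T={T}: fraction of near-binary samples = {frac:.2f}")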