Skip to content

Instantly share code, notes, and snippets.

Created October 8, 2021 16:02
What would you like to do?
Gumbel Softmax
import torch
def randg(*args, like=None, **kwargs):
    """Sample from Gumbel(location=0, scale=1).

    Args:
        *args: Forwarded to ``torch.empty`` (typically the output size) when
            ``like`` is None, otherwise forwarded to ``torch.empty_like``.
        like: Optional tensor used as the template for the output; when
            given, the output is allocated with ``torch.empty_like(like)``.
        **kwargs: Forwarded to the allocation call, except for:
            generator: Optional ``torch.Generator`` used for sampling.
            requires_grad: Whether the returned tensor requires grad
                (default False).

    Returns:
        A tensor of standard Gumbel samples.
    """
    generator = kwargs.pop('generator', None)
    requires_grad = kwargs.pop('requires_grad', False)
    if like is None:
        samples = torch.empty(*args, **kwargs)
    else:
        samples = torch.empty_like(like, *args, **kwargs)
    # If E ~ Exponential(1) then -log(E) ~ Gumbel(0, 1). Sampling is done
    # in place before grad tracking is switched on for the result.
    samples.exponential_(generator=generator).log_().neg_()
    return samples.requires_grad_(requires_grad)
def gumbel(location, scale=1, generator=None):
    """Draw reparameterized Gumbel-perturbed values for ``location``.

    Computes ``(location - log(E)) / scale`` with ``E ~ Exponential(1)``,
    i.e. ``location`` plus standard Gumbel noise, divided by ``scale``.
    The noise is drawn into a fresh tensor shaped like ``location``, so
    ``location`` itself is not modified.

    Args:
        location: Tensor of locations (e.g. logits).
        scale: Divisor applied to the perturbed values (default 1).
        generator: Optional ``torch.Generator`` used for sampling.

    Returns:
        A new tensor allocated via ``torch.empty_like(location)``.
    """
    noise = torch.empty_like(location)
    noise.exponential_(generator=generator)
    # log(E) - location, negated and rescaled below in a single multiply
    noise.log_()
    noise.sub_(location)
    return noise.mul_(-1 / scale)
def gumbel_softmax(logits, hard=True, temperature=1, dim=-1, generator=None):
    """Sample from Gumbel-Softmax; flexible and efficient `F.gumbel_softmax`.

    Args:
        logits: Tensor of unnormalized log-probabilities.
        hard: If True (default), return one-hot samples; otherwise return
            the soft (relaxed) probabilities.
        temperature: Softmax temperature, passed to ``gumbel`` as its scale.
        dim: Dimension along which softmax / argmax are taken.
        generator: Optional ``torch.Generator`` used for sampling.

    Returns:
        A tensor shaped like ``logits``. With ``hard=True`` the forward value
        is one-hot along ``dim``; when gradients are tracked, the backward
        pass uses the gradient of the soft sample (straight-through).
    """
    # sample soft probabilities `softmax(Gumbel(logits, temperature))`
    soft = gumbel(logits, temperature, generator).softmax(dim=dim)
    if not hard:
        return soft
    # get hard probabilities (one_hot(soft)); kept under a separate name so
    # the boolean `hard` argument is not shadowed
    index = soft.argmax(dim, keepdim=True)
    one_hot = torch.zeros_like(logits).scatter_(dim, index, 1)
    # make the hard probabilities differentiable (straight-through
    # estimator): forward value stays one-hot up to float rounding, while
    # gradients flow through `soft`
    if soft.requires_grad:
        one_hot = one_hot - soft.detach() + soft
    return one_hot
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment