@xmodar · Created October 8, 2021 16:02
Gumbel Softmax
import torch


def randg(*args, like=None, **kwargs):
    """Sample from Gumbel(location=0, scale=1)"""
    generator = kwargs.pop('generator', None)
    requires_grad = kwargs.pop('requires_grad', False)
    if like is None:
        samples = torch.empty(*args, **kwargs)
    else:
        samples = torch.empty_like(like, *args, **kwargs)
    # inverse transform: -log(Exponential(1)) ~ Gumbel(0, 1)
    samples.exponential_(generator=generator).log_().neg_()
    return samples.requires_grad_(requires_grad)


def gumbel(location, scale=1, generator=None):
    """Sample from a Gumbel distribution using the reparameterization trick"""
    gumbels = torch.empty_like(location).exponential_(generator=generator)
    # (location - log(Exponential(1))) / scale == (location + Gumbel(0, 1)) / scale
    return gumbels.log_().sub_(location).mul_(-1 / scale)


def gumbel_softmax(logits, hard=True, temperature=1, dim=-1, generator=None):
    """Sample from Gumbel-Softmax; a flexible and efficient `F.gumbel_softmax`"""
    # sample soft probabilities `softmax(Gumbel(logits, temperature))`
    soft = gumbel(logits, temperature, generator).softmax(dim=dim)
    if not hard:
        return soft
    # get hard probabilities (one_hot(argmax(soft)))
    index = soft.argmax(dim, keepdim=True)
    hard = torch.zeros_like(logits).scatter_(dim, index, 1)
    # make the hard probabilities differentiable (straight-through estimator)
    if soft.requires_grad:
        hard.sub_(soft.detach()).add_(soft)
    return hard
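
A minimal usage sketch (not part of the original gist), assuming the functions above are in scope; the shapes, temperature, and per-class values below are illustrative only. It draws soft and hard (straight-through) samples and verifies that gradients reach the logits:

import torch

logits = torch.randn(4, 5, requires_grad=True)  # 4 categorical distributions over 5 classes
gen = torch.Generator().manual_seed(0)          # reproducibility via the `generator` argument

noise = randg(2, 3, generator=gen)  # standalone Gumbel(0, 1) noise of shape (2, 3)

# soft relaxed samples: each row is a probability vector summing to 1
soft = gumbel_softmax(logits, hard=False, temperature=0.5, generator=gen)
print(soft.sum(dim=-1))  # rows sum to 1 (up to float error)

# hard one-hot samples whose gradients follow the soft relaxation
hard = gumbel_softmax(logits, hard=True, temperature=0.5, generator=gen)
values = torch.arange(5.0)            # illustrative per-class scores
(hard * values).sum().backward()
print(logits.grad.abs().sum() > 0)    # tensor(True): gradients flow to the logits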