@kaniblu
Last active May 26, 2019 10:46
Sampling from the Gumbel-softmax distribution
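For reference, this is the standard Gumbel-softmax relaxation (arXiv: 1611.01144) written out. The class probabilities are $\pi_1, \dots, \pi_k$ and the temperature is $\tau$. `logits` plays the role of $\log \pi_i$, and the `eps` stabilizer is omitted:

$$
g_i = -\log\bigl(-\log u_i\bigr), \qquad u_i \sim \mathrm{Uniform}(0, 1)
$$

$$
y_i = \frac{\exp\bigl((\log \pi_i + g_i)/\tau\bigr)}{\sum_{j=1}^{k} \exp\bigl((\log \pi_j + g_j)/\tau\bigr)}, \qquad i = 1, \dots, k
$$

Adding the same constant to every logit leaves $y$ unchanged, so `logits` need not be normalized log-probabilities. As $\tau \to 0$, $y$ approaches the one-hot vector at $\arg\max_i (\log \pi_i + g_i)$, and that index is an exact draw from $\mathrm{Categorical}(\pi)$ (the Gumbel-max trick).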
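```python
import torch


def gumbel_softmax(logits, tau=1.0, eps=1e-10):
    """Generate samples from the Gumbel-softmax distribution.

    (arXiv: 1611.01144)

    Args:
        logits: unnormalized log-probabilities; the last dimension indexes the categories.
        tau: softmax temperature; smaller values give samples closer to one-hot vectors.
        eps: small constant that keeps the nested logarithms finite.

    Examples:
        >>> # sampling from a Gumbel-softmax distribution given a categorical distribution (output is random)
        >>> gumbel_softmax(torch.tensor([0.3, 0.7]).log(), tau=0.1)  # doctest: +SKIP
        tensor([0.0711, 0.9289])
        >>> # with a small tau, the mean of many samples approximates the categorical probabilities
        >>> gumbel_softmax(torch.tensor([0.3, 0.7]).log().unsqueeze(0).expand(1000, -1), tau=0.1).mean(0)  # doctest: +SKIP
        tensor([0.2890, 0.7110])
    """
    # Uniform noise in [0, 1) matching the shape, dtype and device of `logits` (no gradient).
    uniform = torch.rand_like(logits)
    # Gumbel(0, 1) noise via inverse transform, -log(-log(u)); `eps` guards against log(0).
    gumbel_noise = -torch.log(-torch.log(uniform + eps) + eps)
    # Perturb the logits and take a tempered softmax over the last (category) dimension.
    return torch.softmax((logits + gumbel_noise) / tau, dim=-1)
```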
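The sketch below shows the mechanism without PyTorch. It performs the same computation in pure Python for a single 1-D list of logits. The helper names `sample_gumbel`, `gumbel_softmax_sample` and `gumbel_max_sample` are hypothetical and do not appear in the gist. The sketch also adds the Gumbel-max trick, which takes the argmax of the perturbed logits instead of a softmax. That trick gives exact categorical draws, which helps explain why the mean of soft samples lands near the categorical probabilities when `tau` is small.

```python
# Hypothetical stdlib-only sketch of the gist's computation; not part of the original gist.
import math
import random


def sample_gumbel(rng, eps=1e-10):
    """Draw one Gumbel(0, 1) sample by inverse transform: g = -log(-log(u))."""
    u = rng.random()  # uniform on [0, 1)
    return -math.log(-math.log(u + eps) + eps)


def gumbel_softmax_sample(logits, tau=1.0, rng=None, eps=1e-10):
    """Return one relaxed sample (a probability vector) for a 1-D list of logits."""
    rng = rng if rng is not None else random.Random()
    perturbed = [(x + sample_gumbel(rng, eps)) / tau for x in logits]
    m = max(perturbed)  # subtract the max so exp() cannot overflow for small tau
    exps = [math.exp(p - m) for p in perturbed]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_max_sample(logits, rng=None, eps=1e-10):
    """Gumbel-max trick: the argmax of the perturbed logits is an exact categorical draw."""
    rng = rng if rng is not None else random.Random()
    perturbed = [x + sample_gumbel(rng, eps) for x in logits]
    return max(range(len(perturbed)), key=perturbed.__getitem__)


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [math.log(0.3), math.log(0.7)]
    n = 20000
    soft_mean = [0.0] * len(logits)
    hard_counts = [0] * len(logits)
    for _ in range(n):
        y = gumbel_softmax_sample(logits, tau=0.1, rng=rng)
        soft_mean = [s + v / n for s, v in zip(soft_mean, y)]
        hard_counts[gumbel_max_sample(logits, rng=rng)] += 1
    # Both lists should be close to the probabilities [0.3, 0.7].
    print([round(v, 3) for v in soft_mean])
    print([c / n for c in hard_counts])
```

Run as a script, it prints two lists that should both be close to `[0.3, 0.7]`. The soft mean tracks the probabilities only when `tau` is small. For larger `tau`, the relaxed samples are pulled toward the uniform vector.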