@hushell · Last active December 19, 2020
Gumbel softmax
import torch
from torch.distributions.utils import clamp_probs
from torch.distributions import RelaxedOneHotCategorical

class QuantizeCategorical(torch.autograd.Function):
    """Quantize a relaxed sample to a one-hot vector with a straight-through
    gradient: hard argmax in the forward pass, identity in the backward pass."""

    @staticmethod
    def forward(ctx, soft_value):
        # Index of the largest coordinate of the relaxed (soft) sample.
        argmax = soft_value.max(-1)[1]
        hard_value = torch.zeros_like(soft_value)
        # Keep a reference to the soft sample so it can be recovered later.
        hard_value._unquantize = soft_value
        if argmax.dim() < hard_value.dim():
            argmax = argmax.unsqueeze(-1)
        # Place a 1 at the argmax position, producing a one-hot sample.
        return hard_value.scatter_(-1, argmax, 1)

    @staticmethod
    def backward(ctx, grad):
        # Straight-through estimator: pass the gradient through unchanged,
        # as if quantization were the identity map.
        return grad
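
# Minimal illustration (not part of the original gist): quantizing a soft vector
# yields a one-hot tensor in the forward pass, while the backward pass hands the
# incoming gradient straight through to the soft sample.
_soft = torch.tensor([0.2, 0.5, 0.3], requires_grad=True)
_hard = QuantizeCategorical.apply(_soft)
_hard.sum().backward()
print(_hard)       # tensor([0., 1., 0.])
print(_soft.grad)  # tensor([1., 1., 1.]) -- same as d(sum)/d(hard)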

class RelaxedOneHotCategoricalStraightThrough(RelaxedOneHotCategorical):
    """
    An implementation of
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    with a straight-through gradient estimator.

    This distribution has the following properties:

    - The samples returned by the :meth:`rsample` method are discrete/quantized.
    - The :meth:`log_prob` method returns the log probability of the
      relaxed/unquantized sample using the GumbelSoftmax distribution.
    - In the backward pass the gradient of the sample with respect to the
      parameters of the distribution uses the relaxed/unquantized sample.

    References:

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables,
        Chris J. Maddison, Andriy Mnih, Yee Whye Teh
    [2] Categorical Reparameterization with Gumbel-Softmax,
        Eric Jang, Shixiang Gu, Ben Poole
    """

    def rsample(self, sample_shape=torch.Size()):
        # Draw a relaxed (Gumbel-softmax) sample, then quantize it to a one-hot
        # vector; the soft sample stays attached for the backward pass.
        soft_sample = super().rsample(sample_shape)
        soft_sample = clamp_probs(soft_sample)  # keep probabilities strictly inside (0, 1)
        hard_sample = QuantizeCategorical.apply(soft_sample)
        return hard_sample
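
    # Assumption: the docstring above says log_prob is evaluated on the
    # relaxed/unquantized sample. That behavior needs the override below
    # (mirroring Pyro's RelaxedOneHotCategoricalStraightThrough); the original
    # snippet omits it, so treat this as a sketch.
    def log_prob(self, value):
        # Recover the soft sample stashed by QuantizeCategorical.forward;
        # fall back to `value` itself if it is not a quantized sample.
        value = getattr(value, '_unquantize', value)
        return super().log_prob(value)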

# Example
logits = torch.rand(3)
logits.requires_grad_()
dist = RelaxedOneHotCategoricalStraightThrough(temperature=1., logits=logits)
x = dist.rsample()   # a hard one-hot sample
loss = x.max()       # any scalar function of the hard sample
loss.backward()      # gradient reaches `logits` via the straight-through estimator
print(logits.grad)
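
# Further illustration (assumes the log_prob override sketched above): the sample
# is a hard one-hot vector, yet its log probability is computed from the relaxed
# sample it carries.
print(x)                 # e.g. tensor([0., 1., 0.])
print(dist.log_prob(x))  # log density of the relaxed sample under the Concrete distribution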