Last active
December 19, 2020 00:46
-
-
Save hushell/998af0f03ccbddbb8fe2a503d43a7faa to your computer and use it in GitHub Desktop.
Gumbel softmax
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.distributions.utils import clamp_probs | |
from torch.distributions import RelaxedOneHotCategorical | |
class QuantizeCategorical(torch.autograd.Function):
    """Autograd function that snaps a soft vector to a one-hot vector.

    The forward pass places a 1 at the position of the maximum along the
    last dimension and 0 everywhere else. The backward pass hands the
    incoming gradient back unchanged (straight-through).
    """

    @staticmethod
    def forward(ctx, soft_value):
        """Return the one-hot quantization of ``soft_value`` (last dim)."""
        # keepdim=True keeps the index tensor at the same rank as the input,
        # which is the shape ``scatter_`` needs.
        _, winner = torch.max(soft_value, dim=-1, keepdim=True)
        one_hot = torch.zeros_like(soft_value)
        # Stash the relaxed input on the result tensor as ``_unquantize``.
        one_hot._unquantize = soft_value
        one_hot.scatter_(-1, winner, 1)
        return one_hot

    @staticmethod
    def backward(ctx, grad):
        """Pass the upstream gradient through untouched."""
        return grad
class RelaxedOneHotCategoricalStraightThrough(RelaxedOneHotCategorical):
    """
    An implementation of
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    with a straight-through gradient estimator.

    This distribution has the following properties:

    - The samples returned by the :meth:`rsample` method are discrete/quantized.
    - The :meth:`log_prob` method returns the log probability of the
      relaxed/unquantized sample using the GumbelSoftmax distribution.
    - In the backward pass the gradient of the sample with respect to the
      parameters of the distribution uses the relaxed/unquantized sample.

    References:

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables,
        Chris J. Maddison, Andriy Mnih, Yee Whye Teh
    [2] Categorical Reparameterization with Gumbel-Softmax,
        Eric Jang, Shixiang Gu, Ben Poole
    """

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample and quantize it to one-hot.

        :param sample_shape: extra leading shape, forwarded to the parent
            ``rsample``.
        :return: the one-hot tensor produced by ``QuantizeCategorical``.
        """
        soft_sample = super().rsample(sample_shape)
        # Clamp the relaxed sample away from exact 0/1 before quantizing.
        soft_sample = clamp_probs(soft_sample)
        hard_sample = QuantizeCategorical.apply(soft_sample)
        return hard_sample

    def log_prob(self, value):
        """Score ``value`` under the relaxed distribution.

        If ``value`` carries an ``_unquantize`` attribute (set by
        ``QuantizeCategorical`` on quantized samples), that relaxed tensor is
        scored instead; a hard one-hot value would otherwise reach the
        parent's ``log_prob`` with exact zeros. Values without the attribute
        are scored as-is.

        :param value: a sample, either quantized by :meth:`rsample` or
            already relaxed.
        :return: the parent distribution's log probability of the relaxed value.
        """
        value = getattr(value, "_unquantize", value)
        return super().log_prob(value)
# Example: draw one straight-through sample and backpropagate through it.
# ``requires_grad_()`` returns the same tensor, so ``logits`` stays a leaf.
logits = torch.rand(3).requires_grad_()
dist = RelaxedOneHotCategoricalStraightThrough(
    temperature=1.,
    logits=logits,
)
x = dist.rsample()
# Reduce the sample to a scalar so ``backward()`` can be called on it.
loss = x.max()
loss.backward()
print(logits.grad)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment