Skip to content

Instantly share code, notes, and snippets.

@tejaskhot
Created June 18, 2018 15:54
Show Gist options
  • Save tejaskhot/7b01f9cde6b89fb993da5ad7d462f1fc to your computer and use it in GitHub Desktop.
Save tejaskhot/7b01f9cde6b89fb993da5ad7d462f1fc to your computer and use it in GitHub Desktop.
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
def sample_gumbel(shape, eps=1e-20, device=None):
    """Sample standard Gumbel(0, 1) noise of the given shape.

    Uses inverse-transform sampling: ``-log(-log(U))`` with ``U`` drawn from
    ``torch.rand`` (uniform on [0, 1)). ``eps`` is added inside each log so
    neither logarithm is ever evaluated at exactly zero.

    Args:
        shape: size of the returned tensor (anything ``torch.rand`` accepts).
        eps: small constant added inside each log for numerical safety.
        device: where to place the sample. ``None`` (default) uses CUDA when it
            is available and falls back to CPU otherwise; any other value is
            passed to ``Tensor.to``.

    Returns:
        A float tensor of ``shape`` holding Gumbel noise, wrapped in
        ``Variable`` as before (no gradient is tracked).
    """
    U = torch.rand(shape)
    if device is None:
        # Keep the GPU default, but don't crash on machines without CUDA.
        if torch.cuda.is_available():
            U = U.cuda()
    else:
        U = U.to(device)
    return -Variable(torch.log(-torch.log(U + eps) + eps))
def gumbel_softmax_sample(logits, temperature):
    """Draw a relaxed (soft) Gumbel-Softmax sample from ``logits``.

    Computes ``softmax((logits + g) / temperature)`` over the last dimension,
    where ``g`` is Gumbel noise from ``sample_gumbel`` with the same size as
    ``logits``.

    Args:
        logits: unnormalized log-probabilities, class dimension last.
        temperature: divisor applied to the perturbed logits before softmax.
            Presumably expected to be positive; this is not checked.

    Returns:
        Tensor with the same shape as ``logits`` whose last dimension sums to 1.
    """
    # sample_gumbel may create the noise on a different device than the logits
    # (e.g. the current CUDA device while logits live on CPU or another GPU);
    # adding tensors from different devices raises, so align them first.
    noise = sample_gumbel(logits.size()).to(logits.device)
    y = logits + noise
    return F.softmax(y / temperature, dim=-1)
def gumbel_softmax(logits, temperature):
    """Draw a straight-through (hard) Gumbel-Softmax sample.

    input: [*, n_class]
    return: [*, n_class] a one-hot vector

    The forward value is the one-hot encoding of the argmax of a relaxed
    sample from ``gumbel_softmax_sample``. Because the hard-minus-soft
    difference is detached, gradients flow as if the soft sample had been
    returned.

    Args:
        logits: unnormalized log-probabilities, class dimension last.
        temperature: forwarded unchanged to ``gumbel_softmax_sample``.

    Returns:
        Tensor with the same shape as the soft sample (i.e. as ``logits``).
    """
    soft = gumbel_softmax_sample(logits, temperature)
    # Winning class index along the last (class) dimension.
    winner = soft.max(dim=-1)[1]
    # Write a 1 at each winning index directly in the sample's own shape.
    last_dim = soft.dim() - 1
    hard = torch.zeros_like(soft).scatter_(last_dim, winner.unsqueeze(-1), 1)
    # Straight-through estimator: value of `hard`, gradient of `soft`.
    return soft + (hard - soft).detach()
if __name__ == '__main__':
    # Demo: draw many hard samples from a fixed 4-class distribution and
    # count how often each class is selected.
    import math
    # Use the GPU when one is present; otherwise fall back to CPU so the demo
    # does not crash on machines without CUDA.
    float_tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
    probs = [0.1, 0.4, 0.3, 0.2]
    n_samples = 20000
    data = Variable(float_tensor([[math.log(p) for p in probs]] * n_samples))
    one_hot = gumbel_softmax(data, 0.8).sum(dim=0)
    print(one_hot)
    # Empirical class frequencies; by the Gumbel-max trick the argmax of
    # logits + Gumbel noise is distributed as `probs`, so these should be close.
    print(one_hot / n_samples)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment