@dddzg, forked from yzh119/st-gumbel.py, created October 10, 2019
ST-Gumbel-Softmax-Pytorch
import torch
import torch.nn.functional as F


def sample_gumbel(shape, eps=1e-20):
    """Draw a sample from Gumbel(0, 1). Note: hard-coded to a CUDA device."""
    U = torch.rand(shape).cuda()
    return -torch.log(-torch.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    """Draw a soft (relaxed) sample from the Gumbel-softmax distribution."""
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature):
    """
    input: [*, n_class] unnormalized log-probabilities
    return: [*, n_class] a one-hot sample (straight-through: gradients flow through the soft sample)
    """
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    # Build a hard one-hot tensor from the argmax of the soft sample.
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Straight-through estimator: returns the one-hot values in the forward pass,
    # but backpropagates the gradients of the soft sample y.
    return (y_hard - y).detach() + y


if __name__ == '__main__':
    import math
    # 20000 rows with class probabilities [0.1, 0.4, 0.3, 0.2]; the column sums
    # of the one-hot samples should come out close to [2000, 8000, 6000, 4000].
    logits = torch.tensor([[math.log(0.1), math.log(0.4), math.log(0.3), math.log(0.2)]] * 20000,
                          device='cuda')
    print(gumbel_softmax(logits, 0.8).sum(dim=0))
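Since the demo logits are the element-wise log of the class probabilities [0.1, 0.4, 0.3, 0.2] and each hard sample is one-hot, the printed column sums should land near [2000, 8000, 6000, 4000]. A second minimal sanity check, sketched below under the same assumption of an available CUDA device, is that gradients still reach the logits through the soft relaxation:

# Sketch only: reuses the gumbel_softmax defined above and assumes a CUDA device.
logits = torch.randn(4, 6, device='cuda', requires_grad=True)   # hypothetical: 4 rows over 6 classes
y = gumbel_softmax(logits, temperature=0.5)                     # forward pass gives one-hot rows
loss = (y * torch.arange(6.0, device='cuda')).sum()             # any toy loss that depends on y
loss.backward()
print(logits.grad.abs().sum().item() > 0)                       # True: the straight-through trick passes gradients to the logits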