@yzh119
Created January 12, 2018 12:25
ST-Gumbel-Softmax-Pytorch
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

def sample_gumbel(shape, eps=1e-20):
    U = torch.rand(shape).cuda()
    return -Variable(torch.log(-torch.log(U + eps) + eps))

def gumbel_softmax_sample(logits, temperature):
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits, temperature):
    """
    input: [*, n_class]
    return: [*, n_class] a one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    return (y_hard - y).detach() + y

if __name__ == '__main__':
    import math
    print(gumbel_softmax(Variable(torch.cuda.FloatTensor([[math.log(0.1), math.log(0.4), math.log(0.3), math.log(0.2)]] * 20000)), 0.8).sum(dim=0))
@Baukebrenninkmeijer

Hi, this seems to be just the Gumbel-Softmax estimator, not the Straight-Through (ST) Gumbel-Softmax estimator. ST Gumbel-Softmax uses the argmax in the forward pass, and its gradient is then approximated by the gradient of the ordinary Gumbel-Softmax in the backward pass.
So, as far as I know, an ST Gumbel-Softmax implementation would require implementing both the forward and the backward pass, since they differ and the gradient of the hard forward pass cannot be obtained with autograd alone.

Please correct me if I'm wrong, but this seems to be the case.

@swyoon

swyoon commented Sep 25, 2019

@Baukebrenninkmeijer, in principle you are right. Here, however, the author cleverly used detach() in line 26 to pass the gradient w.r.t. y_hard straight through to y. I think it is a neat trick.

@Baukebrenninkmeijer

Hmm, very neat trick indeed!

@yzh119
Author

yzh119 commented Sep 28, 2019

@swyoon is right; this implementation is equivalent to Eric Jang's, where he uses stop_gradient (the TensorFlow equivalent of PyTorch's detach).
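
For illustration, a minimal sketch of what the detach()/stop_gradient trick does (my example, assuming a recent PyTorch and leaving out the Gumbel noise for clarity): the forward value is exactly the one-hot y_hard, while the gradient autograd computes is exactly the gradient of the soft y.

import torch
import torch.nn.functional as F

w = torch.tensor([1.0, 2.0, 3.0])                   # arbitrary downstream weights

# Straight-through output: hard value in the forward pass.
logits = torch.tensor([[0.5, 1.5, 0.2]], requires_grad=True)
y = F.softmax(logits, dim=-1)                       # soft relaxation
y_hard = F.one_hot(y.argmax(dim=-1), num_classes=3).float()
out = (y_hard - y).detach() + y                     # forward value equals y_hard exactly
(out * w).sum().backward()
grad_st = logits.grad.clone()

# The same loss computed through the soft y alone.
logits2 = torch.tensor([[0.5, 1.5, 0.2]], requires_grad=True)
y2 = F.softmax(logits2, dim=-1)
(y2 * w).sum().backward()

print(torch.allclose(grad_st, logits2.grad))        # True: the backward pass is that of the soft sample

Eric Jang's TensorFlow version does the same thing with y = tf.stop_gradient(y_hard - y) + y.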

@JACKHAHA363

JACKHAHA363 commented Oct 26, 2019

The logits are not strictly log-probabilities. It would be better to change line 12 to:

 y = F.log_softmax(logits, dim=-1) + sample_gumbel(logits.size())

@Baukebrenninkmeijer

@JACKHAHA363 Doesn't this have two problems? One, the softmax doesn't sum to 1 anymore, and two, the magnitude of the log_softmax output is much lower than that of the logits, so the noise you're adding is relatively much larger.

@JACKHAHA363

JACKHAHA363 commented Oct 26, 2019

  1. No matter how you compute y in line 12, the softmax in line 13 will make everything sum to 1.
  2. I think the magnitude would depend on the actual application. The main reason for using log_softmax(logits) instead of logits is that, if you check the original paper, y = g + log \pi, and log \pi here means log-probabilities, not the network's output logits.
  3. If you check the definition of softmax, you will quickly realize that log_softmax(logits) = log_softmax(logits + C) for any constant C, so y should not be sensitive to an additive offset in the output logits (see the quick check below).
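
A quick numerical check of points 1 and 3 (a sketch, not from the thread; it assumes a recent PyTorch):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0, 0.5, 0.0]], dtype=torch.float64)
shifted = logits + 7.3                                  # same distribution, arbitrary additive offset

# Point 3: log_softmax removes the additive offset entirely.
print(torch.allclose(F.log_softmax(logits, dim=-1),
                     F.log_softmax(shifted, dim=-1)))   # True

# Point 1: whatever goes into y, the final softmax renormalizes it to sum to 1.
g = -torch.log(-torch.log(torch.rand_like(logits)))     # Gumbel(0, 1) noise
y = F.softmax((F.log_softmax(logits, dim=-1) + g) / 0.8, dim=-1)
print(y.sum(dim=-1))                                    # tensor([1.0000], dtype=torch.float64)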

@ibrahim10h

@JACKHAHA363 Line 30 already provides the logits argument as the element-wise log of what appears to be a normalized probability vector, math.log applied to [0.1, 0.4, 0.3, 0.2]. This may be why F.log_softmax(logits) was not applied on line 12.

@JACKHAHA363

JACKHAHA363 commented Nov 27, 2019

@ibrahim10h Right, it's okay in this case because the example passes the actual log of normalized probabilities. But for a general neural network, we refer to the raw output of the network as logits, which equal the log of normalized probabilities only up to an arbitrary offset. This stand-alone example is correct, but it could lead to errors for people who carelessly copy-paste it.
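
For reference, a minimal variant (a sketch, not the author's code; the function name is mine) that accepts raw network outputs and normalizes them first, as suggested above:

import torch
import torch.nn.functional as F

def sample_gumbel(shape, eps=1e-20, device='cpu'):
    # Sample from Gumbel(0, 1).
    U = torch.rand(shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample_from_raw_logits(logits, temperature):
    # Normalize raw network outputs to log-probabilities before adding the
    # Gumbel noise, so y = log(pi) + g holds regardless of any additive
    # offset in the logits.
    log_probs = F.log_softmax(logits, dim=-1)
    y = log_probs + sample_gumbel(log_probs.size(), device=log_probs.device)
    return F.softmax(y / temperature, dim=-1)

print(gumbel_softmax_sample_from_raw_logits(torch.randn(2, 4), 0.8))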

@EmnamoR

EmnamoR commented Jun 22, 2020

Hi, I am trying to apply this Gumbel-Softmax trick to a VAE for data synthesis. Here is the implementation. Am I doing something wrong? Thank you.

import logging

import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.contrib.distributions import (Bernoulli, OneHotCategorical,
                                              RelaxedOneHotCategorical,
                                              kl_divergence)
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.models import Model

logging.getLogger('tensorflow').disabled = True


class DiscreteVAE:
    def encoder(self, latent_dim, input_dim):
        encoder_input = layers.Input(shape=(input_dim, ), name='encoder_input')
        x = encoder_input
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_1')(x)
        x = layers.Dropout(0.3)(x)
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_2')(x)
        x = tf.keras.layers.Dense(latent_dim)(x)
        encoder_model = Model(inputs=encoder_input, outputs=x)
        encoder_model.summary()
        return encoder_model

    def decoder(self, latent_dim, input_dim):
        decoder_input = layers.Input(latent_dim, name='decoder_input')
        x = decoder_input
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_1')(x)
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_2')(x)
        decoded_input = layers.Dense(input_dim, name='decoded_input')(x)
        decoder_model = Model(decoder_input, decoded_input)
        decoder_model.summary()
        return decoder_model

    def sample_gumbel(self, shape, eps=1e-20):
        """Sample from Gumbel(0, 1)"""
        U = tf.random_uniform(shape, minval=0, maxval=1, dtype=tf.float32)
        return -tf.log(-tf.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        """ Draw a sample from the Gumbel-Softmax distribution"""
        y = logits + self.sample_gumbel(tf.shape(logits))
        return tf.nn.softmax(y / temperature)

    def gumbel_softmax(self, args):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.
        Args:
            logits: [batch_size, n_class] unnormalized log-probs
            temperature: non-negative scalar
            hard: if True, take argmax, but differentiate w.r.t. soft sample y
        Returns:
            [batch_size, n_class] sample from the Gumbel-Softmax distribution.
            If hard=True, then the returned sample will be one-hot, otherwise it will
            be a probability distribution that sums to 1 across classes
        """
        logits, temperature = args
        y = self.gumbel_softmax_sample(logits, temperature)
        # k = tf.shape(logits)[-1]
        # y_hard = tf.cast(tf.one_hot(tf.argmax(y, 1), k), y.dtype)
        y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)),
                         y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
        return y

    def CatVAE_loss(self, encoded_input, decoded_input, z, x, tau, latent_dim):
        reconstruction_error = tf.reduce_sum(
            Bernoulli(logits=decoded_input).log_prob(x), 1)
        logits_pz = tf.ones_like(decoded_input) * (1. / latent_dim)
        q_cat_z = OneHotCategorical(logits=encoded_input)
        p_cat_z = OneHotCategorical(logits=logits_pz)
        KL_qp = kl_divergence(q_cat_z, p_cat_z)
        ELBO = tf.reduce_mean(reconstruction_error - KL_qp)
        loss = -ELBO
        return loss

    def build_vae(self, latent_dim, input_dim, opt, data):
        tau = 0.5
        input_x = layers.Input(shape=input_dim, name='vae_input')
        encoder_m = self.encoder(latent_dim, input_dim)
        logits_y = encoder_m(input_x)

        z = layers.Lambda(self.gumbel_softmax)([logits_y, tau])
        decoder_m = self.decoder(latent_dim, input_dim)
        decoded_input = decoder_m(z)
        # loss = self.vae_loss(input_x, input_dim, decoded_input, data)

        loss = self.CatVAE_loss(logits_y, decoded_input, z, input_x, tau,
                                latent_dim)
        vae = Model(input_x, decoded_input)
        vae.add_loss(loss)
        vae.compile(optimizer=opt)
        return vae, decoder_m, encoder_m
