@yzh119
Created January 12, 2018 12:25
ST-Gumbel-Softmax-Pytorch
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def sample_gumbel(shape, eps=1e-20):
    """Sample from Gumbel(0, 1)."""
    U = torch.rand(shape).cuda()  # hard-coded to GPU, as in the original gist
    return -Variable(torch.log(-torch.log(U + eps) + eps))


def gumbel_softmax_sample(logits, temperature):
    """Draw a soft sample from the Gumbel-Softmax distribution."""
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature):
    """
    input: [*, n_class]
    return: [*, n_class] a one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Straight-through estimator: the forward pass uses the hard one-hot
    # sample, the backward pass uses the gradient of the soft sample y.
    return (y_hard - y).detach() + y


if __name__ == '__main__':
    import math
    print(gumbel_softmax(
        Variable(torch.cuda.FloatTensor(
            [[math.log(0.1), math.log(0.4), math.log(0.3), math.log(0.2)]] * 20000)),
        0.8).sum(dim=0))
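
As a quick sanity check on the test above: with 20,000 rows and class probabilities [0.1, 0.4, 0.3, 0.2], the printed column sums should land near [2000, 8000, 6000, 4000]. Recent PyTorch versions also ship a built-in torch.nn.functional.gumbel_softmax, so the hand-rolled version above is mainly useful for seeing how the trick works; a minimal equivalent sketch (assuming a reasonably recent PyTorch, not part of the original gist):

import math
import torch
import torch.nn.functional as F

# Same log-probability logits as the gist's test, on CPU for simplicity.
logits = torch.tensor(
    [[math.log(0.1), math.log(0.4), math.log(0.3), math.log(0.2)]] * 20000)
# hard=True returns one-hot samples in the forward pass while keeping the
# straight-through gradient of the soft sample, matching the gist's behaviour.
y = F.gumbel_softmax(logits, tau=0.8, hard=True)
print(y.sum(dim=0))  # expected to be roughly [2000., 8000., 6000., 4000.]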
@ibrahim10h

@JACKHAHA363 Line 30 (the test in the __main__ block) already provides the 'logits' parameter as the element-wise log of what is already a normalized probability vector, [0.1, 0.4, 0.3, 0.2]. This may be why F.log_softmax(logits) was not applied in gumbel_softmax_sample on Line 12.


JACKHAHA363 commented Nov 27, 2019

@ibrahim10h Right, it's okay in this case because the test passes the actual log of normalized probabilities. But in a general neural network, we refer to the network's output as logits, which could be the log of normalized probabilities plus an arbitrary offset. This standalone example is correct, but it could cause subtle errors for people who carelessly copy-paste it.
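
A common defensive pattern for that case is to normalize inside the sampling function with F.log_softmax, so callers can pass raw, unnormalized network outputs. A minimal sketch (gumbel_softmax_sample_from_logits is a hypothetical helper, not part of the gist):

import torch
import torch.nn.functional as F

def gumbel_softmax_sample_from_logits(logits, temperature, eps=1e-20):
    # Turn arbitrary (possibly offset) scores into proper log-probabilities.
    log_probs = F.log_softmax(logits, dim=-1)
    # Gumbel(0, 1) noise via the inverse CDF applied to uniform samples.
    u = torch.rand_like(log_probs)
    gumbel = -torch.log(-torch.log(u + eps) + eps)
    # Soft sample; apply the gist's straight-through step for a hard one-hot.
    return F.softmax((log_probs + gumbel) / temperature, dim=-1)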


EmnamoR commented Jun 22, 2020

Hi, I am trying to apply this Gumbel-Softmax trick to a VAE for data synthesis. Here is the implementation. Am I doing something wrong? Thank you.

import logging

import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.contrib.distributions import (Bernoulli, OneHotCategorical,
                                              RelaxedOneHotCategorical,
                                              kl_divergence)
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.models import Model

logging.getLogger('tensorflow').disabled = True


class DiscreteVAE:
    def encoder(self, latent_dim, input_dim):
        encoder_input = layers.Input(shape=(input_dim, ), name='encoder_input')
        x = encoder_input
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_1')(x)
        x = layers.Dropout(0.3)(x)
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_2')(x)
        x = tf.keras.layers.Dense(latent_dim)(x)
        encoder_model = Model(inputs=encoder_input, outputs=x)
        encoder_model.summary()
        return encoder_model

    def decoder(self, latent_dim, input_dim):
        decoder_input = layers.Input(latent_dim, name='decoder_input')
        x = decoder_input
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_1')(x)
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_2')(x)
        decoded_input = layers.Dense(input_dim, name='decoded_input')(x)
        decoder_model = Model(decoder_input, decoded_input)
        decoder_model.summary()
        return decoder_model

    def sample_gumbel(self, shape, eps=1e-20):
        """Sample from Gumbel(0, 1)"""
        U = tf.random_uniform(shape, minval=0, maxval=1, dtype=tf.float32)
        return -tf.log(-tf.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        """ Draw a sample from the Gumbel-Softmax distribution"""
        y = logits + self.sample_gumbel(tf.shape(logits))
        return tf.nn.softmax(y / temperature)

    def gumbel_softmax(self, args):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.
        Args:
            logits: [batch_size, n_class] unnormalized log-probs
            temperature: non-negative scalar
            hard: if True, take argmax, but differentiate w.r.t. soft sample y
        Returns:
            [batch_size, n_class] sample from the Gumbel-Softmax distribution.
            If hard=True, then the returned sample will be one-hot, otherwise it will
            be a probability distribution that sums to 1 across classes
        """
        logits, temperature = args
        y = self.gumbel_softmax_sample(logits, temperature)
        # k = tf.shape(logits)[-1]
        # y_hard = tf.cast(tf.one_hot(tf.argmax(y, 1), k), y.dtype)
        y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)),
                         y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
        return y

    def CatVAE_loss(self, encoded_input, decoded_input, z, x, tau, latent_dim):
        reconstruction_error = tf.reduce_sum(
            Bernoulli(logits=decoded_input).log_prob(x), 1)
        logits_pz = tf.ones_like(decoded_input) * (1. / latent_dim)
        q_cat_z = OneHotCategorical(logits=encoded_input)
        p_cat_z = OneHotCategorical(logits=logits_pz)
        KL_qp = kl_divergence(q_cat_z, p_cat_z)
        ELBO = tf.reduce_mean(reconstruction_error - KL_qp)
        loss = -ELBO
        return loss

    def build_vae(self, latent_dim, input_dim, opt, data):
        tau = 0.5
        input_x = layers.Input(shape=input_dim, name='vae_input')
        encoder_m = self.encoder(latent_dim, input_dim)
        logits_y = encoder_m(input_x)

        z = layers.Lambda(self.gumbel_softmax)([logits_y, tau])
        decoder_m = self.decoder(latent_dim, input_dim)
        decoded_input = decoder_m(z)
        # loss = self.vae_loss(input_x, input_dim, decoded_input, data)

        loss = self.CatVAE_loss(logits_y, decoded_input, z, input_x, tau,
                                latent_dim)
        vae = Model(input_x, decoded_input)
        vae.add_loss(loss)
        vae.compile(optimizer=opt)
        return vae, decoder_m, encoder_m
