
@yzh119
Created January 12, 2018 12:25
ST-Gumbel-Softmax-Pytorch
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def sample_gumbel(shape, eps=1e-20):
    """Sample Gumbel(0, 1) noise via the inverse-CDF trick."""
    U = torch.rand(shape).cuda()
    return -Variable(torch.log(-torch.log(U + eps) + eps))


def gumbel_softmax_sample(logits, temperature):
    """Draw a soft (relaxed) sample from the Gumbel-Softmax distribution."""
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature):
    """
    input: [*, n_class]
    return: [*, n_class] a one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Straight-through estimator: the forward pass returns the hard one-hot
    # y_hard, while the backward pass uses the gradient of the soft sample y.
    return (y_hard - y).detach() + y


if __name__ == '__main__':
    import math
    # With 20000 samples, the column sums should be roughly
    # 20000 * [0.1, 0.4, 0.3, 0.2].
    print(gumbel_softmax(
        Variable(torch.cuda.FloatTensor(
            [[math.log(0.1), math.log(0.4), math.log(0.3), math.log(0.2)]] * 20000)),
        0.8).sum(dim=0))
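
As a quick sanity check (a minimal sketch, not part of the original gist; it assumes the definitions above are in scope and that a CUDA device is available), the snippet below confirms that the forward output of gumbel_softmax is a hard one-hot sample while gradients still reach the logits through the soft relaxation:

logits = Variable(torch.randn(4, 10).cuda(), requires_grad=True)  # CUDA needed, since sample_gumbel() calls .cuda()
y = gumbel_softmax(logits, temperature=0.8)
print(y.max(dim=-1)[0])     # every row's maximum is 1.0, i.e. each sample is one-hot
y[:, 0].sum().backward()    # backprop an arbitrary scalar built from the hard samples
print(logits.grad.size())   # torch.Size([4, 10]): gradients flow via the soft sample
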
@EmnamoR

EmnamoR commented Jun 22, 2020

Hi, I am trying to apply this Gumbel-Softmax trick in a VAE for data synthesis. Here is my implementation. Am I doing something wrong? Thank you.

import logging

import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.contrib.distributions import (Bernoulli, OneHotCategorical,
                                              RelaxedOneHotCategorical,
                                              kl_divergence)
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.models import Model

logging.getLogger('tensorflow').disabled = True


class DiscreteVAE:
    def encoder(self, latent_dim, input_dim):
        encoder_input = layers.Input(shape=(input_dim, ), name='encoder_input')
        x = encoder_input
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_1')(x)
        x = layers.Dropout(0.3)(x)
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_2')(x)
        x = tf.keras.layers.Dense(latent_dim)(x)
        encoder_model = Model(inputs=encoder_input, outputs=x)
        encoder_model.summary()
        return encoder_model

    def decoder(self, latent_dim, input_dim):
        decoder_input = layers.Input(latent_dim, name='decoder_input')
        x = decoder_input
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_1')(x)
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_2')(x)
        decoded_input = layers.Dense(input_dim, name='decoded_input')(x)
        decoder_model = Model(decoder_input, decoded_input)
        decoder_model.summary()
        return decoder_model

    def sample_gumbel(self, shape, eps=1e-20):
        """Sample from Gumbel(0, 1)"""
        U = tf.random_uniform(shape, minval=0, maxval=1, dtype=tf.float32)
        return -tf.log(-tf.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        """ Draw a sample from the Gumbel-Softmax distribution"""
        y = logits + self.sample_gumbel(tf.shape(logits))
        return tf.nn.softmax(y / temperature)

    def gumbel_softmax(self, args):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.
        Args:
            logits: [batch_size, n_class] unnormalized log-probs
            temperature: non-negative scalar
            hard: if True, take argmax, but differentiate w.r.t. soft sample y
        Returns:
            [batch_size, n_class] sample from the Gumbel-Softmax distribution.
            If hard=True, then the returned sample will be one-hot, otherwise it will
            be a probability distribution that sums to 1 across classes
        """
        logits, temperature = args
        y = self.gumbel_softmax_sample(logits, temperature)
        # k = tf.shape(logits)[-1]
        # y_hard = tf.cast(tf.one_hot(tf.argmax(y, 1), k), y.dtype)
        y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)),
                         y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
        return y

    def CatVAE_loss(self, encoded_input, decoded_input, z, x, tau, latent_dim):
        reconstruction_error = tf.reduce_sum(
            Bernoulli(logits=decoded_input).log_prob(x), 1)
        logits_pz = tf.ones_like(decoded_input) * (1. / latent_dim)
        q_cat_z = OneHotCategorical(logits=encoded_input)
        p_cat_z = OneHotCategorical(logits=logits_pz)
        KL_qp = kl_divergence(q_cat_z, p_cat_z)
        ELBO = tf.reduce_mean(reconstruction_error - KL_qp)
        loss = -ELBO
        return loss

    def build_vae(self, latent_dim, input_dim, opt, data):
        tau = 0.5
        input_x = layers.Input(shape=input_dim, name='vae_input')
        encoder_m = self.encoder(latent_dim, input_dim)
        logits_y = encoder_m(input_x)

        z = layers.Lambda(self.gumbel_softmax)([logits_y, tau])
        decoder_m = self.decoder(latent_dim, input_dim)
        decoded_input = decoder_m(z)
        # loss = self.vae_loss(input_x, input_dim, decoded_input, data)

        loss = self.CatVAE_loss(logits_y, decoded_input, z, input_x, tau,
                                latent_dim)
        vae = Model(input_x, decoded_input)
        vae.add_loss(loss)
        vae.compile(optimizer=opt)
        return vae, decoder_m, encoder_m
