# Last active July 15, 2017 03:37
import mxnet as mx
import numpy as np
## Example of Gumbel-softmax ##
## user settings
# number of rows in the logits batch (one discrete distribution per row)
batch_size = 2
# number of categories per distribution (columns of the logits, one-hot depth)
cardinality = 3
# how many samples are drawn and printed by each sampling procedure
num_samples = 5
# temperature reported for the Gumbel-softmax samples
temperature = 1.0
def gumbel_softmax_sample(output_dimension, hard=True):
    """Build the mxnet symbol for one Gumbel-Softmax sample.

    The returned graph has three free variables that must be supplied when
    it is evaluated: 'uniform', 'logits' and 'temperature'.

    output_dimension -- depth of the one-hot vector built when `hard` is true.
    hard -- if true, the forward value is one_hot(argmax(y)) while the
        gradient is taken through the soft sample y (straight-through);
        if false, the soft sample y itself is returned.

    Returns an mx.sym symbol.
    """
    tiny = 1e-20  # added inside both logarithms
    noise_u = mx.sym.Variable('uniform')
    log_odds = mx.sym.Variable('logits')
    tau = mx.sym.Variable('temperature')

    # Gumbel noise: -log(-log(u))
    inner = -mx.sym.log(noise_u + tiny)
    gumbel_noise = -mx.sym.log(inner + tiny)

    # Soft sample: softmax of the perturbed logits divided by the temperature.
    scaled = mx.sym.broadcast_div(log_odds + gumbel_noise, tau)
    soft = mx.sym.softmax(scaled)
    if not hard:
        return soft

    winner = mx.sym.argmax(soft, axis=1)
    one_hot = mx.sym.one_hot(indices=winner, depth=output_dimension)
    # BlockGrad hides the (one_hot - soft) correction from the backward pass,
    # so only the trailing `+ soft` term contributes a gradient.
    return mx.sym.BlockGrad(one_hot - soft) + soft
def usual_sample(logits, num_samples):
    """Draw categorical samples the usual way (numpy implementation).

    Every row of `logits` is converted to a probability distribution with a
    softmax, and `num_samples` indices are drawn from it with
    `np.random.choice`. The distributions and the one-hot form of every
    draw are printed.

    Args:
        logits: 2-D array-like of unnormalised log-probabilities, one
            distribution per row.
        num_samples: number of indices to draw for each row.

    Returns:
        dict mapping row number to a 1-D array holding the `num_samples`
        sampled column indices for that row.
    """
    scores = np.asarray(logits, dtype=np.float64)
    # Shift each row by its maximum before exponentiating. The softmax is
    # unchanged by the shift, but exp() can no longer overflow to inf (which
    # would turn the probabilities into NaN and make the sampling fail).
    shifted = scores - scores.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    probabilities = e / e.sum(axis=1, keepdims=True)
    num_rows, num_classes = e.shape

    c = {}  # sampled indices, keyed by row
    for row in range(num_rows):
        c[row] = np.random.choice(num_classes, num_samples, p=probabilities[row])
    # Single-argument print calls behave the same on Python 2 and Python 3.
    print("probabilities (each row is a distribution):\n" + str(probabilities))

    for n in range(num_samples):
        onehot = np.zeros(e.shape)
        for row, draws in c.items():
            onehot[row, draws[n]] = 1
        print("usual sample %d:" % n)
        print(onehot)
    return c
## create computational graph for gumbel
x = mx.nd.array(np.random.randn(batch_size, cardinality))
y1 = gumbel_softmax_sample(cardinality, True)   # straight-through (hard) version
y2 = gumbel_softmax_sample(cardinality, False)  # original soft version
# The graph divides by a 'temperature' variable with broadcast_div; a (1, 1)
# array broadcasts against the (batch_size, cardinality) logits.
temperature_nd = mx.nd.array([[temperature]])

## print out samples
# Single-argument print calls behave the same on Python 2 and Python 3.
print("logits x for discrete distribution exp(x_k)/sum_{j=1}^K{exp(x_j)}:\n" + str(x.asnumpy()))
print("batch_size=%d, cardinality k=%d " % (batch_size, cardinality))
print("\nsamples from gumbel-softmax, temperature=%f" % temperature)
for i in range(num_samples):
    uniform_sample = mx.nd.random_uniform(low=0, high=1, shape=(batch_size, cardinality))
    # Evaluate both graphs on the same noise, logits and temperature so the
    # hard and soft outputs printed below belong to the same draw.
    # Symbol.eval returns a list of output NDArrays, hence the [0] below.
    feed = {'uniform': uniform_sample, 'logits': x, 'temperature': temperature_nd}
    ex1 = y1.eval(ctx=mx.cpu(), **feed)
    ex2 = y2.eval(ctx=mx.cpu(), **feed)
    print("gumbel-softmax sample %d (straight-through & original soft version):" % i)
    print(ex1[0].asnumpy())
    print(ex2[0].asnumpy())
print("\nsamples from usual procedure:")
usual_sample(x.asnumpy(), num_samples)
