@ericjang
Created November 9, 2016 05:30
import tensorflow as tf  # TensorFlow 1.x API


def sample_gumbel(shape, eps=1e-20):
    """Sample from Gumbel(0, 1)."""
    U = tf.random_uniform(shape, minval=0, maxval=1)
    return -tf.log(-tf.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    """Draw a sample from the Gumbel-Softmax distribution."""
    y = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax(y / temperature)


def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y

    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, the returned sample will be one-hot; otherwise it will
        be a probability distribution that sums to 1 across classes.
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        k = tf.shape(logits)[-1]
        # y_hard = tf.cast(tf.one_hot(tf.argmax(y, 1), k), y.dtype)
        y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)
        # Straight-through estimator: the forward pass uses the hard one-hot
        # sample, while gradients flow through the soft sample y.
        y = tf.stop_gradient(y_hard - y) + y
    return y
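A minimal usage sketch, not part of the original gist: the variable name `my_logits`, the batch/class sizes, and the annealing schedule below are made up for illustration, assuming the TensorFlow 1.x API used above.

```python
import numpy as np
import tensorflow as tf

# Hypothetical unnormalized log-probs for a batch of 4 examples over 10 classes.
logits = tf.get_variable("my_logits", shape=[4, 10])

tau = tf.placeholder(tf.float32, [])              # annealed temperature
y_soft = gumbel_softmax(logits, tau)              # differentiable, rows sum to 1
y_hard = gumbel_softmax(logits, tau, hard=True)   # one-hot forward, soft gradients

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(3):
        # Example schedule: start soft, anneal toward a near-discrete sample.
        t = max(0.5, float(np.exp(-0.01 * step)))
        print(sess.run(y_hard, feed_dict={tau: t}))
```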
@yuxiang-wu
Hi, Eric! Thanks for sharing the code. I have a question regarding lines 24-26. You implemented two ways to compute the one-hot. It seems that line 26 could lead to multiple ones when a tie occurs, though that is very unlikely. Which way is better? Which one was tested and used in your paper?
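For illustration of the tie behavior the comment describes, here is a minimal sketch (not from the gist; the constant logits are contrived purely to force a tie, assuming TensorFlow 1.x):

```python
import tensorflow as tf

# Soft sample with an exact tie between the first two classes.
y = tf.constant([[0.4, 0.4, 0.2]])

# Equality-based one-hot (the active line): both tied entries become 1.
y_eq = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)

# Argmax-based one-hot (the commented-out line): exactly one entry becomes 1.
y_arg = tf.cast(tf.one_hot(tf.argmax(y, 1), 3), y.dtype)

with tf.Session() as sess:
    print(sess.run(y_eq))   # [[1. 1. 0.]]
    print(sess.run(y_arg))  # [[1. 0. 0.]]
```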
