Skip to content

Instantly share code, notes, and snippets.

@Sentient07
Last active June 14, 2017 12:01
Show Gist options
  • Save Sentient07/81bdf04de7d8347b429734da0fecbaf6 to your computer and use it in GitHub Desktop.
Save Sentient07/81bdf04de7d8347b429734da0fecbaf6 to your computer and use it in GitHub Desktop.
Gumbel-softmax draft in Theano
import numpy as np
import theano
import theano.tensor as tensor
from theano.tensor.shared_randomstreams import RandomStreams
# Module-level random state. A fixed NumPy seed gives a fixed seed for
# Theano's symbolic random stream ``srng``, which gumbel_softmax samples its
# uniform noise from, so the drawn noise is repeatable from run to run.
rng = np.random.RandomState(1)
srng = RandomStreams(seed=rng.randint(1234))
def get_one_hot(inp, nb_samples, nb_class):
    """Build a symbolic one-hot matrix from the arg-max along the last axis of ``inp``.

    Parameters
    ----------
    inp : symbolic tensor
        Scores reduced with ``argmax`` over the last axis. Presumably a
        ``(nb_samples, nb_class)`` matrix -- TODO confirm with callers.
    nb_samples : int or symbolic scalar
        Number of rows of the result.
    nb_class : int or symbolic scalar
        Number of columns of the result.

    Returns
    -------
    Symbolic ``(nb_samples, nb_class)`` tensor of zeros. For a 2-D ``inp``,
    row ``i`` has a 1 in column ``argmax(inp[i])``.
    """
    winners = tensor.argmax(inp, -1)
    row_ids = tensor.arange(nb_samples)
    canvas = tensor.zeros((nb_samples, nb_class))
    # Advanced indexing pairs row i with winners[i]; write a 1 at each pair.
    return tensor.set_subtensor(canvas[row_ids, winners], 1)
def gumbel_softmax(inp, temperature, epsilon, nb_classes, hard=False):
    """Draw a relaxed categorical sample with the Gumbel-softmax trick.

    Gumbel(0, 1) noise is added to ``inp``, and a temperature-scaled softmax
    is applied to the result.

    Parameters
    ----------
    inp : symbolic tensor
        Unnormalized log-probabilities (logits). Presumably a float
        ``(nb_samples, nb_classes)`` matrix -- TODO confirm with callers.
    temperature : scalar or shared variable
        Softmax temperature. Smaller values push samples towards one-hot.
    epsilon : float
        Small constant that keeps both logarithms finite when the uniform
        draw is 0 or rounds to 1.
    nb_classes : int or symbolic scalar
        Width of the one-hot output. Used only when ``hard`` is True and
        expected to equal the number of columns of the softmax output.
    hard : bool
        If False, return the relaxed (soft) sample. If True, return the
        one-hot arg-max of that sample (up to float rounding in the forward
        pass), with gradients taken from the soft sample (straight-through
        estimator).

    Returns
    -------
    Symbolic tensor holding the soft sample, or its straight-through one-hot
    version when ``hard`` is True.
    """
    uniform_sample = srng.uniform(inp.shape, low=0, high=1).astype('float32')
    # Inverse-CDF transform: -log(-log(U)) with U ~ Uniform(0, 1) is Gumbel(0, 1).
    gumbel_dist = -tensor.log(-tensor.log(uniform_sample + epsilon) + epsilon)
    soft = tensor.nnet.softmax((inp + gumbel_dist) / temperature)
    if not hard:
        return soft
    one_hot = get_one_hot(soft, soft.shape[0], nb_classes)
    # The arg-max/one-hot step has no useful gradient. Block gradients through
    # the (one_hot - soft) correction so the forward value is one_hot while
    # the backward pass sees only d(soft).
    return theano.gradient.zero_grad(one_hot - soft) + soft
# Demo: build the Gumbel-softmax graph for a batch of logits, compile it and
# print the resulting graph.
# Logits are real-valued, so declare a float32 matrix (one row per sample,
# one column per class). An integer input would be upcast when it is added
# to the float32 Gumbel noise.
t1 = tensor.fmatrix('logits')
temp = theano.shared(np.float32(1e-1), name='temperature')
# nb_classes is taken from the input width so the one-hot output would have
# the right number of columns if hard=True were used.
gs = gumbel_softmax(t1, temp, epsilon=np.float32(1e-20),
                    nb_classes=t1.shape[1], hard=False)
func1 = theano.function([t1], gs)
# Show the apply nodes in execution order, then the full annotated graph.
for node in func1.maker.fgraph.toposort():
    print(node)
theano.printing.debugprint(func1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment