Skip to content

Instantly share code, notes, and snippets.

Last active February 22, 2023 09:02
Show Gist options
  • Star 32 You must be signed in to star a gist
  • Fork 6 You must be signed in to fork a gist
  • Save danijar/1cb4d81fed37fd06ef60d08c1181f557 to your computer and use it in GitHub Desktop.
Save danijar/1cb4d81fed37fd06ef60d08c1181f557 to your computer and use it in GitHub Desktop.
TensorFlow Variational Auto-Encoder
# Full example for my blog post at:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
tfd = tf.contrib.distributions
def make_encoder(data, code_size):
    """Build the approximate posterior q(code | data).

    The input is flattened (all axes except the first), passed through two
    fully connected ReLU layers of width 200, and projected to a mean and a
    per-dimension scale for a diagonal Gaussian.

    Args:
        data: Batch of inputs; every axis after the first is flattened.
        code_size: Dimensionality of the latent code.

    Returns:
        A ``tfd.MultivariateNormalDiag`` over codes of size ``code_size``.
    """
    hidden = tf.layers.flatten(data)
    for _ in range(2):
        hidden = tf.layers.dense(hidden, 200, tf.nn.relu)
    mean = tf.layers.dense(hidden, code_size)
    # Softplus keeps the per-dimension scale positive.
    stddev = tf.layers.dense(hidden, code_size, tf.nn.softplus)
    return tfd.MultivariateNormalDiag(loc=mean, scale_diag=stddev)
def make_prior(code_size):
    """Build the latent prior p(code): a standard normal.

    Args:
        code_size: Dimensionality of the latent code.

    Returns:
        A ``tfd.MultivariateNormalDiag`` with zero mean and unit diagonal
        scale in every one of the ``code_size`` dimensions.
    """
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(code_size),
        scale_diag=tf.ones(code_size))
def make_decoder(code, data_shape):
    """Build the likelihood p(data | code) as independent Bernoulli pixels.

    Args:
        code: Tensor of latent codes; the dense layers act on its last axis.
        data_shape: Shape of a single data example, e.g. ``[28, 28]``. A list
            or a tuple is accepted.

    Returns:
        A ``tfd.Independent`` wrapping a ``tfd.Bernoulli`` whose logits have
        shape ``[-1] + data_shape``. All ``len(data_shape)`` trailing axes are
        reinterpreted as one event, so ``log_prob`` yields one value per
        example.
    """
    # Accept tuples as well as lists; list concatenation below needs a list.
    data_shape = list(data_shape)
    hidden = code
    hidden = tf.layers.dense(hidden, 200, tf.nn.relu)
    hidden = tf.layers.dense(hidden, 200, tf.nn.relu)
    # One logit per data element, then restore the per-example shape.
    logit = tf.layers.dense(hidden, int(np.prod(data_shape)))
    logit = tf.reshape(logit, [-1] + data_shape)
    # Treat every data axis as part of the event instead of hard-coding 2,
    # which only matched 2-D inputs such as images.
    return tfd.Independent(
        tfd.Bernoulli(logits=logit),
        reinterpreted_batch_ndims=len(data_shape))
def plot_codes(ax, codes, labels):
    """Scatter latent codes on ``ax``, coloured by label, without ticks.

    Only the first two code dimensions are drawn. Both axes use the same
    limits (global min/max of ``codes`` padded by 0.1), and all ticks and
    tick labels are hidden.

    Args:
        ax: Matplotlib axes to draw on.
        codes: Array indexed as ``codes[:, 0]`` / ``codes[:, 1]``; assumed
            shape (num_points, code_size >= 2).
        labels: Per-point values used as marker colours.
    """
    ax.scatter(codes[:, 0], codes[:, 1], s=2, c=labels, alpha=0.1)
    # Shared limits for both axes so the latent space is not stretched.
    low = codes.min() - .1
    high = codes.max() + .1
    ax.set_xlim(low, high)
    ax.set_ylim(low, high)
    # Matplotlib documents these flags as booleans; the old 'on'/'off'
    # string form was deprecated, so pass real booleans.
    ax.tick_params(
        axis='both', which='both', left=False, bottom=False,
        labelleft=False, labelbottom=False)
def plot_samples(ax, samples):
    """Show each sample in grayscale on the axes at the same position.

    Sample ``i`` is drawn on ``ax[i]``. If ``ax`` has fewer entries than
    there are samples, indexing ``ax`` raises.

    Args:
        ax: Indexable collection of Matplotlib axes.
        samples: Iterable of images (presumably 2-D, given the gray colormap).
    """
    for position, image in enumerate(samples):
        target_axis = ax[position]
        target_axis.imshow(image, cmap='gray')
# Experiment settings. The figure grid below is derived from these so the
# model and the plots stay in sync when they change.
CODE_SIZE = 2
IMAGE_SHAPE = [28, 28]
NUM_EPOCHS = 20
NUM_SAMPLES = 10
STEPS_PER_EPOCH = 600
BATCH_SIZE = 100

data = tf.placeholder(tf.float32, [None] + IMAGE_SHAPE)

# Templates make every call after the first reuse the same variables; the
# decoder is called twice below (for the loss and for the samples).
make_encoder = tf.make_template('encoder', make_encoder)
make_decoder = tf.make_template('decoder', make_decoder)

# Define the model.
prior = make_prior(code_size=CODE_SIZE)
posterior = make_encoder(data, code_size=CODE_SIZE)
code = posterior.sample()

# Define the loss.
likelihood = make_decoder(code, IMAGE_SHAPE).log_prob(data)
divergence = tfd.kl_divergence(posterior, prior)
elbo = tf.reduce_mean(likelihood - divergence)
optimize = tf.train.AdamOptimizer(0.001).minimize(-elbo)

# Decode prior samples for visualisation. Commenters on TF 1.9 reported a
# matrix-size error with a flat ``prior.sample(NUM_SAMPLES)``; the extra
# singleton axis avoids it and is folded away by the decoder's reshape to
# ``[-1] + IMAGE_SHAPE``.
samples = make_decoder(prior.sample((NUM_SAMPLES, 1)), IMAGE_SHAPE).mean()

mnist = input_data.read_data_sets('MNIST_data/')
# One row per epoch: the latent-code scatter plot, then one column per sample.
fig, ax = plt.subplots(
    nrows=NUM_EPOCHS, ncols=NUM_SAMPLES + 1, figsize=(10, 20))

with tf.train.MonitoredSession() as sess:
    for epoch in range(NUM_EPOCHS):
        # Evaluate on the whole test set before each epoch of training.
        feed = {data: mnist.test.images.reshape([-1] + IMAGE_SHAPE)}
        test_elbo, test_codes, test_samples = sess.run(
            [elbo, code, samples], feed)
        print('Epoch', epoch, 'elbo', test_elbo)
        ax[epoch, 0].set_ylabel('Epoch {}'.format(epoch))
        plot_codes(ax[epoch, 0], test_codes, mnist.test.labels)
        plot_samples(ax[epoch, 1:], test_samples)
        for _ in range(STEPS_PER_EPOCH):
            batch = mnist.train.next_batch(BATCH_SIZE)[0]
            feed = {data: batch.reshape([-1] + IMAGE_SHAPE)}
            sess.run(optimize, feed)

plt.savefig('vae-mnist.png', dpi=300, transparent=True, bbox_inches='tight')
Copy link

... colab? (wink wink)

Copy link

danijar commented Jan 25, 2018

Copy link

kpe commented Jul 25, 2018

It seems like the line

samples = make_decoder(prior.sample(10), [28, 28]).mean()

needs to be replaced with:

samples = make_decoder(prior.sample((10,1)), [28, 28]).mean()

otherwise it won't run (e.g. in Colab).

Copy link

skokalj commented Sep 20, 2018

Hi, I am getting
InvalidArgumentError (see above for traceback): Matrix size-incompatible: In[0]: [2,10], In[1]: [2,200]
the second time it calls make_decoder.
I use tensorflow (and tensorflow-gpu), both version 1.9.0


Copy link

skokalj commented Sep 22, 2018

Just an update: when I applied the modification that @kpe suggested everything worked fine. Thanks!

Copy link

Thank you for the tutorial. I changed the code size to 4, but the code no longer works; it only works with a code size of 2.

Copy link

Thanks. Your code is very helpful!
But I have a question: are you implementing the exact algorithm from "Auto-Encoding Variational Bayes"? In that paper the encoder and decoder are MLPs, so I think the activation function of the first layer in the "make_encoder" function should be tanh rather than relu. The same applies to the "make_decoder" function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment