TensorFlow Variational Auto-Encoder
# Full example for my blog post at:
# https://danijar.com/building-variational-auto-encoders-in-tensorflow/
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
tfd = tf.contrib.distributions


def make_encoder(data, code_size):
  # Map each image to a diagonal Gaussian over the latent code.
  x = tf.layers.flatten(data)
  x = tf.layers.dense(x, 200, tf.nn.relu)
  x = tf.layers.dense(x, 200, tf.nn.relu)
  loc = tf.layers.dense(x, code_size)
  scale = tf.layers.dense(x, code_size, tf.nn.softplus)
  return tfd.MultivariateNormalDiag(loc, scale)


def make_prior(code_size):
  # Standard Normal prior over the latent code.
  loc = tf.zeros(code_size)
  scale = tf.ones(code_size)
  return tfd.MultivariateNormalDiag(loc, scale)


def make_decoder(code, data_shape):
  # Map each code to an independent Bernoulli distribution over pixels.
  x = code
  x = tf.layers.dense(x, 200, tf.nn.relu)
  x = tf.layers.dense(x, 200, tf.nn.relu)
  logit = tf.layers.dense(x, np.prod(data_shape))
  logit = tf.reshape(logit, [-1] + data_shape)
  return tfd.Independent(tfd.Bernoulli(logit), 2)


def plot_codes(ax, codes, labels):
  ax.scatter(codes[:, 0], codes[:, 1], s=2, c=labels, alpha=0.1)
  ax.set_aspect('equal')
  ax.set_xlim(codes.min() - .1, codes.max() + .1)
  ax.set_ylim(codes.min() - .1, codes.max() + .1)
  ax.tick_params(
      axis='both', which='both', left='off', bottom='off',
      labelleft='off', labelbottom='off')


def plot_samples(ax, samples):
  for index, sample in enumerate(samples):
    ax[index].imshow(sample, cmap='gray')
    ax[index].axis('off')


data = tf.placeholder(tf.float32, [None, 28, 28])
# Templates share variables between calls to the encoder and decoder.
make_encoder = tf.make_template('encoder', make_encoder)
make_decoder = tf.make_template('decoder', make_decoder)

# Define the model.
prior = make_prior(code_size=2)
posterior = make_encoder(data, code_size=2)
code = posterior.sample()

# Define the loss.
likelihood = make_decoder(code, [28, 28]).log_prob(data)
divergence = tfd.kl_divergence(posterior, prior)
elbo = tf.reduce_mean(likelihood - divergence)
optimize = tf.train.AdamOptimizer(0.001).minimize(-elbo)
samples = make_decoder(prior.sample(10), [28, 28]).mean()

mnist = input_data.read_data_sets('MNIST_data/')
fig, ax = plt.subplots(nrows=20, ncols=11, figsize=(10, 20))
with tf.train.MonitoredSession() as sess:
  for epoch in range(20):
    # Evaluate and plot on the test set, then train for one epoch.
    feed = {data: mnist.test.images.reshape([-1, 28, 28])}
    test_elbo, test_codes, test_samples = sess.run([elbo, code, samples], feed)
    print('Epoch', epoch, 'elbo', test_elbo)
    ax[epoch, 0].set_ylabel('Epoch {}'.format(epoch))
    plot_codes(ax[epoch, 0], test_codes, mnist.test.labels)
    plot_samples(ax[epoch, 1:], test_samples)
    for _ in range(600):
      feed = {data: mnist.train.next_batch(100)[0].reshape([-1, 28, 28])}
      sess.run(optimize, feed)
plt.savefig('vae-mnist.png', dpi=300, transparent=True, bbox_inches='tight')
@korymath commented Jan 15, 2018

... colab? (wink wink)

@kpe commented Jul 25, 2018

It seems like the line

samples = make_decoder(prior.sample(10), [28, 28]).mean()

needs to be replaced with:

samples = make_decoder(prior.sample((10,1)), [28, 28]).mean()

otherwise it won't run (e.g. in Colab).
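
For reference, here is a quick way to see the shape difference between the two calls (a minimal sketch, assuming TF 1.x with tf.contrib.distributions, as in the gist): prior.sample(10) draws codes of shape [10, 2], while prior.sample((10, 1)) adds a singleton batch dimension.

# Minimal shape check (assumes TF 1.x and tf.contrib.distributions).
import tensorflow as tf
tfd = tf.contrib.distributions
prior = tfd.MultivariateNormalDiag(tf.zeros(2), tf.ones(2))
print(prior.sample(10).shape)       # (10, 2)
print(prior.sample((10, 1)).shape)  # (10, 1, 2)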

@skokalj commented Sep 20, 2018

Hi, I am getting
InvalidArgumentError (see above for traceback): Matrix size-incompatible: In[0]: [2,10], In[1]: [2,200]
the second time make_decoder is called.
I am using tensorflow (and tensorflow-gpu), both at version 1.9.0.

Thanks,
-Silvija

@skokalj commented Sep 22, 2018

Just an update: when I applied the modification that @kpe suggested, everything worked fine. Thanks!

@alla15747 commented Dec 6, 2018

Thank you for the tutorial. I changed the code size to 4, but the code does not work. It only works with code size 2.
