A Simple Implementation of a Variational Autoencoder (VAE)
# Copyright 2018 Yongjin Cho <yongjin.cho@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Simple Implementation of Variational Autoencoder (VAE)
Requirements
============
* Python >= 3.6
* Tensorflow >= 1.8
* imageio
* matplotlib
How to use
==========
Simply run the following command. Reconstruction images and generated
sample images are saved in the './result'.
$ python vae.py
If you set the latent dim. to 2 by following command, then iterpolation and posterior
distribution images are also created.
$ python vae.py --latent_size=2
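
Other hyperparameters can also be changed via command-line flags, e.g. (the
values below are only illustrative, not tuned):

$ python vae.py --num_epochs=20 --batch_size=64 --lr=0.0005
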
References
==========
* https://github.com/hwalsuklee/tensorflow-mnist-VAE
* https://github.com/pytorch/examples/tree/master/vae
"""
import os

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import imageio

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer("hidden_size", 400, "Number of units in the hidden layer.")
flags.DEFINE_integer("latent_size", 20, "Dimensionality of the latent variable z.")
flags.DEFINE_integer("num_epochs", 10, "Number of training epochs.")
flags.DEFINE_integer("batch_size", 128, "Mini-batch size.")
flags.DEFINE_float("lr", 0.001, "Learning rate for the Adam optimizer.")
flags.DEFINE_string("result_dir", "result", "Directory where result images are saved.")

def encode(x):
    with tf.variable_scope("encode"):
        x = tf.layers.dense(x, FLAGS.hidden_size, tf.nn.relu)
        mu = tf.layers.dense(x, FLAGS.latent_size)
        logvar = tf.layers.dense(x, FLAGS.latent_size)
    return mu, logvar

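# The reparameterization trick: instead of sampling z ~ N(mu, sigma^2) directly
# (which is not differentiable with respect to mu and sigma), sample e ~ N(0, I)
# and set
#     z = mu + sigma * e,   with sigma = exp(0.5 * logvar),
# so that gradients can flow back through mu and logvar.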
def reparameterize(mu, logvar):
    std = tf.exp(0.5 * logvar)
    e = tf.random_normal(tf.shape(mu))
    return mu + e * std

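# The decoder is built with AUTO_REUSE so that the same weights are shared
# between the training graph (via VAE) and the generation graph built in main().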
def decode(z):
    with tf.variable_scope("decode", reuse=tf.AUTO_REUSE):
        x = tf.layers.dense(z, FLAGS.hidden_size, tf.nn.relu)
        x = tf.layers.dense(x, 784, tf.nn.sigmoid)
    return x

def VAE(x):
    mu, logvar = encode(x)
    z = reparameterize(mu, logvar)
    x = decode(z)
    return x, mu, logvar

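# The training objective is the negative ELBO:
#     -ELBO = -E_q[log p(x|z)] + KL(q(z|x) || N(0, I))
# Below, `bce` is the Bernoulli log-likelihood of the reconstruction (the negative
# binary cross-entropy) and `kld` is the *negative* KL term,
#     0.5 * sum(1 + logvar - mu^2 - exp(logvar)),
# so elbo = bce + kld and the loss is -elbo, averaged over the batch.
# Note: taking tf.log of the sigmoid output can be numerically unstable if the
# reconstruction saturates at 0 or 1; a logits-based cross-entropy (e.g.
# tf.nn.sigmoid_cross_entropy_with_logits) is often used in practice.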
def compute_loss(x, x_recon, mu, logvar):
    bce = tf.reduce_sum(x * tf.log(x_recon) + (1 - x) * tf.log(1 - x_recon), 1)
    kld = 0.5 * tf.reduce_sum(1 + logvar - tf.square(mu) - tf.exp(logvar), 1)
    elbo = kld + bce
    losses = -elbo
    return tf.reduce_mean(losses)

def make_batch(dataset):
    for i in range(0, len(dataset.images), FLAGS.batch_size):
        yield dataset.images[i:i+FLAGS.batch_size], dataset.labels[i:i+FLAGS.batch_size]

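# Tiles a batch of 28x28 images into an (nrows x ncols) grid and writes it as a PNG.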
def save_image(fname, x, ncols):
    n = len(x)
    assert n % ncols == 0
    nrows = n // ncols
    images = np.split(x, n)
    rows = []
    for i in range(0, n, ncols):
        row = np.concatenate(images[i:i+ncols], 2)  # [1, 28, 28*ncols]
        rows.append(row)
    merged = np.concatenate(rows, 1)                # [1, 28*nrows, 28*ncols]
    out = merged.reshape(28*nrows, 28*ncols)        # [28*nrows, 28*ncols]
    path = os.path.join(FLAGS.result_dir, fname)
    imageio.imwrite(path, (out * 255).astype(np.uint8))

def make_gif(prefix):
    images = []
    for i in range(FLAGS.num_epochs):
        im = imageio.imread(os.path.join(FLAGS.result_dir, f"{prefix}_{i}.png"))
        images.append(im)
    imageio.mimwrite(os.path.join(FLAGS.result_dir, f"{prefix}.gif"), images)

def main(argv):
    # Load data
    mnist = tf.contrib.learn.datasets.load_dataset("mnist")

    if not os.path.exists(FLAGS.result_dir):
        os.makedirs(FLAGS.result_dir)

    # Build the train graph
    x_in = tf.placeholder(tf.float32, [None, 784], "x")
    x_recon, mu, logvar = VAE(x_in)
    loss = compute_loss(x_in, x_recon, mu, logvar)
    train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

    # Build the generation graph
    z_in = tf.placeholder(tf.float32, [None, FLAGS.latent_size], "z")
    x_gen = decode(z_in)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

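    # Each epoch trains on the full training set, then evaluates on the test set
    # and writes visualization images (reconstructions, samples, and, when the
    # latent space is 2-D, posterior and interpolation maps) to the result dir.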
    for epoch in range(FLAGS.num_epochs):
        # Train
        train_loss = 0.0
        for batch_x, _ in make_batch(mnist.train):
            output = sess.run({"train_op": train_op, "loss": loss}, feed_dict={x_in: batch_x})
            train_loss += output["loss"] * len(batch_x)
        train_loss /= len(mnist.train.images)
        print(f"Epoch: {epoch}, Average loss: {train_loss:.3f}")

        # Evaluate
        eval_loss = 0.0
        for i, (batch_x, batch_y) in enumerate(make_batch(mnist.test)):
            output = sess.run({"x": x_recon, "mu": mu, "loss": loss}, feed_dict={x_in: batch_x})
            eval_loss += output["loss"] * len(batch_x)

            if i == 0:
                n = min(len(batch_x), 8)
                image = np.concatenate([batch_x[:n].reshape(-1, 28, 28),
                                        output["x"][:n].reshape(-1, 28, 28)], 0)
                save_image(f"reconstruction_{epoch}.png", image, n)
            # Make posterior distribution map
            if FLAGS.latent_size == 2:
                if i == 0:
                    fig, ax = plt.subplots()
                    ax.set_xlim(-5, 5)
                    ax.set_ylim(-5, 5)
                ax.scatter(output["mu"][:,0], output["mu"][:,1],
                           c=(batch_y / 10.0), cmap=plt.get_cmap("tab10"))

        eval_loss /= len(mnist.test.images)
        print(f"Evaluation loss: {eval_loss:.3f}")

        if FLAGS.latent_size == 2:
            plt.savefig(os.path.join(FLAGS.result_dir, f"posterior_{epoch}.png"))

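        # Note: samples below are drawn uniformly from a box in latent space
        # rather than from the standard normal prior assumed by the KL term;
        # both are common choices for visualization.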
        # Generate samples
        z_sample = np.random.uniform(-2, 2, size=(8 * 8, FLAGS.latent_size))
        output = sess.run({"x": x_gen}, feed_dict={z_in: z_sample})
        save_image(f"sample_{epoch}.png", output["x"].reshape(-1, 28, 28), 8)

        # Generate an interpolation image
        if FLAGS.latent_size == 2:
            n = 20
            z_grid = []
            for z0 in np.linspace(-3, 3, n):
                for z1 in np.linspace(-3, 3, n):
                    z_grid.append([z0, z1])
            output = sess.run({"x": x_gen}, feed_dict={z_in: z_grid})
            save_image(f"interpolation_{epoch}.png", output["x"].reshape(-1, 28, 28), n)

    # Make GIFs
    make_gif("reconstruction")
    if FLAGS.latent_size == 2:
        make_gif("posterior")
        make_gif("interpolation")

if __name__ == "__main__":
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()