Last active
July 27, 2018 18:34
-
-
Save yongjincho/55d415f7c7602c71b03c70a9ad361b0d to your computer and use it in GitHub Desktop.
A Simple Implementation of Variational Autoencoder (VAE)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2018 Yongjin Cho <yongjin.cho@gmail.com> | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""A Simple Implementation of Variational Autoencoder (VAE) | |
Requirements | |
============ | |
* Python >= 3.6 | |
* Tensorflow >= 1.8 | |
* imageio | |
* matplotlib | |
How to use | |
========== | |
Simply run the following command. Reconstruction images and generated
sample images are saved in './result'.
$ python vae.py | |
If you set the latent dimension to 2 with the following command, interpolation
and posterior distribution images are also created.
$ python vae.py --latent_size=2 | |
References | |
========== | |
* https://github.com/hwalsuklee/tensorflow-mnist-VAE | |
* https://github.com/pytorch/examples/tree/master/vae | |
""" | |
import os | |
import numpy as np | |
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
import imageio | |
flags = tf.app.flags
FLAGS = flags.FLAGS

# Command-line hyperparameters; `python vae.py --help` lists them.
flags.DEFINE_integer('hidden_size', 400,
                     "Width of the hidden layer in the encoder and decoder.")
flags.DEFINE_integer('latent_size', 20,
                     "Dimension of the latent code z; a value of 2 also "
                     "produces posterior and interpolation images.")
flags.DEFINE_integer('num_epochs', 10,
                     "Number of passes over the training set.")
flags.DEFINE_integer('batch_size', 128,
                     "Number of examples per mini-batch.")
flags.DEFINE_float('lr', 0.001,
                   "Learning rate for the Adam optimizer.")
flags.DEFINE_string("result_dir", "result",
                    "Directory where result images and GIFs are written.")
def encode(x):
    """Map a batch of flattened images to the parameters of q(z|x).

    A single ReLU hidden layer of width ``FLAGS.hidden_size`` feeds two
    linear heads of width ``FLAGS.latent_size``.

    Args:
        x: Input tensor (in this file, the [None, 784] float32 placeholder
            built in ``main``).

    Returns:
        ``(mu, logvar)``: mean and log-variance of the approximate posterior.
    """
    with tf.variable_scope("encode"):
        hidden = tf.layers.dense(x, FLAGS.hidden_size, activation=tf.nn.relu)
        # Keep the mean head before the log-variance head: layer variables
        # are named in creation order.
        mean = tf.layers.dense(hidden, FLAGS.latent_size)
        log_variance = tf.layers.dense(hidden, FLAGS.latent_size)
    return mean, log_variance
def reparameterize(mu, logvar):
    """Draw z ~ N(mu, exp(logvar)) using the reparameterization trick.

    The sample is formed as ``mu + eps * std`` with ``eps ~ N(0, I)``, which
    keeps it differentiable with respect to ``mu`` and ``logvar``.

    Args:
        mu: Mean tensor of the approximate posterior.
        logvar: Log-variance tensor with the same shape as ``mu``.

    Returns:
        A tensor with the shape and dtype of ``mu``.
    """
    std = tf.exp(0.5 * logvar)
    # tf.random_normal defaults to float32; match mu's dtype so the sum below
    # also works for float16/float64 encoders (TF does not auto-cast).
    eps = tf.random_normal(tf.shape(mu), dtype=mu.dtype)
    return mu + eps * std
def decode(z):
    """Map latent codes to per-pixel Bernoulli means.

    The "decode" scope is opened with ``tf.AUTO_REUSE`` so that ``main`` can
    call this function a second time (for the generation graph fed from the
    ``z`` placeholder) and reuse the decoder variables created by the first
    call.

    Args:
        z: Latent tensor whose last dimension is ``FLAGS.latent_size``.

    Returns:
        A [batch, 784] tensor of sigmoid activations.
    """
    with tf.variable_scope("decode", reuse=tf.AUTO_REUSE):
        hidden = tf.layers.dense(z, FLAGS.hidden_size, activation=tf.nn.relu)
        pixel_probs = tf.layers.dense(hidden, 784, activation=tf.nn.sigmoid)
    return pixel_probs
def VAE(x):
    """Run one encode -> sample -> decode pass of the autoencoder.

    Args:
        x: Input image batch passed to ``encode``.

    Returns:
        ``(x_recon, mu, logvar)``: the decoded posterior sample, plus the
        posterior parameters that the loss needs for its KL term.
    """
    mu, logvar = encode(x)
    x_recon = decode(reparameterize(mu, logvar))
    return x_recon, mu, logvar
def compute_loss(x, x_recon, mu, logvar, eps=1e-7):
    """Return the batch-mean negative ELBO.

    ELBO = E_q[log p(x|z)] - KL(q(z|x) || N(0, I)).

    - The reconstruction term is the Bernoulli log-likelihood (negative
      binary cross-entropy), summed over pixels.
    - The KL term is the closed form for a diagonal Gaussian posterior
      against a standard normal prior, summed over latent dimensions.

    Args:
        x: Target pixels, presumably in [0, 1] (MNIST intensities).
        x_recon: Decoder outputs (Bernoulli means), same shape as ``x``.
        mu: Posterior mean, rank-2 [batch, latent].
        logvar: Posterior log-variance, same shape as ``mu``.
        eps: Margin used to clip ``x_recon`` into [eps, 1 - eps] before
            taking logs. 1e-7 is safe for float32, where ``1 - 1e-8`` would
            round back to exactly 1.0.

    Returns:
        A scalar loss tensor.
    """
    # A saturated sigmoid can emit exactly 0.0 or 1.0 in float32. log() then
    # gives -inf, and 0 * -inf gives NaN, which poisons the whole loss.
    x_recon = tf.clip_by_value(x_recon, eps, 1.0 - eps)
    log_likelihood = tf.reduce_sum(
        x * tf.log(x_recon) + (1 - x) * tf.log(1 - x_recon), 1)
    # This is -KL(q || p), i.e. the KL contribution to the ELBO.
    neg_kl = 0.5 * tf.reduce_sum(
        1 + logvar - tf.square(mu) - tf.exp(logvar), 1)
    elbo = log_likelihood + neg_kl
    return tf.reduce_mean(-elbo)
def make_batch(dataset, batch_size=None):
    """Yield consecutive ``(images, labels)`` mini-batches from *dataset*.

    Batches are taken in dataset order, without shuffling. The final batch
    is smaller when the dataset size is not a multiple of ``batch_size``.

    Args:
        dataset: Object with sliceable, equal-length ``images`` and
            ``labels`` attributes (e.g. an MNIST split).
        batch_size: Examples per batch. Defaults to ``FLAGS.batch_size``, so
            existing callers are unaffected.

    Yields:
        ``(images_slice, labels_slice)`` tuples.

    Raises:
        ValueError: If ``batch_size`` is not positive. A negative step would
            otherwise silently yield nothing.
    """
    if batch_size is None:
        batch_size = FLAGS.batch_size
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    images, labels = dataset.images, dataset.labels
    for start in range(0, len(images), batch_size):
        stop = start + batch_size
        yield images[start:stop], labels[start:stop]
def save_image(fname, x, ncols):
    """Tile a stack of images into a grid and write it to ``FLAGS.result_dir``.

    Image ``k`` lands in grid row ``k // ncols`` and column ``k % ncols``.

    Args:
        fname: Output file name, joined onto ``FLAGS.result_dir``.
        x: Array of shape [n, height, width] with values expected in [0, 1].
            Every caller in this file passes 28x28 images.
        ncols: Number of images per grid row; must evenly divide ``n``.

    Raises:
        ValueError: If ``x`` is empty, ``ncols`` is not positive, or ``n`` is
            not a multiple of ``ncols``.
    """
    x = np.asarray(x)
    n = len(x)
    # Real exceptions rather than assert, which is stripped under `python -O`.
    if ncols <= 0 or n == 0 or n % ncols != 0:
        raise ValueError(f"cannot tile {n} images into rows of {ncols}")
    nrows = n // ncols
    height, width = x.shape[1], x.shape[2]
    # [n, h, w] -> [nrows, ncols, h, w] -> [nrows, h, ncols, w]; a final
    # reshape then lays the tiles out row-major as one 2-D image.
    grid = (x.reshape(nrows, ncols, height, width)
             .transpose(0, 2, 1, 3)
             .reshape(nrows * height, ncols * width))
    # Clip so out-of-range values cannot wrap around in the uint8 cast.
    pixels = (np.clip(grid, 0.0, 1.0) * 255).astype(np.uint8)
    imageio.imwrite(os.path.join(FLAGS.result_dir, fname), pixels)
def make_gif(prefix):
    """Combine the per-epoch ``{prefix}_{epoch}.png`` frames into ``{prefix}.gif``.

    Frames for epochs 0 .. FLAGS.num_epochs - 1 are read, in order, from
    ``FLAGS.result_dir``. The animation is written to the same directory.
    """
    frames = [
        imageio.imread(os.path.join(FLAGS.result_dir, f"{prefix}_{epoch}.png"))
        for epoch in range(FLAGS.num_epochs)
    ]
    gif_path = os.path.join(FLAGS.result_dir, f"{prefix}.gif")
    imageio.mimwrite(gif_path, frames)
def main(argv):
    """Train the VAE on MNIST and write diagnostic images to FLAGS.result_dir.

    Each epoch does the following:
    - one pass over the training split, printing the mean loss;
    - a full pass over the test split, printing the mean loss and saving a
      reconstruction grid from the first test batch;
    - saves a grid of decoded latent samples.

    When ``FLAGS.latent_size == 2``, each epoch also saves:
    - a scatter plot of the test-set posterior means;
    - a decoded 20x20 sweep of the latent plane.

    After training, the per-epoch PNGs are combined into GIFs.

    Args:
        argv: Remaining command-line arguments from ``tf.app.run``; unused.
    """
    # Load data
    mnist = tf.contrib.learn.datasets.load_dataset("mnist")
    os.makedirs(FLAGS.result_dir, exist_ok=True)

    # Build the train graph
    x_in = tf.placeholder(tf.float32, [None, 784], "x")
    x_recon, mu, logvar = VAE(x_in)
    loss = compute_loss(x_in, x_recon, mu, logvar)
    train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

    # Build the generation graph (decode() reuses the decoder variables).
    z_in = tf.placeholder(tf.float32, [None, FLAGS.latent_size], "z")
    x_gen = decode(z_in)

    latent_is_2d = FLAGS.latent_size == 2

    # The context manager releases the session's resources when done.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(FLAGS.num_epochs):
            # Train
            train_loss = 0.0
            for batch_x, _ in make_batch(mnist.train):
                output = sess.run({"train_op": train_op, "loss": loss},
                                  feed_dict={x_in: batch_x})
                # Weight by batch size: the last batch may be smaller.
                train_loss += output["loss"] * len(batch_x)
            train_loss /= len(mnist.train.images)
            print(f"Epoch: {epoch}, Average loss: {train_loss:.3f}")

            # Evaluate
            eval_loss = 0.0
            fig = ax = None
            for i, (batch_x, batch_y) in enumerate(make_batch(mnist.test)):
                output = sess.run({"x": x_recon, "mu": mu, "loss": loss},
                                  feed_dict={x_in: batch_x})
                eval_loss += output["loss"] * len(batch_x)
                if i == 0:
                    n = min(len(batch_x), 8)
                    image = np.concatenate(
                        [batch_x[:n].reshape(-1, 28, 28),
                         output["x"][:n].reshape(-1, 28, 28)], 0)
                    save_image(f"reconstruction_{epoch}.png", image, n)
                # Make posterior distribution map
                if latent_is_2d:
                    if ax is None:
                        fig, ax = plt.subplots()
                        ax.set_xlim(-5, 5)
                        ax.set_ylim(-5, 5)
                    # Pin the color range so digit d always maps to tab10
                    # color d. The default per-call autoscaling shifts colors
                    # for a batch (e.g. the short final one) that lacks the
                    # smallest or largest label.
                    ax.scatter(output["mu"][:, 0], output["mu"][:, 1],
                               c=batch_y, cmap=plt.get_cmap("tab10"),
                               vmin=-0.5, vmax=9.5)
            eval_loss /= len(mnist.test.images)
            print(f"Evaluation loss: {eval_loss:.3f}")
            if fig is not None:
                fig.savefig(os.path.join(FLAGS.result_dir,
                                         f"posterior_{epoch}.png"))
                # Release the figure; otherwise one stays open per epoch.
                plt.close(fig)

            # Generate samples
            # NOTE(review): the loss regularizes toward an N(0, I) prior, so
            # np.random.standard_normal would sample the model's generative
            # distribution; uniform(-2, 2) is kept as the original choice.
            z_sample = np.random.uniform(-2, 2,
                                         size=(8 * 8, FLAGS.latent_size))
            output = sess.run({"x": x_gen}, feed_dict={z_in: z_sample})
            save_image(f"sample_{epoch}.png",
                       output["x"].reshape(-1, 28, 28), 8)

            # Generate an interpolation image: decode a regular n x n grid
            # over [-3, 3]^2 of the 2-D latent space.
            if latent_is_2d:
                n = 20
                axis_values = np.linspace(-3, 3, n)
                z_grid = [[z0, z1] for z0 in axis_values for z1 in axis_values]
                output = sess.run({"x": x_gen}, feed_dict={z_in: z_grid})
                save_image(f"interpolation_{epoch}.png",
                           output["x"].reshape(-1, 28, 28), n)

    # Make GIF
    make_gif("reconstruction")
    if latent_is_2d:
        make_gif("posterior")
        make_gif("interpolation")
if __name__ == "__main__":
    # Show TensorFlow INFO-level logs, then parse flags and dispatch to main().
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run(main=main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment