# Created December 3, 2017 23:34
import time
import numpy as np
import tensorflow as tf
from keras.datasets import mnist, cifar10, cifar100
import matplotlib.pyplot as plt
from utils import get_minibatches_idx
# Based on
# Small constant added inside logarithms so that log(0) is never evaluated.
eps = 1e-10

# Nominal number of training samples (presumably the size of the MNIST
# training split -- confirm if another dataset is used).
n_samples = 60 * 1000
class VariationalAutoencoder(object):
    """Fully connected variational autoencoder over flattened 784-value inputs.

    See "Auto-Encoding Variational Bayes" by Kingma and Welling for more details.

    The graph is built in ``__init__``: a two-layer encoder produces the mean
    and log-variance of a 20-dimensional Gaussian, one latent sample is drawn
    with the reparameterisation trick, and a two-layer decoder maps it back to
    784 sigmoid outputs.  The training cost is the summed Bernoulli
    cross-entropy of the reconstruction plus the KL term of the latent code,
    averaged over the batch.
    """

    def __init__(self, batch_size=100):
        """Build the graph, create the session and initialise all variables.

        Args:
            batch_size: Retained for backward compatibility only.  The noise
                tensor now takes its shape from the encoder output at run
                time, so batches of any size can be fed.
        """
        self.x = tf.placeholder(tf.float32, [None, 784])

        # Encode each image as mean and (log-)variance vectors
        h = tf.contrib.layers.fully_connected(self.x, 500)
        h = tf.contrib.layers.fully_connected(h, 500)
        self.z_mean = tf.contrib.layers.fully_connected(
            h, 20, activation_fn=tf.identity)
        self.z_log_sigma_sq = tf.contrib.layers.fully_connected(
            h, 20, activation_fn=tf.identity)

        # Draw one sample z per input from the Gaussian distribution.  The
        # noise shape is taken from z_mean so that a short final mini-batch
        # or a differently sized reconstruct() call does not fail.
        noise = tf.random_normal(tf.shape(self.z_mean), 0, 1, dtype=tf.float32)
        # sigma = exp(0.5 * log(sigma^2)); same value as sqrt(exp(.)) but it
        # overflows later for large log-variances.
        self.z = self.z_mean + noise * tf.exp(0.5 * self.z_log_sigma_sq)

        # Decoder
        h = tf.contrib.layers.fully_connected(self.z, 500)
        h = tf.contrib.layers.fully_connected(h, 500)
        self.x_reconstr_mean = tf.contrib.layers.fully_connected(
            h, 784, activation_fn=tf.nn.sigmoid)

        # Bernoulli negative log-likelihood, summed over the 784 outputs;
        # eps keeps both logarithms finite when an output saturates.
        reconstr_loss = -tf.reduce_sum(
            self.x * tf.log(eps + self.x_reconstr_mean)
            + (1 - self.x) * tf.log(eps + 1 - self.x_reconstr_mean), 1)

        # KL-divergence
        latent_loss = -0.5 * tf.reduce_sum(
            1 + self.z_log_sigma_sq
            - tf.square(self.z_mean)
            - tf.exp(self.z_log_sigma_sq), 1)

        self.cost = tf.reduce_mean(reconstr_loss + latent_loss)  # average over batch
        self.train_step = tf.train.AdamOptimizer(
            learning_rate=0.001).minimize(self.cost)

        self.sess = tf.Session()
        # Variables (layer weights and Adam slots) must be initialised before
        # the first run() call, otherwise TensorFlow raises
        # FailedPreconditionError.
        self.sess.run(tf.global_variables_initializer())

    def generate(self, z_mu=None):
        """Decode latent vectors into reconstructed outputs.

        Args:
            z_mu: Latent code(s) with 20 values each, either a single vector
                or a 2-D batch.  When None, one vector is drawn from a
                standard normal distribution.

        Returns:
            The decoder output, one row of 784 values per latent vector.
        """
        if z_mu is None:
            z_mu = np.random.normal(size=20)
        # self.z is rank 2, so a single latent vector becomes a batch of one.
        z_mu = np.atleast_2d(z_mu)
        return self.sess.run(self.x_reconstr_mean, feed_dict={self.z: z_mu})

    def reconstruct(self, X):
        """Encode ``X``, sample a latent code and return the decoded output.

        Args:
            X: Batch of inputs with 784 values per row.

        Returns:
            The decoder output, one row of 784 values per input row.
        """
        return self.sess.run(self.x_reconstr_mean, feed_dict={self.x: X})

    def train(self, X, batch_size=100, training_epochs=10):
        """Run mini-batch Adam updates over ``X`` and print the cost per epoch.

        Args:
            X: Training inputs, indexable by integer, 784 values per row.
            batch_size: Mini-batch size handed to ``get_minibatches_idx``.
            training_epochs: Number of passes over ``X``.
        """
        print("\nStarting training")
        total = len(X)
        for epoch in range(training_epochs):
            avg_cost = 0.0
            train_indices = get_minibatches_idx(total, batch_size, shuffle=True)
            for it in train_indices:
                batch_x = [X[i] for i in it]
                _, cost = self.sess.run((self.train_step, self.cost),
                                        feed_dict={self.x: batch_x})
                # Weight each batch by its real length and divide by the real
                # data size, so the figure is a per-sample average for any
                # dataset size and for a short final batch.
                avg_cost += cost * len(batch_x) / total
            print("Epoch:", '%d' % (epoch+1), "cost=", "{:.3f}".format(avg_cost))
def main():
    """Train the VAE on a Keras image dataset and plot test reconstructions.

    Loads the selected dataset, scales pixels to [0, 1], flattens every image
    to one row, trains for five epochs, then shows the first five test images
    next to their reconstructions.

    Raises:
        ValueError: If the dataset name is not recognised, or if its flattened
            image size does not match the 784 inputs the model is built for.
    """
    dataset = 'mnist'  # mnist, cifar10, cifar100
    # Input width of the placeholder created by VariationalAutoencoder.
    model_input_dim = 784

    # Load the data
    # It will be downloaded first if necessary
    if dataset == 'mnist':
        (X_train, _), (X_test, _) = mnist.load_data()
        img_size = 28
        num_channels = 1
    elif dataset == 'cifar10':
        (X_train, _), (X_test, _) = cifar10.load_data()
        img_size = 32
        num_channels = 3
    elif dataset == 'cifar100':
        (X_train, _), (X_test, _) = cifar100.load_data()
        img_size = 32
        num_channels = 3
    else:
        raise ValueError("Unknown dataset: %r" % (dataset,))

    # Flatten with the real per-image size.  A hard-coded 28*28 would silently
    # cut 32x32x3 images into wrongly sized rows, so fail loudly instead.
    input_dim = img_size * img_size * num_channels
    if input_dim != model_input_dim:
        raise ValueError(
            "Dataset %r has %d values per image but the model expects %d"
            % (dataset, input_dim, model_input_dim))

    X_train = np.reshape(X_train.astype('float32'), [-1, input_dim])
    X_test = np.reshape(X_test.astype('float32'), [-1, input_dim])
    X_train /= 255
    X_test /= 255
    print(X_train.shape[0], 'train samples')
    print(X_test.shape[0], 'test samples')

    vae = VariationalAutoencoder(batch_size=100)
    print("Model compiled")
    vae.train(X_train, training_epochs=5)

    x_sample = X_test[:100]
    x_reconstruct = vae.reconstruct(x_sample)
    #x_gen = vae.generate()

    # Left column: test inputs; right column: their reconstructions.
    for i in range(5):
        plt.subplot(5, 2, 2*i + 1)
        plt.imshow(x_sample[i].reshape(img_size, img_size),
                   vmin=0, vmax=1, cmap="gray")
        plt.title("Test input")
        plt.subplot(5, 2, 2*i + 2)
        plt.imshow(x_reconstruct[i].reshape(img_size, img_size),
                   vmin=0, vmax=1, cmap="gray")
        plt.title("Reconstruction")
    # Without show() a plain script run exits without displaying the figure.
    plt.show()
if __name__ == "__main__":
