Skip to content

Instantly share code, notes, and snippets.

/vaemnist.py
Created Sep 8, 2017

Embed
What would you like to do?
import time
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
def weight_variable(shape, name):
    """Return a new trainable ``tf.Variable`` called *name* with the given *shape*.

    Initial values come from ``tf.truncated_normal`` with stddev 0.1.
    """
    init_values = tf.truncated_normal(shape, stddev=0.1)
    var = tf.Variable(init_values, name=name)
    return var
def bias_variable(shape, name):
    """Return a new trainable bias ``tf.Variable`` called *name* with the given *shape*.

    Note: biases use the same ``tf.truncated_normal(stddev=0.1)`` initialiser
    as the weights rather than a constant.
    """
    init_values = tf.truncated_normal(shape, stddev=0.1)
    var = tf.Variable(init_values, name=name)
    return var
def FC_layer(X, W, b):
    """Affine (fully connected) layer without activation: ``X @ W + b``.

    Args:
        X: input tensor, rows are examples.
        W: weight matrix; its first dimension must match X's last dimension.
        b: bias, broadcast over the rows of the product.
    """
    product = tf.matmul(X, W)
    return product + b
# Training-script configuration.
n_pixels = 28 * 28          # flattened MNIST image size
latent_dim = 2              # size of the latent code (previously 20)
h_dim = 500                 # hidden units in both the encoder and the decoder
num_iterations = 1000       # optimizer steps (previously 100000)
recording_interval = 1000   # print progress every this many steps
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# ---- Encoder: X -> h_enc -> (mu, logstd) ------------------------------------
X = tf.placeholder(tf.float32, shape=[None, n_pixels], name='X')

W_enc = weight_variable([n_pixels, h_dim], 'W_enc')
b_enc = bias_variable([h_dim], 'b_enc')
h_enc = tf.nn.tanh(FC_layer(X, W_enc, b_enc))

W_mu = weight_variable([h_dim, latent_dim], 'W_mu')
b_mu = bias_variable([latent_dim], 'b_mu')
mu = FC_layer(h_enc, W_mu, b_mu)

# logstd is the log of the posterior standard deviation: sigma = exp(logstd).
# The sampling step and the KL term below must both use this interpretation
# (previously sampling used exp(0.5 * logstd), i.e. treated it as a log-variance,
# while the KL term treated it as a log-std).
W_logstd = weight_variable([h_dim, latent_dim], 'W_logstd')
b_logstd = bias_variable([latent_dim], 'b_logstd')
logstd = FC_layer(h_enc, W_logstd, b_logstd)

# ---- Reparameterised sample: z = mu + sigma * eps, eps ~ N(0, I) -------------
# Draw independent noise for every example (shape follows mu) instead of one
# [1, latent_dim] draw broadcast across the whole batch.
noise = tf.random_normal(tf.shape(mu))
z = tf.add(mu, noise * tf.exp(logstd), name='z')

# ---- Decoder: z -> h_dec -> per-pixel sigmoid outputs -----------------------
W_dec = weight_variable([latent_dim, h_dim], 'W_dec')
b_dec = bias_variable([h_dim], 'b_dec')
h_dec = tf.nn.tanh(FC_layer(z, W_dec, b_dec))
W_reconstruct = weight_variable([h_dim, n_pixels], 'W_reconstruct')
b_reconstruct = bias_variable([n_pixels], 'b_reconstruct')
reconstruction = tf.nn.sigmoid(FC_layer(h_dec, W_reconstruct, b_reconstruct), name='reconstruction')

# Second path through the same decoder weights, fed directly with latent codes
# (used by the visualisation script via the 'z_custom' / 'reconstruction_custom' names).
z_custom = tf.placeholder(tf.float32, shape=[None, latent_dim], name='z_custom')
h_dec_c = tf.nn.tanh(FC_layer(z_custom, W_dec, b_dec))
reconstruction_custom = tf.nn.sigmoid(FC_layer(h_dec_c, W_reconstruct, b_reconstruct), name='reconstruction_custom')

# ---- Objective --------------------------------------------------------------
# Per-example Bernoulli log-likelihood of X under the decoder outputs;
# the 1e-9 keeps log() away from 0.
log_likelihood = tf.reduce_sum(
    X * tf.log(reconstruction + 1e-9) + (1 - X) * tf.log(1 - reconstruction + 1e-9),
    axis=1)
# Closed-form KL(N(mu, sigma^2) || N(0, I)) with log(sigma^2) = 2 * logstd.
KL_term = -.5 * tf.reduce_sum(1 + 2 * logstd - tf.pow(mu, 2) - tf.exp(2 * logstd), axis=1)
variational_lower_bound = tf.reduce_mean(log_likelihood - KL_term)
optimizer = tf.train.AdadeltaOptimizer().minimize(-variational_lower_bound)
# Exposed through a collection so the restore script can fetch the train op.
tf.add_to_collection("optimizer", optimizer)
import os

# Train the VAE, then write a checkpoint to ./models/vaemnist-2d-<num_iterations>.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    for i in range(num_iterations):
        # Round pixel intensities to 0/1 to match the Bernoulli likelihood.
        x_batch = np.round(mnist.train.next_batch(200)[0])
        sess.run(optimizer, feed_dict={X: x_batch})
        if i % recording_interval == 0:
            # The optimizer minimises -ELBO, so report that as the loss.
            loss = -variational_lower_bound.eval(feed_dict={X: x_batch})
            print("Iteration: {}, Loss: {}".format(i, loss))
    # TF1's Saver.save raises if the parent directory does not exist.
    os.makedirs('./models', exist_ok=True)
    saver.save(sess, './models/vaemnist-2d', global_step=num_iterations)
from tqdm import tqdm
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as manimation
import numpy as np
import re

# Restore/visualisation-script configuration.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
additional_training = False   # run num_iterations more optimizer steps on the restored model
num_iterations = 400000
recording_interval = 1000     # NOTE(review): not referenced by the code below
model_to_import = './models/vaemnist-2d-1000000.meta'

# The checkpoint's global step is the integer between the last '-' and '.meta'
# (Saver writes '<prefix>-<global_step>.meta'). Parse it robustly rather than
# slicing on the first '-' / second '.' of the path.
_step_match = re.search(r'-(\d+)\.meta$', model_to_import)
if _step_match is None:
    raise ValueError(
        "Cannot parse the global step from model path {!r}".format(model_to_import))
total_iterations = num_iterations + int(_step_match.group(1))

movie = True          # render the latent-space walk to writer_test.mp4
comparisons = False   # plot test images next to their reconstructions
grid_show = False     # plot a grid of decodings over the 2-D latent space
with tf.Session() as sess:
    # Rebuild the saved graph and load the weights of the checkpoint that
    # belongs to this .meta file. tf.train.latest_checkpoint() would instead
    # restore whichever checkpoint in ./models was written last.
    loader = tf.train.import_meta_graph(model_to_import)
    checkpoint_prefix = (model_to_import[:-len('.meta')]
                         if model_to_import.endswith('.meta') else model_to_import)
    loader.restore(sess, checkpoint_prefix)

    graph = tf.get_default_graph()
    reconstruction = graph.get_tensor_by_name('reconstruction:0')
    reconstruction_custom = graph.get_tensor_by_name('reconstruction_custom:0')
    X = graph.get_tensor_by_name('X:0')
    z = graph.get_tensor_by_name('z:0')
    z_custom = graph.get_tensor_by_name('z_custom:0')
    optimizer = tf.get_collection("optimizer")[0]

    if additional_training:
        for i in tqdm(range(num_iterations)):
            x_batch = np.round(mnist.train.next_batch(200)[0])
            sess.run(optimizer, feed_dict={X: x_batch})
        saver = tf.train.Saver()
        saver.save(sess, './models/vaemnist-2d', global_step=total_iterations)

    if comparisons:
        # Show random test images (left) next to their reconstructions (right).
        num_pairs = 10
        image_indices = np.random.randint(0, len(mnist.test.images), num_pairs)
        for index in image_indices:
            image = np.reshape(mnist.test.images[index], (1, 28 * 28))
            plt.figure()
            plt.subplot(121)
            plt.imshow(np.reshape(image, (28, 28)))
            x_reconstruction = reconstruction.eval(feed_dict={X: image})
            # Note: this run re-samples the noise, so the printed z is a fresh
            # draw, not the exact z used for the reconstruction above.
            print(sess.run(z, feed_dict={X: image}), mnist.test.labels[index])
            plt.subplot(122)
            plt.imshow(np.reshape(x_reconstruction, (28, 28)))

    # Latent-space grid shared by the grid plot and the movie.
    nx = ny = 20
    x_values = np.linspace(-3, 3, nx)
    y_values = np.linspace(-3, 3, ny)

    if grid_show:
        canvas = np.empty((28 * ny, 28 * nx))
        for i, yi in enumerate(y_values):
            for j, xi in enumerate(x_values):
                x_mean = reconstruction_custom.eval(feed_dict={z_custom: np.array([[xi, yi]])})
                # Fill rows bottom-up so larger yi appears higher in the image.
                canvas[(ny - i - 1) * 28:(ny - i) * 28, j * 28:(j + 1) * 28] = x_mean[0].reshape(28, 28)
        plt.figure(figsize=(8, 10))
        plt.imshow(canvas, origin='upper', cmap='gray')
        plt.tight_layout()

    if movie:
        FFMpegWriter = manimation.writers['ffmpeg']
        metadata = dict(title='vidtest', artist='c', comment='movsupp')
        writer = FFMpegWriter(fps=15, metadata=metadata)
        fig = plt.figure()
        sorted_movie = np.zeros((nx * ny, 28, 28), dtype=float)
        # Order frames as a snake walk over the grid: left-to-right on even rows,
        # right-to-left on odd rows, so consecutive frames are grid neighbours.
        # Every one of the nx * ny slots is filled exactly once.
        for row, yi in enumerate(y_values):
            for j, xi in enumerate(x_values):
                x_mean = reconstruction_custom.eval(feed_dict={z_custom: np.array([[xi, yi]])})
                col = nx - 1 - j if row % 2 else j
                sorted_movie[row * nx + col] = x_mean[0].reshape(28, 28)

        # Print progress roughly every 10% (integer step so the modulo is exact).
        progress_step = max(1, len(sorted_movie) // 10)
        with writer.saving(fig, "writer_test.mp4", dpi=100):
            for i, frame in enumerate(sorted_movie):
                if i % progress_step == 0:
                    print(i / len(sorted_movie) * 100)
                plt.clf()
                plt.imshow(frame)
                writer.grab_frame()

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.