Created
September 8, 2017 05:38
-
-
Save anonymous/90364c863638ac34bd4b2656898a119f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def weight_variable(shape, name):
    """Create a trainable weight tensor.

    Args:
        shape: shape of the variable.
        name: graph name given to the tf.Variable.

    Returns:
        A tf.Variable initialised with tf.truncated_normal(shape, stddev=0.1).
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1), name=name)
def bias_variable(shape, name):
    """Create a trainable bias tensor.

    NOTE: biases use the same tf.truncated_normal(stddev=0.1) initialiser as
    the weights, not a constant.

    Args:
        shape: shape of the variable.
        name: graph name given to the tf.Variable.
    """
    init_value = tf.truncated_normal(shape, stddev=0.1)
    bias = tf.Variable(init_value, name=name)
    return bias
def FC_layer(X, W, b):
    """Fully connected (affine) layer with no activation: returns X @ W + b."""
    pre_activation = tf.matmul(X, W)
    return pre_activation + b
# Load MNIST from the MNIST_data directory with one-hot encoded labels.
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)
# Network and training hyperparameters.
n_pixels = 784             # 28 x 28 MNIST image, flattened
latent_dim = 2             # 2-D so the latent space can be plotted (20 was an alternative)
h_dim = 500                # width of the single hidden layer in encoder and decoder
num_iterations = 1000      # optimizer steps (100000 was an alternative)
recording_interval = 1000  # print the loss every this many steps
# ---------------------------------------------------------------------------
# Encoder q(z | x): one tanh hidden layer, then a diagonal Gaussian over z
# parameterised by its mean (mu) and log standard deviation (logstd).
# ---------------------------------------------------------------------------
X = tf.placeholder(tf.float32, shape=[None, n_pixels], name='X')
W_enc = weight_variable([n_pixels, h_dim], 'W_enc')
b_enc = bias_variable([h_dim], 'b_enc')
h_enc = tf.nn.tanh(FC_layer(X, W_enc, b_enc))

W_mu = weight_variable([h_dim, latent_dim], 'W_mu')
b_mu = bias_variable([latent_dim], 'b_mu')
mu = FC_layer(h_enc, W_mu, b_mu)

W_logstd = weight_variable([h_dim, latent_dim], 'W_logstd')
b_logstd = bias_variable([latent_dim], 'b_logstd')
logstd = FC_layer(h_enc, W_logstd, b_logstd)

# Reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, I).
# - eps has the same (batch, latent_dim) shape as mu so that every example
#   gets its own sample; a [1, latent_dim] draw would be broadcast and shared
#   by the whole batch.
# - sigma = exp(logstd): the KL term below treats exp(2 * logstd) as the
#   variance, so the noise must be scaled by exp(logstd), not exp(0.5 * logstd).
noise = tf.random_normal(tf.shape(mu))
z = tf.add(mu, noise * tf.exp(logstd), name='z')

# ---------------------------------------------------------------------------
# Decoder p(x | z): one tanh hidden layer; the sigmoid output gives the
# per-pixel Bernoulli means used by log_likelihood below.
# ---------------------------------------------------------------------------
W_dec = weight_variable([latent_dim, h_dim], 'W_dec')
b_dec = bias_variable([h_dim], 'b_dec')
h_dec = tf.nn.tanh(FC_layer(z, W_dec, b_dec))

W_reconstruct = weight_variable([h_dim, n_pixels], 'W_reconstruct')
b_reconstruct = bias_variable([n_pixels], 'b_reconstruct')
reconstruction = tf.nn.sigmoid(FC_layer(h_dec, W_reconstruct, b_reconstruct), name='reconstruction')

# The same decoder weights applied to a caller-supplied latent code, so that
# arbitrary latent points can be decoded without running the encoder.
z_custom = tf.placeholder(tf.float32, shape=[None, latent_dim], name='z_custom')
h_dec_c = tf.nn.tanh(FC_layer(z_custom, W_dec, b_dec))
reconstruction_custom = tf.nn.sigmoid(FC_layer(h_dec_c, W_reconstruct, b_reconstruct), name='reconstruction_custom')

# ---------------------------------------------------------------------------
# Objective: the evidence lower bound (ELBO), averaged over the batch.
# ---------------------------------------------------------------------------
# Bernoulli log-likelihood of the inputs; the 1e-9 terms guard against log(0).
log_likelihood = tf.reduce_sum(
    X * tf.log(reconstruction + 1e-9) + (1 - X) * tf.log(1 - reconstruction + 1e-9),
    axis=1)
# Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian with std exp(logstd).
KL_term = -.5 * tf.reduce_sum(1 + 2 * logstd - tf.square(mu) - tf.exp(2 * logstd), axis=1)
variational_lower_bound = tf.reduce_mean(log_likelihood - KL_term)

# Maximise the ELBO by minimising its negation.
# NOTE(review): TF 1.x AdadeltaOptimizer defaults to learning_rate=0.001, far
# below the 1.0 used in the Adadelta paper -- confirm this is intended.
optimizer = tf.train.AdadeltaOptimizer().minimize(-variational_lower_bound)
# Stored in a collection so it can be recovered via tf.get_collection("optimizer")
# after the graph is re-imported with import_meta_graph.
tf.add_to_collection("optimizer", optimizer)
# TF 1.x Saver.save() refuses to write into a missing directory, so create it
# up front instead of failing after all the training has been done.
if not os.path.isdir('./models'):
    os.makedirs('./models')

batch_size = 200

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    for i in range(num_iterations):
        # Binarise pixels to {0, 1} to match the Bernoulli log-likelihood.
        x_batch = np.round(mnist.train.next_batch(batch_size)[0])
        feed = {X: x_batch}
        if i % recording_interval == 0:
            # Fetch the bound in the same run as the update (no extra forward
            # pass). The loss being minimised is -ELBO, so report that.
            _, elbo = sess.run([optimizer, variational_lower_bound], feed_dict=feed)
            print("Iteration: {}, Loss: {}".format(i, -elbo))
        else:
            sess.run(optimizer, feed_dict=feed)
    saver.save(sess, './models/vaemnist-2d', global_step=num_iterations)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tqdm import tqdm | |
from tensorflow.examples.tutorials.mnist import input_data | |
import tensorflow as tf | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import matplotlib.animation as manimation | |
import numpy as np | |
# Load MNIST from the MNIST_data directory with one-hot encoded labels.
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)
additional_training = False | |
num_iterations = 400000 | |
recording_interval = 1000 | |
model_to_import = './models/vaemnist-2d-1000000.meta' | |
total_iterations = num_iterations + int(model_to_import[model_to_import.index('-') + (3 if '2d' in model_to_import else 0) + 1:model_to_import.index('.', model_to_import.index('.') + 1)]) | |
movie = True | |
comparisons = False | |
grid_show = False | |
def _restore_session(sess, meta_path):
    """Import the graph stored in *meta_path* and load its matching weights into *sess*.

    Weights come from the checkpoint prefix that belongs to this .meta file.
    tf.train.latest_checkpoint() would return whichever checkpoint was written
    last in the directory, which need not match the imported graph.
    """
    loader = tf.train.import_meta_graph(meta_path)
    suffix = '.meta'
    prefix = meta_path[:-len(suffix)] if meta_path.endswith(suffix) else meta_path
    loader.restore(sess, prefix)


def _show_comparisons(sess, dataset, X, z, reconstruction, num_pairs=10):
    """Plot random test digits beside their reconstructions and print each latent code and label."""
    # Samples are drawn from the first 200 test images.
    image_indices = np.random.randint(0, 200, num_pairs)
    for idx in image_indices:
        x = np.reshape(dataset.test.images[idx], (1, 28 * 28))
        # z depends on a random-noise op, so fetch it together with the
        # reconstruction: the printed code is then the one actually decoded.
        z_value, x_reconstruction = sess.run([z, reconstruction], feed_dict={X: x})
        print(z_value, dataset.test.labels[idx])
        plt.figure()
        plt.subplot(121)
        plt.imshow(np.reshape(x, (28, 28)))
        plt.subplot(122)
        plt.imshow(np.reshape(x_reconstruction, (28, 28)))


def _decode_latent_grid(sess, decoded, latent, x_values, y_values):
    """Decode every (x, y) point of a latent grid in one batched run.

    Returns an array of shape (len(y_values), len(x_values), 28, 28) whose
    entry [r, c] is the image decoded from latent point (x_values[c], y_values[r]).
    """
    # NOTE(review): assumes `latent` accepts a batch of 2-D codes (z_custom is
    # built with shape [None, latent_dim] and latent_dim = 2 in the training script).
    points = np.array([[xi, yi] for yi in y_values for xi in x_values])
    images = sess.run(decoded, feed_dict={latent: points})
    return np.reshape(images, (len(y_values), len(x_values), 28, 28))


def _tile_grid(images):
    """Tile a (rows, cols, h, w) image grid into one canvas, drawing row 0 at the bottom."""
    rows, cols, h, w = images.shape
    return images[::-1].transpose(0, 2, 1, 3).reshape(rows * h, cols * w)


def _serpentine_frames(images):
    """Flatten a (rows, cols, h, w) grid into frames that snake through it.

    Even rows are traversed left to right and odd rows right to left, so
    consecutive frames always come from neighbouring latent points and every
    grid cell appears exactly once.
    """
    frames = []
    for r, row in enumerate(images):
        frames.extend(row if r % 2 == 0 else row[::-1])
    return frames


def _write_movie(frames, path="writer_test.mp4", fps=15):
    """Render *frames* in order into a video via the ffmpeg writer, printing progress in ~10% steps."""
    FFMpegWriter = manimation.writers['ffmpeg']
    metadata = dict(title='vidtest', artist='c', comment='movsupp')
    writer = FFMpegWriter(fps=fps, metadata=metadata)
    fig = plt.figure()
    report_every = max(1, len(frames) // 10)
    with writer.saving(fig, path, dpi=100):
        for i, frame in enumerate(frames):
            if i % report_every == 0:
                print(100.0 * i / len(frames))
            plt.clf()
            plt.imshow(frame)
            writer.grab_frame()


# Latent-space grid: nx points along the first latent axis, ny along the second.
nx = ny = 20
x_values = np.linspace(-3, 3, nx)
y_values = np.linspace(-3, 3, ny)

with tf.Session() as sess:
    _restore_session(sess, model_to_import)
    graph = tf.get_default_graph()
    reconstruction = graph.get_tensor_by_name('reconstruction:0')
    reconstruction_custom = graph.get_tensor_by_name('reconstruction_custom:0')
    X = graph.get_tensor_by_name('X:0')
    z = graph.get_tensor_by_name('z:0')
    z_custom = graph.get_tensor_by_name('z_custom:0')
    optimizer = tf.get_collection("optimizer")[0]

    if additional_training:
        for _ in tqdm(range(num_iterations)):
            x_batch = np.round(mnist.train.next_batch(200)[0])
            sess.run(optimizer, feed_dict={X: x_batch})
        # Save the further-trained weights under the cumulative step count.
        saver = tf.train.Saver()
        saver.save(sess, './models/vaemnist-2d', global_step=total_iterations)

    if comparisons:
        _show_comparisons(sess, mnist, X, z, reconstruction)

    if grid_show or movie:
        latent_grid = _decode_latent_grid(sess, reconstruction_custom, z_custom, x_values, y_values)
        if grid_show:
            plt.figure(figsize=(8, 10))
            plt.imshow(_tile_grid(latent_grid), origin='upper', cmap='gray')
            plt.tight_layout()
        if movie:
            _write_movie(_serpentine_frames(latent_grid))

plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment