# autoencoder.py -- GitHub gist by rbrigden (last active Jul 7, 2017).
# Convolutional autoencoder on MNIST with a 2-D latent code.
from __future__ import division, print_function, absolute_import
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
# Training hyperparameters.
EPOCHS = 150          # number of training epochs
BATCH_SIZE = 100      # examples per optimizer step
EPOCH_SIZE = 50000    # examples that make up one "epoch"
GAMMA = 0.00008       # Adam learning rate

# Command-line switches: --train trains from scratch, otherwise a saved
# checkpoint is evaluated; --plot_point selects the latent-space scatter plot.
tf.app.flags.DEFINE_boolean("train", False, "train if true, eval if false")
tf.app.flags.DEFINE_boolean("plot_point", False, "plot points")
FLAGS = tf.app.flags.FLAGS

# MNIST read from ./MNIST_data with integer (not one-hot) labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=False)
def _truncated_normal_initializer(stddev=0.1):
    """Return an initializer callable that samples a truncated normal.

    Args:
        stddev: standard deviation of the (zero-mean) distribution.

    Returns:
        A function ``(shape, dtype=tf.float32, partition_info=None)`` that
        builds a ``tf.truncated_normal`` tensor of the requested shape/dtype.
    """
    def init_fn(shape, dtype=tf.float32, partition_info=None):
        # partition_info is accepted but ignored; presumably it is only here
        # to match the signature tf.get_variable calls initializers with.
        return tf.truncated_normal(shape, dtype=dtype, stddev=stddev)

    return init_fn
def _variable_on_cpu(name, shape, initializer=None, reuse=None):
    """Create (or reuse) a variable placed on the CPU.

    The variable lives inside its own variable scope carrying the same name,
    so its full name is ``<outer scopes>/<name>/<name>`` (for example
    ``encoder/conv1/weights/weights``). Checkpoints written by this script
    depend on that naming, so do not remove the extra scope.

    Args:
        name: used both as the scope name and as the variable name.
        shape: shape of the variable.
        initializer: forwarded to ``tf.get_variable``.
        reuse: forwarded to ``tf.variable_scope``.

    Returns:
        The variable returned by ``tf.get_variable``.
    """
    # Parameters are pinned to host memory; other ops may be placed elsewhere
    # (run() builds the model under /gpu:0).
    with tf.device("/cpu:0"):
        with tf.variable_scope(name, reuse=reuse):
            return tf.get_variable(name, shape=shape, initializer=initializer)
def _affine(x, weights_shape=None, bias_shape=None, decay=None, name=None,
            reuse=None):
    """Fully connected layer computing ``x @ W + b`` (no activation).

    Args:
        x: 2-D input; presumably [batch, weights_shape[0]].
        weights_shape: [fan_in, fan_out]; W starts from a truncated normal
            (stddev 0.1).
        bias_shape: shape of b, initialized to zeros.
        decay: if not None, ``decay * l2_loss(W)`` is appended to the
            "losses" collection (it is not applied here).
        name: name of the output op.
        reuse: forwarded to the variable scopes, matching ``_conv_2d``, so a
            layer can share existing variables. Defaults to the old behavior.

    Returns:
        The affine output tensor.
    """
    W = _variable_on_cpu("weights",
                         shape=weights_shape,
                         initializer=_truncated_normal_initializer(),
                         reuse=reuse)
    b = _variable_on_cpu("biases",
                         shape=bias_shape,
                         initializer=tf.constant_initializer(0.0),
                         reuse=reuse)
    # Record an L2 penalty on the weights; the caller decides whether to use it.
    if decay is not None:
        wd = tf.multiply(decay, tf.nn.l2_loss(W), name="l2_weight_loss")
        tf.add_to_collection("losses", wd)
    return tf.add(tf.matmul(x, W), b, name=name)
def _conv_2d(x, kernel_shape, bias_shape, stride=1, padding="SAME",
             decay=None, name=None, reuse=None):
    """2-D convolution (NHWC) plus bias, with no activation.

    Args:
        x: NHWC input tensor.
        kernel_shape: [height, width, in_channels, out_channels].
        bias_shape: shape of the bias, initialized to zeros.
        stride: step along both spatial axes.
        padding: "SAME" or "VALID", passed to ``tf.nn.conv2d``.
        decay: if not None, ``decay * l2_loss(kernel)`` is appended to the
            "losses" collection.
        name: name of the output op.
        reuse: forwarded to the variable scopes of both parameters.

    Returns:
        The convolution output with the bias added.
    """
    kernel = _variable_on_cpu(name="weights",
                              shape=kernel_shape,
                              initializer=_truncated_normal_initializer(),
                              reuse=reuse)
    bias = _variable_on_cpu(name="biases",
                            shape=bias_shape,
                            initializer=tf.constant_initializer(0.0),
                            reuse=reuse)
    if decay is not None:
        penalty = tf.multiply(decay, tf.nn.l2_loss(kernel),
                              name="l2_weight_loss")
        tf.add_to_collection("losses", penalty)
    # Only the spatial (H, W) axes are strided; batch and channel use 1.
    window_strides = [1, stride, stride, 1]
    activation = tf.nn.conv2d(x, filter=kernel, strides=window_strides,
                              padding=padding, data_format="NHWC")
    return tf.add(activation, bias, name=name)
def encode(x):
    """Encode a batch of MNIST images into a 2-D code.

    Layers, all under variable scope "encoder" (each with L2 decay 0.01
    recorded in the "losses" collection):
      conv1:   5x5 conv, 1 -> 32 channels, ReLU, 2x2/2 max-pool (28 -> 14)
      conv2:   5x5 conv, 32 -> 64 channels, ReLU, 2x2/2 max-pool (14 -> 7)
      affine1: 7*7*64 -> 1024, ReLU
      affine2: 1024 -> 512, ReLU
      affine3: 512 -> 2, ReLU

    Args:
        x: image batch; reshaped here to [-1, 28, 28, 1].

    Returns:
        Tensor of shape [batch, 2] (the ReLU'd affine3 output).
    """
    pool_args = dict(ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    with tf.variable_scope("encoder"):
        with tf.variable_scope("conv1") as scope:
            images = tf.reshape(x, [-1, 28, 28, 1])
            conv = _conv_2d(images, kernel_shape=[5, 5, 1, 32],
                            bias_shape=[32], decay=0.01)
            net = tf.nn.max_pool(tf.nn.relu(conv, name=scope.name),
                                 **pool_args)
        with tf.variable_scope("conv2") as scope:
            conv = _conv_2d(net, kernel_shape=[5, 5, 32, 64],
                            bias_shape=[64], decay=0.01)
            net = tf.nn.max_pool(tf.nn.relu(conv, name=scope.name),
                                 **pool_args)
        with tf.variable_scope("affine1") as scope:
            flat = tf.reshape(net, [-1, 7 * 7 * 64])
            dense = _affine(flat, weights_shape=[7 * 7 * 64, 1024],
                            bias_shape=[1024], decay=0.01)
            net = tf.nn.relu(dense, name=scope.name)
        with tf.variable_scope("affine2") as scope:
            dense = _affine(net, weights_shape=[1024, 512],
                            bias_shape=[512], decay=0.01)
            net = tf.nn.relu(dense, name=scope.name)
        with tf.variable_scope("affine3") as scope:
            dense = _affine(net, weights_shape=[512, 2],
                            bias_shape=[2], decay=0.01)
            # NOTE(review): a ReLU on the 2-D code confines it to the
            # non-negative quadrant and can zero out a coordinate entirely;
            # kept to match how existing checkpoints were trained.
            net = tf.nn.relu(dense, name=scope.name)
    return net
def decode(z):
    """Decode 2-D codes back into flattened 28x28 images.

    Layers, all under variable scope "decoder" (each with L2 decay 0.01
    recorded in the "losses" collection):
      affine1: 2 -> 512, ReLU
      affine2: 512 -> 1024, ReLU
      affine3: 1024 -> 7*7*64, ReLU, reshaped to [-1, 7, 7, 64]
      deconv1: bilinear resize to 14x14, 5x5 conv 64 -> 32, ReLU
      deconv2: bilinear resize to 28x28, 5x5 conv 32 -> 1, ReLU

    Args:
        z: codes; presumably [batch, 2] as produced by ``encode``.

    Returns:
        Tensor of shape [batch, 784] with non-negative values (ReLU output).
    """
    bilinear = tf.image.ResizeMethod.BILINEAR
    with tf.variable_scope("decoder"):
        with tf.variable_scope("affine1") as scope:
            dense = _affine(z, weights_shape=[2, 512],
                            bias_shape=[512], decay=0.01)
            net = tf.nn.relu(dense, name=scope.name)
        with tf.variable_scope("affine2") as scope:
            dense = _affine(net, weights_shape=[512, 1024],
                            bias_shape=[1024], decay=0.01)
            net = tf.nn.relu(dense, name=scope.name)
        with tf.variable_scope("affine3") as scope:
            dense = _affine(net, weights_shape=[1024, 7 * 7 * 64],
                            bias_shape=[7 * 7 * 64], decay=0.01,
                            name=scope.name)
            net = tf.nn.relu(dense, name=scope.name)
        with tf.variable_scope("deconv1") as scope:
            grid = tf.reshape(net, [-1, 7, 7, 64])
            upsampled = tf.image.resize_images(grid, [14, 14], method=bilinear)
            conv = _conv_2d(upsampled, kernel_shape=[5, 5, 64, 32],
                            bias_shape=[32], decay=0.01)
            net = tf.nn.relu(conv, name=scope.name)
        with tf.variable_scope("deconv2") as scope:
            upsampled = tf.image.resize_images(net, [28, 28], method=bilinear)
            conv = _conv_2d(upsampled, kernel_shape=[5, 5, 32, 1],
                            bias_shape=[1], decay=0.01)
            pixels = tf.nn.relu(conv, name=scope.name)
            flat_pixels = tf.reshape(pixels, [-1, 784])
    return flat_pixels
def _train(sess, X, init, train_op, loss, saver):
    """Initialize variables and train for EPOCHS epochs, checkpointing each epoch."""
    sess.run(init)
    steps_per_epoch = EPOCH_SIZE // BATCH_SIZE
    for epoch in range(EPOCHS):
        for _step in range(steps_per_epoch):
            xs, _labels = mnist.train.next_batch(BATCH_SIZE)
            _, batch_loss = sess.run([train_op, loss], feed_dict={X: xs})
        # Loss of the epoch's final batch, not an epoch average.
        print("Epoch: {}/{}, Loss: {}".format(epoch, EPOCHS, batch_loss))
        # Overwritten every epoch so an interrupted run keeps its latest state.
        saver.save(sess, "model.ckpt")
    print("Done training")


def _plot_latent_space(sess, X, encode_op):
    """Scatter-plot the 2-D codes of the first 10000 test images, colored by digit."""
    images = mnist.test.images[:10000]
    labels = mnist.test.labels[:10000]
    points = sess.run(encode_op, feed_dict={X: images})
    fig = plt.figure()
    # `facecolor` replaces `axisbg`, which matplotlib removed in 2.2.
    ax = fig.add_subplot(1, 1, 1, facecolor="1.0")
    for digit in range(10):
        mask = labels == digit
        ax.scatter(points[mask, 0], points[mask, 1], alpha=0.8,
                   c="C{}".format(digit), edgecolors="none", s=30,
                   label=str(digit))
    plt.title('Matplot scatter plot')
    plt.legend(loc=2)
    plt.show()


def _plot_reconstructions(sess, X, pred):
    """Save 10 random test digits (top row) and their reconstructions (bottom row) to plot.png."""
    # replace=False so the 10 panels show 10 distinct test images.
    idxs = np.random.choice(len(mnist.test.images), 10, replace=False)
    originals = np.take(mnist.test.images, idxs, axis=0)
    reconstructions = sess.run(pred, feed_dict={X: originals})
    fig, axes = plt.subplots(2, 10, figsize=(10, 2))
    for col in range(10):
        for row, image in ((0, originals[col]), (1, reconstructions[col])):
            axes[row][col].imshow(np.reshape(image, (28, 28)))
            axes[row][col].axis('off')
    # Figure has no `save` method; `savefig` writes the image file.
    fig.savefig("plot.png")


def run():
    """Build the autoencoder graph, then train or evaluate according to FLAGS.

    With --train, variables are initialized and trained, and the checkpoint
    is saved to "model.ckpt". Otherwise "model.ckpt" is restored. With
    --plot_point, the test-set codes are scatter-plotted. Without it, a
    reconstruction grid is written to "plot.png".
    """
    with tf.Graph().as_default():
        # Autoencoder: the input doubles as the reconstruction target.
        X = Y = tf.placeholder("float", shape=[None, 784])
        with tf.device("/gpu:0"):
            encode_op = encode(X)
            pred = decode(encode_op)
            loss = tf.reduce_mean(tf.pow(Y - pred, 2))  # MSE
            # NOTE(review): the layers put L2 terms in the "losses"
            # collection, but they are never added to `loss`, so `decay` has
            # no effect on training. Adding
            # tf.add_n(tf.get_collection("losses")) at decay=0.01 would
            # probably need retuning first.
            optim = tf.train.AdamOptimizer(GAMMA)
            train_op = optim.minimize(loss)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
        # Let ops pinned to /gpu:0 fall back to CPU on hosts without a GPU.
        config = tf.ConfigProto(allow_soft_placement=True)
        # The context manager closes the session on every path.
        with tf.Session(config=config) as sess:
            if FLAGS.train:
                _train(sess, X, init, train_op, loss, saver)
                return
            saver.restore(sess, "model.ckpt")
            if FLAGS.plot_point:
                _plot_latent_space(sess, X, encode_op)
            else:
                _plot_reconstructions(sess, X, pred)
# Script entry point. There is no `if __name__ == "__main__":` guard, so
# importing this module also builds the graph and trains/evaluates.
run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.