Skip to content

Instantly share code, notes, and snippets.

@wedesoft
Last active November 6, 2017 10:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wedesoft/c2827ce7abedae2ebe04ad68909b8c84 to your computer and use it in GitHub Desktop.
Save wedesoft/c2827ce7abedae2ebe04ad68909b8c84 to your computer and use it in GitHub Desktop.
MNIST autoencoder implemented using TensorFlow
#!/usr/bin/env python3
# http://blog.aloni.org/posts/backprop-with-tensorflow/
import random
import numpy as np
import tensorflow as tf
import pickle
import gzip
import cv2
def show(title, img, wait=-1):
    """Display a flattened 28x28 image enlarged tenfold, then poll the keyboard.

    :param title: name of the OpenCV window to draw into.
    :param img: array holding 784 pixel values; it is reshaped to 28x28.
    :param wait: delay in milliseconds handed to ``cv2.waitKey``. Any falsy
        value (``0``, ``False``, ``None``) is replaced by 1 so the call only
        polls; the default ``-1`` is passed through unchanged (OpenCV blocks
        until a key is pressed when the delay is <= 0).
    :return: ``False`` if ``cv2.waitKey`` reported key code 27 (Escape),
        otherwise ``True``.
    """
    enlarged = cv2.resize(img.reshape(28, 28), (280, 280))
    cv2.imshow(title, enlarged)
    delay = wait if wait else 1
    pressed = cv2.waitKey(delay)
    return pressed != 27
if __name__ == '__main__':
    # Load the pickled MNIST splits; only the test images are used below.
    # NOTE(review): pickle.load can execute arbitrary code - only load a
    # trusted copy of mnist.pkl.gz.
    with gzip.open('mnist.pkl.gz', 'rb') as archive:
        _training, _validation, testing = pickle.load(archive, encoding='iso-8859-1')
    test_images = testing[0]
    with tf.Session() as sess:
        # Rebuild the saved graph from 'auto.meta' and restore its weights.
        saver = tf.train.import_meta_graph('auto.meta')
        saver.restore(sess, 'auto')
        # The reconstruction tensor is expected in the 'prediction' collection
        # (the training script adds it there before saving).
        prediction = tf.get_collection('prediction')[0]
        while True:
            index = random.randrange(len(test_images))
            noise = np.random.rand(1, 784)
            # Map uniform [0, 1) noise to [-1, 1) and raise it to the 5th power:
            # the sign is kept, most perturbations shrink towards 0, and a few
            # stay large. The slice (not an index) keeps a leading batch axis;
            # the result is clipped back into [0, 1].
            image = np.clip(test_images[index:index + 1] + (2 * noise - 1) ** 5, 0, 1)
            # Honour Escape in either window; previously a key press landing in
            # the input window's short wait was silently discarded.
            if not show('input', image, False):
                break
            result = sess.run(prediction, feed_dict={'x:0': image})
            if not show('prediction', result, 300):
                break
    cv2.destroyAllWindows()
#!/usr/bin/env python3
# https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/autoencoder.py
import math
import pickle
import gzip
from functools import reduce
from operator import add
import numpy as np
import tensorflow as tf
from tqdm import tqdm
import cv2
def random_choice(count, size):
    """Draw ``size`` distinct integers from ``range(count)`` in random order.

    Uses NumPy's global random state, so ``np.random.seed`` makes the draw
    reproducible. NumPy raises ``ValueError`` when ``size`` exceeds ``count``.
    """
    return np.random.choice(a=count, size=size, replace=False)
def random_selection(size, *arrays):
    """Take the same ``size`` random rows (along axis 0) from every array given.

    Row indices are drawn without replacement from ``len(arrays[0])`` via
    ``random_choice``. NOTE(review): the other arrays are assumed to be at
    least as long as the first one; ``np.take`` raises otherwise.

    :return: the single selected array when exactly one array is passed,
        otherwise a tuple of selections in argument order.
    """
    rows = random_choice(len(arrays[0]), size)
    picked = [np.take(array, rows, axis=0) for array in arrays]
    if len(picked) == 1:
        return picked[0]
    return tuple(picked)
def show(title, img, wait=True):
    """Render a flattened 28x28 image at ten times its size and check for Escape.

    :param title: OpenCV window name.
    :param img: array with 784 values, reshaped to 28x28 for display.
    :param wait: value passed to ``cv2.waitKey``; falsy values become 1 so the
        call merely polls. The default ``True`` is passed as-is (it equals 1).
    :return: ``True`` unless ``cv2.waitKey`` returned key code 27 (Escape).
    """
    pixels = img.reshape(28, 28)
    cv2.imshow(title, cv2.resize(pixels, (280, 280)))
    delay = 1 if not wait else wait
    escape_pressed = cv2.waitKey(delay) == 27
    return not escape_pressed
if __name__ == '__main__':
    # http://deeplearning.net/data/mnist/mnist.pkl.gz
    # NOTE(review): pickle.load can execute arbitrary code - only load a
    # trusted copy of this file.
    with gzip.open('mnist.pkl.gz', 'rb') as archive:
        training, validation, testing = pickle.load(archive, encoding='iso-8859-1')
    train_images = training[0]

    n_iterations = 50000
    batch_size = 256
    n_hidden1 = 200
    n_hidden2 = 30
    alpha = 0.1  # learning rate of the hand-written gradient-descent step

    # Fully connected 784 -> 200 -> 30 -> 200 -> 784 network with sigmoid units.
    # Variables are created in this fixed order; the checkpoint relies on it.
    x = tf.placeholder(tf.float32, [None, 28 * 28], name='x')
    m1 = tf.Variable(tf.random_normal([28 * 28, n_hidden1], stddev=0.1))
    b1 = tf.Variable(tf.random_normal([n_hidden1]))
    m2 = tf.Variable(tf.random_normal([n_hidden1, n_hidden2], stddev=0.1))
    b2 = tf.Variable(tf.random_normal([n_hidden2]))
    m3 = tf.Variable(tf.random_normal([n_hidden2, n_hidden1], stddev=0.1))
    b3 = tf.Variable(tf.random_normal([n_hidden1]))
    m4 = tf.Variable(tf.random_normal([n_hidden1, 28 * 28], stddev=0.1))
    b4 = tf.Variable(tf.random_normal([28 * 28]))
    theta = [m1, b1, m2, b2, m3, b3, m4, b4]

    z1 = tf.add(tf.matmul(x, m1), b1)
    a1 = tf.sigmoid(z1)
    z2 = tf.add(tf.matmul(a1, m2), b2)
    a2 = tf.sigmoid(z2)  # 30-unit bottleneck code
    z3 = tf.add(tf.matmul(a2, m3), b3)
    a3 = tf.sigmoid(z3)
    z4 = tf.add(tf.matmul(a3, m4), b4)
    a4 = tf.sigmoid(z4)
    h = a4  # reconstruction of x

    # Sum of squared reconstruction errors divided by the batch size.
    m = tf.cast(tf.shape(x)[0], tf.float32)
    cost = tf.reduce_sum(tf.pow(x - h, 2)) / m
    # Plain gradient descent: theta <- theta - alpha * d(cost)/d(theta).
    dtheta = tf.gradients(cost, theta)
    step = [tf.assign(value, tf.subtract(value, tf.multiply(alpha, dvalue)))
            for value, dvalue in zip(theta, dtheta)]
    saver = tf.train.Saver()
    with tf.Session() as session:
        # Seed of the exponential moving average of the cost shown in the
        # progress bar. NOTE(review): 768 may be a typo for 784 (= 28 * 28);
        # the value only affects the displayed number, not training.
        j_train = 0.5 * 768
        session.run(tf.global_variables_initializer())
        progress = tqdm(range(n_iterations))
        for i in progress:
            selection = random_selection(batch_size, train_images)
            mini_batch = {x: selection}
            if i % 50 == 0:
                # Preview one sample and its reconstruction with pre-update weights.
                show('original', selection[0:1], False)
                show('reconstruction', session.run(h, feed_dict={x: selection[0:1]}), 10)
            # A single run evaluates the cost and applies the update. The cost
            # depends only on forward activations, which every assign op waits
            # for, so the reported value is still the pre-update cost. This
            # avoids a second forward pass per iteration.
            batch_cost, _ = session.run([cost, step], feed_dict=mini_batch)
            j_train = 0.99 * j_train + 0.01 * batch_cost
            progress.set_description('cost: %8.6f' % j_train)
        # Register the reconstruction tensor so a restored graph can look it up.
        tf.add_to_collection('prediction', h)
        saver.save(session, './auto')
    cv2.destroyAllWindows()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment