# Source: GitHub gist sbarratt/859dff5e89729bc7eaf5b8bb20b31c16 by @sbarratt,
# created August 11, 2016 16:19.
# https://gist.github.com/sbarratt/859dff5e89729bc7eaf5b8bb20b31c16
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load data: the train/validation/test splits used below are read from the
# 'MNIST_data' directory; one_hot=True encodes each label as a 10-vector.
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)
# Neural Network Initialization
# Two-layer network: 784 inputs -> 100 sigmoid units -> 10-way softmax.
# x holds a batch of flattened 28x28 images; None allows any batch size.
x = tf.placeholder(tf.float32, shape=[None, 784])
# Hidden layer; weights and biases drawn from N(0, 0.05^2).
W_1 = tf.Variable(tf.random_normal([784, 100], stddev=.05))
b_1 = tf.Variable(tf.random_normal([100], stddev=.05))
h_1 = tf.nn.sigmoid(tf.matmul(x, W_1) + b_1)
# Output layer. The pre-softmax logits are kept separately because the loss
# below is computed from them directly.
W_2 = tf.Variable(tf.random_normal([100, 10], stddev=.05))
b_2 = tf.Variable(tf.random_normal([10], stddev=.05))
logits = tf.matmul(h_1, W_2) + b_2
y = tf.nn.softmax(logits)
# Target labels, fed as one-hot vectors alongside x.
y_ = tf.placeholder(tf.float32, shape=[None, 10])
# Output Initialization
# Mean cross-entropy over the batch. Computing it from the logits avoids
# tf.log(softmax), which produces -inf (and NaN gradients) as soon as any
# softmax probability underflows to exactly 0.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
train_step = tf.train.GradientDescentOptimizer(0.05).minimize(cross_entropy)
# A prediction is correct when the arg-max class matches the labelled class.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Training
# Accuracy history, sampled once every 100 training steps, used for the plot.
train_accs, validation_accs, test_accs = [], [], []
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for i in range(100000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        # train_step.run / accuracy.eval use sess, the default session here.
        train_step.run(feed_dict={x: batch_xs, y_: batch_ys})
        if i % 100 == 0:
            train_acc = accuracy.eval(feed_dict={x: mnist.train.images, y_: mnist.train.labels})
            validation_acc = accuracy.eval(feed_dict={x: mnist.validation.images, y_: mnist.validation.labels})
            test_acc = accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
            # Append to the history lists, not to the scalar accuracies.
            train_accs.append(train_acc)
            validation_accs.append(validation_acc)
            test_accs.append(test_acc)
            print("Iteration %s: train accuracy %.3f" % (i, train_acc))
    # Save Model (must happen while the session that owns the variables is open)
    saver = tf.train.Saver()
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in file: %s" % save_path)
# Plot Training Results: error rate (1 - accuracy) at each evaluation point;
# consecutive points are 100 training batches apart.
plt.figure(figsize=(18, 12))
plt.plot(1 - np.array(train_accs), label="train")
plt.plot(1 - np.array(validation_accs), label="validation")
plt.plot(1 - np.array(test_accs), label="test")
plt.xlabel("Number of Batches in Hundreds")
plt.ylabel("Fraction Error")
plt.ylim([0, 0.15])
plt.legend()
plt.show()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment