Skip to content

Instantly share code, notes, and snippets.

@jwalsh
Last active February 12, 2019 16:39
Show Gist options
  • Save jwalsh/e40779b707f610787e21cc2fae18ca85 to your computer and use it in GitHub Desktop.
Save jwalsh/e40779b707f610787e21cc2fae18ca85 to your computer and use it in GitHub Desktop.
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
image_size = 28
labels_size = 10
learning_rate = 0.05
steps_number = 10000
batch_size = 200
training_data = tf.placeholder(tf.float32, [None, image_size*image_size])
labels = tf.placeholder(tf.float32, [None, labels_size])
# Variables to be tuned
W = tf.Variable(tf.truncated_normal([image_size*image_size, labels_size], stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[labels_size]))
# Build the network (only output layer)
output = tf.matmul(training_data, W) + b
# Define the loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=output))
# Training step
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
# Accuracy calculation
correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Run the training
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
for i in range(steps_number):
# Get the next batch
input_batch, labels_batch = mnist.train.next_batch(batch_size)
feed_dict = {training_data: input_batch, labels: labels_batch}
# Run the training step
train_step.run(feed_dict=feed_dict)
# Print the accuracy progress on the batch every 100 steps
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict=feed_dict)
print("Step %d, training batch accuracy %g %%"%(i, train_accuracy*100))
# Evaluate on the test set
test_accuracy = accuracy.eval(feed_dict={training_data: mnist.test.images, labels: mnist.test.labels})
print("Test accuracy: %g %%"%(test_accuracy*100))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment