Minimal TensorFlow Example
"""
A minimal implementation of the MNIST handwritten digits classification task in TensorFlow.
This runs MNIST images images through a single hidden layer and softmax loss function.
It demonstrates in a single Python source file the basics of creating a model, training and evaluating data sets, and
writing summaries that can be visualized by TensorBoard.
"""
from __future__ import division
import math
import tensorflow as tf
from six.moves import xrange as range
from tensorflow.examples.tutorials.mnist import input_data
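
# Model and training hyperparameters: MNIST image size, hidden layer width, minibatch size, SGD learning rate, and
# how often to report validation accuracy.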
PIXELS = 28 * 28
HIDDEN = 128
BATCH_SIZE = 50
LEARNING_RATE = 0.01
REPORT_INTERVAL = 100
SUMMARY_DIRECTORY = "summary"
def epoch(data, operations):
    """
    Iterate one epoch of a data set in batches of size BATCH_SIZE through specified operations in the graph.

    Note that this relies on the module-level session, x, and y defined below.

    :param data: the data set to iterate over
    :type data: data set defining num_examples and next_batch
    :param operations: operations in a TensorFlow graph
    :type operations: Operation or list of Operation (anything session.run accepts)
    :return: iteration over the operation results for each batch
    :rtype: iterator
    """
    for _ in range(data.num_examples // BATCH_SIZE):
        batch = data.next_batch(BATCH_SIZE)
        yield session.run(operations, feed_dict={x: batch[0], y: batch[1]})
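
# Start each run with an empty summary directory so TensorBoard only sees events from the current run.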
if tf.gfile.Exists(SUMMARY_DIRECTORY):
    tf.gfile.DeleteRecursively(SUMMARY_DIRECTORY)
tf.gfile.MakeDirs(SUMMARY_DIRECTORY)

with tf.Graph().as_default():
    with tf.name_scope("Input"):
        x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, PIXELS], name="input_image")
        y = tf.placeholder(tf.int64, shape=[BATCH_SIZE], name="true_digit")
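    # A single fully-connected layer. The truncated-normal initialization with stddev 1/sqrt(PIXELS) keeps the
    # initial outputs at a reasonable scale. Note that its HIDDEN outputs are used directly as the class logits
    # below; there is no nonlinearity or separate output layer, which keeps the example minimal.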
with tf.name_scope("Hidden"):
w = tf.Variable(
tf.truncated_normal(
[PIXELS, HIDDEN], stddev=1.0 / math.sqrt(float(PIXELS))
),
name="weights"
)
b = tf.Variable(tf.zeros([HIDDEN]), name="biases")
y_predicted = tf.matmul(x, w) + b
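    # sparse_softmax_cross_entropy_with_logits takes integer class labels directly, so the true digits do not need
    # to be one-hot encoded. The per-example cross entropies are averaged into a single scalar loss, which is also
    # logged as a TensorBoard summary.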
with tf.name_scope("Loss"):
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(y_predicted, y, name="cross_entropy")
loss = tf.reduce_mean(cross_entropy, name="mean_cross_entropy")
tf.scalar_summary(loss.op.name, loss)
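    # minimize() both applies the gradient update and increments global_step, so each training step has an index
    # that can be attached to summaries and progress reports.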
with tf.name_scope("Train"):
global_step = tf.Variable(0, name="global_step", trainable=False)
training_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss, global_step=global_step)
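    # in_top_k yields one boolean per example (is the true digit the highest-scoring logit?); casting and summing
    # counts the correct predictions in the batch.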
with tf.name_scope("Evaluate"):
correct = tf.reduce_sum(tf.cast(tf.nn.in_top_k(y_predicted, y, 1), tf.int32), name="correct")
summary = tf.merge_all_summaries()

    mnist = input_data.read_data_sets("MNIST_data")
    with tf.Session() as session:
        train_writer = tf.train.SummaryWriter(SUMMARY_DIRECTORY, session.graph)
        session.run(tf.initialize_all_variables())
        # Train the model, periodically evaluating on the validation set.
        for i, l, s, _ in epoch(mnist.train, [global_step, loss, summary, training_step]):
            if i % REPORT_INTERVAL == 0:
                total_correct = sum(c for c in epoch(mnist.validation, correct))
                print("Iteration %d: Training loss %0.5f, Validation correct %0.5f" %
                      (i, l, total_correct / mnist.validation.num_examples))
                train_writer.add_summary(s, global_step=i)
                train_writer.flush()
        # Run the model on test data.
        total_correct = sum(c for c in epoch(mnist.test, correct))
        print("Test set correct: %0.5f" % (total_correct / mnist.test.num_examples))