Last active
April 3, 2016 20:43
-
-
Save nijotz/bbbfb27ad5336251cea2007ec1c7a8cd to your computer and use it in GitHub Desktop.
Handwritten digit analysis
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# MNIST: 28x28 grayscale images of handwritten digits, flattened to 784
# values; labels are one-hot vectors over the 10 digit classes.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Model inputs: x is a batch of flattened images, y_ ("y prime") is the
# true one-hot label distribution for each image in the batch.
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

# Single-layer softmax regression; weights and biases start at zero.
W = tf.Variable(tf.zeros([784, 10]))  # weights
b = tf.Variable(tf.zeros([10]))       # biases

# Keep the raw logits so the numerically stable cross-entropy op can be
# applied below; y is the predicted probability distribution.
logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)

# Cross-entropy cost function. softmax_cross_entropy_with_logits computes
# the same quantity as -tf.reduce_sum(y_ * tf.log(y)) but is numerically
# stable: the naive form produces NaN once a predicted probability
# underflows to zero. reduce_sum (not reduce_mean) preserves the original
# loss scale, so the 0.01 learning rate behaves the same.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

# Use gradient descent on the cost function for training.
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# initialize_all_variables() is deprecated; global_variables_initializer()
# is the supported replacement.
init = tf.global_variables_initializer()

# Context manager guarantees the session's resources are released even if
# training raises.
with tf.Session() as sess:
    sess.run(init)
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Measure accuracy: the fraction of test images whose argmax
    # prediction matches the true label.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(accuracy,
                   feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment