Last active: November 18, 2019, 12:12
-
-
Save sometimescasey/496beabb98da2fe10996f0d9b5428a6b to your computer and use it in GitHub Desktop.
TensorFlow tutorial: Small CNN to process MNIST data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# From instructions at https://www.tensorflow.org/versions/r1.0/get_started/mnist/pros | |
import argparse | |
import sys | |
import tensorflow as tf | |
import time | |
from tensorflow.examples.tutorials.mnist import input_data | |
FLAGS = None  # Command-line flags; replaced in the __main__ block by the namespace returned from parser.parse_known_args() (provides FLAGS.data_dir).
def weight_variable(shape, stddev=0.1):
    """Create a trainable weight variable of the given shape.

    The initial values are drawn from a truncated normal distribution with
    standard deviation ``stddev``. Small random values break the symmetry
    between units that a constant initializer would leave in place.

    Args:
      shape: list of ints giving the variable's shape, e.g. [5, 5, 1, 32]
        for a 5x5 convolution from 1 to 32 channels.
      stddev: standard deviation of the truncated normal initializer.
        Defaults to 0.1, the value previously hard-coded here.

    Returns:
      A ``tf.Variable`` holding the randomly initialized weights.
    """
    initial = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial)
def bias_variable(shape, value=0.1):
    """Create a trainable bias variable filled with a constant.

    The layers in this file feed their biases into ReLU activations. A
    slightly positive start value (default 0.1) keeps units from beginning
    in the zero-gradient region.

    Args:
      shape: list of ints giving the variable's shape, e.g. [32].
      value: constant used to fill the variable. Defaults to 0.1, the value
        previously hard-coded here.

    Returns:
      A ``tf.Variable`` whose every element equals ``value``.
    """
    initial = tf.constant(value, shape=shape)
    return tf.Variable(initial)
def conv2d(x, W):
    """Convolve ``x`` with the filter bank ``W`` at stride 1, 'SAME' padding.

    With unit stride in every dimension and SAME padding, the output has the
    same spatial size as the input. Only the channel count changes, to the
    last dimension of ``W``.
    """
    unit_stride = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_stride, padding='SAME')
def max_pool_2x2(x):
    """Halve the height and width of ``x`` using 2x2 max pooling.

    The window and the stride are the same [1, 2, 2, 1] layout (batch,
    height, width, channel), so windows do not overlap. SAME padding rounds
    an odd spatial size up rather than dropping the last row or column.
    """
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
def _build_cnn(x, keep_prob):
    """Build the two-conv-layer CNN and return its unscaled logits.

    Args:
      x: float32 tensor of flattened images. It is reshaped to
        [-1, 28, 28, 1], so each row must hold 784 pixel values.
      keep_prob: float32 scalar tensor, the dropout keep probability applied
        after the fully connected layer.

    Returns:
      A [batch, 10] float32 tensor of logits (no softmax applied).
    """
    # --- 1 --- first convolutional layer: 5x5 kernels, 1 -> 32 channels.
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)  # image is now 14x14

    # --- 2 --- second convolutional layer: 5x5 kernels, 32 -> 64 channels.
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)  # image is now 7x7

    # --- 3 --- densely connected layer: 1024 units over the 7x7x64 feature maps.
    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    # --- 4 --- dropout, controlled at run time through keep_prob.
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    # --- 5 --- readout layer: one logit per digit class. The labels fed
    # alongside are integer class ids, not one-hot vectors.
    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])
    return tf.matmul(h_fc1_drop, W_fc2) + b_fc2


def _batched_accuracy(sess, accuracy, x, y_, keep_prob, images, labels,
                      batch_size=1000):
    """Evaluate ``accuracy`` over (images, labels) in fixed-size batches.

    Feeding a whole dataset in one run materializes activations for every
    example at once. The first conv layer alone is N x 28 x 28 x 32 floats,
    about 1 GB for 10,000 images, which can exhaust a small GPU. Each batch's
    mean is weighted by its size, so the result equals the full-set mean.

    Returns:
      The accuracy as a float, or NaN when ``images`` is empty (matching what
      a mean over zero examples would give).
    """
    num_examples = len(images)
    if num_examples == 0:
        return float('nan')
    correct = 0.0
    for start in range(0, num_examples, batch_size):
        stop = min(start + batch_size, num_examples)
        batch_acc = sess.run(accuracy, feed_dict={
            x: images[start:stop], y_: labels[start:stop], keep_prob: 1.0})
        correct += batch_acc * (stop - start)
    return correct / num_examples


def main(_):
    """Train the MNIST CNN for 1000 steps, then report test accuracy and time.

    Reads MNIST from FLAGS.data_dir. Training accuracy is printed every 100
    steps. After training, accuracy on the test split and the wall-clock
    training time are printed.
    """
    mnist = input_data.read_data_sets(FLAGS.data_dir)

    # ---------------- MODEL DEF -------------------
    x = tf.placeholder(tf.float32, [None, 784])
    # read_data_sets is called without one_hot=True, so labels are integer
    # class ids. Hence the int64 placeholder and the sparse loss below.
    y_ = tf.placeholder(tf.int64, [None])
    keep_prob = tf.placeholder(tf.float32)
    y_conv = _build_cnn(x, keep_prob)

    # --------------- TRAINING STEPS -----------------
    # tf.losses.sparse_softmax_cross_entropy already returns a scalar mean
    # loss, so no extra tf.reduce_mean is needed.
    cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y_conv)
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

    # evaluating our model
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # The `with` block closes the session (releasing its resources) when
    # training and evaluation finish. Variables are initialized exactly once.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        start = time.time()
        for i in range(1000):
            batch = mnist.train.next_batch(50)
            if i % 100 == 0:
                # Dropout is disabled (keep_prob=1.0) when measuring accuracy.
                train_accuracy = sess.run(accuracy, feed_dict={
                    x: batch[0], y_: batch[1], keep_prob: 1.0})
                print("Step %d, training accuracy %g" % (i, train_accuracy))
            sess.run(train_step, feed_dict={
                x: batch[0], y_: batch[1], keep_prob: 0.5})
        end = time.time()

        # Evaluate on the test data in batches to bound peak memory.
        test_accuracy = _batched_accuracy(
            sess, accuracy, x, y_, keep_prob,
            mnist.test.images, mnist.test.labels)
        print("test accuracy %g" % test_accuracy)

    print("Training time: {} sec".format(end - start))
if __name__ == '__main__':
    # Parse only the flags this script defines. Anything unrecognized is
    # forwarded to tf.app.run, together with the program name.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--data_dir',
                            type=str,
                            default='/tmp/tensorflow/mnist/input_data',
                            help='Directory for storing input data')
    FLAGS, unparsed = arg_parser.parse_known_args()
    forwarded_argv = [sys.argv[0]] + unparsed
    tf.app.run(main=main, argv=forwarded_argv)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.