Skip to content

Instantly share code, notes, and snippets.

@namakemono
Last active September 6, 2016 18:05
Show Gist options
  • Save namakemono/1ff6604065fa6eb32983 to your computer and use it in GitHub Desktop.
Save namakemono/1ff6604065fa6eb32983 to your computer and use it in GitHub Desktop.
# mnist_expert.py
import tensorflow as tf
import os
if not os.path.exists("input_data.py"):
    # One-time fetch of the MNIST loader helper so that `import input_data`
    # below can succeed.
    import subprocess
    _INPUT_DATA_URL = ("https://raw.githubusercontent.com/tensorflow/tensorflow/"
                       "master/tensorflow/examples/tutorials/mnist/input_data.py")
    # -f: fail on HTTP errors instead of writing the error page (e.g. a 404
    #     body) to input_data.py. Such a file would "exist" on every later run,
    #     so it would never be re-downloaded and the import would keep failing.
    # -L: follow redirects.
    # An argument list (no shell) means nothing is interpreted by /bin/sh.
    if subprocess.call(["curl", "-fL", "-o", "input_data.py", _INPUT_DATA_URL]) != 0:
        # Do not leave a partial or empty file behind to be imported later.
        if os.path.exists("input_data.py"):
            os.remove("input_data.py")
        raise RuntimeError("could not download input_data.py from " + _INPUT_DATA_URL)
import input_data
def conv2d(x, W):
    """Convolve ``x`` with the filter bank ``W`` at stride 1 with 'SAME' padding.

    Args:
        x: 4-D input tensor; presumably NHWC (TensorFlow's default layout),
            as produced by the [-1, 28, 28, 1] reshape in ``main``.
        W: filter tensor [filter_h, filter_w, in_channels, out_channels].

    Returns:
        The convolution result. Stride 1 with 'SAME' padding keeps the
        spatial size of ``x``.
    """
    stride_one = [1] * 4  # no stepping along batch, height, width or channel
    return tf.nn.conv2d(x, W, strides=stride_one, padding='SAME')
def max_pool_2x2(x):
    """Downsample ``x`` with a 2x2 max-pool at stride 2 and 'SAME' padding.

    For an NHWC tensor this halves height and width (rounding up),
    e.g. 28x28 -> 14x14 -> 7x7 as used in ``main``.
    """
    window = [1, 2, 2, 1]  # batch, height, width, channel
    # The pooling window and the stride are identical: non-overlapping 2x2 tiles.
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
def weight_variable(shape):
    """Return a trainable weight ``tf.Variable`` of the given ``shape``.

    Initial values come from a truncated normal with stddev 0.1, i.e. small
    non-zero noise, so units do not all start out identical.
    """
    initial_value = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial_value)
def bias_variable(shape):
    """Return a trainable bias ``tf.Variable`` of ``shape``, every entry 0.1.

    The slightly positive start presumably keeps the following ReLUs active
    at the beginning of training.
    """
    initial_value = tf.constant(0.1, shape=shape)
    return tf.Variable(initial_value)
def main(args):
    """Train a two-conv-layer CNN on MNIST for 1000 steps and report accuracy.

    Called by ``tf.app.run()``; ``args`` (remaining argv) is not used.

    Every 100 steps it prints the accuracy on the current 50-image training
    batch and on the full test set. At the end it prints the final test-set
    accuracy.
    """
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    with tf.Session() as sess:
        # Inputs fed at run time (placeholders, not trainable variables).
        x = tf.placeholder("float", shape=[None, 784])  # flattened 28x28 images
        y_ = tf.placeholder("float", shape=[None, 10])  # one-hot labels (y')
        keep_prob = tf.placeholder("float")             # dropout keep probability

        # Multilayer convolutional network:
        # INPUT -> [CONV -> RELU -> POOL] * 2 -> FC -> RELU -> DROPOUT -> FC
        W1, b1 = weight_variable([3, 3, 1, 32]), bias_variable([32])    # 3x3 filter, 1 -> 32 channels
        W2, b2 = weight_variable([3, 3, 32, 64]), bias_variable([64])   # 3x3 filter, 32 -> 64 channels
        W3, b3 = weight_variable([7 * 7 * 64, 1024]), bias_variable([1024])
        W4, b4 = weight_variable([1024, 10]), bias_variable([10])

        x_image = tf.reshape(x, [-1, 28, 28, 1])                  # 28x28, 1 channel
        h1 = max_pool_2x2(tf.nn.relu(conv2d(x_image, W1) + b1))   # 28x28 -> 14x14
        h2 = max_pool_2x2(tf.nn.relu(conv2d(h1, W2) + b2))        # 14x14 -> 7x7
        h3 = tf.nn.relu(tf.matmul(tf.reshape(h2, [-1, 7 * 7 * 64]), W3) + b3)  # dense layer
        h4 = tf.nn.dropout(h3, keep_prob)
        logits = tf.matmul(h4, W4) + b4                           # readout layer
        y = tf.nn.softmax(logits)

        # Cross-entropy H(y', y) = -sum(y' * log(y)), summed over the batch, but
        # computed from the logits. Taking tf.log of the softmax output gives
        # -inf once a probability underflows to 0, turning the loss and its
        # gradients into NaN.
        cross_entropy = tf.reduce_sum(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

        # Build the evaluation ops once. Creating them inside the training loop
        # added new nodes to the graph on every iteration.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

        # NOTE(review): newer TF (>= 0.12) names this
        # tf.global_variables_initializer(); kept for the TF version this
        # script targets.
        sess.run(tf.initialize_all_variables())
        test_feed = {x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}
        for i in range(1000):
            images, labels = mnist.train.next_batch(50)
            train_step.run(feed_dict={x: images, y_: labels, keep_prob: 0.5})
            if i % 100 == 0:
                # Dropout disabled (keep_prob 1.0) for evaluation.
                train_acc = accuracy.eval(feed_dict={x: images, y_: labels, keep_prob: 1.0})
                test_acc = accuracy.eval(feed_dict=test_feed)
                # Single-argument print(...) behaves the same on Python 2 and 3.
                print("[%d]\ttrain-accuracy:%.5f\ttest-accuracy:%.5f" % (i, train_acc, test_acc))
        print("Accuracy: %.3f" % accuracy.eval(feed_dict=test_feed))
if __name__ == "__main__":
    # Run through TensorFlow's app runner (parses flags, then calls the entry
    # point with argv). Name the entry point explicitly instead of relying on
    # the runner looking up `main` in __main__.
    tf.app.run(main=main)
@namakemono
Copy link
Author

$time python mnist_expert.py
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 4
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 4
[0] test-accuracy:0.10000 test-accuracy:0.10110
[100] test-accuracy:0.80000 test-accuracy:0.77120
[200] test-accuracy:0.94000 test-accuracy:0.87740
[300] test-accuracy:0.86000 test-accuracy:0.89560
[400] test-accuracy:0.98000 test-accuracy:0.91680
[500] test-accuracy:0.90000 test-accuracy:0.92050
[600] test-accuracy:0.96000 test-accuracy:0.93260
[700] test-accuracy:0.90000 test-accuracy:0.93360
[800] test-accuracy:0.88000 test-accuracy:0.94250
[900] test-accuracy:1.00000 test-accuracy:0.94650
Accuracy: 0.951

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment