Created
November 2, 2018 09:04
-
-
Save kezunlin/9b3a36c2fc52f5c0436cb49fe62b1124 to your computer and use it in GitHub Desktop.
Tensorflow MNIST CNN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3
# TensorFlow 1.x CNN on MNIST, following the tutorials at:
# http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
# http://www.tensorfly.cn/tfdoc/tutorials/mnist_pros.html
"""
import numpy as np
a = [1,0,0,0,0]
b = [0,1,0,0,0]
c = [0,0,1,0,0]
data = [a,b,c]
data2 = [a,b,a]
print(np.argmax(a)) # 1 dim---> 0 dim
print(np.argmax(b)) # 1 dim---> 0 dim
print(np.argmax(c)) # 1 dim---> 0 dim
#print(np.argmax(data,0)) # 2 dim: matrix---> 1 dim: vector
#print(np.argmax(data2,0))
print(np.argmax(data,1)) # 2 dim: matrix---> 1 dim: vector
print(np.argmax(data2,1))
print( np.equal(np.argmax(data,1),np.argmax(data2,1)) ) # 1 vector
0
1
2
[0 1 2]
[0 1 0]
[ True True False]
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# TF C++ logging threshold:
#   0: show all log messages
#   1: additionally suppress INFO
#   2: additionally suppress WARNING
#   3: additionally suppress ERROR
import tensorflow as tf
# NOTE(review): tensorflow.examples.tutorials was removed in TF 2.x; this
# script targets TF 1.x only.
import tensorflow.examples.tutorials.mnist.input_data as input_data
# Downloads/caches MNIST under MNIST_data/; labels are one-hot vectors.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x = tf.placeholder("float", [None, 784])  # flattened 28*28 input images
y_ = tf.placeholder("float", [None,10]) # ground-truth one-hot labels
# W, b initializer | |
def weight_variable(shape):
    """Return a trainable weight Variable drawn from a truncated normal (stddev 0.1)."""
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
def bias_variable(shape):
    """Return a trainable bias Variable filled with the constant 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))
# conv2d and maxpool | |
def conv2d(x, W):
    """2-D convolution of x with filter W, stride 1, SAME padding (spatial size preserved)."""
    unit_stride = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_stride, padding='SAME')
def max_pool_2x2(x):
    """2x2 max-pooling with stride 2 and SAME padding (halves each spatial dimension)."""
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
# Reshape flat 784-vectors back to NHWC image tensors: [batch, 28, 28, 1].
x_image = tf.reshape(x, [-1,28,28,1])
# conv1 + pool1: 5x5 conv, 1 -> 32 channels; 2x2 pooling halves 28x28 -> 14x14
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
# conv2 + pool2: 5x5 conv, 32 -> 64 channels; 2x2 pooling halves 14x14 -> 7x7
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
# fc1 (relu): flatten the 7*7*64 feature maps into a 1024-unit hidden layer
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# Dropout on fc1: callers feed keep_prob = 0.5 for training, 1.0 for evaluation.
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
# fc2: final fully-connected layer producing 10 class logits.
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
y = tf.nn.softmax(logits)  # predicted class probabilities, shape [batch, 10]
# Cost: cross-entropy summed over the batch.  Computed from the logits via
# tf.nn.softmax_cross_entropy_with_logits rather than the original
# -tf.reduce_sum(y_ * tf.log(y)), which yields NaN as soon as any predicted
# probability underflows to 0.  The summed value is mathematically identical.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
#train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# Accuracy: fraction of samples whose argmax prediction matches the label,
# e.g. [True, True, False] -> [1, 1, 0] -> 0.67.
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# Create the session and run the initializers for all variables above.
init = tf.global_variables_initializer()
config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True  # grow GPU memory on demand
sess = tf.Session(config=config_proto)
sess.run(init)
# NOTE(review): the graph was already built above, so this device context only
# pins ops created inside it (none); kept for parity with the original script.
with tf.device("/gpu:0"):
    for i in range(20000):
        batch_xs, batch_ys = mnist.train.next_batch(50)
        # Report training-batch accuracy every 100 iterations.
        if i % 100 == 0:
            train_accuracy = sess.run(
                accuracy,
                feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 1.0})  # no dropout for eval
            print("step {0}, training accuracy {1:.4f}".format(i, train_accuracy))
        sess.run(train_step,
                 feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})  # dropout for training
print("")
test_accuracy = sess.run(
    accuracy,
    feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
print("test accuracy ",test_accuracy)  # original author reports ~99.20%
sess.close()  # release the session's GPU/CPU resources
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment