Created
January 10, 2017 14:41
-
-
Save Adagio-cantabile/9f29b7f003e3c2876f8e645e4f55c02d to your computer and use it in GitHub Desktop.
MNIST的学习资料
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Train and evaluate a single-layer softmax regression classifier on MNIST.

Uses the TensorFlow 1.x graph/session API: builds the model graph, trains it
with plain gradient descent on mini-batches, then prints test-set accuracy.
"""
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

# Load MNIST from "MNIST_data/" (downloading if needed); labels are one-hot.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Input images: None means the batch dimension can be any size.
# The input is a (batch_size, 784) matrix -- one flattened image per row.
x = tf.placeholder(tf.float32, [None, 784])

# Model parameters -- note the shapes: W maps 784 inputs to 10 classes.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Unnormalized class scores, shape (batch_size, 10).
logits = tf.matmul(x, W) + b
# y is a (batch_size, 10) matrix of predicted class probabilities.
y = tf.nn.softmax(logits)
# Ground-truth one-hot labels, shape (batch_size, 10).
y_ = tf.placeholder(tf.float32, [None, 10])

# Cross-entropy summed over the batch. Computing it from the logits instead of
# -sum(y_ * log(y)) avoids log(0) = -inf, which would turn the loss and its
# gradients into NaN once the softmax saturates.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# Accuracy: fraction of examples whose highest-probability class matches the label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Variable initializer (replaces the deprecated tf.initialize_all_variables()).
init = tf.global_variables_initializer()

# Use the session as a context manager so its resources are released at the end.
with tf.Session() as sess:
    sess.run(init)
    # Train for 1000 steps, each on a mini-batch of 100 training images.
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    # Evaluate once on the held-out test set.
    print(sess.run(accuracy,
                   feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment