Created
July 22, 2019 07:29
-
-
Save R97416032/8c6950d1b01001126fa0a9353aa3098f to your computer and use it in GitHub Desktop.
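自动下载 Fashion-MNIST 数据集，构建分类模型（softmax 回归、CNN）并训练、保存模型、预测，TensorBoard 可视化。(Automatically downloads the Fashion-MNIST dataset, builds classification models (softmax regression, CNN), trains them, saves the model, runs prediction, and visualizes training with TensorBoard.)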
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Softmax-regression classifier for Fashion-MNIST (TensorFlow 1.x).

Fetches Fashion-MNIST into ``data/`` (the reader only downloads files that are
not already there), trains a single-layer softmax regression model with plain
gradient descent, writes loss/accuracy summaries for TensorBoard to ``log/``,
prints the test-set accuracy every ``LOG_EVERY`` steps, and saves the trained
variables to ``model_data/model.ckpt``.
"""
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Fashion-MNIST ships with the same file names and format as MNIST, so the MNIST
# reader works unchanged; only the download location differs. Without
# ``source_url`` the reader would fetch the handwritten-digit MNIST set instead.
FASHION_MNIST_URL = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
DATA_DIR = "data"
LOG_DIR = "log/"
# Relative to the working directory, like DATA_DIR and LOG_DIR (previously a
# hard-coded absolute path on one user's Windows machine).
MODEL_DIR = "model_data"

NUM_PIXELS = 784      # 28 x 28 images, flattened by read_data_sets
NUM_CLASSES = 10
LEARNING_RATE = 0.0028
BATCH_SIZE = 100
NUM_STEPS = 301       # steps 0..300 inclusive, so the final step 300 is logged
LOG_EVERY = 10

data = input_data.read_data_sets(
    DATA_DIR, one_hot=True, source_url=FASHION_MNIST_URL)

x = tf.placeholder(tf.float32, [None, NUM_PIXELS])
W = tf.Variable(tf.zeros([NUM_PIXELS, NUM_CLASSES]))
b = tf.Variable(tf.zeros([NUM_CLASSES]))
logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)
y_ = tf.placeholder(tf.float32, [None, NUM_CLASSES])

# Summed (not averaged) over the batch so the effective step size matches the
# former -sum(y_ * log(y)) loss. Computing it from the logits avoids
# log(0) = -inf (and NaN gradients) when a softmax probability underflows to 0.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=logits))
tf.summary.scalar('cross_entropy', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(
    cross_entropy)

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)

saver = tf.train.Saver()
# Saver.save refuses to write into a directory that does not exist.
os.makedirs(MODEL_DIR, exist_ok=True)

with tf.Session() as sess:
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    try:
        sess.run(tf.global_variables_initializer())
        for i in range(NUM_STEPS):
            batch_xs, batch_ys = data.train.next_batch(BATCH_SIZE)
            feed = {x: batch_xs, y_: batch_ys}
            sess.run(train_step, feed_dict=feed)
            if i % LOG_EVERY == 0:
                # Summaries are computed on the current training batch ...
                writer.add_summary(sess.run(merged, feed_dict=feed), i)
                # ... while the printed accuracy is on the full test set.
                test_accuracy = sess.run(
                    accuracy,
                    feed_dict={x: data.test.images, y_: data.test.labels})
                print("accuracy :", test_accuracy)
        saver.save(sess, os.path.join(MODEL_DIR, 'model.ckpt'))
    finally:
        # Flush buffered events; a short run can finish before the writer's
        # periodic flush would otherwise write them to disk.
        writer.close()

# Earlier experiment notes: steps, learning rate, loss expression, test accuracy.
#   500    0.001  cross_entropy=-tf.reduce_sum(y_*tf.log(y))      0.811
#   500    0.001  cross_entropy=tf.reduce_mean(y_*tf.log(y))      0.0025
#                 (no minus sign: minimizing this maximizes the cross-entropy)
#   500    0.001  cross_entropy=tf.reduce_mean(tf.square(y-y_))   0.6761
#   50000  0.001  cross_entropy=tf.reduce_mean(tf.square(y-y_))   0.7326
#   50000  0.01   cross_entropy=tf.reduce_mean(ty-y_))            0.8917
#                 (expression garbled in the original note; exact loss unknown)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model_checkpoint_path: "C:\\Users\\R\\PycharmProjects\\PyC\\Fashion\\model_data\\model.ckpt"
all_model_checkpoint_paths: "C:\\Users\\R\\PycharmProjects\\PyC\\Fashion\\model_data\\model.ckpt"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters