Created
July 22, 2019 07:26
-
-
Save R97416032/53682b6c9fe3a72616c8803a403e1c03 to your computer and use it in GitHub Desktop.
fashion-mnist数据集自动下载,构建分类模型(softmax回归)并训练、保存model、预测。
tensorboard可视化
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import os | |
from tensorflow.examples.tutorials.mnist import input_data | |
# Read the dataset through the TF1 tutorial helper, using the local 'data'
# directory, with labels encoded as one-hot vectors.
data = input_data.read_data_sets('data', one_hot=True)

# Softmax regression: one affine layer mapping 784 input features to
# 10 class scores. Weights and biases start at zero.
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
# Class probabilities; used for prediction/accuracy only.
y = tf.nn.softmax(logits)
# One-hot ground-truth labels.
y_ = tf.placeholder(tf.float32, [None, 10])

# Cross-entropy summed over the batch. This equals
# -tf.reduce_sum(y_ * tf.log(y)), but is computed from the raw logits by the
# fused op: taking tf.log of an explicit softmax yields log(0) = -inf (and
# NaN gradients) as soon as any probability underflows to zero.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=logits)
)
tf.summary.scalar('cross_entropy', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.0028).minimize(cross_entropy)

# Fraction of samples whose highest-probability class matches the label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# NOTE: the summary tag spelling is left unchanged so that existing
# TensorBoard logs keep lining up with new runs.
tf.summary.scalar('accrucy', accuracy)

# Saver over all variables created above (W and b).
saver = tf.train.Saver()
# Train the model, log summaries for TensorBoard, report test accuracy and
# save a checkpoint.
# NOTE(review): the original indentation was lost; evaluation and saving are
# assumed to run once, after the training loop -- confirm.
with tf.Session() as sess:
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter("log/", sess.graph)
    try:
        sess.run(tf.global_variables_initializer())
        for i in range(301):
            batch_xs, batch_ys = data.train.next_batch(100)
            batch_feed = {x: batch_xs, y_: batch_ys}
            sess.run(train_step, feed_dict=batch_feed)
            if i % 10 == 0:
                # Record loss/accuracy on the batch just trained on,
                # evaluated after the parameter update.
                result = sess.run(merged, feed_dict=batch_feed)
                writer.add_summary(result, i)
    finally:
        # Flush buffered events to disk and release the event file even if
        # training raises; an unclosed writer can drop the last summaries.
        writer.close()

    print("accrucy :", (sess.run(accuracy, feed_dict={x: data.test.images, y_: data.test.labels})))

    checkpoint_path = "C:\\Users\\R\\PycharmProjects\\PyC\\Fashion\\model_data\\" + 'model.ckpt'
    # Saver.save fails when the parent directory is missing, so create it
    # up front.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    saver.save(sess, checkpoint_path)
# Experiment notes. Columns appear to be: training steps, learning rate,
# loss definition, resulting test accuracy (legend inferred -- confirm).
##500 0.001 cross_entropy=-tf.reduce_sum(y_*tf.log(y)) 0.811
##500 0.001 cross_entropy=tf.reduce_mean(y_*tf.log(y)) 0.0025
##500 0.001 cross_entropy=tf.reduce_mean(tf.square(y-y_)) 0.6761
##50000 0.001 cross_entropy=tf.reduce_mean(tf.square(y-y_)) 0.7326
##50000 0.01 cross_entropy=tf.reduce_mean(tf.square(y-y_)) 0.8917
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model_checkpoint_path: "C:\\Users\\R\\PycharmProjects\\PyC\\Fashion\\model_data\\model.ckpt" | |
all_model_checkpoint_paths: "C:\\Users\\R\\PycharmProjects\\PyC\\Fashion\\model_data\\model.ckpt" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters