Skip to content

Instantly share code, notes, and snippets.

@R97416032
Created July 22, 2019 07:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save R97416032/a12a5a0ae380915558dc212879f4ec31 to your computer and use it in GitHub Desktop.
Save R97416032/a12a5a0ae380915558dc212879f4ec31 to your computer and use it in GitHub Desktop.
自动下载 Fashion-MNIST，构建分类模型（softmax 回归）并训练、保存 model、预测，TensorBoard 可视化
import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data
# Softmax-regression classifier on (Fashion-)MNIST, TensorFlow 1.x graph mode:
# download the data, train with plain gradient descent, log TensorBoard
# summaries, report test accuracy, and save a checkpoint.

DATA_DIR = 'data'
# read_data_sets() downloads classic MNIST unless source_url is given; this
# mirror serves Fashion-MNIST with the same file names and format. Files that
# already exist in DATA_DIR are reused without downloading.
FASHION_MNIST_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
LOG_DIR = 'log/'
# Machine-specific checkpoint directory carried over from the original script;
# adjust for your environment.
MODEL_DIR = r'C:\Users\R\PycharmProjects\PyC\Fashion\model_data'
LEARNING_RATE = 0.0028
TRAIN_STEPS = 301
BATCH_SIZE = 100
LOG_EVERY = 10

data = input_data.read_data_sets(DATA_DIR, one_hot=True,
                                 source_url=FASHION_MNIST_URL)

# Model: y = softmax(x W + b) over flattened 784-pixel images and 10 classes.
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)
y_ = tf.placeholder(tf.float32, [None, 10])  # one-hot labels

# Batch-summed cross-entropy, mathematically equal to
# -reduce_sum(y_ * log(softmax(logits))). It is computed from the logits so
# that a softmax output underflowing to 0 cannot yield log(0) = -inf and
# NaN gradients.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=logits))
tf.summary.scalar('cross_entropy', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)

merged = tf.summary.merge_all()
saver = tf.train.Saver()

with tf.Session() as sess:
    # View the curves with:  tensorboard --logdir=<path to LOG_DIR>
    writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    try:
        sess.run(tf.global_variables_initializer())
        test_feed = {x: data.test.images, y_: data.test.labels}
        for i in range(TRAIN_STEPS):
            batch_xs, batch_ys = data.train.next_batch(BATCH_SIZE)
            batch_feed = {x: batch_xs, y_: batch_ys}
            sess.run(train_step, feed_dict=batch_feed)
            if i % LOG_EVERY == 0:
                # Summaries are evaluated on the batch just trained on,
                # after the update.
                writer.add_summary(sess.run(merged, feed_dict=batch_feed), i)
                print("step", i, "accuracy :",
                      sess.run(accuracy, feed_dict=test_feed))
    finally:
        # Flush buffered events to disk even if training fails part-way.
        writer.close()

    # Saver.save() refuses to write into a missing directory, so create it.
    os.makedirs(MODEL_DIR, exist_ok=True)
    saver.save(sess, os.path.join(MODEL_DIR, 'model.ckpt'))

# Experiment notes from the original author. Columns are presumably:
# training steps, learning rate, loss expression, final test accuracy.
#   500    0.001  -tf.reduce_sum(y_ * tf.log(y))      0.811
#   500    0.001  tf.reduce_mean(y_ * tf.log(y))      0.0025  (no minus sign:
#                 this maximizes the cross-entropy, hence the collapse)
#   500    0.001  tf.reduce_mean(tf.square(y - y_))   0.6761
#   50000  0.001  tf.reduce_mean(tf.square(y - y_))   0.7326
#   50000  0.01   tf.reduce_mean(ty-y_))              0.8917
#                 (expression garbled in the original note; presumably
#                 tf.reduce_mean(tf.square(y - y_)) -- TODO confirm)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment