Skip to content

Instantly share code, notes, and snippets.

@Myeongjoon
Created August 9, 2018 23:23
Show Gist options
  • Save Myeongjoon/15dcebf10c2e147fd624ee24735b7334 to your computer and use it in GitHub Desktop.
Save Myeongjoon/15dcebf10c2e147fd624ee24735b7334 to your computer and use it in GitHub Desktop.
RNN train, test 스텝
# Merge all the summaries defined earlier in the graph into one op, so a
# single sess.run() call produces every summary for a step.
merged = tf.summary.merge_all()

# Get a small, fixed test set: the first `batch_size` MNIST test images,
# reshaped to (batch, time_steps, element_size) to match what `_inputs`
# is fed with during training.
test_data = mnist.test.images[:batch_size].reshape((-1, time_steps, element_size))
test_label = mnist.test.labels[:batch_size]

with tf.Session() as sess:
    # Write summaries to LOG_DIR -- used by TensorBoard
    train_writer = tf.summary.FileWriter(LOG_DIR + '/train',
                                         graph=tf.get_default_graph())
    test_writer = tf.summary.FileWriter(LOG_DIR + '/test',
                                        graph=tf.get_default_graph())
    try:
        sess.run(tf.global_variables_initializer())

        for i in range(10000):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # Reshape each flat image into `time_steps` sequences of
            # `element_size` pixels (presumably 28 x 28 for MNIST -- confirm
            # against where those constants are defined).
            batch_x = batch_x.reshape((batch_size, time_steps, element_size))
            summary, _ = sess.run([merged, train_step],
                                  feed_dict={_inputs: batch_x, y: batch_y})
            # Add to summaries
            train_writer.add_summary(summary, i)

            if i % 1000 == 0:
                # Report loss/accuracy on the current training minibatch.
                acc, loss = sess.run([accuracy, cross_entropy],
                                     feed_dict={_inputs: batch_x, y: batch_y})
                print("Iter {}, Minibatch Loss= {:.6f}, "
                      "Training Accuracy= {:.5f}".format(i, loss, acc))
            if i % 100 == 0:
                # Evaluate on the fixed test batch (`batch_size` images) and
                # log the merged summaries through the test writer.
                summary, acc = sess.run([merged, accuracy],
                                        feed_dict={_inputs: test_data,
                                                   y: test_label})
                test_writer.add_summary(summary, i)

        # Final accuracy on the same fixed test batch; must run while the
        # session is still open.
        test_acc = sess.run(accuracy,
                            feed_dict={_inputs: test_data, y: test_label})
        print("Test Accuracy:", test_acc)
    finally:
        # FileWriter buffers events; close() flushes them to disk so
        # TensorBoard sees every summary, even if training fails midway.
        train_writer.close()
        test_writer.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment