Last active: July 13, 2017 19:59
-
-
Save springle/95e5e766b7ebf2cdb80cad4c079e1f48 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
from tensorflow.examples.tutorials.mnist import input_data | |
import tensorflow as tf | |
def main(server, data_dir='/tmp/tensorflow/mnist/input_data', last_step=30000,
         batch_size=100, learning_rate=0.5, log_every=1000):
    """Train a softmax-regression MNIST model under a MonitoredTrainingSession.

    Builds a single linear layer ``y = x @ W + b`` trained with softmax
    cross-entropy and plain gradient descent. The session connects to
    ``server.target``. Task 0 acts as chief. Training stops once the global
    step reaches ``last_step``. Every ``log_every`` local iterations the global
    step, local iteration count, current biases and test-set accuracy are
    printed. After the session has closed, the final values of ``W`` and ``b``
    captured by a ``FinalOpsHook`` are printed.

    Args:
        server: server object whose ``target`` is used as the session master
            and whose ``server_def.task_index`` selects the chief (index 0).
        data_dir: directory passed to ``input_data.read_data_sets``.
        last_step: global step at which ``StopAtStepHook`` requests a stop.
        batch_size: number of training examples fed per training step.
        learning_rate: learning rate for ``GradientDescentOptimizer``.
        log_every: print progress every this many local iterations; a value
            <= 0 disables periodic progress output.
    """
    # Import data
    mnist = input_data.read_data_sets(data_dir, one_hot=True)

    # Create the model: 784 input features -> 10 output logits.
    # NOTE(review): no tf.device / replica_device_setter is applied here, so in
    # a multi-worker cluster W and b may be placed on each worker's own task
    # rather than on a shared ps job -- confirm this is intended.
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.matmul(x, W) + b
    y_ = tf.placeholder(tf.float32, [None, 10])

    # StopAtStepHook requests a stop once the global step reaches last_step;
    # FinalOpsHook evaluates [W, b] when the session ends.
    global_step = tf.contrib.framework.get_or_create_global_step()
    stop_at_step_hook = tf.train.StopAtStepHook(last_step=last_step)
    final_ops_hook = tf.train.FinalOpsHook([W, b])
    hooks = [stop_at_step_hook, final_ops_hook]

    # Define loss and optimizer
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        cross_entropy, global_step=global_step)

    # Create testing ops
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    test_feed = {x: mnist.test.images, y_: mnist.test.labels}

    # Start monitored training session
    local_count = 0
    with tf.train.MonitoredTrainingSession(
            master=server.target,
            is_chief=(server.server_def.task_index == 0),
            hooks=hooks) as sess:
        while not sess.should_stop():
            local_count += 1
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

            # Every sess.run also runs the hooks, so the training step above
            # may already have made StopAtStepHook request a stop, and a
            # monitored session raises RuntimeError if run() is called after
            # that. Check first, then fetch all reporting values in a single
            # run so no further run can follow a stop request in this pass.
            if (log_every > 0 and local_count % log_every == 0
                    and not sess.should_stop()):
                step_value, biases, test_accuracy = sess.run(
                    [global_step, b, accuracy], feed_dict=test_feed)
                print("--------------------------------------")
                print("GLOBAL STEP {}".format(step_value))
                print("LOCAL COUNT {}".format(local_count))
                print("BIASES: {}".format(biases))
                print("ACCURACY: {}".format(test_accuracy))
                print("--------------------------------------\n")

    # final_ops_values is filled in by FinalOpsHook.end(), which runs when the
    # session is closed on leaving the with block; read it only after that.
    print("FINAL VALUES: {}".format(final_ops_hook.final_ops_values))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.