Skip to content

Instantly share code, notes, and snippets.

@springle
Last active July 13, 2017 19:59
Show Gist options
  • Save springle/95e5e766b7ebf2cdb80cad4c079e1f48 to your computer and use it in GitHub Desktop.
Save springle/95e5e766b7ebf2cdb80cad4c079e1f48 to your computer and use it in GitHub Desktop.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
def main(server, data_dir='/tmp/tensorflow/mnist/input_data', last_step=30000,
         batch_size=100, eval_every=1000, learning_rate=0.5):
    """Train a softmax-regression MNIST classifier in a monitored session.

    Builds a single-layer linear model ``y = xW + b`` and trains it with
    gradient descent inside a ``tf.train.MonitoredTrainingSession`` whose
    master is ``server.target``. The task with index 0 acts as chief. Every
    ``eval_every`` local iterations, the global step, local iteration count,
    bias vector and test-set accuracy are printed. After the session closes,
    the final values of ``W`` and ``b`` captured by a ``FinalOpsHook`` are
    printed.

    Args:
        server: Server object providing ``target`` (used as the session
            master) and ``server_def.task_index`` (0 means chief);
            presumably a ``tf.train.Server``.
        data_dir: Directory passed to ``input_data.read_data_sets``.
        last_step: Global step at which ``StopAtStepHook`` stops training.
        batch_size: Number of training examples fed per training step.
        eval_every: Print progress every this many local iterations
            (must be a positive integer).
        learning_rate: Learning rate for ``GradientDescentOptimizer``.
    """
    # Import data (labels are one-hot encoded).
    mnist = input_data.read_data_sets(data_dir, one_hot=True)

    # Create the model: 784 input features -> 10 class logits.
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.matmul(x, W) + b
    y_ = tf.placeholder(tf.float32, [None, 10])

    # StopAtStepHook ends training once the global step reaches last_step.
    # FinalOpsHook evaluates [W, b] when the monitored session is closed.
    global_step = tf.contrib.framework.get_or_create_global_step()
    stop_at_step_hook = tf.train.StopAtStepHook(last_step=last_step)
    final_ops_hook = tf.train.FinalOpsHook([W, b])
    hooks = [stop_at_step_hook, final_ops_hook]

    # Define loss and optimizer. `y` holds raw logits; the softmax is applied
    # inside softmax_cross_entropy_with_logits.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        cross_entropy, global_step=global_step)

    # Create testing ops
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Start monitored training session
    local_count = 0
    is_chief = server.server_def.task_index == 0
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=is_chief,
                                           hooks=hooks) as sess:
        while not sess.should_stop():
            local_count += 1
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

            # Periodically test the model. All reported values are fetched in
            # a single run call, so the session hooks fire once rather than
            # once per value.
            if local_count % eval_every == 0:
                step_val, bias_val, acc_val = sess.run(
                    [global_step, b, accuracy],
                    feed_dict={x: mnist.test.images,
                               y_: mnist.test.labels})
                print("--------------------------------------")
                print("GLOBAL STEP {}".format(step_val))
                print("LOCAL COUNT {}".format(local_count))
                print("BIASES: {}".format(bias_val))
                print("ACCURACY: {}".format(acc_val))
                print("--------------------------------------\n")

    # Must run after the `with` block: FinalOpsHook only populates
    # final_ops_values when the monitored session is closed.
    print("FINAL VALUES: {}".format(final_ops_hook.final_ops_values))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment