Implementing the TensorFlow MNIST example without feed_dict
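For contrast, the stock MNIST tutorial pushes every batch through placeholders with feed_dict; a minimal sketch of that pattern (names follow the tutorial's conventions, not this file):

x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
# each training step copies a numpy batch from Python into the runtime:
# sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

The version below instead wires tf.train.shuffle_batch tensors directly into the graph, so sess.run pulls batches from a queue and no per-step feeding is needed.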
# coding: utf-8

# In[1]:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# In[9]:
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
BATCH_SIZE = 100
# In[10]:
# initialize weights
def weight_variable(name, shape):
    # tf.get_variable already creates (or reuses) a variable; wrapping the
    # result in tf.Variable() again would register a duplicate copy
    return tf.get_variable(name, shape=shape,
                           initializer=tf.truncated_normal_initializer(stddev=0.1))

# initialize neurons w/ small bias
def bias_variable(name, shape):
    return tf.get_variable(name, shape=shape,
                           initializer=tf.constant_initializer(0.1))
# In[11]:
# convolution
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

# max pooling
def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')
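# a quick shape check for the layers below: with 'SAME' padding, conv2d
# preserves height and width, and each max_pool_2x2 halves them, so a 28x28
# image goes 28 -> 14 -> 7 across the two conv/pool stages; that is where
# the 7 * 7 * 64 input size of the first fully connected layer comes from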
# In[12]:
def input_pipeline(batch_size, test=False):
    if test:
        inputs, labels = mnist.test.images, mnist.test.labels
    else:
        inputs, labels = mnist.train.images, mnist.train.labels
    # min_after_dequeue defines how big a buffer we will randomly sample
    # from -- bigger means better shuffling but slower start up and more
    # memory used.
    # capacity must be larger than min_after_dequeue and the amount larger
    # determines the maximum we will prefetch. Recommendation:
    # min_after_dequeue + (num_threads + a small safety margin) * batch_size
    min_after_dequeue = 1000
    capacity = min_after_dequeue + 3 * batch_size
    if test:
        example_batch, label_batch = tf.train.batch(
            [inputs, labels], batch_size=batch_size, capacity=capacity,
            enqueue_many=True)
    else:
        example_batch, label_batch = tf.train.shuffle_batch(
            [inputs, labels], batch_size=batch_size, num_threads=3,
            capacity=capacity, min_after_dequeue=min_after_dequeue,
            enqueue_many=True)
    return example_batch, label_batch
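# plugging in the values used here: with min_after_dequeue=1000,
# num_threads=3 (and no extra safety margin) and batch_size=100, the
# recommended capacity works out to 1000 + 3 * 100 = 1300, which is
# exactly what `capacity` computes above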
# In[13]: | |
with tf.variable_scope("queue"): | |
#initialize batches (uses RandomShuffleQueue under the hood) | |
batch, labels = input_pipeline(BATCH_SIZE) | |
input_, y_true = batch, tf.cast(labels, tf.float32) | |
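# with enqueue_many=True the pipeline treats the [55000, 784] training array
# as 55000 individual examples, so `batch` comes back as a [100, 784] float32
# tensor and `labels` as [100, 10] -- ordinary graph tensors, nothing is fed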
# In[14]:
def model(input_, keep_prob, reuse=False):
    # the whole net lives in one variable scope so that a second tower
    # (used for the test set below) can share its weights via reuse=True
    with tf.variable_scope("layers", reuse=reuse):
        # first convolutional layer
        W_conv1 = weight_variable('W_conv1', [5, 5, 1, 32])
        b_conv1 = bias_variable('b_conv1', [32])
        x_image = tf.reshape(input_, [-1, 28, 28, 1])
        h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
        h_pool1 = max_pool_2x2(h_conv1)

        # second convolutional layer
        W_conv2 = weight_variable('W_conv2', [5, 5, 32, 64])
        b_conv2 = bias_variable('b_conv2', [64])
        h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
        h_pool2 = max_pool_2x2(h_conv2)

        # fully connected layer
        W_fc1 = weight_variable('W_fc1', [7 * 7 * 64, 1024])
        b_fc1 = bias_variable('b_fc1', [1024])
        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

        # DROPOUT: this is where the magic happens
        # this is the stock version: keep_prob holds the probability that a
        # neuron's output is kept, and since it is a plain Python float it is
        # passed in per tower (0.5 for training, 1.0 for evaluation) instead
        # of being fed through a placeholder
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

        # readout layer
        W_fc2 = weight_variable('W_fc2', [1024, 10])
        b_fc2 = bias_variable('b_fc2', [10])
        y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
    return y_conv

# training tower: dropout active with keep_prob = 0.5
y_conv = model(input_, keep_prob=0.5)
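# a small sketch of what tf.nn.dropout does (inverted dropout): each unit is
# zeroed with probability 1 - keep_prob and the survivors are scaled by
# 1 / keep_prob, so the expected activation stays unchanged, e.g.:
#   tf.nn.dropout(tf.ones([4]), keep_prob=0.5)  # each element is 0.0 or 2.0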
# In[15]:
with tf.variable_scope("loss"):
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=y_conv, labels=y_true)
    loss_op = tf.reduce_mean(loss)
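# softmax_cross_entropy_with_logits fuses the softmax with the cross entropy
# for numerical stability; a hand-rolled (less stable) equivalent, for
# illustration only:
#   loss_manual = -tf.reduce_sum(y_true * tf.log(tf.nn.softmax(y_conv)), axis=1)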
# In[16]:
with tf.variable_scope("accuracy"):
    correct_pred = tf.cast(tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_true, 1)), tf.float32)
    accuracy = tf.reduce_mean(correct_pred)
    # tf.Print is an identity op that logs its data tensors each time it runs
    accuracy_ = tf.Print(accuracy, data=[accuracy], message="accuracy: ")
# In[17]:
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss_op, name="train_op")

# build the test-set evaluation graph before the session starts, since
# tf.train.start_queue_runners() only launches queue runners that already
# exist in the graph; reuse=True shares the trained weights, and
# keep_prob=1.0 keeps dropout off during evaluation
test_batch, test_labels = input_pipeline(len(mnist.test.images), test=True)
test_logits = model(test_batch, keep_prob=1.0, reuse=True)
correct_test = tf.cast(tf.equal(tf.argmax(test_logits, 1),
                                tf.argmax(test_labels, 1)), tf.float32)
test_accuracy = tf.reduce_mean(correct_test)

# In[18]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # start the threads that keep the input queues full
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(accuracy_)
    for i in range(20000):
        _, loss_val, acc_val = sess.run([train_op, loss_op, accuracy])
        # we regularly check the loss and training-batch accuracy
        if i % 100 == 0:
            print('iter:%d - loss:%f, acc:%f' % (i, loss_val, acc_val))
            sess.run(accuracy_)
    sess.run(accuracy_)
    # run the test set through the evaluation tower built above
    acc = sess.run(test_accuracy)
    print("Test accuracy: " + str(acc))
    coord.request_stop()
    coord.join(threads)