Skip to content

Instantly share code, notes, and snippets.

@jychstar
Last active January 20, 2017 00:39
Show Gist options
  • Save jychstar/1516daff64c8f9e1a69dbbc0db662660 to your computer and use it in GitHub Desktop.
Save jychstar/1516daff64c8f9e1a69dbbc0db662660 to your computer and use it in GitHub Desktop.
deep learning, assignment 2
# One-hidden-layer fully connected classifier:
# 784 input features -> hidden_size ReLU units -> 10 output logits,
# trained with plain gradient descent (learning rate 0.5) on softmax
# cross-entropy.
# NOTE(review): 784 inputs presumably are flattened 28x28 images — confirm.
batch_size = 128
hidden_size = 1024

graph = tf.Graph()
with graph.as_default():
    # One pair of placeholders serves every data split. Training mini-batches
    # and the full validation / test sets are all fed through X_train /
    # y_train at run time; the batch dimension is left as None for that reason.
    X_train = tf.placeholder(tf.float32, shape=(None, 784))
    # Assumes one-hot encoded labels (argmax over axis 1 is compared below).
    y_train = tf.placeholder(tf.float32, shape=(None, 10))

    # Hidden layer: affine transform followed by ReLU.
    # NOTE(review): tf.truncated_normal defaults to stddev=1.0, which is large
    # for a 784 x 1024 layer and likely inflates the initial loss; kept as-is
    # so results stay reproducible.
    weights1 = tf.Variable(tf.truncated_normal([784, hidden_size]))
    biases1 = tf.Variable(tf.zeros([hidden_size]))
    logits1 = tf.matmul(X_train, weights1) + biases1  # hidden pre-activation
    hidden1 = tf.nn.relu(logits1)

    # Output layer: unnormalized class scores; the softmax is applied inside
    # the loss op rather than as a separate layer.
    weights = tf.Variable(tf.truncated_normal([hidden_size, 10]))
    biases = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(hidden1, weights) + biases

    # Mean softmax cross-entropy over the fed batch, minimized by SGD.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_train, logits=logits))
    optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    # Accuracy: fraction of fed rows whose arg-max logit matches the arg-max
    # label. It is evaluated on whichever split is fed into the placeholders.
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_train, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# `time()` was never imported in the visible code; bring it into scope here
# (harmless if an earlier cell already did `from time import time`).
from time import time

num_steps = 4001
t0 = time()
# The validation and test sets are fed in full through the same placeholders
# used for training mini-batches.
# NOTE(review): valid_dataset / valid_labels / test_dataset / test_labels /
# train_dataset / train_labels are defined outside this snippet — confirm.
feed_valid = {X_train: valid_dataset, y_train: valid_labels}
feed_test = {X_train: test_dataset, y_train: test_labels}

with tf.Session(graph=graph) as s:
    # Inside this `with`, `s` is the default session, so .run()/.eval() below
    # need no explicit session argument.
    tf.global_variables_initializer().run()
    print("Initialized")
    for step in range(num_steps):
        # Walk through the training set in consecutive mini-batches. The
        # modulus keeps offset + batch_size within the array bounds, so the
        # last (train_labels.shape[0] - batch_size) boundary wraps to 0.
        offset = (step * batch_size) % (train_labels.shape[0] - batch_size)
        batch_data = train_dataset[offset:(offset + batch_size), :]
        batch_labels = train_labels[offset:(offset + batch_size), :]
        feed_train = {X_train: batch_data, y_train: batch_labels}
        # One gradient-descent step. Also fetches the mini-batch loss and the
        # mini-batch accuracy; the latter was previously misnamed
        # `predictions` and is currently unused.
        _, l, batch_accuracy = s.run([optimizer, loss, accuracy],
                                     feed_dict=feed_train)
        if step % 500 == 0:
            print("Step = {0:4d}, loss = {1:5.2f},Valid accuracy ={2:g}".
                  format(step, l, accuracy.eval(feed_dict=feed_valid)))
    print("Test accuracy: {0:g}".format(accuracy.eval(feed_dict=feed_test)))
print('Time cost:', time() - t0)  # 88% at 47 seconds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment