import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
"""
Utilities
"""
def orthogonal_initializer(scale=1.1):
    '''From Lasagne and Keras. Reference: Saxe et al., http://arxiv.org/abs/1312.6120
    '''
    def get_orthogonal(shape):
        flat_shape = (shape[0], np.prod(shape[1:]))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        # pick the one with the correct shape
        q = u if u.shape == flat_shape else v
        q = q.reshape(shape)
        initial_val = scale * q[:shape[0], :shape[1]]
        return initial_val

    def _initializer(shape, dtype=tf.float32):
        initial_val = get_orthogonal(shape)
        return tf.constant(initial_val, dtype=dtype)

    return _initializer
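

# (Added sketch, not part of the original gist) Quick sanity check of the
# initializer: for a square shape and scale=1.0 the generated matrix Q should
# satisfy Q^T Q ~= I. The helper name below is hypothetical.
def _check_orthogonal(n=4):
    with tf.Session() as s:
        q = s.run(orthogonal_initializer(scale=1.0)([n, n]))
    assert np.allclose(q.T.dot(q), np.eye(n), atol=1e-4)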
def smooth(x, window_len=11, window='hanning'):
    """
    Ripped from http://scipy-cookbook.readthedocs.org/items/SignalSmooth.html
    """
    s = np.r_[x[window_len - 1:0:-1], x, x[-1:-window_len:-1]]
    if window == 'flat':  # moving average
        w = np.ones(window_len, 'd')
    else:
        w = eval('np.' + window + '(window_len)')
    y = np.convolve(w / w.sum(), s, mode='valid')
    return y
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)


def get_batch(batch_size, which_set="train"):
    if which_set == "train":
        X, Y = mnist.train.next_batch(batch_size)
    elif which_set == "test":
        X, Y = mnist.test.next_batch(batch_size)
    X = X.reshape((batch_size, 28, 28)).astype("float32")
    # transpose to time-major layout: (seq_len, batch_size, input_dim)
    X = X.transpose(1, 0, 2)
    Y = Y.astype("int32")
    return (X, Y)
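

# (Added note, not part of the original gist) The rest of the script assumes
# this time-major layout: each MNIST image is fed to the GRU one row per
# timestep, so X has shape (seq_len=28, batch_size, input_dim=28) and Y holds
# integer class labels. The helper name below is hypothetical.
def _check_batch_shapes(batch_size=4):
    X, Y = get_batch(batch_size)
    assert X.shape == (28, batch_size, 28)
    assert Y.shape == (batch_size,)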
"""
GRU Model
"""
class GRU(object):
    def __init__(self, input_dim, n_hidden):
        self.input_dim = input_dim
        self.n_hidden = n_hidden
        with tf.variable_scope("weights", initializer=orthogonal_initializer()):
            self.W_z = tf.get_variable("W_z", [self.input_dim, self.n_hidden])
            self.W_r = tf.get_variable("W_r", [self.input_dim, self.n_hidden])
            self.W_h = tf.get_variable("W_h", [self.input_dim, self.n_hidden])
            self.U_z = tf.get_variable("U_z", [self.n_hidden, self.n_hidden])
            self.U_r = tf.get_variable("U_r", [self.n_hidden, self.n_hidden])
            self.U_h = tf.get_variable("U_h", [self.n_hidden, self.n_hidden])
        with tf.variable_scope("biases", initializer=tf.constant_initializer(0.0)):
            self.b_z = tf.get_variable("b_z", [self.n_hidden])
            self.b_r = tf.get_variable("b_r", [self.n_hidden])
            self.b_h = tf.get_variable("b_h", [self.n_hidden])

    def step(self, h_tm1, x):
        # update gate
        z = tf.nn.sigmoid(tf.nn.xw_plus_b(x, self.W_z, self.b_z) +
                          tf.matmul(h_tm1, self.U_z))
        # reset gate
        r = tf.nn.sigmoid(tf.nn.xw_plus_b(x, self.W_r, self.b_r) +
                          tf.matmul(h_tm1, self.U_r))
        # candidate state
        h = tf.nn.tanh(tf.nn.xw_plus_b(x, self.W_h, self.b_h) +
                       tf.matmul(tf.mul(h_tm1, r), self.U_h))
        # interpolate between the candidate and the previous state
        h_t = tf.mul((1. - z), h) + tf.mul(z, h_tm1)
        return h_t

    def initial_state(self, batch_size):
        return tf.zeros([batch_size, self.n_hidden])

    def __call__(self, X):
        # X is time-major: (seq_len, batch_size, input_dim)
        batch_size = X.get_shape()[1].value
        return tf.scan(self.step, X, initializer=self.initial_state(batch_size))
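
# (Added note, not part of the original gist) step() above is the standard GRU
# update, written out:
#   z_t  = sigmoid(x_t W_z + h_{t-1} U_z + b_z)          update gate
#   r_t  = sigmoid(x_t W_r + h_{t-1} U_r + b_r)          reset gate
#   h~_t = tanh(x_t W_h + (r_t * h_{t-1}) U_h + b_h)     candidate state
#   h_t  = (1 - z_t) * h~_t + z_t * h_{t-1}
# so z_t close to 1 keeps the previous state; some references swap the roles
# of z_t and (1 - z_t).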
"""
Logits and Cost ops
"""
def get_logits(input_, input_size, n_hidden, num_classes):
    gru = GRU(input_size, n_hidden)
    seq_len = input_.get_shape()[0].value
    batch_size = input_.get_shape()[1].value
    W_out = tf.get_variable("W_out", [n_hidden, num_classes],
                            initializer=orthogonal_initializer())
    b_out = tf.get_variable("b_out", [num_classes],
                            initializer=tf.constant_initializer(0.0))
    # gru(input_) has shape (seq_len, batch_size, n_hidden); keep only the
    # hidden state after the final timestep and project it to class logits.
    states = tf.split(0, seq_len, gru(input_))
    final_state = states[-1]
    final_state = tf.squeeze(final_state)
    final_state.set_shape([batch_size, n_hidden])
    return tf.nn.xw_plus_b(final_state, W_out, b_out)
def get_cost(logits, targets):
    return tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
def train():
    bs = 64
    n_iter = 1000
    # inputs are time-major: 28 rows per image, bs images, 28 pixels per row
    X = tf.placeholder(tf.float32, [28, bs, 28])
    targets = tf.placeholder(tf.int32, [bs])
    logits = get_logits(X, input_size=28, n_hidden=256, num_classes=10)
    loss = tf.reduce_mean(get_cost(logits, targets))
    train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

    sess = tf.Session()
    writer = tf.train.SummaryWriter("./gru_logs", sess.graph)
    sess.run(tf.initialize_all_variables())

    tr_losses = []
    te_losses = []
    tr_accs = []
    te_accs = []
    for i in range(n_iter):
        trX, trY = get_batch(batch_size=bs)
        teX, teY = get_batch(batch_size=bs, which_set="test")
        sess.run(train_op, feed_dict={X: trX, targets: trY})
        tr_loss, tr_logits = sess.run([loss, logits], feed_dict={X: trX, targets: trY})
        te_loss, te_logits = sess.run([loss, logits], feed_dict={X: teX, targets: teY})  # super legit way to estimate test error /s
        tr_acc = (trY == np.argmax(tr_logits, axis=1)).mean()
        te_acc = (teY == np.argmax(te_logits, axis=1)).mean()
        tr_losses.append(tr_loss)
        te_losses.append(te_loss)
        tr_accs.append(tr_acc)
        te_accs.append(te_acc)
        print "iter: %d, train_loss: %f, test_loss: %f, train_acc: %f, test_acc: %f" % (i, tr_loss, te_loss, tr_acc, te_acc)

    plt.subplot(211)
    plt.title('cost')
    plt.plot(tr_losses)
    plt.plot(te_losses, '--')
    plt.subplot(212)
    plt.title('accuracy')
    plt.plot(smooth(tr_accs))
    plt.plot(smooth(te_accs), '--')
    plt.show()
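

# (Added sketch, not part of the original gist) The single-batch test numbers
# logged above are noisy; a slightly better estimate averages the same ops over
# several held-out batches. The helper below is hypothetical and takes the
# session and graph ops built in train() as arguments (batch_size must match
# the placeholder's fixed batch dimension).
def evaluate(sess, loss, logits, X, targets, batch_size=64, n_batches=20):
    losses, accs = [], []
    for _ in range(n_batches):
        teX, teY = get_batch(batch_size=batch_size, which_set="test")
        l, lg = sess.run([loss, logits], feed_dict={X: teX, targets: teY})
        losses.append(l)
        accs.append((teY == np.argmax(lg, axis=1)).mean())
    return np.mean(losses), np.mean(accs)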


if __name__ == "__main__":
    train()