Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
TensorFlow Sequence Classification
# Example for my blog post at:
import functools
import sets
import tensorflow as tf
def lazy_property(function):
    """Turn a zero-argument method into a cached, read-only property.

    The first time the attribute is read, ``function(self)`` is called and
    its result is stored on the instance under ``'_' + function.__name__``.
    Every later read returns that stored object without calling
    ``function`` again, so expensive construction work happens only once.

    Args:
        function: A method taking only ``self``.

    Returns:
        A ``property`` whose getter performs the cached lookup.
    """
    attribute = '_' + function.__name__

    # ``property`` lets callers write ``obj.name`` instead of ``obj.name()``;
    # ``wraps`` keeps the original method's name and docstring.
    @property
    @functools.wraps(function)
    def wrapper(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return wrapper
class SequenceClassification:
    """Classify whole sequences with a stacked GRU network and a softmax layer.

    Graph pieces (``prediction``, ``cost``, ``optimize``, ``error``) are
    lazy properties: each is built on first access and reused afterwards.

    Args:
        data: float32 input tensor, assumed (batch, time, features) with a
            static time dimension (the last-step selection below needs it).
        target: label tensor of shape (batch, num_classes); it is used as
            one-hot labels (multiplied with log-probabilities and argmax'd
            over axis 1).
        dropout: keep probability applied to every GRU layer's output.
        num_hidden: number of units in each GRU layer.
        num_layers: number of stacked GRU layers.
        learning_rate: constant RMSProp learning rate (no decay schedule).
    """

    def __init__(self, data, target, dropout, num_hidden=200, num_layers=3,
                 learning_rate=0.003):
        self.data = data
        self.target = target
        self.dropout = dropout
        self._num_hidden = num_hidden
        self._num_layers = num_layers
        self._learning_rate = learning_rate

    @lazy_property
    def prediction(self):
        """Softmax class probabilities computed from the last RNN output."""
        # Recurrent network. Each layer needs its own cell object:
        # ``[cell] * n`` would put the same cell (and one set of weights) in
        # every layer, but the first layer sees ``features`` inputs while the
        # others see ``num_hidden`` inputs, so those weights cannot be shared.
        cells = []
        for _ in range(self._num_layers):
            cell = tf.contrib.rnn.GRUCell(self._num_hidden)
            cell = tf.contrib.rnn.DropoutWrapper(
                cell, output_keep_prob=self.dropout)
            cells.append(cell)
        network = tf.contrib.rnn.MultiRNNCell(cells)
        output, _ = tf.nn.dynamic_rnn(network, self.data, dtype=tf.float32)
        # Select last output: move time to axis 0, then take the final step.
        output = tf.transpose(output, [1, 0, 2])
        last = tf.gather(output, int(output.get_shape()[0]) - 1)
        # Softmax layer mapping the hidden state to the target classes.
        weight, bias = self._weight_and_bias(
            self._num_hidden, int(self.target.get_shape()[1]))
        prediction = tf.nn.softmax(tf.matmul(last, weight) + bias)
        return prediction

    @lazy_property
    def cost(self):
        """Summed cross-entropy between ``target`` and ``prediction``."""
        # Clip so a saturated softmax output of exactly 0 does not produce
        # log(0) = -inf and NaN gradients.
        probabilities = tf.clip_by_value(self.prediction, 1e-10, 1.0)
        cross_entropy = -tf.reduce_sum(self.target * tf.log(probabilities))
        return cross_entropy

    @lazy_property
    def optimize(self):
        """Training op: one RMSProp step on ``cost``."""
        optimizer = tf.train.RMSPropOptimizer(self._learning_rate)
        return optimizer.minimize(self.cost)

    @lazy_property
    def error(self):
        """Fraction of examples whose predicted class differs from target."""
        mistakes = tf.not_equal(
            tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))

    @staticmethod
    def _weight_and_bias(in_size, out_size):
        """Create a (in_size, out_size) weight matrix and an out_size bias."""
        weight = tf.truncated_normal([in_size, out_size], stddev=0.01)
        bias = tf.constant(0.1, shape=[out_size])
        return tf.Variable(weight), tf.Variable(bias)
def main():
    """Train the sequence classifier on MNIST and report test error.

    Runs 10 epochs of 100 mini-batches (10 examples each, dropout keep
    probability 0.5) and prints the test-set error after every epoch.
    """
    # We treat images as sequences of pixel rows.
    train, test = sets.Mnist()
    _, rows, row_size = train.data.shape
    num_classes = train.target.shape[1]
    data = tf.placeholder(tf.float32, [None, rows, row_size])
    target = tf.placeholder(tf.float32, [None, num_classes])
    dropout = tf.placeholder(tf.float32)
    model = SequenceClassification(data, target, dropout)
    # Build the ops before initializing: the model's graph pieces are
    # created on first access, and the initializer only covers variables
    # (weights, optimizer slots) that already exist at that point.
    optimize = model.optimize
    error = model.error
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(10):
            for _ in range(100):
                batch = train.sample(10)
                sess.run(optimize, {
                    data: batch.data, target: batch.target, dropout: 0.5})
            test_error = sess.run(error, {
                data: test.data, target: test.target, dropout: 1})
            print('Epoch {:2d} error {:3.1f}%'.format(
                epoch + 1, 100 * test_error))
if __name__ == '__main__':
    # Run the training script only when executed directly, not on import.
    main()

wirth6 commented Dec 19, 2016 edited

If it's not a problem, I have a question about the learning rate: does the hard-coded value 0.003 mean that the learning rate will be the same in every epoch? Also, I'm fairly new to Python and TensorFlow, so I don't quite understand what `@lazy_property` actually does. Can anyone tell me where to read about constructs like this?

wirth6 commented Jan 3, 2017

And another thing I've been wondering about: shouldn't `data` in the following line be `self.data` instead?
output, _ = rnn.dynamic_rnn(network, data, dtype=tf.float32)


danijar commented Feb 28, 2017

@wirth6 Sorry for taking so long. `@lazy_property` causes the method to act like a property, so you can access it without parentheses. Moreover, the function is only evaluated once, when it's accessed for the first time. The result is stored and directly returned for later accesses. This is useful because we want to create this part of the TensorFlow graph only once but access the resulting tensor multiple times. For more information, please refer to my post Structuring Your TensorFlow Models.

Regarding your second question, you're right. It unexpectedly worked anyway because the `__name__ == '__main__'` block created a global `data` variable. In the updated example, I moved that code into a `main()` function so that this can't happen anymore.

wazzy commented Mar 23, 2017 edited

Great tutorial....
One problem I ran into when trying this on another dataset: after I added `sequence_length` to `dynamic_rnn`, the model stopped training.
output, _ = tf.nn.dynamic_rnn(network, self.data, sequence_length=self.seq_len, dtype=tf.float32)
Can you please suggest what is going wrong?

herleeyandi commented May 10, 2017 edited

Hello, I have a problem with the line `train, test = sets.Mnist()`. I found that I must first install `sets`, which is your own package. @danijar, can you tell me how to configure and install it? I need help. I installed it successfully but still cannot import `Mnist()`. I hope you can list the requirements for your tutorial so I know what to install before I start learning. The error is `ImportError: cannot import name Mnist`; I installed the package with `sudo pip install sets==0.3.2`.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment