Create a gist now

Instantly share code, notes, and snippets.

TensorFlow Variable-Length Sequence Classification
# Working example for my blog post at:
# http://danijar.com/variable-sequence-lengths-in-tensorflow/
import functools
import sets
import tensorflow as tf
from tensorflow.models.rnn import rnn_cell
from tensorflow.models.rnn import rnn
def lazy_property(function):
    """Turn a zero-argument method into a read-only property computed once.

    The first access calls ``function(self)`` and stores the result on the
    instance; every later access returns that stored object without calling
    ``function`` again.

    The result is cached under ``'_cache_' + function.__name__`` rather than
    a bare underscore prefix, so that an ordinary private attribute such as
    ``self._foo`` is never mistaken for the cached value of a lazy property
    named ``foo``.

    Args:
        function: Method taking only ``self``.

    Returns:
        A ``property`` whose getter memoizes ``function`` per instance.
    """
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def wrapper(self):
        # Compute on first access only; afterwards reuse the stored result.
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return wrapper
class VariableSequenceClassification:
    """GRU classifier for zero-padded sequences of differing lengths.

    The graph is assembled from lazy properties: each one creates its
    TensorFlow ops on first access and returns the same tensor/op afterwards.

    Args:
        data: Rank-3 float tensor; axis 1 is reduced as the time axis and
            axis 2 as the feature axis (the script below feeds
            ``[batch, rows, row_size]``). Padding frames must be all zeros.
        target: Rank-2 tensor whose second dimension is statically known and
            gives the number of classes (presumably one-hot rows -- confirm
            against the data source).
        num_hidden: Number of units in the GRU cell.
        num_layers: Stored on the instance but not read by the network built
            in ``prediction``, which uses a single GRU cell.
    """

    def __init__(self, data, target, num_hidden=200, num_layers=2):
        self.data = data
        self.target = target
        self._num_hidden = num_hidden
        self._num_layers = num_layers
        # Touch the lazy properties so that all ops and variables exist
        # before the caller runs the variable initializer.
        self.prediction
        self.error
        self.optimize

    @lazy_property
    def length(self):
        """Per-example count of frames that contain a non-zero value (int32).

        This equals the sequence length as long as real frames are never
        entirely zero and padding frames always are.
        """
        # 1.0 for frames with any non-zero feature, 0.0 for all-zero frames.
        used = tf.sign(tf.reduce_max(tf.abs(self.data), reduction_indices=2))
        length = tf.reduce_sum(used, reduction_indices=1)
        length = tf.cast(length, tf.int32)
        return length

    @lazy_property
    def prediction(self):
        """Class probabilities computed from the last relevant GRU output."""
        # Recurrent network.
        # Use the tensor given to the constructor; the previous bare `data`
        # resolved to a module-level global and only worked by accident.
        output, _ = rnn.dynamic_rnn(
            rnn_cell.GRUCell(self._num_hidden),
            self.data,
            dtype=tf.float32,
            sequence_length=self.length,
        )
        last = self._last_relevant(output, self.length)
        # Softmax layer.
        weight, bias = self._weight_and_bias(
            self._num_hidden, int(self.target.get_shape()[1]))
        prediction = tf.nn.softmax(tf.matmul(last, weight) + bias)
        return prediction

    @lazy_property
    def cost(self):
        """Cross-entropy between ``target`` and ``prediction``, summed."""
        # A saturated softmax can output exactly 0; log(0) = -inf and
        # 0 * -inf = NaN would poison the loss, so clip before the log.
        clipped = tf.clip_by_value(self.prediction, 1e-10, 1.0)
        cross_entropy = -tf.reduce_sum(self.target * tf.log(clipped))
        return cross_entropy

    @lazy_property
    def optimize(self):
        """Training op: one RMSProp step minimizing ``cost``."""
        learning_rate = 0.003
        optimizer = tf.train.RMSPropOptimizer(learning_rate)
        return optimizer.minimize(self.cost)

    @lazy_property
    def error(self):
        """Fraction of examples whose arg-max prediction differs from target."""
        mistakes = tf.not_equal(
            tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))

    @staticmethod
    def _weight_and_bias(in_size, out_size):
        """Create the variables of a dense layer.

        Returns:
            ``(weight, bias)``: weight is ``[in_size, out_size]`` initialized
            from a truncated normal (stddev 0.01); bias is ``[out_size]``
            initialized to 0.1.
        """
        weight = tf.truncated_normal([in_size, out_size], stddev=0.01)
        bias = tf.constant(0.1, shape=[out_size])
        return tf.Variable(weight), tf.Variable(bias)

    @staticmethod
    def _last_relevant(output, length):
        """Select, for each example, the output frame at index ``length - 1``.

        Args:
            output: Rank-3 tensor; dimensions 1 and 2 must be statically
                known because they are converted with ``int()``.
            length: Integer tensor with one entry per example.

        Returns:
            Tensor with one row of ``output_size`` values per example.
        """
        # The batch size is read dynamically with tf.shape; the other two
        # dimensions come from the static shape.
        batch_size = tf.shape(output)[0]
        max_length = int(output.get_shape()[1])
        output_size = int(output.get_shape()[2])
        # Position of each example's last frame after flattening the batch
        # and time axes into one.
        # NOTE(review): a length of 0 yields an index one before the
        # example's first frame (-1 for the first example) -- confirm that
        # inputs never contain all-zero sequences.
        index = tf.range(0, batch_size) * max_length + (length - 1)
        flat = tf.reshape(output, [-1, output_size])
        relevant = tf.gather(flat, index)
        return relevant
if __name__ == '__main__':
    # We treat images as sequences of pixel rows.
    train, test = sets.Mnist()
    _, rows, row_size = train.data.shape
    num_classes = train.target.shape[1]
    data = tf.placeholder(tf.float32, [None, rows, row_size])
    target = tf.placeholder(tf.float32, [None, num_classes])
    model = VariableSequenceClassification(data, target)
    # Use the session as a context manager so its resources are released
    # when training finishes or an exception is raised.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for epoch in range(10):
            # 100 optimization steps on random batches of 10 examples.
            for _ in range(100):
                batch = train.sample(10)
                sess.run(
                    model.optimize,
                    {data: batch.data, target: batch.target})
            # Report the error on the whole test set once per epoch.
            error = sess.run(
                model.error, {data: test.data, target: test.target})
            print('Epoch {:2d} error {:3.1f}%'.format(epoch + 1, 100 * error))
@MartinThoma
Traceback (most recent call last):
  File "blog_tensorflow_variable_sequence_classification.py", line 90, in <module>
    train, test = sets.Mnist()
AttributeError: 'module' object has no attribute 'Mnist'
@danijar
Owner
danijar commented Jun 29, 2016

@MartinThoma sorry, just noticed your comment. That's weird. If you still have the problem, please open an issue over at danijar/sets.

@lan2720
lan2720 commented Aug 30, 2016

Great post and code! One question is

def _last_relevant(output, length):
        batch_size = tf.shape(output)[0]
        max_length = int(output.get_shape()[1])
        output_size = int(output.get_shape()[2])

I wonder why you use tf.shape(output) to get batch_size but output.get_shape() to get max_length and output_size? I think batch_size = int(output.get_shape()[0]) is ok too. Is there any difference?

@joyspark
joyspark commented Nov 4, 2016

I have the same problem as MartinThoma.

@zak27
zak27 commented Nov 17, 2016

Hello, I would like to use your code with variable-length strings. I need help, please. My goal is to classify each string with a target (the target is 0 or 1, so there are only 2 classes); I do not work with images.
My trainset file is like the following:
1 s1 s2 ... sn
0 s1 s2 ... si
1 s1 s2
where the first column is the target of the sequence represented by 's' values (the sequence can be binary or not).

After reading the files, what do I need to change to adapt your main section?
Thank you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment