Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save danijar/d11c77c5565482e965d1919291044470 to your computer and use it in GitHub Desktop.
Save danijar/d11c77c5565482e965d1919291044470 to your computer and use it in GitHub Desktop.
TensorFlow Variable-Length Sequence Labelling
# Working example for my blog post at:
# http://danijar.com/variable-sequence-lengths-in-tensorflow/
import functools
import sets
import tensorflow as tf
from tensorflow.models.rnn import rnn_cell
from tensorflow.models.rnn import rnn
def lazy_property(function):
    """Turn a zero-argument method into a property computed at most once.

    The first access calls ``function(self)`` and stores the result on the
    instance under ``'_' + function.__name__``; later accesses return the
    stored value without calling ``function`` again.
    """
    cache_name = '_' + function.__name__

    def getter(self):
        # Compute and remember the value only if it has not been stored yet.
        if not hasattr(self, cache_name):
            setattr(self, cache_name, function(self))
        return getattr(self, cache_name)

    # Copy name/docstring from the wrapped method before exposing it as a
    # read-only property.
    return property(functools.wraps(function)(getter))
class VariableSequenceLabelling:
    """Per-frame classifier for batches of zero-padded sequences.

    Args:
        data: Float tensor indexed as [batch, time, features]. A frame whose
            features are all zero is treated as padding when computing
            ``length``.
        target: Float tensor indexed as [batch, time, classes]. Its time and
            class dimensions must be statically known, because
            ``prediction`` converts them to Python ints. A frame whose
            target is all zero is excluded from ``cost`` and ``error``.
        num_hidden: Number of units of the GRU cell.
        num_layers: Stored on the instance; see the note in ``__init__``.
    """

    def __init__(self, data, target, num_hidden=200, num_layers=3):
        self.data = data
        self.target = target
        self._num_hidden = num_hidden
        # NOTE(review): this value is stored but never read by the class;
        # `prediction` builds a single GRU cell regardless of it.
        self._num_layers = num_layers
        # Access the lazy properties once so that the corresponding graph
        # operations are created during construction.
        self.prediction
        self.error
        self.optimize

    @lazy_property
    def length(self):
        """Int32 count, per sequence, of frames that contain a non-zero value."""
        used = tf.sign(tf.reduce_max(tf.abs(self.data), reduction_indices=2))
        length = tf.reduce_sum(used, reduction_indices=1)
        return tf.cast(length, tf.int32)

    @lazy_property
    def prediction(self):
        """Softmax class scores reshaped to [batch, max_length, num_classes]."""
        # Recurrent network.
        output, _ = rnn.dynamic_rnn(
            rnn_cell.GRUCell(self._num_hidden),
            self.data,
            dtype=tf.float32,
            sequence_length=self.length,
        )
        # Softmax layer.
        max_length = int(self.target.get_shape()[1])
        num_classes = int(self.target.get_shape()[2])
        weight, bias = self._weight_and_bias(self._num_hidden, num_classes)
        # Flatten to apply same weights to all time steps.
        output = tf.reshape(output, [-1, self._num_hidden])
        prediction = tf.nn.softmax(tf.matmul(output, weight) + bias)
        return tf.reshape(prediction, [-1, max_length, num_classes])

    @lazy_property
    def cost(self):
        """Masked cross entropy, averaged per sequence and then over the batch."""
        # Clip before taking the log: a probability that underflows to 0
        # would give 0 * log(0) = NaN, and multiplying by the mask afterwards
        # cannot turn that NaN back into 0.
        prediction = tf.clip_by_value(self.prediction, 1e-10, 1.0)
        # Compute cross entropy for each frame.
        cross_entropy = -tf.reduce_sum(
            self.target * tf.log(prediction), reduction_indices=2)
        return self._masked_mean(cross_entropy)

    @lazy_property
    def optimize(self):
        """Adam training step (learning rate 0.0003) minimizing ``cost``."""
        learning_rate = 0.0003
        optimizer = tf.train.AdamOptimizer(learning_rate)
        return optimizer.minimize(self.cost)

    @lazy_property
    def error(self):
        """Masked fraction of frames whose arg-max class differs from the target."""
        mistakes = tf.not_equal(
            tf.argmax(self.target, 2), tf.argmax(self.prediction, 2))
        return self._masked_mean(tf.cast(mistakes, tf.float32))

    def _masked_mean(self, values):
        """Average per-frame `values` over the frames that have a target.

        Frames whose target is all zero are zeroed out, the rest are summed
        per sequence and divided by ``length``; the result is the mean of
        those per-sequence averages.
        """
        mask = tf.sign(tf.reduce_max(tf.abs(self.target), reduction_indices=2))
        total = tf.reduce_sum(values * mask, reduction_indices=1)
        # A sequence consisting only of padding has length 0; clamp the
        # divisor so it contributes 0 instead of NaN/Inf to the batch mean.
        # NOTE(review): the mask is derived from `target` while the divisor
        # is derived from `data`; the two agree only when both are padded
        # with zeros at the same frames.
        divisor = tf.cast(tf.maximum(self.length, 1), tf.float32)
        return tf.reduce_mean(total / divisor)

    @staticmethod
    def _weight_and_bias(in_size, out_size):
        """Create an [in_size, out_size] weight and an [out_size] bias variable."""
        weight = tf.truncated_normal([in_size, out_size], stddev=0.01)
        bias = tf.constant(0.1, shape=[out_size])
        return tf.Variable(weight), tf.Variable(bias)
def get_dataset():
    """Load the OCR data, one-hot encode its targets, and flatten each image.

    The last two axes of ``data`` are merged into one and the values are
    converted to float. The result of ``sets.Split(0.66)`` is returned as a
    ``(train, test)`` pair.
    """
    ocr = sets.Ocr()
    encode = sets.OneHot(ocr.target, depth=2)
    ocr = encode(ocr, columns=['target'])
    # Keep the leading axes and collapse the trailing two into a single one.
    flat_shape = ocr.data.shape[:-2] + (-1,)
    ocr['data'] = ocr.data.reshape(flat_shape).astype(float)
    train, test = sets.Split(0.66)(ocr)
    return train, test
if __name__ == '__main__':
    # Build the model for the dataset's padded sequence length and sizes.
    train, test = get_dataset()
    _, length, image_size = train.data.shape
    num_classes = train.target.shape[2]
    data = tf.placeholder(tf.float32, [None, length, image_size])
    target = tf.placeholder(tf.float32, [None, length, num_classes])
    model = VariableSequenceLabelling(data, target)
    # Use the session as a context manager so it is closed even when
    # training raises part-way through.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for epoch in range(10):
            # 100 training steps on random batches of 10 sequences each.
            for _ in range(100):
                batch = train.sample(10)
                sess.run(
                    model.optimize,
                    {data: batch.data, target: batch.target})
            # Report the error on the full test split after every epoch.
            error = sess.run(
                model.error, {data: test.data, target: test.target})
            print('Epoch {:2d} error {:3.1f}%'.format(epoch + 1, 100 * error))
@Engineero
Copy link

@bivanalhar the 2 in tf.argmax(input, 2) is the dimension over which the argmax is taken. Since his targets and predictions have shape [batch_size, num_steps, num_classes], for each sample and step the prediction holds num_classes probabilities, one per class, expressing how likely the network thinks it is that the observation belongs to that class, and the target is a vector of length num_classes that is all zeros except for a one at the correct class. You want to reduce each of these vectors to the single index of its maximum, so for every sample and step, tf.argmax(input, 2) gives you one number: the correct class (if input is self.target), or the class the network predicts (if input is self.prediction).

By checking where these values are not equal with tf.not_equal(...) you get a vector of all zeros where your network predicted correctly, and ones where your network made a mistake. You then take the mean of this vector to get an idea of the error. This is equivalent to (but probably more efficient than) manually determining all of your false positives, fp, false negatives, fn, true positives, tp, and true negatives, tn, and doing:

error = (fp + fn) / (fp + fn + tp + tn)

@yoon28
Copy link

yoon28 commented Mar 5, 2018

Does max_length in line 50 have the same value as length in line 108?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment