TensorFlow Variable-Length Sequence Labelling
# Working example for my blog post at:
# http://danijar.com/variable-sequence-lengths-in-tensorflow/
import functools
import sets
import tensorflow as tf
from tensorflow.models.rnn import rnn_cell
from tensorflow.models.rnn import rnn


def lazy_property(function):
    """Build the decorated graph property once and cache it on the instance."""
    attribute = '_' + function.__name__

    @property
    @functools.wraps(function)
    def wrapper(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return wrapper


class VariableSequenceLabelling:

    def __init__(self, data, target, num_hidden=200, num_layers=3):
        self.data = data
        self.target = target
        self._num_hidden = num_hidden
        self._num_layers = num_layers
        # Touch the properties so the graph is constructed up front.
        self.prediction
        self.error
        self.optimize

    @lazy_property
    def length(self):
        # A frame counts as used if any of its features is non-zero; summing
        # the per-frame flags gives each example's actual sequence length.
        used = tf.sign(tf.reduce_max(tf.abs(self.data), reduction_indices=2))
        length = tf.reduce_sum(used, reduction_indices=1)
        length = tf.cast(length, tf.int32)
        return length

    @lazy_property
    def prediction(self):
        # Recurrent network.
        output, _ = rnn.dynamic_rnn(
            rnn_cell.GRUCell(self._num_hidden),
            self.data,
            dtype=tf.float32,
            sequence_length=self.length,
        )
        # Softmax layer.
        max_length = int(self.target.get_shape()[1])
        num_classes = int(self.target.get_shape()[2])
        weight, bias = self._weight_and_bias(self._num_hidden, num_classes)
        # Flatten to apply same weights to all time steps.
        output = tf.reshape(output, [-1, self._num_hidden])
        prediction = tf.nn.softmax(tf.matmul(output, weight) + bias)
        prediction = tf.reshape(prediction, [-1, max_length, num_classes])
        return prediction

    @lazy_property
    def cost(self):
        # Compute cross entropy for each frame.
        cross_entropy = self.target * tf.log(self.prediction)
        cross_entropy = -tf.reduce_sum(cross_entropy, reduction_indices=2)
        mask = tf.sign(tf.reduce_max(tf.abs(self.target), reduction_indices=2))
        cross_entropy *= mask
        # Average over actual sequence lengths.
        cross_entropy = tf.reduce_sum(cross_entropy, reduction_indices=1)
        cross_entropy /= tf.cast(self.length, tf.float32)
        return tf.reduce_mean(cross_entropy)

    @lazy_property
    def optimize(self):
        learning_rate = 0.0003
        optimizer = tf.train.AdamOptimizer(learning_rate)
        return optimizer.minimize(self.cost)

    @lazy_property
    def error(self):
        mistakes = tf.not_equal(
            tf.argmax(self.target, 2), tf.argmax(self.prediction, 2))
        mistakes = tf.cast(mistakes, tf.float32)
        mask = tf.sign(tf.reduce_max(tf.abs(self.target), reduction_indices=2))
        mistakes *= mask
        # Average over actual sequence lengths.
        mistakes = tf.reduce_sum(mistakes, reduction_indices=1)
        mistakes /= tf.cast(self.length, tf.float32)
        return tf.reduce_mean(mistakes)

    @staticmethod
    def _weight_and_bias(in_size, out_size):
        weight = tf.truncated_normal([in_size, out_size], stddev=0.01)
        bias = tf.constant(0.1, shape=[out_size])
        return tf.Variable(weight), tf.Variable(bias)


def get_dataset():
    """Read dataset and flatten images."""
    dataset = sets.Ocr()
    dataset = sets.OneHot(dataset.target, depth=2)(dataset, columns=['target'])
    dataset['data'] = dataset.data.reshape(
        dataset.data.shape[:-2] + (-1,)).astype(float)
    train, test = sets.Split(0.66)(dataset)
    return train, test


if __name__ == '__main__':
    train, test = get_dataset()
    _, length, image_size = train.data.shape
    num_classes = train.target.shape[2]
    data = tf.placeholder(tf.float32, [None, length, image_size])
    target = tf.placeholder(tf.float32, [None, length, num_classes])
    model = VariableSequenceLabelling(data, target)
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    for epoch in range(10):
        for _ in range(100):
            batch = train.sample(10)
            sess.run(model.optimize, {data: batch.data, target: batch.target})
        error = sess.run(model.error, {data: test.data, target: test.target})
        print('Epoch {:2d} error {:3.1f}%'.format(epoch + 1, 100 * error))
@jinghuangzhu

Why don't you use TensorFlow's built-in op softmax_cross_entropy_with_logits instead of writing your own?

@danijar (author) commented Jun 29, 2016

@jinghuangzhu You could do that and it's a bit more efficient. You would have a logits property and implement prediction just as tf.nn.softmax(self.logits). You'd need the flatten/unflatten trick as currently used inside the prediction property in order to make the built-in cost function work with sequences though.
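
A rough sketch of what that refactor could look like, reusing the gist's names and the old tensorflow.models.rnn API (and the positional (logits, labels) argument order of TensorFlow versions from that era); this is an illustration, not code from the gist:

@lazy_property
def logits(self):
    # Same recurrent network and flattened linear layer as in prediction,
    # but without the softmax; shape is [batch * max_length, num_classes].
    output, _ = rnn.dynamic_rnn(
        rnn_cell.GRUCell(self._num_hidden), self.data,
        dtype=tf.float32, sequence_length=self.length)
    num_classes = int(self.target.get_shape()[2])
    weight, bias = self._weight_and_bias(self._num_hidden, num_classes)
    output = tf.reshape(output, [-1, self._num_hidden])
    return tf.matmul(output, weight) + bias

@lazy_property
def prediction(self):
    max_length = int(self.target.get_shape()[1])
    num_classes = int(self.target.get_shape()[2])
    return tf.reshape(tf.nn.softmax(self.logits), [-1, max_length, num_classes])

@lazy_property
def cost(self):
    max_length = int(self.target.get_shape()[1])
    num_classes = int(self.target.get_shape()[2])
    # Flatten the targets the same way as the logits, apply the built-in op,
    # then unflatten so padded frames can be masked out per sequence.
    flat_target = tf.reshape(self.target, [-1, num_classes])
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(self.logits, flat_target)
    cross_entropy = tf.reshape(cross_entropy, [-1, max_length])
    mask = tf.sign(tf.reduce_max(tf.abs(self.target), reduction_indices=2))
    cross_entropy = tf.reduce_sum(cross_entropy * mask, reduction_indices=1)
    cross_entropy /= tf.cast(self.length, tf.float32)
    return tf.reduce_mean(cross_entropy)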

@jinghuangzhu

Thanks for replying. Did you run this code with the initial learning rate set to 0.0003 (very low, IMO)? There's no need to adjust the learning rate because you used AdamOptimizer, right?

@sonalgupta

The cost function implementation is prone to NaN values. The built-in softmax_cross_entropy_with_logits is more robust.

http://stackoverflow.com/questions/34240703/difference-between-tensorflow-tf-nn-softmax-and-tf-nn-softmax-cross-entropy-with

@Sraw commented Nov 23, 2016

The reason the cost function is prone to NaN is the log; just add a small constant: tf.log(self.prediction + 1e-10)
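
In the gist's cost property the fix would look roughly like this (a sketch; everything except the added epsilon is unchanged from the gist):

@lazy_property
def cost(self):
    # Identical to the gist except for the epsilon inside the log, which keeps
    # padded or fully confident frames from producing log(0) = -inf.
    cross_entropy = self.target * tf.log(self.prediction + 1e-10)
    cross_entropy = -tf.reduce_sum(cross_entropy, reduction_indices=2)
    mask = tf.sign(tf.reduce_max(tf.abs(self.target), reduction_indices=2))
    cross_entropy *= mask
    # Average over actual sequence lengths.
    cross_entropy = tf.reduce_sum(cross_entropy, reduction_indices=1)
    cross_entropy /= tf.cast(self.length, tf.float32)
    return tf.reduce_mean(cross_entropy)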

@wazzy commented Mar 24, 2017

I am trying with targets of shape
target = tf.placeholder(tf.float32, [None, num_classes])
instead of
target = tf.placeholder(tf.float32, [None, length, num_classes])

Please help me with the cross entropy; I am making a mistake there with the dimensions.

@Tingbopku commented Mar 26, 2017

outputs, states = tf.nn.dynamic_rnn(lstm_cell, x_sector, dtype=tf.float32, time_major=True, sequence_length=seqlen)
sq_err = tf.square(outputs-y)
mask = tf.sequence_mask(seqlen, seq_max_len)
cost = tf.reduce_mean(tf.boolean_mask(sq_err, mask))

The tf.boolean_mask function can be used to mask the cost, too.
I wonder which method is more efficient?
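
Applied to the cross entropy in this gist, that approach might look roughly like the sketch below. It assumes a TensorFlow version that has tf.sequence_mask, and note it averages over all unpadded frames in the batch rather than averaging per sequence first, so the two costs are not identical:

# Sketch: per-frame cross entropy masked with tf.boolean_mask instead of a
# 0/1 multiply. tf.sequence_mask builds a boolean [batch, max_length] mask
# from the per-example lengths; tf.boolean_mask keeps only the valid frames.
cross_entropy = -tf.reduce_sum(
    self.target * tf.log(self.prediction), reduction_indices=2)
mask = tf.sequence_mask(self.length, int(self.target.get_shape()[1]))
cost = tf.reduce_mean(tf.boolean_mask(cross_entropy, mask))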

@bivanalhar

I am quite curious, though, about the meaning of tf.not_equal(tf.argmax(self.target, 2), tf.argmax(self.prediction, 2)).
I mean, what does that 2 stand for? I still don't understand it.
Also, how are two labelling sequences considered equal in this case?

@dipspb commented Apr 25, 2017

Hello, is it possible to add a dropout wrapper to this sample? Like this:
https://gist.github.com/danijar/61f9226f7ea498abce36187ddaf51ed5#file-blog_tensorflow_sequence_labelling-py-L36
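
With the rnn_cell API used in this gist, wrapping the cell inside the prediction property could look roughly like this sketch; dropout_keep_prob is a hypothetical extra placeholder (e.g. fed 0.5 during training and 1.0 during evaluation), not something defined in the gist:

# Sketch of the change inside the prediction property: wrap the GRU cell
# in a DropoutWrapper before handing it to dynamic_rnn. dropout_keep_prob
# is an assumed additional placeholder, not part of the original code.
cell = rnn_cell.GRUCell(self._num_hidden)
cell = rnn_cell.DropoutWrapper(cell, output_keep_prob=dropout_keep_prob)
output, _ = rnn.dynamic_rnn(
    cell, self.data, dtype=tf.float32, sequence_length=self.length)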

@Engineero

@bivanalhar the 2 in tf.argmax(input, 2) is the dimension over which the argmax is taken. Since his targets and predictions have shape [batch_size, num_steps, num_classes], for each sample and step you have num_classes probabilities, one for how strongly the network believes the observation falls into each class, and the target for each sample and step is a vector of length num_classes that is all zeros with a one in the correct class. You want to reduce this to a single value, the index of the maximum of each of these vectors, so for every sample and step tf.argmax(input, 2) gives you one number representing which class is correct (if input is self.target) or which class the network predicts (if input is self.prediction).

By checking where these values are not equal with tf.not_equal(...) you get a vector of all zeros where your network predicted correctly, and ones where your network made a mistake. You then take the mean of this vector to get an idea of the error. This is equivalent to (but probably more efficient than) manually determining all of your false positives, fp, false negatives, fn, true positives, tp, and true negatives, tn, and doing:

error = (fp + fn) / (fp + fn + tp + tn)
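
A tiny concrete example of that reduction, for one sample with two time steps and three classes (an illustration, not code from the gist):

import tensorflow as tf

# One sample, two time steps, three classes.
target = tf.constant([[[0., 1., 0.], [1., 0., 0.]]])            # true classes: 1, 0
prediction = tf.constant([[[0.1, 0.7, 0.2], [0.2, 0.3, 0.5]]])  # argmax: 1, 2

# argmax over dimension 2 collapses the class dimension to one class index
# per sample and step; not_equal then flags the steps that were misclassified.
mistakes = tf.not_equal(tf.argmax(target, 2), tf.argmax(prediction, 2))
# Evaluates to [[False, True]]: first step correct, second step wrong.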

@yoon28 commented Mar 5, 2018

Does max_length in line 50 have the same value as length in line 108?
