Skip to content

Instantly share code, notes, and snippets.

@elibixby
Last active October 25, 2019 21:51
Show Gist options
  • Save elibixby/1c7a2497f96a457130241c59c676ebd4 to your computer and use it in GitHub Desktop.
Save elibixby/1c7a2497f96a457130241c59c676ebd4 to your computer and use it in GitHub Desktop.
Batch read of serialized SequenceExamples into a single SparseTensor
import tensorflow as tf
def make_test_data():
    """Build a small batch of serialized ``tf.train.SequenceExample`` protos.

    Each hard-coded sentence is a list of words, and each word is a list of
    integers. A sentence becomes one SequenceExample whose ``'chars'``
    feature list holds one int64 ``Feature`` per word, so both the number of
    words and the number of integers per word vary between examples.

    Returns:
        A list with one ``SerializeToString()`` result per sentence.
    """
    sentences = [
        [[5, 10], [5, 10, 20]],
        [[0, 1, 2], [2, 1, 0], [0, 1, 2, 3]],
        [[5, 10], [5, 10, 20]],
    ]

    serialized = []
    for words in sentences:
        # One Feature per word, wrapping that word's integer values.
        word_features = []
        for word_chars in words:
            int64_list = tf.train.Int64List(value=word_chars)
            word_features.append(tf.train.Feature(int64_list=int64_list))

        chars = tf.train.FeatureList(feature=word_features)
        feature_lists = tf.train.FeatureLists(feature_list={'chars': chars})
        example = tf.train.SequenceExample(feature_lists=feature_lists)
        serialized.append(example.SerializeToString())
    return serialized
def main():
    """Parse a batch of serialized SequenceExamples into one SparseTensor.

    ``tf.parse_single_sequence_example`` handles a single example, so the
    batch is walked with ``tf.while_loop``. For every example, the sparse
    ``'chars'`` feature gets the example's position in the batch prepended
    to each of its index rows; the per-example indices and values are
    collected in two ``tf.TensorArray`` objects and concatenated into one
    batch-level ``tf.SparseTensor``.

    The graph is fed with the output of ``make_test_data()`` and the
    evaluated sparse tensor is printed.

    NOTE(review): assumes a non-empty batch with at least one value, since
    ``dense_shape`` is derived from the maximum index -- TODO confirm the
    behaviour wanted for empty input.
    """
    test_data = make_test_data()

    graph = tf.Graph()
    with graph.as_default():
        sequence_example_binaries = tf.placeholder(shape=[None], dtype=tf.string)
        sequence_features = {'chars': tf.VarLenFeature(dtype=tf.int64)}

        # Build the batch-size tensor once and share it between the
        # TensorArray sizes and the loop condition.
        batch_size = tf.shape(sequence_example_binaries)[0]
        indices_array = tf.TensorArray(tf.int64, size=batch_size)
        values_array = tf.TensorArray(tf.int64, size=batch_size)

        def keep_going(i, ia, va):
            """Loop condition: continue until every example is consumed."""
            return i < batch_size

        def parse_one(i, ia, va):
            """Loop body: parse example ``i`` and store its indices/values."""
            _, seq_dict = tf.parse_single_sequence_example(
                serialized=sequence_example_binaries[i],
                sequence_features=sequence_features)
            sparse_tensor = seq_dict['chars']

            # Repeat the batch position once per value and make it a column
            # so it can be prepended to the example's own index rows. The
            # counter is cast up to int64 so the parsed indices never have
            # to be narrowed to int32 and widened again.
            batch_dim = tf.tile(
                [tf.cast(i, tf.int64)],
                multiples=tf.shape(sparse_tensor.values))
            batch_dim = tf.expand_dims(batch_dim, axis=-1)
            new_indices = tf.concat(
                [batch_dim, sparse_tensor.indices], axis=-1)
            return (
                i + 1,
                ia.write(i, new_indices),
                va.write(i, sparse_tensor.values),
            )

        _, indices, values = tf.while_loop(
            keep_going,
            parse_one,
            [0, indices_array, values_array]
        )

        # concat() joins the per-example pieces along the first axis, which
        # tolerates a different number of values in every example.
        indices_final = indices.concat()
        values_final = values.concat()
        final = tf.SparseTensor(
            indices=indices_final,
            values=values_final,
            # Tightest shape containing every index: per-axis max index + 1.
            dense_shape=1 + tf.reduce_max(indices_final, axis=0)
        )

    # The graph defines no variables and no queue runners, so the session
    # can evaluate the result directly.
    with tf.Session(graph=graph) as sess:
        print(sess.run(final, feed_dict={sequence_example_binaries: test_data}))
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
@jenishah
Copy link

Awesome code!
Thanks a lot for sharing. :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment