Skip to content

Instantly share code, notes, and snippets.

@talolard
Created February 15, 2018 18:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save talolard/3195243f4f40ff7ccdd4ecb287b4c062 to your computer and use it in GitHub Desktop.
Save talolard/3195243f4f40ff7ccdd4ecb287b4c062 to your computer and use it in GitHub Desktop.
Storing and Extracting TFRecords
class BibPreppy(Preppy):
    '''
    Preppy variant that also records which book each sequence came from.

    We slightly extend the way we write TFRecords so that every serialized
    example carries the id of its source book as a context feature.
    '''

    def __init__(self, tokenizer_fn):
        super(BibPreppy, self).__init__(tokenizer_fn)
        # Not used inside this class; presumably populated by calling code to
        # map books to ids -- TODO confirm against callers.
        self.book_map = {}

    def sequence_to_tf_example(self, sequence, book_id):
        '''
        Convert one raw sequence into a ``tf.train.SequenceExample``.

        :param sequence: raw input handed to ``self.sentance_to_id_list`` for
            tokenization and id lookup.
        :param book_id: integer id of the book the sequence came from; stored
            in the ``"book_id"`` int64 context feature.
        :return: a ``tf.train.SequenceExample`` with int64 context features
            ``"length"`` and ``"book_id"`` and one int64 entry per token id in
            the ``"tokens"`` feature list.
        '''
        # Materialize so len() works and we write exactly what we measured,
        # even if the id list comes back as a one-shot iterator.
        id_list = list(self.sentance_to_id_list(sequence))
        ex = tf.train.SequenceExample()
        # Non-sequential (context) features. "length" must describe the
        # "tokens" feature list written below, so count token ids rather than
        # elements of the raw input: the two differ for any tokenizer that
        # does not emit exactly one token per input element.
        ex.context.feature["length"].int64_list.value.append(len(id_list))
        ex.context.feature["book_id"].int64_list.value.append(book_id)
        # The single sequential feature: one int64 entry per token id.
        fl_tokens = ex.feature_lists.feature_list["tokens"]
        for token in id_list:
            fl_tokens.feature.add().int64_list.value.append(token)
        return ex

    @staticmethod
    def parse(ex):
        '''
        Explain to TF how to go from a serialized example back to tensors.

        Inverse of ``sequence_to_tf_example``.

        :param ex: scalar string tensor holding one serialized
            ``SequenceExample`` as written by ``sequence_to_tf_example``.
        :return: dict with ``"seq"`` (1-D int64 tensor of token ids),
            ``"length"`` (scalar int64) and ``"book_id"`` (scalar int64).
        '''
        # TF 1.x API names; under TF 2 these live in the tf.io namespace.
        context_features = {
            "length": tf.FixedLenFeature([], dtype=tf.int64),
            "book_id": tf.FixedLenFeature([], dtype=tf.int64),
        }
        sequence_features = {
            "tokens": tf.FixedLenSequenceFeature([], dtype=tf.int64),
        }
        # Parse the example (returns a pair of dictionaries of tensors).
        context_parsed, sequence_parsed = tf.parse_single_sequence_example(
            serialized=ex,
            context_features=context_features,
            sequence_features=sequence_features,
        )
        return {
            "seq": sequence_parsed["tokens"],
            "length": context_parsed["length"],
            "book_id": context_parsed["book_id"],
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment