Skip to content

Instantly share code, notes, and snippets.

@gvanhorn38
Last active January 20, 2020 01:19
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gvanhorn38/ac19b85a4f7b5fb9e82e04f4ac6d5566 to your computer and use it in GitHub Desktop.
Save gvanhorn38/ac19b85a4f7b5fb9e82e04f4ac6d5566 to your computer and use it in GitHub Desktop.
Basics of generating a TFRecord file for a dataset.
import tensorflow as tf
def _float_feature(value):
    """Build a ``tf.train.Feature`` holding a list of floats.

    Args:
        value: iterable of float values; it is passed unchanged to
            ``tf.train.FloatList``.

    Returns:
        tf.train.Feature: a feature whose ``float_list`` contains ``value``.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Build a ``tf.train.Feature`` holding a list of 64-bit integers.

    Args:
        value: iterable of integer values; it is passed unchanged to
            ``tf.train.Int64List``.

    Returns:
        tf.train.Feature: a feature whose ``int64_list`` contains ``value``.
    """
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Build a ``tf.train.Feature`` holding a list of byte strings.

    ``tf.train.BytesList`` only accepts ``bytes`` elements; on Python 3 a
    ``str`` element raises ``TypeError``. Any ``str`` elements are therefore
    encoded as UTF-8 first. ``bytes`` elements are passed through unchanged,
    so existing callers that already pass bytes see identical behavior.

    Args:
        value: iterable of ``bytes`` and/or ``str`` values.

    Returns:
        tf.train.Feature: a feature whose ``bytes_list`` contains the
        (encoded) elements of ``value``.
    """
    encoded = [v.encode('utf-8') if isinstance(v, str) else v for v in value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))
def _to_bytes(text):
    """Return ``text`` as bytes, UTF-8 encoding it if it is a ``str``."""
    # tf.train.BytesList rejects ``str`` on Python 3, so encode explicitly.
    return text.encode('utf-8') if isinstance(text, str) else text


def write_examples(image_data, output_path):
    """
    Create a TFRecord file with one ``tf.train.Example`` per image.

    Each example stores three features: ``'label'`` (int64 list),
    ``'path'`` (bytes list) and ``'instance'`` (bytes list).

    Args:
        image_data (list of (image_file_path, label, instance_id) tuples): the
            data to store in the TFRecord file. ``image_file_path`` (str or
            bytes) should be the full path to the image, accessible by the
            machine that will be running the TensorFlow network. ``label``
            (int) should be in the range [0, number_of_classes).
            ``instance_id`` (str or bytes) should be some unique identifier for
            this example (such as a database identifier). ``str`` values are
            stored UTF-8 encoded.
        output_path (str): the full path name for the TFRecord file.
    """
    # tf.io.TFRecordWriter replaces the TF1-only tf.python_io.TFRecordWriter;
    # the ``with`` block guarantees the file is closed even if a record fails.
    with tf.io.TFRecordWriter(output_path) as writer:
        for image_path, label, instance_id in image_data:
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'label': _int64_feature([label]),
                    'path': _bytes_feature([_to_bytes(image_path)]),
                    'instance': _bytes_feature([_to_bytes(instance_id)]),
                }
            ))
            writer.write(example.SerializeToString())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment