Skip to content

Instantly share code, notes, and snippets.

@qbeer
Created September 24, 2019 10:55
Show Gist options
  • Save qbeer/d65f4cfd472e395406dc11c82dee4b63 to your computer and use it in GitHub Desktop.
Save qbeer/d65f4cfd472e395406dc11c82dee4b63 to your computer and use it in GitHub Desktop.
How to use the tensroflow dataset API to create TFRecords files in order to store data in raw format
# Tensorflow throws a bunch of `FutureWarning`s
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import tensorflow as tf
import glob
BUFFER_SIZE = 50000
BATCH_SIZE = 32
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 32
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def create_tf_example_from_image(image_string):
feature = {
'image': _bytes_feature(image_string),
'height': _int64_feature(218), # all files most have the same dimensions in this example, sorry
'width': _int64_feature(178), # -
}
return tf.train.Example(features=tf.train.Features(feature=feature))
image_paths = glob.glob("<files-location>")
with tf.io.TFRecordWriter('celeba.tfrecords') as writer:
for image_path in image_paths:
image_string = open(image_path, 'rb').read()
tf_example = create_tf_example_from_image(image_string)
writer.write(tf_example.SerializeToString())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment