Skip to content

Instantly share code, notes, and snippets.

@zacharynevin
Last active February 13, 2019 04:48
Show Gist options
  • Save zacharynevin/d9c6aa21a2d52299dfc56c12804d6770 to your computer and use it in GitHub Desktop.
Save zacharynevin/d9c6aa21a2d52299dfc56c12804d6770 to your computer and use it in GitHub Desktop.
Creating a TFRecord file for image data
import os
import cv2
import tensorflow as tf
def _bytes_feature(value):
    """Wrap a single value in a bytes-typed ``tf.train.Feature``.

    ``value`` becomes the sole element of a ``tf.train.BytesList``, which is
    then stored in the ``bytes_list`` field of the returned feature.
    """
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def _int64_feature(value):
    """Wrap a single value in an int64-typed ``tf.train.Feature``.

    ``value`` becomes the sole element of a ``tf.train.Int64List``, which is
    then stored in the ``int64_list`` field of the returned feature.
    """
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)
# Serialize every readable image in ./data into one TFRecord file. Each record
# stores the raw pixel buffer plus the height/width/channels needed to reshape
# it when reading, and a constant "animals" label.
#
# Assume that ./data dir contains image files (e.g. ./data/elephant.jpg)
data_dir = "./data"

# Using the writer as a context manager guarantees the record file is closed
# even if an image fails part-way through the loop.
with tf.python_io.TFRecordWriter("my_dataset.tfrecords") as writer:
    # os.listdir returns entries in arbitrary order; sort so that repeated
    # runs over the same directory produce the same record order.
    for f in sorted(os.listdir(data_dir)):
        img = cv2.imread(os.path.join(data_dir, f))
        if img is None:
            # cv2.imread signals "could not decode" (sub-directory, non-image
            # file, corrupt image) by returning None rather than raising.
            # Skip such entries instead of crashing on img.shape.
            continue

        # NOTE(review): relies on imread's default flag yielding a
        # 3-dimensional (height, width, channels) array -- confirm if the
        # read flag is ever changed.
        height, width, channels = img.shape

        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    # tobytes() replaces the deprecated ndarray.tostring()
                    # alias and returns the same raw buffer.
                    "image": _bytes_feature(img.tobytes()),
                    "height": _int64_feature(height),
                    "width": _int64_feature(width),
                    "channels": _int64_feature(channels),
                    "label": _bytes_feature(tf.compat.as_bytes("animals")),
                }
            )
        )
        writer.write(example.SerializeToString())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment