"""Create TFRecords.

Created November 24, 2017 12:18

Create the tfrecord files for a dataset.

A lot of this code comes from the TensorFlow Inception example, so here is their license:
"""
# Copyright 2016 Google Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from datetime import datetime
import json
import os
from queue import Queue
import random
import sys
import threading
import numpy as np
import tensorflow as tf
def _int64_feature(value):
    """Build an int64 ``tf.train.Feature`` for an Example proto.

    A value that is not already a ``list`` is wrapped in a one-element list
    before being handed to ``tf.train.Int64List``.
    """
    values = value if isinstance(value, list) else [value]
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)
def _float_feature(value):
    """Build a float ``tf.train.Feature`` for an Example proto.

    A value that is not already a ``list`` is wrapped in a one-element list
    before being handed to ``tf.train.FloatList``.
    """
    values = value if isinstance(value, list) else [value]
    float_list = tf.train.FloatList(value=values)
    return tf.train.Feature(float_list=float_list)
def _bytes_feature(value):
    """Build a bytes ``tf.train.Feature`` for an Example proto.

    A value that is not already a ``list`` is wrapped in a one-element list
    before being handed to ``tf.train.BytesList``.
    """
    values = value if isinstance(value, list) else [value]
    bytes_list = tf.train.BytesList(value=values)
    return tf.train.Feature(bytes_list=bytes_list)
def _validate_text(text):
    """Return `text` as a `str`, converting it when it is some other type.

    Args:
      text: the value to normalise. `str` is returned unchanged; `bytes` and
        `bytearray` are decoded as UTF-8 (undecodable bytes are replaced);
        anything else is passed through `str()`.

    Returns:
      str, the textual form of `text`.
    """
    if isinstance(text, str):
        return text
    if isinstance(text, (bytes, bytearray)):
        # On Python 3, str(b'cat') yields the repr "b'cat'" rather than the
        # text itself, which would then be stored verbatim in the record.
        return bytes(text).decode('utf-8', errors='replace')
    return str(text)
def _convert_to_example(image_example, image_buffer, height, width, colorspace=b'RGB',
                        channels=3, image_format=b'JPEG'):
    """Build an Example proto for an example.

    Args:
      image_example: dict, an image example. 'filename' and 'id' are required;
        the optional 'class', 'drive' and 'object' sub-dicts fall back to the
        defaults used below when absent.
      image_buffer: string, JPEG encoding of RGB image
      height: integer, image height in pixels
      width: integer, image width in pixels
      colorspace: bytes, stored under 'image/colorspace'
      channels: integer, stored under 'image/channels'
      image_format: bytes, stored under 'image/format'

    Returns:
      Example proto
    """
    # Required. Both are stringified before encoding so that non-string values
    # (e.g. integer ids) do not make str.encode raise; object ids below are
    # handled the same way.
    filename = str(image_example['filename']).encode()
    image_id = str(image_example['id']).encode()

    # Class label for the whole image
    image_class = image_example.get('class', {})
    class_label = image_class.get('label', 0)
    class_text = _validate_text(image_class.get('text', '')).encode()
    class_conf = image_class.get('conf', 1.)

    # Driving related labels
    image_drive = image_example.get('drive', {})
    drive_steer = image_drive.get('steer', 0.0)

    # Objects
    image_objects = image_example.get('object', {})
    object_count = image_objects.get('count', 0)

    # Bounding Boxes
    image_bboxes = image_objects.get('bbox', {})
    xmin = image_bboxes.get('xmin', [])
    xmax = image_bboxes.get('xmax', [])
    ymin = image_bboxes.get('ymin', [])
    ymax = image_bboxes.get('ymax', [])
    bbox_scores = image_bboxes.get('score', [])
    bbox_labels = image_bboxes.get('label', [])
    bbox_text = [_validate_text(text).encode()
                 for text in image_bboxes.get('text', [])]
    bbox_label_confs = image_bboxes.get('conf', [])

    # Parts
    image_parts = image_objects.get('parts', {})
    parts_x = image_parts.get('x', [])
    parts_y = image_parts.get('y', [])
    parts_v = image_parts.get('v', [])
    parts_s = image_parts.get('score', [])

    # Areas
    object_areas = image_objects.get('area', [])

    # Ids
    object_ids = [str(object_id).encode()
                  for object_id in image_objects.get('id', [])]

    # Any extra data (e.g. stringified json)
    extra_info = str.encode(image_class.get('extra', ''))

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(colorspace),
        'image/channels': _int64_feature(channels),
        'image/format': _bytes_feature(image_format),
        'image/filename': _bytes_feature(filename),
        'image/id': _bytes_feature(image_id),
        # The encoded image is a byte string; an Int64List cannot hold it.
        'image/encoded': _bytes_feature(image_buffer),
        'image/extra': _bytes_feature(extra_info),
        # NOTE(review): the class label is written as a float feature while
        # the bbox labels are int64 -- confirm which type the readers of
        # these records expect before changing it.
        'image/class/label': _float_feature(class_label),
        'image/class/text': _bytes_feature(class_text),
        'image/class/conf': _float_feature(class_conf),
        'image/drive/steer': _float_feature(drive_steer),
        'image/object/bbox/xmin': _float_feature(xmin),
        'image/object/bbox/xmax': _float_feature(xmax),
        'image/object/bbox/ymin': _float_feature(ymin),
        'image/object/bbox/ymax': _float_feature(ymax),
        'image/object/bbox/label': _int64_feature(bbox_labels),
        'image/object/bbox/text': _bytes_feature(bbox_text),
        'image/object/bbox/conf': _float_feature(bbox_label_confs),
        'image/object/bbox/score': _float_feature(bbox_scores),
        'image/object/parts/x': _float_feature(parts_x),
        'image/object/parts/y': _float_feature(parts_y),
        'image/object/parts/v': _int64_feature(parts_v),
        'image/object/parts/score': _float_feature(parts_s),
        'image/object/count': _int64_feature(object_count),
        'image/object/area': _float_feature(object_areas),
        'image/object/id': _bytes_feature(object_ids),
    }))
    return example
class ImageCoder(object):
    """Helper class that provides TensorFlow image coding utilities."""

    def __init__(self):
        # Create a single Session to run all image coding calls.
        self._sess = tf.Session()

        # Initializes function that converts PNG to JPEG data.
        self._png_data = tf.placeholder(dtype=tf.string)
        image = tf.image.decode_png(self._png_data, channels=3)
        self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def png_to_jpeg(self, image_data):
        """Convert PNG-encoded `image_data` to JPEG-encoded data.

        Feeds `image_data` into the PNG placeholder and evaluates the
        decode_png -> encode_jpeg graph built in `__init__`.
        """
        return self._sess.run(self._png_to_jpeg,
                              feed_dict={self._png_data: image_data})

    def decode_jpeg(self, image_data):
        """Decode JPEG-encoded `image_data` and return the decoded image.

        Raises:
          AssertionError: if the decoded image is not rank 3 or does not have
            3 channels.
        """
        image = self._sess.run(self._decode_jpeg,
                               feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3, "JPEG needs to have height x width x channels"
        assert image.shape[2] == 3, "JPEG needs to have 3 channels (RGB)"
        return image
def _is_png(filename):
    """Determine if a file contains a PNG format image.

    The decision is made from the file extension alone (case-insensitive);
    the file contents are not inspected.

    Args:
      filename: string, path of the image file.

    Returns:
      boolean indicating if the image is a PNG.
    """
    _, file_extension = os.path.splitext(filename)
    return file_extension.lower() == '.png'
def _process_image(filename, coder):
    """Process a single image file.

    Args:
      filename: string, path to an image file e.g., '/path/to/example.JPG'.
      coder: instance of ImageCoder to provide TensorFlow image coding utils.

    Returns:
      image_buffer: string, JPEG encoding of RGB image.
      height: integer, image height in pixels.
      width: integer, image width in pixels.

    Raises:
      AssertionError: if the decoded image is not height x width x 3.
    """
    # Read the image file. The context manager closes the handle even when
    # the read fails, instead of leaving it to the garbage collector.
    with tf.gfile.FastGFile(filename, 'rb') as image_file:
        image_data = image_file.read()

    # Clean the dirty data: PNG files are re-encoded so that every record
    # holds JPEG data.
    if _is_png(filename):
        image_data = coder.png_to_jpeg(image_data)

    # Decode the RGB JPEG.
    image = coder.decode_jpeg(image_data)

    # Check that image converted to RGB
    assert len(image.shape) == 3
    height = image.shape[0]
    width = image.shape[1]
    assert image.shape[2] == 3

    return image_data, height, width
def _process_image_files_batch(coder, thread_index, ranges, name, output_directory,
                               dataset, num_shards, error_queue):
    """Processes and saves list of images as TFRecord in 1 thread.

    Args:
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
      thread_index: integer, unique batch to run index is within [0, len(ranges)).
      ranges: list of pairs of integers specifying ranges of each batches to
        analyze in parallel.
      name: string, unique identifier specifying the data set (e.g. `train` or `test`)
      output_directory: string, file path to store the tfrecord files.
      dataset: list, a list of image example dicts
      num_shards: integer number of shards for this data set.
      error_queue: Queue, a queue to place image examples that failed. A failed
        example is skipped (not written) and gets an 'error_msg' entry holding
        the repr of the exception before it is queued.
    """
    # Each thread produces N shards where N = int(num_shards / num_threads).
    # For instance, if num_shards = 128, and the num_threads = 2, then the first
    # thread would produce shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = num_shards // num_threads

    # Split this thread's [start, end) slice of the dataset evenly over its shards.
    batch_start, batch_end = ranges[thread_index]
    shard_ranges = np.linspace(batch_start, batch_end,
                               num_shards_per_batch + 1).astype(int)
    num_files_in_thread = batch_end - batch_start

    counter = 0
    error_counter = 0
    for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards)
        output_file = os.path.join(output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)

        shard_counter = 0
        try:
            for i in range(int(shard_ranges[s]), int(shard_ranges[s + 1])):
                image_example = dataset[i]
                try:
                    filename = str(image_example['filename'])
                    if 'encoded' in image_example:
                        # The example already carries its encoded image and metadata.
                        image_buffer = image_example['encoded']
                        height = image_example['height']
                        width = image_example['width']
                        colorspace = image_example['colorspace']
                        image_format = image_example['format']
                        num_channels = image_example['channels']
                        example = _convert_to_example(image_example, image_buffer, height,
                                                      width, colorspace, num_channels,
                                                      image_format)
                    else:
                        image_buffer, height, width = _process_image(filename, coder)
                        example = _convert_to_example(image_example, image_buffer, height,
                                                      width)
                    writer.write(example.SerializeToString())
                except Exception as e:  # one bad example must not end the whole batch
                    error_counter += 1
                    image_example['error_msg'] = repr(e)
                    error_queue.put(image_example)
                    continue

                shard_counter += 1
                counter += 1
                if not counter % 1000:
                    print('%s [thread %d]: Processed %d of %d images in thread batch, with %d errors.' %
                          (datetime.now(), thread_index, counter, num_files_in_thread, error_counter))
                    sys.stdout.flush()
        finally:
            # Always flush and release the shard file, even on an unexpected error.
            writer.close()

        print('%s [thread %d]: Wrote %d images to %s, with %d errors.' %
              (datetime.now(), thread_index, shard_counter, output_file, error_counter))
        sys.stdout.flush()

    # Report the number of shards written (previously the file count was
    # printed in the "%d shards" slot).
    print('%s [thread %d]: Wrote %d images to %d shards, with %d errors.' %
          (datetime.now(), thread_index, counter, num_shards_per_batch, error_counter))
    sys.stdout.flush()
def create(dataset, dataset_name, output_directory, num_shards, num_threads, shuffle=True):
    """Create the tfrecord files to be used to train or test a model.

    Args:
      dataset : [{
        "filename" : <REQUIRED: path to the image file>,
        "id" : <REQUIRED: id of the image>,
        "class" : {
          "label" : <[0, num_classes)>,
          "text" : <text description of class>
        },
        "object" : {
          "bbox" : {
            "xmin" : [],
            "xmax" : [],
            "ymin" : [],
            "ymax" : [],
            "label" : []
          }
        }
      }]
      dataset_name: a name for the dataset
      output_directory: path to a directory to write the tfrecord files
      num_shards: the number of tfrecord files to create
      num_threads: the number of threads to use
      shuffle : bool, should the image examples be shuffled or not prior to creating the tfrecords.
        Shuffling is done in place on `dataset`.

    Returns:
      list : a list of image examples that failed to process.

    Raises:
      ValueError: if `num_threads` is not positive or does not evenly divide
        `num_shards` (each thread writes num_shards / num_threads shards).
    """
    # Fail up front: otherwise every worker thread dies on its own assertion
    # and this function returns as if nothing had gone wrong.
    if num_threads < 1 or num_shards % num_threads:
        raise ValueError(
            'num_shards (%d) must be a multiple of num_threads (%d).'
            % (num_shards, num_threads))

    # Images in the tfrecords set must be shuffled properly
    if shuffle:
        random.shuffle(dataset)

    # Break all images into batches with a [ranges[i][0], ranges[i][1]].
    # `.tolist()` yields plain Python ints, which also print cleanly below.
    spacing = np.linspace(0, len(dataset), num_threads + 1).astype(int).tolist()
    ranges = [[start, end] for start, end in zip(spacing[:-1], spacing[1:])]

    # Launch a thread for each batch.
    print('Launching %d threads for spacings: %s' % (num_threads, ranges))
    sys.stdout.flush()

    # Create a mechanism for monitoring when all threads are finished.
    coord = tf.train.Coordinator()

    # Create a generic TensorFlow-based utility for converting all image codings.
    coder = ImageCoder()

    # A Queue to hold the image examples that fail to process.
    error_queue = Queue()

    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges, dataset_name, output_directory, dataset,
                num_shards, error_queue)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    print('%s: Finished writing all %d images in data set.' %
          (datetime.now(), len(dataset)))
    sys.stdout.flush()

    # Collect the errors
    errors = []
    while not error_queue.empty():
        errors.append(error_queue.get())
    print('%d examples failed.' % (len(errors),))

    return errors
def parse_args():
    """Parse the command line arguments for creating tfrecord files.

    Returns:
      argparse.Namespace with `dataset_path`, `dataset_name`, `output_dir`,
      `num_shards`, `num_threads` and `shuffle`.
    """
    parser = argparse.ArgumentParser(
        description='Create tfrecord files from a dataset json file.')

    parser.add_argument('--dataset_path', dest='dataset_path',
                        help='Path to the dataset json file.', type=str,
                        required=True)

    parser.add_argument('--prefix', dest='dataset_name',
                        help='Prefix for the tfrecords (e.g. `train`, `test`, `val`).', type=str,
                        required=True)

    parser.add_argument('--output_dir', dest='output_dir',
                        help='Directory for the tfrecords.', type=str,
                        required=True)

    parser.add_argument('--shards', dest='num_shards',
                        help='Number of shards to make.', type=int,
                        required=True)

    parser.add_argument('--threads', dest='num_threads',
                        help='Number of threads to make.', type=int,
                        required=True)

    # `--shuffle` is a store_true flag whose default is already True, so on
    # its own it can never switch shuffling off. The default is kept for
    # backward compatibility and `--no_shuffle` provides the way to disable it.
    parser.add_argument('--shuffle', dest='shuffle',
                        help='Shuffle the records before saving them.',
                        required=False, action='store_true', default=True)

    parser.add_argument('--no_shuffle', dest='shuffle',
                        help='Keep the records in their original order.',
                        required=False, action='store_false', default=True)

    parsed_args = parser.parse_args()

    return parsed_args
def main():
    """Command line entry point: load the dataset json and write the tfrecords.

    Returns:
      list, the image examples that `create` reports as failed.
    """
    args = parse_args()

    with open(args.dataset_path) as f:
        dataset = json.load(f)

    errors = create(
        dataset=dataset,
        dataset_name=args.dataset_name,
        output_directory=args.output_dir,
        num_shards=args.num_shards,
        num_threads=args.num_threads,
        shuffle=args.shuffle,
    )

    return errors
# Run the converter only when executed as a script, not when imported.
if __name__ == '__main__':
    main()