Skip to content

Instantly share code, notes, and snippets.

@psycharo-zz
Created January 23, 2017 14:57
Show Gist options
  • Save psycharo-zz/58717872a3a00284fbbcd9575d265785 to your computer and use it in GitHub Desktop.
Example of an efficient and simple input pipeline in TensorFlow.
import threading
def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature holding an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature holding a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _convert_example(rgb_path, label_path):
    """Serialize one (rgb, label) PNG pair into a tf.train.Example string.

    Args:
        rgb_path: path to the RGB PNG file.
        label_path: path to the label PNG file.

    Returns:
        The serialized tf.train.Example proto as a bytes string.
    """
    # Context managers close the file handles deterministically; the
    # original relied on the GC to reclaim the open handles.
    with open(rgb_path, 'rb') as f:
        rgb_png = f.read()
    with open(label_path, 'rb') as f:
        label_png = f.read()
    example = tf.train.Example(features=tf.train.Features(feature={
        'rgb_png': _bytes_feature(rgb_png),
        'label_png': _bytes_feature(label_png),
    }))
    return example.SerializeToString()
def _convert_dataset_shard(filenames, output_path):
    """A per-thread unit of work for dataset processing.

    Args:
        filenames: a list of (rgb_path, label_path) tuples.
        output_path: where to store the records.
    """
    # TFRecordWriter is a context manager: the record file is flushed and
    # closed even if _convert_example raises mid-shard (the original
    # leaked the writer on exception).
    with tf.python_io.TFRecordWriter(output_path) as writer:
        for rgb_path, label_path in filenames:
            writer.write(_convert_example(rgb_path, label_path))
def _to_filenames(raw_data_dir, tag, city, fid):
"""returns a tuple of filenames"""
rgb_path = ('%s/leftImg8bit/%s/%s/%s_leftImg8bit.png' %
(raw_data_dir, tag, city, fid))
label_path = ('%s/gtFine/%s/%s/%s_gtFine_labelIds.png' %
(raw_data_dir, tag, city, fid))
return rgb_path, label_path
def convert_dataset(raw_data_dir, processed_dir, tag, num_threads=1,
                    max_num_examples=10):
    """Converts the dataset into TFRecords.

    Args:
        raw_data_dir: directory with the unprocessed dataset.
        processed_dir: where to store TFRecords.
        tag: "train"|"test"|"val".
        num_threads: number of threads to use in parallel.
        max_num_examples: maximum number of examples to load.
    """
    city_root = '%s/leftImg8bit/%s' % (raw_data_dir, tag)
    # Frame ids: strip the trailing "_leftImg8bit.png" suffix component.
    fids = []
    for city in sorted(os.listdir(city_root)):
        for fname in os.listdir('%s/%s' % (city_root, city)):
            fids.append((city, fname.rsplit('_', 1)[0]))
    filenames = [_to_filenames(raw_data_dir, tag, city, fid)
                 for (city, fid) in fids[:max_num_examples]]
    # Evenly partition the work into one contiguous slice per thread.
    bounds = np.linspace(0, len(filenames), num_threads + 1).astype(np.int32)
    shards = [filenames[bounds[i]:bounds[i + 1]] for i in range(num_threads)]
    coord = tf.train.Coordinator()
    threads = []
    for i, shard in enumerate(shards):
        output_path = (processed_dir +
                       '/%s-%02d-of-%02d.tfrecord' % (tag, i, num_threads))
        worker = threading.Thread(target=_convert_dataset_shard,
                                  args=(shard, output_path))
        worker.start()
        threads.append(worker)
    # Block until every shard writer has finished.
    coord.join(threads)
# inputs
def read_and_decode(filename_queue, height=1024, width=2048):
    """Parse one TFRecord example into a (rgb, label) image tensor pair."""
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    feature_spec = {
        'rgb_png': tf.FixedLenFeature([], tf.string),
        'label_png': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)
    rgb = tf.image.decode_png(parsed['rgb_png'], channels=3)
    label = tf.image.decode_png(parsed['label_png'], channels=1)
    # decode_png leaves the static shape unknown; pin it so downstream
    # batching ops see fully-defined shapes.
    rgb.set_shape([height, width, 3])
    label.set_shape([height, width, 1])
    return rgb, label
def input_pipeline(file_pattern, batch_size, min_values_dequeue,
                   num_epochs=None, num_reader_threads=1):
    """Creates input pipeline.

    Args:
        file_pattern: pattern for input files, e.g. 'train-??-of-10.tfrecord'
        batch_size: used to determine buffer sizes
        min_values_dequeue: the size of buffer: the more, the better shuffling
        num_epochs: how many times to go through all the data
        num_reader_threads: number of parallel readers feeding the batch

    Returns:
        A (rgb_batch, label_batch) tuple of batched image tensors.
    """
    # TODO: is_training flag?
    filenames = tf.gfile.Glob(file_pattern)
    filename_queue = tf.train.string_input_producer(filenames, num_epochs,
                                                    shuffle=False,
                                                    name='filename_queue')
    # One reader op per thread; batch_join dequeues from all of them.
    example_list = [read_and_decode(filename_queue)
                    for _ in range(num_reader_threads)]
    capacity = min_values_dequeue + 64 * batch_size
    rgb_batch, label_batch = tf.train.batch_join(example_list, batch_size,
                                                 capacity)
    return rgb_batch, label_batch
# Build the training input graph from the sharded TFRecord files.
tf.reset_default_graph()
# NOTE(review): `processed_data_dir` is not defined in this view — presumably
# set in an earlier cell/section; confirm before running standalone.
train_file_pattern = os.path.join(processed_data_dir, 'train-??-of-??.tfrecord')
rgb_batch, label_batch = input_pipeline(train_file_pattern,
batch_size=8,
min_values_dequeue=3200,
num_epochs=1)
# string_input_producer with num_epochs creates a local epoch-counter
# variable, so local variables must be initialized before queue runners start.
init_op = tf.local_variables_initializer()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment