Skip to content

Instantly share code, notes, and snippets.

@kashif kashif/input_fn.py
Last active Mar 2, 2019

Embed
What would you like to do?
TensorFlow 1.x Estimator input function that reads images organised into per-class folders
def input_fn(file_pattern, labels,
             image_size=(224, 224),
             shuffle=False,
             batch_size=64,
             num_epochs=None,
             buffer_size=4096,
             prefetch_buffer_size=None):
    """Build a tf.data pipeline of batched (image, one-hot label) pairs.

    Each file matched by ``file_pattern`` is labelled by the name of its
    parent directory, so images are expected to be laid out as
    ``.../<class_name>/<image_file>``.

    The lookup table created here must be initialized before iteration
    (``tf.tables_initializer()``); ``tf.estimator`` does this for input_fns.

    Args:
        file_pattern: Glob pattern (or list of patterns) passed to
            ``tf.data.Dataset.list_files``.
        labels: Sequence of class names. A class's position in this sequence
            is the hot index of its one-hot label. Directory names not found
            in ``labels`` look up to -1, which ``tf.one_hot`` encodes as an
            all-zeros vector.
        image_size: ``(height, width)`` every image is resized to.
        shuffle: If true, shuffle the file listing and also shuffle filenames
            through a buffer of ``buffer_size`` elements.
        batch_size: Number of examples per batch.
        num_epochs: Number of passes over the files. ``None`` means a single
            pass (no ``repeat``) -- note this differs from
            ``Dataset.repeat(None)``, which repeats forever.
        buffer_size: Shuffle buffer size, counted in filenames.
        prefetch_buffer_size: Number of batches to prefetch. ``None`` selects
            ``tf.contrib.data.AUTOTUNE``.

    Returns:
        A ``tf.data.Dataset`` of ``(images, one_hot_labels)`` batches. Images
        are decoded as 3-channel JPEG, converted to float32 by
        ``tf.image.convert_image_dtype`` and resized to ``image_size``; labels
        are ``tf.one_hot`` vectors of length ``len(labels)``.

    Raises:
        ValueError: If ``labels`` is empty.
    """
    num_classes = len(labels)
    if num_classes == 0:
        # An empty vocabulary cannot form a string lookup table or a
        # meaningful one-hot encoding; fail early with a clear message.
        # (len() rather than truthiness so numpy arrays of names still work.)
        raise ValueError("labels must contain at least one class name")

    table = tf.contrib.lookup.index_table_from_tensor(
        mapping=tf.constant(labels))

    # tf.string_split treats every character of `delimiter` as a split point,
    # so include the OS's alternate separator too ('/' on Windows, None on
    # POSIX, where this leaves the behavior unchanged).
    separators = os.sep + (os.altsep or "")

    def _map_func(filename):
        """Decode one image file and derive its one-hot label."""
        # The class name is the parent directory: the second-to-last path
        # component of the filename.
        class_name = tf.string_split(
            [filename], delimiter=separators).values[-2]
        image = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.image.resize_images(image, size=image_size)
        return image, tf.one_hot(table.lookup(class_name), num_classes)

    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)
    if shuffle and num_epochs is not None:
        # Fused op: shuffles within each epoch and repeats num_epochs times.
        dataset = dataset.apply(
            tf.contrib.data.shuffle_and_repeat(buffer_size, num_epochs))
    elif shuffle:
        dataset = dataset.shuffle(buffer_size)
    elif num_epochs is not None:
        dataset = dataset.repeat(num_epochs)

    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            map_func=_map_func,
            batch_size=batch_size,
            # os.cpu_count() may return None, which map_and_batch treats as
            # "use its default parallelism".
            num_parallel_calls=os.cpu_count()))

    if prefetch_buffer_size is None:
        # Some TF 1.x releases cannot convert None to prefetch's int64
        # buffer_size; request autotuning explicitly instead.
        prefetch_buffer_size = tf.contrib.data.AUTOTUNE
    return dataset.prefetch(buffer_size=prefetch_buffer_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.