Skip to content

Instantly share code, notes, and snippets.

@geyang
Created December 11, 2018 02:23
Show Gist options
  • Save geyang/51846895da7b84a07e89510fafa9ba70 to your computer and use it in GitHub Desktop.
Save geyang/51846895da7b84a07e89510fafa9ba70 to your computer and use it in GitHub Desktop.
tensorflow dataloading scripts
# Feedable-iterator pattern: one `get_next` op, switched between datasets
# at run time by feeding a string handle.
import tensorflow as tf
from termcolor import cprint
from ml_logger import logger

with tf.Session() as sess:
    # Two pipelines that share a single element structure.
    train_ds = tf.data.Dataset.range(100).map(
        lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
    val_ds = tf.data.Dataset.range(50)

    # The feedable iterator is defined by a handle placeholder plus the
    # shared types/shapes; either dataset's properties would work since
    # their structures are identical.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_ds.output_types, train_ds.output_shapes)
    next_element = iterator.get_next()

    # Any concrete iterator kind (one-shot, initializable, ...) can back
    # the handle.
    train_it = train_ds.make_one_shot_iterator()
    val_it = val_ds.make_initializable_iterator()

    # `string_handle()` returns a tensor; evaluate it once per iterator
    # and feed the resulting value into `handle` below.
    train_handle = sess.run(train_it.string_handle())
    val_handle = sess.run(val_it.string_handle())

    # Alternate between training and validation.
    for _ in range(10):
        # 200 training steps; the infinite training stream resumes where
        # the previous outer iteration left off.
        for _ in range(200):
            sess.run(next_element, feed_dict={handle: train_handle})
        # One full pass over the validation dataset.
        # logger.split()
        sess.run(val_it.initializer)
        # print(logger.split())
        for _ in range(50):
            sess.run(next_element, feed_dict={handle: val_handle})

cprint('done!', 'green')
"""
Initializable iterators take inputs. These inputs are fed into
the initialization operator as a feed_dict.
"""
import tensorflow as tf
from termcolor import cprint
with tf.Session() as sess:
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
it = dataset.make_initializable_iterator()
next_element = it.get_next()
# Initialize an iterator over a dataset with 10 elements.
sess.run(it.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value, "numbers should be correct"
# Initialize the same iterator over a dataset with 100 elements.
sess.run(it.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value, "numbers should be correct"
cprint("done!", 'green')
# One-shot iterators need no explicit initialization: build, get_next, run.
import tensorflow as tf

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    ds = tf.data.Dataset.range(100)
    next_element = ds.make_one_shot_iterator().get_next()
    for expected in range(100):
        assert sess.run(next_element) == expected
# Image-loading pipeline: map filenames through a decode/resize parse fn.
import tensorflow as tf
from PIL import Image
from termcolor import cprint

# Demonstrates shuffle/repeat on a filename dataset.
# NOTE(review): this dataset is built but never consumed below; it is kept
# only as a stand-alone example of shuffling and repeating a filename list.
string_tensor = tf.convert_to_tensor(['../figures/PointMass-v0.png'])
filename_queue = tf.data.Dataset.from_tensor_slices(string_tensor) \
    .shuffle(string_tensor.shape[0]).repeat(10)


def _parse_function(filename, label):
    """Read one image file, decode it, and resize it to a fixed 28x28."""
    image_string = tf.read_file(filename)
    # BUG FIX: the input files are PNGs, so decode with decode_png —
    # tf.image.decode_jpeg raises at runtime on PNG data.
    image_decoded = tf.image.decode_png(image_string)
    image_resized = tf.image.resize_images(image_decoded, [28, 28])
    return image_resized, label


# A vector of filenames; `labels[i]` is the label for `filenames[i]`.
filenames = tf.constant(['test-dataset/PointMass-v0.png'])
labels = tf.constant([37])

iterator = tf.data.Dataset \
    .from_tensor_slices((filenames, labels)) \
    .map(_parse_function) \
    .make_one_shot_iterator()
next_element = iterator.get_next()

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    for _ in range(1):
        img, label = sess.run(next_element)
        # BUG FIX: resize_images yields float32, but PIL's fromarray needs
        # uint8 for display; assumes a single-channel image so H*W*1 can be
        # reshaped to H*W — TODO confirm against the actual PNG.
        im = Image.fromarray(img.reshape(img.shape[:2]).astype('uint8'))
        im.show()
# Simplest pattern: a one-shot iterator consumed inside a session.
import tensorflow as tf
from termcolor import cprint

with tf.Session() as sess:
    ds = tf.data.Dataset.range(100)
    next_element = ds.make_one_shot_iterator().get_next()
    for expected in range(100):
        assert sess.run(next_element) == expected, "numbers should be correct"
    cprint("done!", 'green')

List of Data Loader Patterns for Tensorflow

This code comes from the TensorFlow documentation for the (experimental) new Dataset API.

General types:

  • one_shot
  • initializable
  • reinitializable
  • feedable

I also included an image loading (and processing) example. See loading_images.py

-- Ge

"""
Initializable iterators take inputs. These inputs are fed into
the initialization operator as a feed_dict.
"""
import tensorflow as tf
from termcolor import cprint
from ml_logger import logger
with tf.Session() as sess:
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)
# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element)
# Initialize an iterator over the validation dataset.
# logger.split()
sess.run(validation_init_op)
# cprint(logger.split(), 'yellow')
for _ in range(50):
sess.run(next_element)
# assert i == value, "numbers should be correct"
cprint("done!", 'green')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment