Skip to content

Instantly share code, notes, and snippets.

@alexwal
Last active December 15, 2020 12:03
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save alexwal/9fca4efb936265d62e389fba5bacd4b3 to your computer and use it in GitHub Desktop.
Save alexwal/9fca4efb936265d62e389fba5bacd4b3 to your computer and use it in GitHub Desktop.
Example of how to handle errors in a `tf.data.Dataset` input pipeline (TensorFlow 1.x: uses `tf.Session` and `tf.contrib.data`).
import tensorflow as tf
def create_bad_dataset(create_batches=True, batch_size=2):
    """Build a dataset whose map function fails on exactly one element.

    The source values are ``[1., 2., 0., 4., 8., 16.]`` and each one is mapped
    to ``tf.check_numerics(1. / x, 'error')``. The ``0.`` element produces a
    non-finite value, so evaluating it raises ``tf.errors.InvalidArgumentError``.

    Args:
        create_batches: If True, map and batch in a single step with
            ``tf.contrib.data.map_and_batch`` (to show that error handling also
            works with the fused op). Otherwise use a plain ``Dataset.map``
            and yield scalars.
        batch_size: Batch size handed to ``map_and_batch``. Only used when
            ``create_batches`` is True. Defaults to 2.

    Returns:
        The resulting ``tf.data.Dataset``.
    """
    def checked_reciprocal(x):
        # 1. / 0. is not finite; check_numerics turns that into an
        # InvalidArgumentError at evaluation time.
        return tf.check_numerics(1. / x, 'error')

    dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4., 8., 16.])
    if create_batches:
        return dataset.apply(tf.contrib.data.map_and_batch(
            map_func=checked_reciprocal, batch_size=batch_size))
    return dataset.map(checked_reciprocal)
def create_bad_dataset_with_filter(create_batches=True, batch_size=4):
    """Build a 1/x dataset that never errors because zeros are filtered out first.

    The source values are ``[1., 2., 0., 4., 8., 16.]``. The ``0.`` element is
    removed by ``filter`` *before* the ``tf.check_numerics(1. / x, 'error')``
    map runs, so the division by zero is never evaluated and no
    ``InvalidArgumentError`` is raised.

    Args:
        create_batches: If True, map and batch with
            ``tf.contrib.data.map_and_batch`` (``drop_remainder=False``, so a
            final short batch is kept). Otherwise use a plain ``Dataset.map``.
        batch_size: Batch size handed to ``map_and_batch``. Only used when
            ``create_batches`` is True. Defaults to 4.

    Returns:
        The resulting ``tf.data.Dataset``.
    """
    def checked_reciprocal(x):
        # Would raise InvalidArgumentError for x == 0., but the filter below
        # guarantees that value never reaches this function.
        return tf.check_numerics(1. / x, 'error')

    dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4., 8., 16.])
    # Buffers up to 3 source elements ahead; does not change the values produced.
    dataset = dataset.prefetch(3)
    dataset = dataset.filter(lambda x: tf.not_equal(x, 0.))
    if create_batches:
        return dataset.apply(tf.contrib.data.map_and_batch(
            map_func=checked_reciprocal, batch_size=batch_size,
            drop_remainder=False))
    return dataset.map(checked_reciprocal)
def test_without_error_handling():
    """Iterate the failing dataset without handling InvalidArgumentError.

    Every produced element is printed until the iterator is exhausted. Only
    ``tf.errors.OutOfRangeError`` (end of data) is caught, so the
    ``tf.errors.InvalidArgumentError`` raised for the ``1. / 0.`` element
    escapes to the caller.
    """
    get_next = create_bad_dataset().make_one_shot_iterator().get_next()
    with tf.Session() as session:
        try:
            while True:
                print(session.run(get_next))
        except tf.errors.OutOfRangeError:
            # End of the dataset; any other error propagates unhandled.
            print('break from loop')
def test_catch_error_in_run_loop():
    """Iterate the failing dataset, reporting and skipping the bad element.

    Each ``session.run`` is guarded individually: an
    ``InvalidArgumentError`` is reported and iteration continues with the next
    call, while ``OutOfRangeError`` marks the end of the data and stops the
    loop.
    """
    get_next = create_bad_dataset().make_one_shot_iterator().get_next()
    with tf.Session() as session:
        exhausted = False
        while not exhausted:
            try:
                value = session.run(get_next)
            except tf.errors.InvalidArgumentError:
                # The failing element is reported; the next run() moves on.
                print('Error: InvalidArgumentError')
            except tf.errors.OutOfRangeError:
                print('break from loop')
                exhausted = True
            else:
                print(value)
def test_ignore_errors():
    """Iterate the failing dataset with failing elements silently dropped.

    ``tf.contrib.data.ignore_errors()`` removes every dataset element whose
    computation fails, so the loop only has to handle ``OutOfRangeError``.
    """
    dataset = create_bad_dataset()
    # create_bad_dataset() batches pairs by default, so an "element" here is a
    # whole batch: the batch containing 0. ([0., 4.] -> error) is dropped as a
    # unit, including the valid 4. entry. Expected output:
    #   [1. 0.5]  then  [0.125 0.0625]
    # (With create_batches=False the output would be 1., 0.5, 0.25, 0.125, 0.0625.)
    dataset = dataset.apply(tf.contrib.data.ignore_errors())
    next_element = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        while True:
            try:
                x = sess.run(next_element)
            except tf.errors.OutOfRangeError:
                print('break from loop')
                break
            print(x)
def run(include_uncaught_error=True):
    """Run each error-handling demo in turn, printing a header before each one.

    Args:
        include_uncaught_error: If True (the default, matching the original
            script), finish with ``test_without_error_handling()``. That demo
            does not catch ``tf.errors.InvalidArgumentError``, so the error
            propagates out of this function and ends the run. Pass False to
            run only the demos that handle the error.
    """
    print('\n--> Testing by catching errors in run loop...')
    test_catch_error_in_run_loop()
    print('\n--> Testing by catching errors with tf.contrib.data.ignore_errors()...')
    test_ignore_errors()
    if include_uncaught_error:
        # Deliberately last: the uncaught InvalidArgumentError terminates the run.
        print('\n--> Testing without error handling...')
        test_without_error_handling()


if __name__ == '__main__':
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment