Skip to content

Instantly share code, notes, and snippets.

@raytroop
Last active September 16, 2018 04:39
Show Gist options
  • Save raytroop/05f4f697c014ee991597c12c6a25aa2a to your computer and use it in GitHub Desktop.
Save raytroop/05f4f697c014ee991597c12c6a25aa2a to your computer and use it in GitHub Desktop.
map and flat_map
import tensorflow as tf
# >1  flat_map + from_tensor_slices: every input value x is expanded into a
# sub-dataset of three scalars (x, x+1, x+2), and flat_map concatenates those
# sub-datasets into a single flat stream of scalars.
values = [10, 20, 30]
ds = tf.data.Dataset.from_tensor_slices(values)
ds = ds.flat_map(lambda x: tf.data.Dataset.from_tensor_slices([x, x+1, x+2]))
element = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    # 3 inputs x 3 slices each = 9 scalars, so 9 reads drain the dataset.
    for _ in range(9):
        print(sess.run(element))
# 10
# 11
# 12
# 20
# 21
# 22
# 30
# 31
# 32
# >2  flat_map + from_tensors: every input value x becomes a one-element
# sub-dataset holding the whole vector [x, x+1, x+2], so the flattened stream
# contains one vector per input value.
values = [10, 20, 30]
ds = tf.data.Dataset.from_tensor_slices(values)
ds = ds.flat_map(lambda x: tf.data.Dataset.from_tensors([x, x+1, x+2]))
element = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    # Only 3 elements exist but the loop asks for 9, so the 4th read raises
    # OutOfRangeError (recorded in the expected output below).
    for _ in range(9):
        print(sess.run(element))
# [10 11 12]
# [20 21 22]
# [30 31 32]
# OutOfRangeError
# >3  plain map: one output element per input element. The Python list
# returned by the lambda comes back as a 3-tuple of scalars (see the expected
# output below), not as a single vector.
values = [10, 20, 30]
ds = tf.data.Dataset.from_tensor_slices(values)
ds = ds.map(lambda x: [x, x+1, x+2])
element = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    # Only 3 elements exist but the loop asks for 9, so the 4th read raises
    # OutOfRangeError (recorded in the expected output below).
    for _ in range(9):
        print(sess.run(element))
# (10, 11, 12)
# (20, 21, 22)
# (30, 31, 32)
# OutOfRangeError
@raytroop
Copy link
Author

https://stackoverflow.com/a/50533300/8037585
It seems that `tf.data.Dataset.from_tensors` is usually used together with `flat_map`:

# Read records from a GZIP-compressed TFRecord file and parse each serialized
# record against feature_schema.
# NOTE(review): `filename` and `feature_schema` are not defined in this
# snippet; they must be supplied by the surrounding code.
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(
    lambda serialized: tf.parse_single_example(serialized, feature_schema))

def flat_map_impl(tf_example):
  """Return a sub-dataset that drops or duplicates one parsed example.

  Intended as the ``map_func`` of ``Dataset.flat_map``.

  Args:
    tf_example: A parsed example supporting ``tf_example["a"]``; the value
      under ``"a"`` is compared against 1.

  Returns:
    A dataset holding ``tf_example`` repeated 0 times when
    ``tf_example["a"] == 1`` (i.e. an empty dataset) and 2 times otherwise.
  """
  # tf.cond takes the predicate and both branch callables as arguments of the
  # same call; the repeat count is built as an int64 scalar tensor.
  count = tf.cond(tf.equal(tf_example["a"], 1),
                  lambda: tf.constant(0, dtype=tf.int64),
                  lambda: tf.constant(2, dtype=tf.int64))

  return tf.data.Dataset.from_tensors(tf_example).repeat(count)

# Replace every element with the sub-dataset returned by flat_map_impl and
# concatenate those sub-datasets into one stream.
dataset = dataset.flat_map(flat_map_impl)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment