Skip to content

Instantly share code, notes, and snippets.

@MathiasGruber
Created July 16, 2019 03:55
Show Gist options
  • Save MathiasGruber/a582f6d02d8658c02d5762518801f96f to your computer and use it in GitHub Desktop.
Save MathiasGruber/a582f6d02d8658c02d5762518801f96f to your computer and use it in GitHub Desktop.
tf.data.Dataset pipeline boilerplate
# Feature spec handed to tf.parse_single_example in `parse`: each record
# carries the encoded image as a scalar string plus its height and width
# as scalar int64 values.
features = {
    'image/encoded': tf.FixedLenFeature([], tf.string),
    'image/height': tf.FixedLenFeature([], tf.int64),
    'image/width': tf.FixedLenFeature([], tf.int64),
}
def parse(record, image_size=256):
    """Decode one serialized example into an ``(input, target)`` pair.

    Args:
        record: Serialized example, parsed against the module-level
            ``features`` spec.
        image_size: Base side length. The image is cropped or padded to
            ``image_size + 4`` pixels along both height and width.

    Returns:
        A tuple ``(img, img)`` holding the same float32 tensor twice
        (input and reconstruction target for an autoencoder). The tensor
        has its channel axis reversed and a per-channel mean subtracted.
    """
    parsed = tf.parse_single_example(record, features)

    # Decode straight from the parsed feature; forcing channels=3 makes the
    # channel count fixed regardless of how the JPEG was stored.
    img = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
    img = tf.cast(img, tf.float32)

    # Reshape using the stored dimensions. tf.reshape requires the element
    # count to match, so a record whose stored height/width disagree with
    # the decoded pixels raises at runtime.
    img_h = tf.cast(parsed['image/height'], tf.int32)
    img_w = tf.cast(parsed['image/width'], tf.int32)
    img = tf.reshape(img, [img_h, img_w, 3])

    # Originally labelled "Augmentation", but this step is deterministic:
    # it only crops/pads to a fixed size.
    # NOTE(review): the 4-pixel margin looks intended for a later random
    # crop that is not present in this function -- confirm.
    target_side = image_size + 4
    img = tf.image.resize_image_with_crop_or_pad(img, target_side, target_side)

    # Manual VGG19-style preprocessing. The original author notes that
    # preprocess_input misbehaves here because it comes from keras rather
    # than tf.keras. The means are negated so that bias_add subtracts them;
    # they are listed in BGR order, hence the channel flip first.
    mean_tensor = tf.keras.backend.constant(-np.array([103.939, 116.779, 123.68]))
    img = img[..., ::-1]  # 'RGB' -> 'BGR'
    img = tf.keras.backend.bias_add(img, mean_tensor)

    # Same tensor as both input and target, for autoencoder training.
    return img, img
# NOTE(review): this is the body of a dataset-building function whose `def`
# line is not part of this snippet. `channel`, `epochs`, `batch_size`,
# SHUFFLE_SIZE, NUM_PARALLEL_BATCHES and PREFETCH_SIZE are defined elsewhere.
# Read TFRecord-formatted records from the given channel.
# (PipeModeDataset is presumably the SageMaker pipe-mode dataset -- confirm.)
ds = PipeModeDataset(channel=channel, record_format='TFRecord')
# Shuffle with a buffer of SHUFFLE_SIZE records and repeat for `epochs` passes.
ds = ds.apply(tf.data.experimental.shuffle_and_repeat(SHUFFLE_SIZE, epochs))
# Decode and preprocess records in parallel with `parse`.
ds = ds.map(parse, num_parallel_calls=NUM_PARALLEL_BATCHES)
ds = ds.batch(batch_size)
# Prefetch PREFETCH_SIZE elements; at this point each element is a batch.
ds = ds.prefetch(PREFETCH_SIZE)
# NOTE(review): ignore_errors() is applied after batch(), so one broken
# record likely discards the whole batch being assembled rather than just
# that record. Consider moving it directly after map() -- confirm intent.
ds = ds.apply(tf.data.experimental.ignore_errors()) # ignore broken records
return ds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment