@mmeendez8
Last active January 18, 2019 12:33
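import tensorflow as tf
from tensorflow import keras

# Define parameters
batch_size = 128

# Get data: MNIST digits as uint8 arrays (train_images has shape [60000, 28, 28])
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Create tf dataset (TensorFlow 1.x graph-mode API)
with tf.variable_scope("DataPipe"):
    dataset = tf.data.Dataset.from_tensor_slices(train_images)
    # Wrapping x in a list adds a leading axis of size 1; convert_image_dtype
    # then maps uint8 [0, 255] to float32 [0, 1]
    dataset = dataset.map(lambda x: tf.image.convert_image_dtype([x], dtype=tf.float32))
    # prefetch() follows batch(), so its buffer size counts batches, not images
    dataset = dataset.batch(batch_size=batch_size).prefetch(batch_size)
    # Run sess.run(iterator.initializer) before fetching input_batch
    iterator = dataset.make_initializable_iterator()
    input_batch = iterator.get_next()
    # [B, 1, 28, 28] -> [B, 28, 28, 1] (NHWC with a single channel)
    input_batch = tf.reshape(input_batch, shape=[-1, 28, 28, 1])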
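In a TF 1.x session, the pipeline is consumed by running `sess.run(iterator.initializer)` once and then `sess.run(input_batch)` for each batch. To check the shapes and value range it should produce without a TF 1.x install, here is a minimal NumPy sketch that mirrors the same steps. The helper name `numpy_batches` and the synthetic input are hypothetical stand-ins, not part of the original gist.

```python
import numpy as np


def numpy_batches(images, batch_size=128):
    """Hypothetical NumPy mirror of the DataPipe graph: yields float32 [B, 28, 28, 1] batches."""
    for start in range(0, len(images), batch_size):
        # Keep the final partial batch, as Dataset.batch() does by default
        chunk = images[start:start + batch_size]
        # Same role as wrapping x in a list: add an axis of size 1 -> [B, 1, 28, 28]
        chunk = chunk[:, np.newaxis]
        # uint8 [0, 255] -> float32 [0, 1], the scaling convert_image_dtype applies
        chunk = chunk.astype(np.float32) / 255.0
        # Move the size-1 axis to the end (NHWC), as tf.reshape does in the graph
        yield chunk.reshape(-1, 28, 28, 1)


if __name__ == "__main__":
    # Synthetic stand-in for train_images so the sketch runs without downloading MNIST
    fake_images = np.random.randint(0, 256, size=(300, 28, 28), dtype=np.uint8)
    for batch in numpy_batches(fake_images):
        print(batch.shape, batch.dtype)
```

For the full 60,000-image training set with `batch_size = 128`, this yields 469 batches, the last holding 96 images. The tf.data pipeline behaves the same way, because `batch()` keeps the partial final batch unless `drop_remainder=True` is passed.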