@AFAgarap
Last active September 29, 2019 05:20
Loading the MNIST dataset and creating tf.data.Dataset objects for training, validation, and testing. Link to blog: https://towardsdatascience.com/how-can-i-trust-you-fb433a06256c?source=friends_link&sk=0af208dc53be2a326d2407577184686b
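import tensorflow as tf
from sklearn.model_selection import train_test_split

# Assumed value: BATCH_SIZE is not defined in this snippet; adjust it to your setup.
BATCH_SIZE = 512

# Load MNIST: 60,000 training and 10,000 test grayscale 28x28 images.
(train_features, train_labels), (test_features, test_labels) = tf.keras.datasets.mnist.load_data()

# Add a channel dimension and scale pixel values from [0, 255] to [0, 1].
train_features = train_features.reshape(-1, 28, 28, 1)
train_features = train_features.astype('float32')
train_features = train_features / 255.
test_features = test_features.reshape(-1, 28, 28, 1)
test_features = test_features.astype('float32')
test_features = test_features / 255.

# One-hot encode the integer labels (10 digit classes).
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)

# Split the 10,000 test examples in half (5,000 validation, 5,000 test),
# stratified so both halves keep the class proportions of the original test set.
validation_features, test_features, validation_labels, test_labels = train_test_split(
    test_features, test_labels, test_size=0.50, stratify=test_labels)

# Training pipeline: shuffle over the whole training set (reshuffled every epoch),
# batch while dropping the last partial batch, then prefetch so input preparation
# overlaps with training. Prefetching belongs at the end of the pipeline.
train_dataset = tf.data.Dataset.from_tensor_slices((train_features, train_labels))
train_dataset = train_dataset.shuffle(train_features.shape[0])
train_dataset = train_dataset.batch(BATCH_SIZE, drop_remainder=True)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

# Validation and test pipelines: no shuffling, smaller batches, keep the remainder.
validation_dataset = tf.data.Dataset.from_tensor_slices((validation_features, validation_labels))
validation_dataset = validation_dataset.batch(BATCH_SIZE // 4)
test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_labels))
test_dataset = test_dataset.batch(BATCH_SIZE // 4)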
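To get a feel for what these pipelines yield per epoch, the batch counts can be worked out by hand. The sketch below is pure Python and makes two assumptions: BATCH_SIZE = 512 (the snippet never defines it), and `num_batches` is a hypothetical helper written here to mirror how `Dataset.batch` treats `drop_remainder`. Shuffling and prefetching do not change the count.

```python
import math

# Assumed value: the original snippet does not define BATCH_SIZE.
BATCH_SIZE = 512


def num_batches(num_examples: int, batch_size: int, drop_remainder: bool) -> int:
    """Hypothetical helper: how many batches Dataset.batch() yields."""
    if drop_remainder:
        # A trailing partial batch is discarded.
        return num_examples // batch_size
    # A trailing partial batch is kept.
    return math.ceil(num_examples / batch_size)


# MNIST has 60,000 training images; the 50/50 split turns the 10,000 test
# images into 5,000 validation and 5,000 test examples.
splits = {
    "train": (60_000, BATCH_SIZE, True),
    "validation": (5_000, BATCH_SIZE // 4, False),
    "test": (5_000, BATCH_SIZE // 4, False),
}

for name, (n, bs, drop) in splits.items():
    batches = num_batches(n, bs, drop)
    dropped = n - batches * bs if drop else 0
    print(f"{name}: {batches} batches of up to {bs}, {dropped} examples dropped")
```

With BATCH_SIZE = 512 this reports 117 training batches (96 examples dropped each epoch, a different 96 every time because of the reshuffle) and 40 batches each for validation and test, the last of which holds only 8 examples.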