Skip to content

Instantly share code, notes, and snippets.

@aqibsaeed
Last active October 12, 2021 10:29
Show Gist options
  • Save aqibsaeed/d01aeceec0563a0f48b456026f73ef3f to your computer and use it in GitHub Desktop.
Save aqibsaeed/d01aeceec0563a0f48b456026f73ef3f to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets
# Per-channel (R, G, B) statistics used to standardize CIFAR-10 pixels after
# they have been scaled to [0, 1].
# NOTE(review): this std triple is a widely copied set that differs from the
# per-pixel std of the CIFAR-10 training set (~0.247, 0.243, 0.261); kept
# unchanged so results stay reproducible — confirm it is intended.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def _normalize_cifar(images):
    """Return a float32 copy of *images* scaled to [0, 1] and standardized per channel.

    Assumes channels-last images whose last axis holds the 3 colour channels.
    """
    mean = np.asarray(_CIFAR10_MEAN, dtype=np.float32)
    std = np.asarray(_CIFAR10_STD, dtype=np.float32)
    # One float32 copy, then in-place ops: same values as per-channel slicing,
    # without materializing extra full-size temporaries.
    out = images.astype("float32")
    out /= 255
    out -= mean  # broadcasts over the trailing channel axis
    out /= std
    return out


def get_cifar_10(batch_size=128, test_batch_size=128, shuffle_buffer=1024):
    """Load CIFAR-10 and return batched ``tf.data`` input pipelines.

    Images are scaled to [0, 1] and standardized per channel. The training
    pipeline shuffles, applies ``augment_cifar`` to every example, batches
    (dropping the final partial batch) and prefetches. The test pipeline only
    batches and prefetches.

    Args:
        batch_size: Batch size of the training dataset.
        test_batch_size: Batch size of the test dataset.
        shuffle_buffer: Buffer size passed to ``Dataset.shuffle`` for training.

    Returns:
        A tuple ``(train_data, test_data, num_classes)``. Both datasets yield
        ``(images, labels)`` batches with float32 images and int64 labels.
        ``num_classes`` is the number of distinct labels in the training set.
    """
    (train_images, train_labels), (test_images, test_labels) = (
        datasets.cifar10.load_data()
    )
    num_classes = len(np.unique(train_labels))

    train_images = _normalize_cifar(train_images)
    test_images = _normalize_cifar(test_images)
    # Only the dtype changes; labels keep the shape load_data returned
    # (NOTE(review): (N, 1) per the Keras docs — losses must accept that).
    train_labels = np.array(train_labels, dtype=np.int64)
    test_labels = np.array(test_labels, dtype=np.int64)

    autotune = tf.data.experimental.AUTOTUNE
    train_data = (
        tf.data.Dataset.from_tensor_slices((train_images, train_labels))
        .shuffle(shuffle_buffer)
        .map(augment_cifar, num_parallel_calls=autotune)
        .batch(batch_size, drop_remainder=True)
        .prefetch(autotune)
    )
    test_data = (
        tf.data.Dataset.from_tensor_slices((test_images, test_labels))
        .batch(test_batch_size)
        .prefetch(autotune)
    )
    return train_data, test_data, num_classes
def augment_cifar(x, y, padding=2):
    """Randomly pad-and-crop and horizontally flip one image; pass the label through.

    Designed to be used with ``Dataset.map`` over ``(image, label)`` pairs.
    The image is zero-padded by *padding* pixels on every side, randomly
    cropped back to its original shape, and then randomly flipped left/right.

    Args:
        x: A single image tensor. Assumed to be (height, width, channels) with
            a statically known height and width (TODO confirm for new callers).
        y: The label; returned unchanged.
        padding: Pixels of padding added on each side before cropping. The
            default of 2 preserves the original behaviour; 4 is a common
            choice in CIFAR training recipes.

    Returns:
        A tuple ``(augmented_image, y)``; the image has the same shape as *x*.

    Raises:
        ValueError: If *padding* is negative.
    """
    if padding < 0:
        raise ValueError("padding must be non-negative")
    height, width = x.shape[:2]
    padded = tf.image.pad_to_bounding_box(
        x, padding, padding, height + 2 * padding, width + 2 * padding
    )
    # Cropping back to x.shape yields a random translation of up to
    # `padding` pixels in each direction.
    cropped = tf.image.random_crop(padded, x.shape)
    return (tf.image.random_flip_left_right(cropped), y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment