Skip to content

Instantly share code, notes, and snippets.

import tensorflow as tf
from tensorflow.keras import datasets
def get_cifar_10(batch_size = 128, test_batch_size = 128, shuffle_buffer = 1024):
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
num_classes = len(np.unique(train_labels))
train_images = train_images.astype("float32")/255
train_images[:,:,:,0] = (train_images[:,:,:,0] - 0.4914) / 0.2023
train_images[:,:,:,1] = (train_images[:,:,:,1] - 0.4822) / 0.1994