@eileen-code4fun
Last active November 23, 2021 14:00
CIFAR10 Data Preparation
import tensorflow as tf

# Load CIFAR10: 50,000 training and 10,000 test images with integer labels.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()


def preprocess(filename, images, labels):
    """Write (image, label) pairs to a TFRecord file as tf.train.Example records."""
    with tf.io.TFRecordWriter(filename) as writer:
        for image, label in zip(images, labels):
            # Encode the image and label in tf.train.Example.
            feature = {
                # Normalize the image to range [0, 1] and flatten it to a 1-D list of floats.
                'image': tf.train.Feature(float_list=tf.train.FloatList(value=(image/255.0).reshape(-1))),
                # Each label is a length-1 array, so it can be passed as the value list directly.
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label))
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


# Use the first 40,000 training images for training and the remaining 10,000 for validation.
preprocess('train.tfrecord', train_images[:40000], train_labels[:40000])
preprocess('val.tfrecord', train_images[40000:], train_labels[40000:])
preprocess('test.tfrecord', test_images, test_labels)
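
Because the writer stores each image as a flat list of floats, whatever reads the records back has to reshape that list to 32x32x3. A minimal NumPy-only sketch of that round trip follows. The helper names `flatten_image` and `restore_image` are hypothetical and not part of the gist, and the random array is only a stand-in for a real CIFAR10 image. The cast to float32 mirrors the fact that `tf.train.FloatList` holds 32-bit floats.

```python
import numpy as np

IMAGE_SHAPE = (32, 32, 3)  # CIFAR10 image: height, width, RGB channels


def flatten_image(image):
    """Scale uint8 pixels to [0, 1] and flatten, as the writer above does."""
    # Hypothetical helper; float32 matches the precision of tf.train.FloatList.
    return (image / 255.0).reshape(-1).astype(np.float32)


def restore_image(flat_values):
    """Rebuild the (32, 32, 3) array from the flat list of floats."""
    # Hypothetical helper; a TFRecord reader would apply the same reshape.
    return np.asarray(flat_values, dtype=np.float32).reshape(IMAGE_SHAPE)


# Stand-in for one CIFAR10 image (random pixels, not real data).
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=IMAGE_SHAPE, dtype=np.uint8)

flat = flatten_image(image)
restored = restore_image(flat)
print(flat.shape, restored.shape)
```

The flat vector has 32 * 32 * 3 = 3072 entries, which is the length a fixed-length feature spec for 'image' would need when parsing these records.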
@MarCialRG

should be 'datasets' instead of 'dataset'
tf.keras.datasets.cifar10.load_data()
