@eileen-code4fun
Created May 26, 2021 02:39
CIFAR10 Data Verification
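import matplotlib.pyplot as plt
import tensorflow as tf

# CIFAR-10 label indices 0-9 map to these class names, in order.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# GCS_PATH_FOR_DATA is assumed to be defined earlier as a GCS directory prefix ending in '/'.
dataset = tf.data.TFRecordDataset([GCS_PATH_FOR_DATA + 'train.tfrecord'])

plt.figure(figsize=(10, 10))
# Decode the first 16 records and plot them in a 4x4 grid.
for i, example in enumerate(dataset.take(16)):
    # Each record is a serialized tf.train.Example proto.
    data = tf.train.Example()
    data.ParseFromString(example.numpy())
    # The image is stored as a flat float list; reshape it to height x width x channels.
    image = tf.constant(data.features.feature['image'].float_list.value, shape=[32, 32, 3])
    label = data.features.feature['label'].int64_list.value[0]
    plt.subplot(4, 4, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # imshow expects float RGB values in [0, 1] and clips anything outside that range.
    plt.imshow(image.numpy())
    plt.xlabel(class_names[label])
plt.show()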
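Reshaping with `shape=[32, 32, 3]` only produces a correct picture if the writer flattened each image in row-major height, width, channel order. The plain-Python sketch below (no TensorFlow needed) makes that index mapping explicit and adds the two checks a verification pass would likely want: the flat list must hold exactly 32 * 32 * 3 = 3072 values, and the label must fall in 0..9. The helpers `flat_index`, `unflatten` and `check_label` are hypothetical names for illustration, not part of the original gist, and the sample data is synthetic, not real CIFAR-10 pixels.

```python
# Hypothetical helpers illustrating how a flat float_list maps onto an
# [height, width, channels] image in row-major (C) order, the same order
# tf.constant(values, shape=[32, 32, 3]) assumes.
HEIGHT, WIDTH, CHANNELS = 32, 32, 3

CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


def flat_index(row, col, channel, width=WIDTH, channels=CHANNELS):
    """Position of pixel (row, col, channel) in the flattened list."""
    return (row * width + col) * channels + channel


def unflatten(values, height=HEIGHT, width=WIDTH, channels=CHANNELS):
    """Rebuild a nested [height][width][channels] list from flat values."""
    expected = height * width * channels
    if len(values) != expected:
        raise ValueError(f"expected {expected} values, got {len(values)}")
    return [[[values[flat_index(r, c, ch, width, channels)]
              for ch in range(channels)]
             for c in range(width)]
            for r in range(height)]


def check_label(label, class_names=CLASS_NAMES):
    """Return the class name, failing loudly on an out-of-range label."""
    if not 0 <= label < len(class_names):
        raise ValueError(f"label {label} outside 0..{len(class_names) - 1}")
    return class_names[label]


# Synthetic data: the value at flat position i is simply float(i).
fake = [float(i) for i in range(HEIGHT * WIDTH * CHANNELS)]
img = unflatten(fake)
print(img[0][1][2])    # flat position (0*32 + 1)*3 + 2 = 5, prints 5.0
print(check_label(3))  # prints cat
```

If a plotted grid looks like noise or color-shifted stripes, a mismatch between this assumed layout and the one the writer actually used (for example channels-first) is a likely cause.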