@eileen-code4fun
Created May 26, 2021 02:52
Input Datasets for CIFAR10
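import tensorflow as tf

# GCS_PATH_FOR_DATA is assumed to be defined elsewhere (the path prefix of the TFRecord files).


def extract(example):
    # Parse a serialized tf.train.Example into an (image, label) pair.
    data = tf.io.parse_example(
        example,
        # Schema of the example.
        {
            'image': tf.io.FixedLenFeature(shape=(32, 32, 3), dtype=tf.float32),
            'label': tf.io.FixedLenFeature(shape=(), dtype=tf.int64)
        }
    )
    return data['image'], data['label']


def get_dataset(filename):
    # Note: cache() comes after shuffle() and batch(), so the shuffled batches
    # from the first pass are replayed unchanged on every later pass.
    return (
        tf.data.TFRecordDataset([GCS_PATH_FOR_DATA + filename])
        .map(extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .shuffle(1024)
        .batch(128)
        .cache()
        .prefetch(tf.data.experimental.AUTOTUNE)
    )


train_dataset = get_dataset('train.tfrecord')
val_dataset = get_dataset('val.tfrecord')
test_dataset = get_dataset('test.tfrecord')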
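
The schema in `extract` implies that each record holds a float `image` feature with 32 * 32 * 3 values and a single int64 `label`. The gist does not show how the `.tfrecord` files were produced, so the following is only a minimal round-trip sketch under that assumption: the `serialize` helper, the random stand-in images, and the temporary file path are all hypothetical and not part of the original.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def serialize(image, label):
    """Hypothetical writer: encode one (32, 32, 3) float image and an int label."""
    feature = {
        'image': tf.train.Feature(
            float_list=tf.train.FloatList(value=image.reshape(-1).tolist())),
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[int(label)])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


def extract(example):
    # Same schema as the gist's extract function.
    data = tf.io.parse_example(
        example,
        {
            'image': tf.io.FixedLenFeature(shape=(32, 32, 3), dtype=tf.float32),
            'label': tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
        },
    )
    return data['image'], data['label']


# Stand-in data (an assumption): four random images with labels 0..3.
images = np.random.rand(4, 32, 32, 3).astype(np.float32)
labels = np.arange(4, dtype=np.int64)

# Write the records to a temporary local file instead of a GCS bucket.
path = os.path.join(tempfile.mkdtemp(), 'sample.tfrecord')
with tf.io.TFRecordWriter(path) as writer:
    for image, label in zip(images, labels):
        writer.write(serialize(image, label))

# Read the records back with the same parsing function and batch them.
dataset = tf.data.TFRecordDataset([path]).map(extract).batch(4)
batch_images, batch_labels = next(iter(dataset))
```

Reading the file back through `extract` should return tensors whose shapes match the schema, which is a quick way to check that a writer and the parsing schema agree before pointing `get_dataset` at real data.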