@levimcclenny
Created January 25, 2021 18:32
Build and train a large dataset using tf.Data and tf.Keras
import tensorflow as tf
import numpy as np
from tensorflow.keras.applications.vgg16 import VGG16
# generate 10k random 224x224x3 tensors (to simulate images)
dataset = tf.random.normal((10000, 224, 224, 3))
# generate 10k one-hot labels for categorical cross entropy loss
labels = tf.constant(np.eye(1000)[np.random.choice(1000, 10000)])
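# (Sketch, not in the original gist) the same one-hot labels can be built natively in
# TensorFlow, which avoids the intermediate NumPy array:
# labels = tf.one_hot(tf.random.uniform((10000,), maxval=1000, dtype=tf.int32), depth=1000)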
# use the tf.data API to build a dataset
data = tf.data.Dataset.from_tensor_slices((dataset, labels))
# Batch it
# may need to lower the batch size depending on your system RAM capabilities
data = data.batch(64)
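# (Sketch, not in the original gist) shuffling and prefetching are common additions to a
# tf.data input pipeline; the buffer size here is an illustrative assumption, tune for your system
# data = data.shuffle(buffer_size=1000).prefetch(tf.data.experimental.AUTOTUNE)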
# Import model
model = VGG16()
model.compile(optimizer='adam',
              loss="categorical_crossentropy",
              metrics=['accuracy'])
# Take a look at the model; it's a computational graph that takes roughly 528 MB in memory
# and has about 138M trainable parameters
model.summary()
# fit it, super simple with tf.data
# Note that the metrics will be meaningless because the model is trained on random tensors
# and labels, but this workflow shows how to get a Keras model training with the tf.data API
model.fit(data, epochs=2)
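# (Sketch, not in the original gist) with real images on disk the same workflow applies,
# just build the tf.data.Dataset from files instead of in-memory tensors. The directory
# path below is a hypothetical placeholder, and the folder's class count would need to
# match the model's 1000-way output (or VGG16 would need a custom top).
# real_data = tf.keras.preprocessing.image_dataset_from_directory(
#     "path/to/images",           # hypothetical directory of images organized by class
#     image_size=(224, 224),      # VGG16's expected input size
#     batch_size=64,
#     label_mode="categorical")   # one-hot labels for categorical_crossentropy
# model.fit(real_data, epochs=2)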