Created
January 25, 2021 18:32
-
-
Save levimcclenny/75994c4e889afdcc727fc8d1fd9d845e to your computer and use it in GitHub Desktop.
Build and train a large dataset using tf.Data and tf.Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
import numpy as np
from tensorflow.keras.applications.vgg16 import VGG16

# Demo: build and train a large in-memory dataset with tf.data and tf.keras.

# Generate 10k random 224x224x3 tensors (to simulate images).
dataset = tf.random.normal((10000, 224, 224, 3))

# Generate 10k one-hot labels for categorical cross-entropy loss.
# tf.one_hot keeps the labels as float32 TF tensors instead of detouring
# through a float64 NumPy array (np.eye), halving label memory.
labels = tf.one_hot(
    tf.random.uniform((10000,), minval=0, maxval=1000, dtype=tf.int32),
    depth=1000,
)

# Use the tf.data API to build the input pipeline.
data = tf.data.Dataset.from_tensor_slices((dataset, labels))

# Batch it.
# You may need to lower the batch size depending on your system's RAM capabilities.
# prefetch(AUTOTUNE) overlaps host-side batch preparation with training steps.
data = data.batch(64).prefetch(tf.data.AUTOTUNE)

# Import the model (downloads the pretrained ImageNet weights on first use).
model = VGG16()
model.compile(optimizer='adam',
              loss="categorical_crossentropy",
              metrics=['accuracy'])

# Take a look at the model: it's a computational graph that takes roughly
# 528 MB in memory and has about 138M trainable parameters.
model.summary()

# Fit it -- super simple with tf.data.
# Note that this will output garbage because it's being trained on random
# numbers and labels, but this workflow is how to get off the ground
# training a Keras model with the tf.data API.
model.fit(data, epochs=2)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.