Skip to content

Instantly share code, notes, and snippets.

@FreeFly19
Last active November 21, 2022 20:22
Show Gist options
  • Save FreeFly19/fa10a04dd08ca0b520a2e661110dc8e7 to your computer and use it in GitHub Desktop.
Save FreeFly19/fa10a04dd08ca0b520a2e661110dc8e7 to your computer and use it in GitHub Desktop.
TensorFlow 2.x data loading with on-the-fly augmentation and dataset caching
import tensorflow as tf
# Load the training subset (90% after the 10% validation split) from the
# image directory. Labels are inferred from subdirectory names and one-hot
# encoded because label_mode="categorical". Shuffling defaults to True, and
# the fixed seed keeps the train/validation split reproducible and consistent
# with the validation loader that uses the same seed.
train_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    validation_split=0.1,
    subset="training",
    label_mode="categorical",
    seed=123,
    image_size=(image_size, image_size),
    batch_size=batch_size)
# Capture class names now — later tf.data transformations return plain
# Dataset objects that no longer expose .class_names.
class_names = train_ds.class_names
# Validation subset (10%) drawn from the same directory. Reusing seed=123
# with validation_split=0.1 guarantees these images do not overlap the
# training subset loaded above.
val_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    seed=123,
    subset="validation",
    validation_split=0.1,
    label_mode="categorical",
    batch_size=batch_size,
    image_size=(image_size, image_size))
# Map uint8 pixel values from [0, 255] into [0, 1].
normalization_layer = tf.keras.layers.Rescaling(1.0 / 255)

# Stack of random augmentations, applied in order. value_range=(0.0, 1.0)
# on RandomBrightness matches the post-Rescaling pixel range.
augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomBrightness(0.2, value_range=(0.0, 1.0)),
    tf.keras.layers.RandomContrast(0.1),
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomZoom(0.3),
])
def augmentate(x, y):
    """Apply the random augmentation stack to an image batch; labels pass through.

    training=True forces the random layers to stay active even when this is
    invoked outside of Model.fit (e.g. inside a tf.data map).
    """
    augmented = augmentation(x, training=True)
    return augmented, y
# Let tf.data choose buffer sizes and parallelism at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Pipeline order matters: normalize -> cache (stores normalized images once)
# -> augment (after cache so every epoch sees fresh random augmentations)
# -> prefetch LAST so preprocessing overlaps with training. The original
# chained .prefetch(...).map(augmentate), which runs augmentation after the
# prefetch buffer and defeats the overlap.
train_ds = (train_ds
            .map(lambda x, y: (normalization_layer(x), y),
                 num_parallel_calls=AUTOTUNE)
            .cache()
            .map(augmentate, num_parallel_calls=AUTOTUNE)
            .prefetch(buffer_size=AUTOTUNE))
# Validation pipeline: no augmentation; normalized batches are cached and
# prefetched.
val_ds = (val_ds
          .map(lambda x, y: (normalization_layer(x), y),
               num_parallel_calls=AUTOTUNE)
          .cache()
          .prefetch(buffer_size=AUTOTUNE))
# Model.fit accepts tf.data.Dataset inputs directly; fit_generator was
# deprecated in TF 2.1 and removed in later releases.
model.fit(train_ds,
          validation_data=val_ds,
          epochs=100,
          verbose=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment