Skip to content

Instantly share code, notes, and snippets.

@mathandy
Last active January 13, 2022 06:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mathandy/d0272ced4d867f2be78de9c52d558fd7 to your computer and use it in GitHub Desktop.
Save mathandy/d0272ced4d867f2be78de9c52d558fd7 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import tensorflow as tf
import pathlib
# Root folder of the dataset; the loader below infers one class per sub-folder.
data_dir = pathlib.Path.home() / 'datasets' / 'dogs-v-cats' / 'train'
# alternative dataset:
# data_dir = pathlib.Path.home() / 'datasets' / '10flowers' / 'images'
batch_size = 32
img_height = 224
img_width = 224
epochs = 10
use_pretrained_model = False


def count_images(root):
    """Return the number of `.jpg` files inside the sub-folders of `root`.

    Labels are inferred from sub-folder names, so the images sit one or more
    levels below `root`.  A plain `root.glob('*.jpg')` only sees files lying
    directly in `root` and therefore reports 0 for that layout.  Files placed
    directly in `root` are deliberately not counted, and a missing `root`
    simply yields 0.
    """
    return sum(1 for _ in root.glob('*/**/*.jpg'))


print('Images:', count_images(data_dir))
def _load_split(subset):
    """Load one split of the images under `data_dir` as a batched dataset.

    `subset` is 'training' or 'validation'.  Every other argument is the
    same for both splits: labels inferred from the folder structure, a 20%
    validation share, seed 123, and images resized to
    (`img_height`, `img_width`) in batches of `batch_size`.
    """
    return tf.keras.utils.image_dataset_from_directory(
        data_dir,
        labels='inferred',
        validation_split=0.2,
        subset=subset,
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size,
    )


train_ds = _load_split('training')
val_ds = _load_split('validation')
# read the class names now, before any further transformation of `train_ds`
class_names = train_ds.class_names
@tf.function
def cast_x_to_float(x, y):
    """Return `(x, y)` with `x` cast to `tf.float32`; `y` is passed through."""
    return tf.cast(x, tf.float32), y


# make sure the images of both splits are float32 before further processing
train_ds = train_ds.map(cast_x_to_float)
val_ds = val_ds.map(cast_x_to_float)
# `cache()` keeps the dataset in memory as it is read the first time, so
# later epochs don't have to read it from the hard drive again (which is
# slower than the main system memory).
# In an ideal world the whole dataset could live on the GPU, but usually
# it's a better idea to reserve that precious GPU memory for other things.
# Remember: most of the time a computer spends is not spent on calculations
# but on moving data from HDD to memory to CPU to GPU, etc.
# NOTE(review): everything upstream of `cache()` runs only while the cache is
# being filled, so later epochs presumably replay the first epoch's batches
# in the same order -- confirm whether a `shuffle()` after `cache()` is wanted.
train_ds, val_ds = train_ds.cache(), val_ds.cache()
# ATTENTION YUIGO!
# Two things are done to improve this machine learning pipeline:
#   1. the three conv layers get replaced with a model that has been
#      pretrained on imagenet
#   2. some data augmentation is added to help avoid overfitting
# These are the layers used for (2); note that only the training set is
# ever augmented.
flip = tf.keras.layers.RandomFlip(mode='horizontal')
rotate = tf.keras.layers.RandomRotation(factor=0.2)
@tf.function
def augment(image, label):
    """Randomly flip and/or rotate `image`; `label` is returned unchanged.

    Each of the two transforms is applied with probability `p`.  One random
    number is drawn per call and per transform, so when this is mapped over
    a batched dataset the decision is shared by the whole batch.
    """
    p = 0.5
    # `training=True` is passed explicitly so the random layers really
    # transform their input here, instead of relying on the layers' default
    # or on whatever Keras learning phase happens to be active.
    if tf.random.uniform([]) < p:
        image = flip(image, training=True)
    if tf.random.uniform([]) < p:
        image = rotate(image, training=True)
    return image, label


train_ds = train_ds.map(augment)  # comment-out this line to try without augmentation
# always remember to normalize data so its values aren't too large in magnitude
# there are many slight variations to this -- since images are stored as uint8
# values (integers 0-255), usually this just means dividing by 255.
# When using a pretrained model it can help a bit to know how the model's inputs
# were normalized when it was originally trained. Sometimes they come with a
# "preprocess" function that takes care of this step. We'll use their function,
# but it's probably very similar to the `a_typical_way_to_normalize` function below.
# Also notice, since we're doing this here, I removed the
# `tf.keras.layers.Rescaling(1. / 255)` layer from the model.
@tf.function
def normalize(x, y):
    """Apply MobileNetV2's `preprocess_input` to `x`; `y` is passed through."""
    x = tf.keras.applications.mobilenet_v2.preprocess_input(x)
    return x, y
@tf.function
def a_typical_way_to_normalize(x, y):
    """Scale images to have values between -0.5 and 0.5 (instead of 0 to 255).

    `x` is cast to float32, divided by 255 and shifted down by 0.5; `y` is
    passed through unchanged.  Shown for comparison with `normalize`; it is
    not called anywhere in the visible part of this script.
    """
    x = tf.cast(x, tf.float32) / 255. - 0.5
    return x, y
# both splits go through the same normalization
train_ds, val_ds = [split.map(normalize) for split in (train_ds, val_ds)]
# here's (1) our new model
# there are many pre-trained models you can easily use with the keras api
# see https://www.tensorflow.org/api_docs/python/tf/keras/applications
# Both variants end in a Dense layer without activation, i.e. they output one
# raw score (logit) per class.
if use_pretrained_model:
    # NOTE(review): the pretrained base is left trainable here, so all of its
    # weights are updated together with the new, untrained head -- confirm
    # this is intended rather than freezing the base first.
    model = tf.keras.Sequential([
        tf.keras.applications.MobileNetV2(
            input_shape=(img_height, img_width, 3),
            include_top=False,  # this means the dense layers at the end are removed
            weights='imagenet'
        ),
        # here we add back (untrained) dense layers to replace the ones removed
        tf.keras.layers.Flatten(),
        # tf.keras.layers.Dense(64, activation='relu'),  # this extra layer may help or may hurt
        tf.keras.layers.Dense(len(class_names))
    ])
else:
    # small from-scratch baseline: three conv/pool stages and a dense head
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),  # this extra layer may help or may hurt
        tf.keras.layers.Dense(len(class_names)),
    ])
# `from_logits=True` tells the loss that the model outputs raw scores rather
# than probabilities.
logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=logits_loss, metrics=['accuracy'])
# `prefetch()` prepares upcoming batches in the background while the model is
# still busy with the current one, which can speed things up, or slow them
# down -- autotune experiments a bit and tries to decide the best setting for
# your system
# let tf.data choose the prefetch buffer size for both splits
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(AUTOTUNE)
val_ds = val_ds.prefetch(AUTOTUNE)
# train for `epochs` passes over the training set, using `val_ds` as the
# validation data
model.fit(train_ds, validation_data=val_ds, epochs=epochs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment