Skip to content

Instantly share code, notes, and snippets.

@neworderofjamie
Created September 7, 2021 09:57
Show Gist options
  • Save neworderofjamie/49d7729962d3c301a34570f852b2b4ba to your computer and use it in GitHub Desktop.
Save neworderofjamie/49d7729962d3c301a34570f852b2b4ba to your computer and use it in GitHub Desktop.
import os, pathlib
import tensorflow as tf
from tensorflow.keras import (models, layers, datasets, callbacks, optimizers,
initializers, regularizers)
from tensorflow.keras.utils import CustomObjectScope
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from glob import glob
import numpy as np
# Learning rate schedule
def schedule(epoch, learning_rate):
    """Return the learning rate to use for a given epoch.

    Step schedule: 0.05 for epochs 0-29, 0.005 for epochs 30-59 and
    0.0005 from epoch 60 onwards.

    Args:
        epoch: Epoch index the rate is requested for.
        learning_rate: Current learning rate. Accepted so the function
            matches the two-argument callback signature it is used with,
            but ignored: the returned rate depends only on ``epoch``.

    Returns:
        The learning rate for ``epoch`` as a float.
    """
    if epoch < 30:
        return 0.05
    elif epoch < 60:
        return 0.005
    else:
        return 0.0005
def initializer(shape, dtype=None):
    """Draw initial kernel weights from a zero-mean normal distribution.

    The standard deviation is ``sqrt(2 / (shape[0] * shape[1] * shape[3]))``.
    For a convolution kernel laid out as (height, width, in_channels,
    out_channels) - the layout this script's Conv2D layers are assumed to
    use, TODO confirm - that denominator is kernel area times the number of
    output channels.

    Args:
        shape: Shape of the weight tensor to create. Must have at least four
            dimensions because index 3 is read.
        dtype: Desired dtype of the returned tensor. When omitted (``None``)
            ``tf.float32`` is used, since the random op needs a concrete dtype.

    Returns:
        A tensor of the requested shape filled with normally distributed values.
    """
    if dtype is None:
        dtype = tf.float32
    stddev = np.sqrt(2.0 / float(shape[0] * shape[1] * shape[3]))
    return tf.random.normal(shape, dtype=dtype, stddev=stddev)
def parse_buffer(buffer):
    """Decode one serialized example into an augmented ``(image, label)`` pair.

    Args:
        buffer: A serialized example holding a JPEG under ``image/encoded``
            and an integer class under ``image/class/label``.

    Returns:
        A tuple ``(image, label)`` where ``image`` is a float32 224x224x3
        tensor with the per-channel offsets subtracted, and ``label`` is an
        int32 scalar equal to the stored label minus one.
    """
    keys_to_features = {
        "image/encoded": tf.io.FixedLenFeature((), tf.string, ''),
        "image/class/label": tf.io.FixedLenFeature([], tf.int64, -1)}
    parsed = tf.io.parse_single_example(buffer, keys_to_features)

    # Stored labels are shifted down by one (presumably they are 1-based on
    # disk - confirm against whatever wrote the records). A record with no
    # label falls back to the default of -1 and therefore yields -2.
    label = tf.cast(tf.reshape(parsed["image/class/label"], shape=[]), dtype=tf.int32) - 1

    # Decode to RGB and rescale to floats in [0, 1]
    image = tf.image.decode_jpeg(tf.reshape(parsed["image/encoded"], shape=[]), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)

    # Enlarge any side shorter than 224 so the crop below always fits. Height
    # and width are clamped independently, so small images are not resized
    # with their aspect ratio preserved.
    shape = tf.math.maximum(tf.shape(image)[0:2], 224)
    image = tf.image.resize(image, shape)

    # Random crop and horizontal flip
    image = tf.image.random_crop(image, [224, 224, 3])
    image = tf.image.random_flip_left_right(image)

    # Zero mean colours
    image = image - [0.485, 0.456, 0.406]
    return image, label


def build_dataset(train_root, batch_size):
    """Build the endlessly repeating, shuffled and batched training dataset.

    Args:
        train_root: Directory containing the record shards named ``train*``.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches forever.
    """
    shards = tf.data.Dataset.list_files(str(pathlib.Path(train_root) / 'train*'), shuffle=True)
    shards = shards.shuffle(len(shards))
    shards = shards.repeat()
    # Read four shards at a time so consecutive records come from different files
    ds = shards.interleave(tf.data.TFRecordDataset, cycle_length=4)
    ds = ds.shuffle(buffer_size=8192)
    ds = ds.map(parse_buffer, num_parallel_calls=16)
    ds = ds.batch(batch_size)
    return ds.prefetch(buffer_size=1)


def build_model():
    """Create the (uncompiled) ``vgg16_imagenet`` Sequential model.

    Five convolutional stages (3x3 bias-free ReLU convolutions separated by
    dropout, each stage closed by 2x2 average pooling) are followed by two
    4096-unit dense layers and a 1000-way softmax. Every weight matrix shares
    one L2 regularizer; convolution kernels use the file's ``initializer``.

    Returns:
        The freshly constructed Keras model.
    """
    # Create L2 regularizer
    regularizer = regularizers.l2(0.0001)
    conv_kwargs = dict(padding='same', activation='relu', use_bias=False,
                       kernel_initializer=initializer, kernel_regularizer=regularizer)

    # (filters, number of convolutions, dropout rate between convolutions)
    stages = [(64, 2, 0.3), (128, 2, 0.4), (256, 3, 0.4), (512, 3, 0.4), (512, 3, 0.4)]

    stack = []
    for filters, num_convs, dropout_rate in stages:
        for conv_index in range(num_convs):
            # Dropout sits between convolutions, never before the first of a stage
            if conv_index > 0:
                stack.append(layers.Dropout(dropout_rate))
            if not stack:
                # Only the very first layer declares the input shape
                stack.append(layers.Conv2D(filters, 3, input_shape=(224, 224, 3), **conv_kwargs))
            else:
                stack.append(layers.Conv2D(filters, 3, **conv_kwargs))
        stack.append(layers.AveragePooling2D(2))

    stack.extend([
        layers.Flatten(),
        layers.Dense(4096, activation="relu", use_bias=False, kernel_regularizer=regularizer),
        layers.Dropout(0.5),
        layers.Dense(4096, activation="relu", use_bias=False, kernel_regularizer=regularizer),
        layers.Dropout(0.5),
        layers.Dense(1000, activation="softmax", use_bias=False, kernel_regularizer=regularizer),
    ])
    return models.Sequential(stack, name='vgg16_imagenet')


def find_latest_checkpoint(checkpoint_root):
    """Locate the checkpoint with the highest epoch number.

    Files are matched against ``checkpoint-*.hdf5`` and ranked by the integer
    after the first ``-``. Ranking numerically (rather than sorting the file
    names) keeps ``checkpoint-100`` ahead of ``checkpoint-99``. Files whose
    epoch field is not an integer are ignored.

    Args:
        checkpoint_root: Directory to search.

    Returns:
        ``(filename, epoch)`` for the newest checkpoint, or ``(None, 0)`` when
        the directory holds no usable checkpoint.
    """
    newest_file, newest_epoch = None, 0
    for filename in glob(str(pathlib.Path(checkpoint_root) / 'checkpoint-*.hdf5')):
        # Extract epoch number from the file title
        title = os.path.splitext(os.path.basename(filename))[0]
        try:
            epoch = int(title.split('-')[1])
        except ValueError:
            continue
        if newest_file is None or epoch > newest_epoch:
            newest_file, newest_epoch = filename, epoch
    return newest_file, newest_epoch


if __name__ == '__main__':
    # for gpu in tf.config.experimental.list_physical_devices('GPU'):
    #     tf.config.experimental.set_memory_growth(gpu, True)

    # READING INPUT FILES
    # ===================
    batch_size = 128
    train_root = pathlib.Path('/mnt/data0/train')
    #train_root = pathlib.Path('/mnt/data0/validation')
    checkpoint_root = pathlib.Path('./training_checkpoints')
    # Make sure the checkpoint directory exists before the first save
    checkpoint_root.mkdir(parents=True, exist_ok=True)

    ds = build_dataset(train_root, batch_size)

    # If there are any existing checkpoints, training resumes from the newest
    newest_checkpoint_file, initial_epoch = find_latest_checkpoint(checkpoint_root)

    # Create mirrored strategy
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        if newest_checkpoint_file is not None:
            # Load inside the strategy scope so the restored variables are
            # created under the distribution strategy; the custom initializer
            # must be registered for the saved layer configs to deserialize.
            with CustomObjectScope({'initializer': initializer}):
                tf_model = models.load_model(newest_checkpoint_file)
            print("Resuming training at epoch %u from checkpoint %s" % (initial_epoch, newest_checkpoint_file))
        else:
            # Create and compile a fresh TensorFlow model
            tf_model = build_model()
            optimizer = optimizers.SGD(learning_rate=0.05, momentum=0.9)
            tf_model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # Named so that it does not shadow the imported ``callbacks`` module.
    # The checkpoint path is passed as a plain string because it is a
    # template that gets formatted with the epoch number.
    training_callbacks = [callbacks.LearningRateScheduler(schedule),
                          callbacks.ModelCheckpoint(str(checkpoint_root / 'checkpoint-{epoch:02d}.hdf5'))]

    # Flip to True to record TensorBoard histograms and profile batches 4-8
    enable_tensorboard = False
    if enable_tensorboard:
        training_callbacks.append(callbacks.TensorBoard(log_dir="logs", histogram_freq=1, profile_batch=(4, 8)))

    # 1281167 is presumably the number of training images - TODO confirm
    tf_model.fit(ds, epochs=100, steps_per_epoch=1281167 // batch_size,
                 initial_epoch=initial_epoch, callbacks=training_callbacks)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment