import os, pathlib
import tensorflow as tf
from tensorflow.keras import models, layers, callbacks, optimizers, regularizers
from tensorflow.keras.utils import CustomObjectScope
from glob import glob
import numpy as np

# Learning rate schedule: step the learning rate down by 10x at epochs 30 and 60
def schedule(epoch, learning_rate):
    if epoch < 30:
        return 0.05
    elif epoch < 60:
        return 0.005
    else:
        return 0.0005
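
# Illustrative values (not in the original gist) as LearningRateScheduler would
# see them at the step boundaries:
#   schedule(0, 0.05)   -> 0.05
#   schedule(30, 0.05)  -> 0.005
#   schedule(60, 0.005) -> 0.0005
# The learning_rate argument is accepted but unused; the schedule is a pure
# function of the epoch index.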

# He-style normal initializer for the conv kernels; the fan term is taken as
# kernel_h * kernel_w * out_channels (shape[3] of a Conv2D kernel)
def initializer(shape, dtype=None):
    stddev = np.sqrt(2.0 / float(shape[0] * shape[1] * shape[3]))
    return tf.random.normal(shape, dtype=dtype, stddev=stddev)
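
# Worked example (for illustration only): the first conv layer below has kernel
# shape (3, 3, 3, 64), so stddev = sqrt(2 / (3 * 3 * 64)) ~= 0.059.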

if __name__ == '__main__':
    # for gpu in tf.config.experimental.list_physical_devices('GPU'):
    #     tf.config.experimental.set_memory_growth(gpu, True)

    # READING INPUT FILES
    # ===================
    batch_size = 128
    train_root = pathlib.Path('/mnt/data0/train')
    #train_root = pathlib.Path('/mnt/data0/validation')
    checkpoint_root = pathlib.Path('./training_checkpoints')

    # ModelCheckpoint does not create missing directories, so ensure it exists
    checkpoint_root.mkdir(parents=True, exist_ok=True)
    def parse_buffer(buffer):
        keys_to_features = {
            "image/encoded": tf.io.FixedLenFeature((), tf.string, ''),
            "image/class/label": tf.io.FixedLenFeature([], tf.int64, -1)}
        parsed = tf.io.parse_single_example(buffer, keys_to_features)

        # Get label (stored 1-based in the TFRecords, shifted to 0-based)
        label = tf.cast(tf.reshape(parsed["image/class/label"], shape=[]), dtype=tf.int32) - 1

        # Decode image
        image = tf.image.decode_jpeg(tf.reshape(parsed["image/encoded"], shape=[]), channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)

        # Upscale any dimension smaller than 224 so the random crop below is valid
        shape = tf.math.maximum(tf.shape(image)[0:2], 224)
        image = tf.image.resize(image, shape)

        # Random crop and horizontal flip
        image = tf.image.random_crop(image, [224, 224, 3])
        image = tf.image.random_flip_left_right(image)

        # Zero mean colours (the standard ImageNet per-channel means)
        image = image - [0.485, 0.456, 0.406]
        return image, label
    shards = tf.data.Dataset.list_files(str(train_root / 'train*'), shuffle=True)
    shards = shards.shuffle(len(shards))
    shards = shards.repeat()
    ds = shards.interleave(tf.data.TFRecordDataset, cycle_length=4)
    ds = ds.shuffle(buffer_size=8192)
    ds = ds.map(parse_buffer, num_parallel_calls=16)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=1)
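
    # Optional sanity check (not in the original gist): pull one batch and
    # confirm the pipeline output. Guarded off by default, in the same style
    # as the TensorBoard toggle further down.
    if False:
        for image_batch, label_batch in ds.take(1):
            print(image_batch.shape, image_batch.dtype)  # (128, 224, 224, 3) float32
            print(label_batch.shape, label_batch.dtype)  # (128,) int32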

    # Create mirrored strategy so the model replicates across all visible GPUs
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        # Create L2 regularizer
        regularizer = regularizers.l2(0.0001)
        # Create, train and evaluate TensorFlow model
        tf_model = models.Sequential([
            layers.Conv2D(64, 3, padding='same', activation='relu', use_bias=False, input_shape=(224, 224, 3),
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.Dropout(0.3),
            layers.Conv2D(64, 3, padding='same', activation='relu', use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.AveragePooling2D(2),
            layers.Conv2D(128, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.Dropout(0.4),
            layers.Conv2D(128, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.AveragePooling2D(2),
            layers.Conv2D(256, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.Dropout(0.4),
            layers.Conv2D(256, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.Dropout(0.4),
            layers.Conv2D(256, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.AveragePooling2D(2),
            layers.Conv2D(512, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.Dropout(0.4),
            layers.Conv2D(512, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.Dropout(0.4),
            layers.Conv2D(512, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.AveragePooling2D(2),
            layers.Conv2D(512, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.Dropout(0.4),
            layers.Conv2D(512, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.Dropout(0.4),
            layers.Conv2D(512, 3, padding="same", activation="relu", use_bias=False,
                          kernel_initializer=initializer, kernel_regularizer=regularizer),
            layers.AveragePooling2D(2),
            layers.Flatten(),
            layers.Dense(4096, activation="relu", use_bias=False, kernel_regularizer=regularizer),
            layers.Dropout(0.5),
            layers.Dense(4096, activation="relu", use_bias=False, kernel_regularizer=regularizer),
            layers.Dropout(0.5),
            layers.Dense(1000, activation="softmax", use_bias=False, kernel_regularizer=regularizer),
        ], name='vgg16_imagenet')
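
        # Optional (not in the original gist): print the layer stack. This
        # bias-free VGG16 variant has roughly 138M parameters, most of them in
        # the two 4096-unit dense layers.
        # tf_model.summary()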
        # SGD with momentum; initial learning rate matches the schedule above
        optimizer = optimizers.SGD(learning_rate=0.05, momentum=0.9)
        tf_model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

        initial_epoch = 0

        # If there are any existing checkpoints, sort them numerically by epoch
        # (a plain lexical sort would order checkpoint-100 before checkpoint-21)
        def checkpoint_epoch(path):
            return int(os.path.splitext(os.path.basename(path))[0].split('-')[1])
        existing_checkpoints = sorted(glob(str(checkpoint_root / 'checkpoint-*.hdf5')),
                                      key=checkpoint_epoch)
        if len(existing_checkpoints) > 0:
            # Load model from newest checkpoint
            newest_checkpoint_file = existing_checkpoints[-1]
            with CustomObjectScope({'initializer': initializer}):
                tf_model = models.load_model(newest_checkpoint_file)

            # Extract epoch number from checkpoint
            initial_epoch = checkpoint_epoch(newest_checkpoint_file)
            print("Resuming training at epoch %u from checkpoint %s" % (initial_epoch, newest_checkpoint_file))

        # Named training_callbacks so the list doesn't shadow the imported
        # tensorflow.keras.callbacks module used to construct it
        training_callbacks = [callbacks.LearningRateScheduler(schedule),
                              callbacks.ModelCheckpoint(str(checkpoint_root / 'checkpoint-{epoch:02d}.hdf5'))]
        if False:
            training_callbacks.append(callbacks.TensorBoard(log_dir="logs", histogram_freq=1, profile_batch=(4, 8)))

        tf_model.fit(ds, epochs=100, steps_per_epoch=1281167 // batch_size,
                     initial_epoch=initial_epoch, callbacks=training_callbacks)
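
        # Optional follow-up (an assumption, not in the original gist): persist
        # the final trained model once fit() completes; the filename here is
        # purely illustrative.
        # tf_model.save('vgg16_imagenet_final.hdf5')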