Skip to content

Instantly share code, notes, and snippets.

@staceysv
Created October 2, 2019 21:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save staceysv/446cc3d3d80b3c7715f4e38eae511a56 to your computer and use it in GitHub Desktop.
Save staceysv/446cc3d3d80b3c7715f4e38eae511a56 to your computer and use it in GitHub Desktop.
Fine-tune a Keras model and log per-class precision with a custom W&B callback
from keras.preprocessing import image
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
import wandb
from wandb.keras import WandbCallback
from keras_callbacks import PerClassMetrics
def _num_steps(num_samples, batch_size):
    """Return how many whole batches fit in ``num_samples``, but never less than 1.

    Plain floor division gives 0 when the dataset is smaller than one batch,
    and Keras' ``fit_generator`` needs a positive ``steps_per_epoch`` /
    ``validation_steps``.
    """
    return max(1, num_samples // batch_size)


def run_experiment(args):
    """Train a model in two phases (pretrain, then fine-tune) and log it to W&B.

    Phase 1 trains the model returned by ``load_model`` for
    ``args.pretrain_epochs`` epochs.
    Phase 2 sets every layer before ``args.freeze_layer`` to non-trainable and
    the rest to trainable, recompiles with SGD, and trains for ``args.epochs``
    epochs. It optionally adds a ``PerClassMetrics`` callback when
    ``args.callback`` is set.

    Args:
        args: parsed options. This function reads ``initial_model``,
            ``train_data``, ``val_data``, ``batch_size``, ``fc_size``,
            ``num_classes``, ``num_train``, ``num_valid``,
            ``pretrain_epochs``, ``epochs``, ``freeze_layer``,
            ``learning_rate``, ``momentum`` and ``callback``.
    """
    wandb.init()
    callbacks = [WandbCallback()]

    # Training images are augmented; validation images are only rescaled.
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    test_datagen = ImageDataGenerator(rescale=1. / 255)

    # Input size depends on the base model: 299x299 by default, and
    # `resnet_img_dim` (defined elsewhere in the module) for "resnet".
    img_width, img_height = 299, 299
    if args.initial_model == "resnet":
        img_width = resnet_img_dim
        img_height = resnet_img_dim

    train_generator = train_datagen.flow_from_directory(
        args.train_data,
        target_size=(img_width, img_height),
        batch_size=args.batch_size,
        class_mode='categorical',
        follow_links=True)
    validation_generator = test_datagen.flow_from_directory(
        args.val_data,
        target_size=(img_width, img_height),
        batch_size=args.batch_size,
        class_mode='categorical',
        follow_links=True)

    # Clamp to >= 1 so a dataset smaller than one batch still yields a
    # valid step count.
    train_steps = _num_steps(args.num_train, args.batch_size)
    valid_steps = _num_steps(args.num_valid, args.batch_size)

    model = load_model(args.initial_model, args.fc_size, args.num_classes)
    log_model_params(model, wandb.config, args, img_width)

    # Phase 1: pretrain to solidify the top layers.
    model.fit_generator(
        train_generator,
        steps_per_epoch=train_steps,
        epochs=args.pretrain_epochs,
        validation_data=validation_generator,
        callbacks=callbacks,
        validation_steps=valid_steps)

    # Phase 2: fine-tune. Freeze layers [0, freeze_layer) and unfreeze the rest.
    # To see layer indices when choosing freeze_layer:
    #     for i, layer in enumerate(model.layers):
    #         print(i, layer.name)
    for layer in model.layers[:args.freeze_layer]:
        layer.trainable = False
    for layer in model.layers[args.freeze_layer:]:
        layer.trainable = True

    # Recompile so the updated `trainable` flags are used by training.
    from keras.optimizers import SGD
    model.compile(optimizer=SGD(lr=args.learning_rate, momentum=args.momentum),
                  loss='categorical_crossentropy',
                  metrics=["accuracy"])

    # Optionally add the PerClassMetrics callback (mode taken from
    # args.callback: per the original notes, an image table or per-class
    # precision). It is evaluated on at most 50 validation batches to bound
    # its cost.
    # NOTE(review): the original author noted that per-class precision is
    # currently computed twice (sklearn and hand-written code) and includes
    # timing. Confirm inside PerClassMetrics.
    if args.callback:
        num_valid_batches = min(valid_steps, 50)
        callbacks.append(PerClassMetrics(validation_generator,
                                         num_batches=num_valid_batches,
                                         mode=args.callback))

    # Finish training with the partially unfrozen model.
    model.fit_generator(
        train_generator,
        steps_per_epoch=train_steps,
        epochs=args.epochs,
        validation_data=validation_generator,
        callbacks=callbacks,
        validation_steps=valid_steps)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment