Instantly share code, notes, and snippets.

Embed
What would you like to do?
Early Stopping Experiment with MNIST
# Early Stopping Experiment with MNIST
# http://fouryears.eu/2017/12/05/the-mystery-of-early-stopping/
#
# Code adapted from: https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py
# By: Konstantin Tretyakov
# License: MIT
import os
import pickle
import tempfile

import keras
import numpy as np
from keras import backend as K
from keras.datasets import mnist
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Dense, Dropout, Flatten
from keras.models import Sequential
from keras.regularizers import l2
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split
img_rows, img_cols = 28, 28
num_classes = 10

# Keras' loader hands back MNIST already split into a train and a test set.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a single grayscale channel axis on whichever side the backend expects.
if K.image_data_format() != 'channels_first':
    input_shape = (img_rows, img_cols, 1)
else:
    input_shape = (1, img_rows, img_cols)
x_train = x_train.reshape((x_train.shape[0],) + input_shape)
x_test = x_test.reshape((x_test.shape[0],) + input_shape)

# Rescale pixel intensities from 0..255 to 0..1 in float32.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# One-hot encode the integer class labels.
y_train, y_test = (keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test))
def model_bare():
    """Build and compile the unregularized baseline CNN.

    Layers: 3x3 conv (32 filters, ReLU) -> 3x3 conv (64 filters, ReLU) ->
    2x2 max-pool -> flatten -> dense 128 (ReLU) -> dense softmax over
    ``num_classes``. Reads the module-level ``input_shape`` and
    ``num_classes``.

    Returns:
        A compiled, untrained ``Sequential`` model using categorical
        cross-entropy loss, SGD with learning rate 0.05, and an accuracy
        metric.
    """
    m = Sequential([
        Conv2D(32, kernel_size=(3, 3), activation='relu',
               input_shape=input_shape),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(num_classes, activation='softmax'),
    ])
    # Use `learning_rate`, not the deprecated `lr` alias: newer tf.keras
    # optimizers warn about and discard `lr`, and Keras 3 rejects it.
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(learning_rate=0.05),
              metrics=['accuracy'])
    return m
def model_l2():
    """Build and compile the baseline CNN with an L2 penalty on the dense layer.

    Same layers as the bare model: 3x3 conv (32, ReLU) -> 3x3 conv (64, ReLU)
    -> 2x2 max-pool -> flatten -> dense 128 (ReLU) -> dense softmax over
    ``num_classes``. The only difference is an L2 kernel regularizer
    (factor 0.001) on the 128-unit dense layer. The conv layers and the
    output layer are not regularized. Reads the module-level
    ``input_shape`` and ``num_classes``.

    Returns:
        A compiled, untrained ``Sequential`` model using categorical
        cross-entropy loss, SGD with learning rate 0.05, and an accuracy
        metric.
    """
    m = Sequential([
        Conv2D(32, kernel_size=(3, 3), activation='relu',
               input_shape=input_shape),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(128, activation='relu', kernel_regularizer=l2(0.001)),
        Dense(num_classes, activation='softmax'),
    ])
    # Use `learning_rate`, not the deprecated `lr` alias: newer tf.keras
    # optimizers warn about and discard `lr`, and Keras 3 rejects it.
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(learning_rate=0.05),
              metrics=['accuracy'])
    return m
def model_dropout():
    """Build and compile the baseline CNN with dropout regularization.

    Layers: 3x3 conv (32, ReLU) -> 3x3 conv (64, ReLU) -> 2x2 max-pool ->
    dropout 0.25 -> flatten -> dense 128 (ReLU) -> dropout 0.5 -> dense
    softmax over ``num_classes``. Reads the module-level ``input_shape`` and
    ``num_classes``.

    Returns:
        A compiled, untrained ``Sequential`` model using categorical
        cross-entropy loss, SGD with learning rate 0.05, and an accuracy
        metric.
    """
    m = Sequential([
        Conv2D(32, kernel_size=(3, 3), activation='relu',
               input_shape=input_shape),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.25),
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax'),
    ])
    # Use `learning_rate`, not the deprecated `lr` alias: newer tf.keras
    # optimizers warn about and discard `lr`, and Keras 3 rejects it.
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(learning_rate=0.05),
              metrics=['accuracy'])
    return m
def fit_partial(model, test_size, batch_size=512, epochs_max=100,
                patience=10, random_state=None):
    """Train with early stopping on a held-out slice and score on the test set.

    How it works:
    1. Part of the training data (``test_size``) is split off as a
       stopping set, and the model is trained on the remainder.
    2. After every epoch, ``val_loss`` on the stopping set is checked.
       Training ends once it has not improved for ``patience`` epochs, or
       after ``epochs_max`` epochs.
    3. The weights from the best-``val_loss`` epoch are restored before
       predicting on the test set.

    Args:
        model: Zero-argument factory returning a compiled Keras model.
        test_size: Portion of the training set held out for stopping; passed
            straight to ``train_test_split``.
        batch_size: Mini-batch size for both training and prediction.
        epochs_max: Upper bound on the number of training epochs.
        patience: Epochs without ``val_loss`` improvement before stopping.
        random_state: Seed for the fit/stop split. ``None`` (the default)
            draws a fresh random split on every call.

    Returns:
        Log-loss of the restored best model's predictions on the module-level
        test set ``(x_test, y_test)``.

    Raises:
        RuntimeError: If no checkpoint was ever written (e.g. ``val_loss`` was
            NaN from the first epoch), so there are no best weights to load.
    """
    m = model()
    x_fit, x_stop, y_fit, y_stop = train_test_split(
        x_train, y_train, test_size=test_size, random_state=random_state)

    # Checkpoint into a private temporary directory, so that no stale file
    # from an earlier run can be loaded, nothing is left in the cwd, and
    # concurrent runs cannot clobber each other.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save weights only. The '.weights.h5' suffix is required by Keras 3
        # for weights-only checkpoints, and the '.h5' ending keeps tf.keras 2.x
        # on HDF5 instead of writing a SavedModel directory.
        best_path = os.path.join(tmp_dir, 'best.weights.h5')
        save_best = keras.callbacks.ModelCheckpoint(
            best_path, monitor='val_loss', verbose=1,
            save_best_only=True, save_weights_only=True)
        early_stop = keras.callbacks.EarlyStopping(
            monitor='val_loss', min_delta=0, patience=patience, verbose=1)
        m.fit(x_fit, y_fit, batch_size=batch_size, epochs=epochs_max,
              verbose=1, validation_data=(x_stop, y_stop),
              callbacks=[early_stop, save_best])
        if not os.path.exists(best_path):
            raise RuntimeError(
                'no checkpoint was saved; val_loss never improved '
                '(possibly NaN), so there are no best weights to restore')
        m.load_weights(best_path)

    p = m.predict(x_test, batch_size=batch_size, verbose=0)
    return log_loss(y_test, p)
def fit_full(model, batch_size=512, epochs_max=100):
    """Baseline run: train on the whole training set for a fixed epoch budget.

    Nothing is held out and no early stopping is applied. The weights left
    after the final epoch are the ones that get scored.

    Args:
        model: Zero-argument factory returning a compiled Keras model.
        batch_size: Mini-batch size for both training and prediction.
        epochs_max: Number of epochs to train for.

    Returns:
        Log-loss of the final model's predictions on the module-level test set
        ``(x_test, y_test)``.
    """
    net = model()
    net.fit(x_train, y_train, batch_size=batch_size, epochs=epochs_max,
            verbose=1)
    probs = net.predict(x_test, batch_size=batch_size, verbose=0)
    return log_loss(y_test, probs)
# Hold-out fractions for the early-stopping runs: 0.05 up to 0.95 in steps of
# 0.05. The grid is identical for every model, so it is built once.
stops = np.arange(0.05, 1.0, 0.05)

all_results = {}
for m, title in [(model_bare, "Bare"), (model_l2, "L2"), (model_dropout, "Dropout")]:
    # Element 0: full-data baseline without early stopping.
    # Elements 1..: one early-stopped run per hold-out fraction, in `stops` order.
    res = [fit_full(m, epochs_max=200)]
    res.extend(fit_partial(m, s, epochs_max=200) for s in stops)
    print(res)
    all_results[title] = res
print(all_results)

# Persist {model title: [baseline loss, loss per hold-out fraction, ...]}.
with open('results.pkl', 'wb') as f:
    pickle.dump(all_results, f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment