Early Stopping Experiment with MNIST
# Early Stopping Experiment with MNIST
# http://fouryears.eu/2017/12/05/the-mystery-of-early-stopping/
#
# Code adapted from: https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py
# By: Konstantin Tretyakov
# License: MIT
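#
# Note: written against the standalone Keras 2.x API of late 2017
# (e.g. SGD(lr=...) and K.image_data_format()); newer Keras / TF versions
# may need minor renames such as lr -> learning_rate.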
import os
import pickle
import numpy as np
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.regularizers import l2
from keras import backend as K
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss
img_rows, img_cols = 28, 28
num_classes = 10

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
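
# MNIST arrives as uint8 arrays of shape (n, 28, 28); the backend's
# preferred image layout decides where the singleton channel axis goes.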
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
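
# Three variants of the same small convnet are compared below: no
# regularization ("bare"), an L2 weight penalty, and dropout.
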
def model_bare():
    m = Sequential([Conv2D(32, kernel_size=(3, 3),
                           activation='relu',
                           input_shape=input_shape),
                    Conv2D(64, (3, 3), activation='relu'),
                    MaxPooling2D(pool_size=(2, 2)),
                    Flatten(),
                    Dense(128, activation='relu'),
                    Dense(num_classes, activation='softmax')])
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(lr=0.05),
              metrics=['accuracy'])
    return m
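
# Same convnet, with an L2 penalty on the 128-unit dense layer's weights.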
def model_l2():
    m = Sequential([Conv2D(32, kernel_size=(3, 3),
                           activation='relu',
                           input_shape=input_shape),
                    Conv2D(64, (3, 3), activation='relu'),
                    MaxPooling2D(pool_size=(2, 2)),
                    Flatten(),
                    Dense(128, activation='relu', kernel_regularizer=l2(0.001)),
                    Dense(num_classes, activation='softmax')])
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(lr=0.05),
              metrics=['accuracy'])
    return m
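
# Same convnet, with dropout after the pooling layer and the dense layer.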
def model_dropout():
    m = Sequential([Conv2D(32, kernel_size=(3, 3),
                           activation='relu',
                           input_shape=input_shape),
                    Conv2D(64, (3, 3), activation='relu'),
                    MaxPooling2D(pool_size=(2, 2)),
                    Dropout(0.25),
                    Flatten(),
                    Dense(128, activation='relu'),
                    Dropout(0.5),
                    Dense(num_classes, activation='softmax')])
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(lr=0.05),
              metrics=['accuracy'])
    return m
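
# Optional sanity check (not part of the experiment below; uncomment to
# confirm a variant builds and to inspect its parameter counts):
# model_dropout().summary()
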
def fit_partial(model, test_size, batch_size=512, epochs_max=100):
    """Train with early stopping: hold out a `test_size` fraction of the
    training data as the stopping set, keep the weights with the best
    validation loss, and return the test-set log-loss."""
    m = model()
    # Remove a stale checkpoint left over from a previous run
    if os.path.exists('best.weights'): os.unlink('best.weights')
    x_fit, x_stop, y_fit, y_stop = train_test_split(x_train, y_train, test_size=test_size)
    save_best = keras.callbacks.ModelCheckpoint('best.weights', monitor='val_loss', verbose=1, save_best_only=True)
    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)
    m.fit(x_fit, y_fit, batch_size=batch_size, epochs=epochs_max, verbose=1,
          validation_data=(x_stop, y_stop),
          callbacks=[early_stop, save_best])
    # Roll back to the checkpoint with the lowest validation loss
    m.load_weights('best.weights')
    p = m.predict(x_test, batch_size=batch_size, verbose=0)
    return log_loss(y_test, p)
def fit_full(model, batch_size=512, epochs_max=100):
    """Baseline: train on all of x_train for the full epoch budget,
    with no validation split and no early stopping."""
    m = model()
    m.fit(x_train, y_train, batch_size=batch_size, epochs=epochs_max, verbose=1)
    #return m.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)
    p = m.predict(x_test, batch_size=batch_size, verbose=0)
    return log_loss(y_test, p)
all_results = {}
for m, title in [(model_bare, "Bare"), (model_l2, "L2"), (model_dropout, "Dropout")]:
    res = [fit_full(m, epochs_max=200)]
    stops = np.arange(0.05, 1.0, 0.05)
    for s in stops:
        res.append(fit_partial(m, s, epochs_max=200))
        print(res)
    all_results[title] = res

print(all_results)
with open('results.pkl', 'wb') as f:
    pickle.dump(all_results, f)
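
# ---------------------------------------------------------------------
# Follow-up sketch (not part of the original script; assumes matplotlib
# is installed): load results.pkl and plot test log-loss against the
# fraction of training data held out for early stopping. Index 0 of each
# result list is the full-data, no-early-stopping baseline.
import matplotlib.pyplot as plt

with open('results.pkl', 'rb') as f:
    all_results = pickle.load(f)

stops = np.arange(0.05, 1.0, 0.05)
for title, res in all_results.items():
    plt.axhline(res[0], linestyle='--', alpha=0.4)  # no-early-stopping baseline
    plt.plot(stops, res[1:], marker='o', label=title)
plt.xlabel('Fraction of training data held out for early stopping')
plt.ylabel('Test log-loss')
plt.legend()
plt.show()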