# Early Stopping Experiment with MNIST
# http://fouryears.eu/2017/12/05/the-mystery-of-early-stopping/
#
# Code adapted from: https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py
# By: Konstantin Tretyakov
# License: MIT
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.regularizers import l2
from keras import backend as K
from sklearn.model_selection import train_test_split
import numpy as np
import pickle
import os
from sklearn.metrics import log_loss
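
# Note: written against Keras 2.x (circa 2017). Newer Keras/TensorFlow
# releases rename SGD's `lr` argument to `learning_rate`.
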
img_rows, img_cols = 28, 28
num_classes = 10
# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
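
# Three variants of the same convnet are compared below: no regularization
# ("bare"), L2 weight decay on the hidden dense layer, and dropout.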
def model_bare():
    m = Sequential([Conv2D(32, kernel_size=(3, 3),
                           activation='relu',
                           input_shape=input_shape),
                    Conv2D(64, (3, 3), activation='relu'),
                    MaxPooling2D(pool_size=(2, 2)),
                    Flatten(),
                    Dense(128, activation='relu'),
                    Dense(num_classes, activation='softmax')])
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(lr=0.05),
              metrics=['accuracy'])
    return m

def model_l2():
    m = Sequential([Conv2D(32, kernel_size=(3, 3),
                           activation='relu',
                           input_shape=input_shape),
                    Conv2D(64, (3, 3), activation='relu'),
                    MaxPooling2D(pool_size=(2, 2)),
                    Flatten(),
                    Dense(128, activation='relu', kernel_regularizer=l2(0.001)),
                    Dense(num_classes, activation='softmax')])
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(lr=0.05),
              metrics=['accuracy'])
    return m

def model_dropout():
    m = Sequential([Conv2D(32, kernel_size=(3, 3),
                           activation='relu',
                           input_shape=input_shape),
                    Conv2D(64, (3, 3), activation='relu'),
                    MaxPooling2D(pool_size=(2, 2)),
                    Dropout(0.25),
                    Flatten(),
                    Dense(128, activation='relu'),
                    Dropout(0.5),
                    Dense(num_classes, activation='softmax')])
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(lr=0.05),
              metrics=['accuracy'])
    return m

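# Hold out a fraction `test_size` of the training data for early stopping:
# train with patience-10 early stopping on validation loss, restore the best
# checkpoint, and return the log-loss on the untouched test set.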
def fit_partial(model, test_size, batch_size=512, epochs_max=100):
    m = model()
    if os.path.exists('best.weights'):
        os.unlink('best.weights')
    x_fit, x_stop, y_fit, y_stop = train_test_split(x_train, y_train, test_size=test_size)
    save_best = keras.callbacks.ModelCheckpoint('best.weights', monitor='val_loss', verbose=1, save_best_only=True)
    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)
    m.fit(x_fit, y_fit, batch_size=batch_size, epochs=epochs_max, verbose=1,
          validation_data=(x_stop, y_stop),
          callbacks=[early_stop, save_best])
    m.load_weights('best.weights')
    p = m.predict(x_test, batch_size=batch_size, verbose=0)
    return log_loss(y_test, p)

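# Baseline: train on all of the training data for the full epoch budget,
# with no holdout and no early stopping.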
def fit_full(model, batch_size=512, epochs_max=100):
    m = model()
    m.fit(x_train, y_train, batch_size=batch_size, epochs=epochs_max, verbose=1)
    # return m.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)
    p = m.predict(x_test, batch_size=batch_size, verbose=0)
    return log_loss(y_test, p)

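# For each variant, record the no-early-stopping baseline first, then one
# run per early-stopping holdout fraction from 0.05 to 0.95.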
all_results = {}
for m, title in [(model_bare, "Bare"), (model_l2, "L2"), (model_dropout, "Dropout")]:
    res = [fit_full(m, epochs_max=200)]
    stops = np.arange(0.05, 1.0, 0.05)
    for s in stops:
        res.append(fit_partial(m, s, epochs_max=200))
    print(res)
    all_results[title] = res

print(all_results)
with open('results.pkl', 'wb') as f:
    pickle.dump(all_results, f)
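
# A minimal sketch (not part of the original experiment) of how results.pkl
# could be inspected afterwards, assuming matplotlib is available:
#
#   import pickle
#   import numpy as np
#   import matplotlib.pyplot as plt
#
#   with open('results.pkl', 'rb') as f:
#       results = pickle.load(f)
#   fractions = np.arange(0.05, 1.0, 0.05)
#   for title, res in results.items():
#       # res[0] is the no-early-stopping baseline; res[1:] line up with fractions
#       plt.axhline(res[0], linestyle='--', alpha=0.3)
#       plt.plot(fractions, res[1:], label=title)
#   plt.xlabel('Early-stopping holdout fraction (test_size)')
#   plt.ylabel('Test log-loss')
#   plt.legend()
#   plt.show()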