Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Early Stopping Experiment with MNIST
# Early Stopping Experiment with MNIST
# Code adapted from:
# By: Konstantin Tretyakov
# License: MIT
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from sklearn.model_selection import train_test_split
import numpy as np
import pickle
import os
img_rows, img_cols = 28, 28
num_classes = 10

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add an explicit single-channel axis in the position the backend expects.
# The two layouts are mutually exclusive: exactly one reshape must apply.
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Convert to float32 and divide by 255 to scale the raw pixel values into [0, 1].
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
def make_model():
    """Build and compile a fresh, untrained CNN classifier for the MNIST images.

    Layers: two 3x3 ReLU convolutions (32 and 64 filters), 2x2 max-pooling,
    dropout, flatten, a 128-unit ReLU dense layer, dropout, and a softmax
    output over ``num_classes`` classes. The input shape comes from the
    module-level ``input_shape``.

    Returns:
        A compiled ``Sequential`` model, ready for ``fit``.
    """
    m = Sequential([
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.25),
        # Collapse the convolutional feature maps into one vector per sample.
        # Without this the Dense layers act per spatial position, and the
        # output would not match the (batch, num_classes) one-hot labels.
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax'),
    ])
    # Labels are one-hot encoded via to_categorical, hence categorical
    # cross-entropy against the softmax output.
    # NOTE(review): the original compile call is not visible in this copy.
    # Adadelta follows the Keras mnist_cnn example this script appears to be
    # adapted from; confirm the choice. Newer Keras releases lowered
    # Adadelta's default learning rate, which makes training much slower.
    m.compile(loss='categorical_crossentropy',
              optimizer='adadelta',
              metrics=['accuracy'])
    return m
def fit_partial(test_size, batch_size=512, epochs_max=100, weights_path='best.weights'):
    """Train with early stopping on a held-out slice of the training set.

    A fraction ``test_size`` of ``x_train``/``y_train`` is split off as a
    "stop" set and the remainder is used for fitting. Training halts once
    ``val_loss`` on the stop set has not improved for 10 epochs. The weights
    from the best epoch are then restored before scoring on the test set.

    Args:
        test_size: Fraction of the training data held out for early stopping.
            Passed straight to ``train_test_split``.
        batch_size: Mini-batch size for fitting and evaluation.
        epochs_max: Upper bound on the number of training epochs.
        weights_path: Where the best-epoch weights are checkpointed.
            NOTE(review): Keras 3 requires this to end in ``.weights.h5``
            when ``save_weights_only=True``.

    Returns:
        The result of ``m.evaluate`` on ``(x_test, y_test)``: the test loss,
        followed by any metrics the model was compiled with.
    """
    m = make_model()
    # Drop a checkpoint from a previous call so it can never be loaded in
    # place of this run's best weights.
    if os.path.exists(weights_path):
        os.unlink(weights_path)
    x_fit, x_stop, y_fit, y_stop = train_test_split(x_train, y_train, test_size=test_size)
    # Only the weights are ever reloaded. Saving just the weights also avoids
    # tf.keras writing a SavedModel *directory* at weights_path, which the
    # os.unlink above could not remove.
    save_best = keras.callbacks.ModelCheckpoint(weights_path, monitor='val_loss', verbose=1,
                                                save_best_only=True, save_weights_only=True)
    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)
    m.fit(x_fit, y_fit, batch_size=batch_size, epochs=epochs_max, verbose=1,
          validation_data=(x_stop, y_stop),
          callbacks=[early_stop, save_best])
    # EarlyStopping leaves the model at the *last* trained epoch, up to
    # `patience` epochs past the best one. Reload the checkpointed best
    # weights so the test score reflects the early-stopped model.
    m.load_weights(weights_path)
    return m.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)
def fit_full(batch_size=512, epochs_max=100):
    """Baseline: train on the whole training set without early stopping.

    No data is held out and no callbacks are used, so the model trains for
    exactly ``epochs_max`` epochs before being scored on the test set.

    Args:
        batch_size: Mini-batch size for fitting and evaluation.
        epochs_max: Number of training epochs.

    Returns:
        The result of ``m.evaluate`` on ``(x_test, y_test)``.
    """
    m = make_model()
    m.fit(x_train, y_train, batch_size=batch_size, epochs=epochs_max, verbose=1)
    return m.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)
# res[0] is the no-early-stopping baseline. res[i] (i >= 1) is the
# early-stopped run whose stop-set fraction is stops[i - 1].
res = [fit_full()]
# 0.05, 0.10, ..., 0.95 built from integers. np.arange(0.05, 1, 0.05)
# accumulates float error (e.g. 0.15000000000000002) in the fractions.
stops = np.arange(1, 20) / 20
for s in stops:
    res.append(fit_partial(s))
    # Rewrite the pickle after every run so a crash partway through the
    # sweep still leaves the finished runs on disk.
    with open('results.pkl', 'wb') as f:
        pickle.dump(res, f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment