Skip to content

Instantly share code, notes, and snippets.

@phreeza
Created June 12, 2015 17:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save phreeza/ba64c8fd26d331576fe2 to your computer and use it in GitHub Desktop.
Save phreeza/ba64c8fd26d331576fe2 to your computer and use it in GitHub Desktop.
Test the regularisation parameter on the convolutional layer in Keras
# Compare overfitting of a Keras convolutional layer trained with and without an L2 weight regulariser
import numpy as np
from keras.datasets import mnist
from keras.layers.convolutional import Convolution2D
from keras.layers.core import Activation
from keras.layers.core import Dense,Flatten
from keras.models import Sequential
from keras.regularizers import l2
from keras.utils import np_utils
# Train the same single-filter conv net on a small MNIST subset twice --
# once without and once with an L2 penalty on the conv weights -- and report
# the test-minus-train score gap for each run as a measure of overfitting.

random_seed = 1337
np.random.seed(random_seed)

# A deliberately tiny training set so the unregularised model can overfit.
max_train_samples = 128*4
max_test_samples = 1000

# Network geometry. One filter applied in "valid" mode (no padding) shrinks
# each side by (filter - 1), so 28x28 inputs with a 20x20 filter give a
# 9x9 = 81-entry flattened feature map (the value previously hard-coded).
img_rows, img_cols = 28, 28
filter_rows, filter_cols = 20, 20
conv_output_dim = (img_rows - filter_rows + 1) * (img_cols - filter_cols + 1)

(X_train, y_train), (X_test, y_test) = mnist.load_data()
nb_classes = len(np.unique(y_train))

# Truncate the inputs as well as the targets. Previously only Y_* was
# sliced, so X_* and Y_* had different numbers of rows.
# The layout is (samples, channels, rows, cols), with a single channel.
X_train = X_train[:max_train_samples].reshape(-1, 1, img_rows, img_cols)
X_test = X_test[:max_test_samples].reshape(-1, 1, img_rows, img_cols)
# Scale raw 0-255 pixel intensities to floats in [0, 1].
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
Y_train = np_utils.to_categorical(y_train[:max_train_samples], nb_classes)
Y_test = np_utils.to_categorical(y_test[:max_test_samples], nb_classes)


def _build_model(W_regularizer=None):
    """Build and compile the single-filter conv net shared by both runs.

    numpy's global RNG is re-seeded first so that both runs see the same
    random draws. The intent is that *W_regularizer* is the only difference
    between them. NOTE(review): this assumes Keras draws weight
    initialisation and fit-time shuffling from numpy's global RNG -- confirm
    for the installed version.
    """
    np.random.seed(random_seed)
    net = Sequential()
    net.add(Convolution2D(1, 1, filter_rows, filter_cols,
                          W_regularizer=W_regularizer))
    net.add(Flatten())
    net.add(Dense(conv_output_dim, nb_classes))
    # categorical_crossentropy expects one probability distribution per row.
    # Without this the Dense output would be linear and unnormalised.
    net.add(Activation('softmax'))
    net.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    return net


model_noreg = _build_model()
model_noreg.fit(X_train, Y_train)
score_noreg = model_noreg.evaluate(X_test, Y_test)
score_train_noreg = model_noreg.evaluate(X_train, Y_train)

model = _build_model(W_regularizer=l2(0.05))
model.fit(X_train, Y_train)
score_reg = model.evaluate(X_test, Y_test)
score_train_reg = model.evaluate(X_train, Y_train)

# Scores are formatted with %f, so evaluate() is assumed to return a scalar
# loss here. NOTE(review): check whether the L2 penalty term is included in
# that loss, since that affects how the two gaps compare.
# print("") emits a blank line under both Python 2 and Python 3, separating
# the summary from Keras' progress output.
print("")
print("")
print("Overfitting without regularisation: %f - %f = %f"
      % (score_noreg, score_train_noreg, score_noreg - score_train_noreg))
print("Overfitting with regularisation: %f - %f = %f"
      % (score_reg, score_train_reg, score_reg - score_train_reg))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment