Skip to content

Instantly share code, notes, and snippets.

@arngarden
Created March 24, 2014 19:28
Show Gist options
Save arngarden/9747287 to your computer and use it in GitHub Desktop.
Pylearn2 example with a termination criterion, momentum, and a learning-rate adjustor
import theano
from pylearn2.models import mlp
from pylearn2.train_extensions import best_params
from pylearn2.training_algorithms import sgd, learning_rule
from pylearn2.utils import serial
from pylearn2.termination_criteria import MonitorBased
from pylearn2.datasets.dense_design_matrix import DenseDesignMatrix
from sklearn.preprocessing import StandardScaler
import numpy as np
from random import randint
import itertools
import os
# Remove any checkpoint left by a previous run so that /tmp/best.pkl only ever
# holds a model saved during this run. This is best effort, like the original
# `rm`: the file usually does not exist, and that is fine. os.remove avoids
# spawning a shell and the "No such file" noise on stderr.
try:
    os.remove('/tmp/best.pkl')
except OSError:
    pass
# Comma-separated data file: float feature columns followed by an integer
# class label in the last column.
PIMA_DATASET = './pima-indians-diabetes.data'
# Shared scaler that Pima() fits and applies to the features when it reads
# the file.
scaler = StandardScaler()
class Pima(DenseDesignMatrix):
    """Pima Indians diabetes data wrapped as a pylearn2 DenseDesignMatrix.

    Without ``X``, the data is read from the comma-separated file at
    ``path`` (default: ``PIMA_DATASET``). Every column except the last is a
    float feature; the last column is an integer label. Features are
    standardized with the module-level ``scaler``. Labels are one-hot
    encoded as ``[1, 0]`` for 0 and ``[0, 1]`` for any other value.

    With ``X`` (and ``y``) given, the arrays are passed to DenseDesignMatrix
    unchanged. ``split`` relies on this.
    """

    def __init__(self, X=None, y=None, path=None):
        if X is None:
            # A y passed without X is ignored: labels come from the file.
            X, y = self._read_pima_file(PIMA_DATASET if path is None else path)
        super(Pima, self).__init__(X=X, y=y)

    @staticmethod
    def _read_pima_file(path):
        """Parse *path*; return (standardized feature matrix, one-hot labels)."""
        rows = []
        labels = []
        with open(path) as f:
            for line in f:
                # Blank lines (e.g. a trailing one at end of file) would make
                # the rsplit unpacking below fail, so skip them.
                if not line.strip():
                    continue
                features, label = line.rsplit(',', 1)
                # Use a real list, not map(). On Python 3 map() is lazy, and
                # np.asarray would then build an object array instead of a
                # float matrix.
                rows.append([float(v) for v in features.split(',')])
                labels.append([1, 0] if int(label) == 0 else [0, 1])
        # NOTE(review): the scaler is fit on every row before split() runs, so
        # validation/test statistics influence the standardization.
        X = scaler.fit_transform(np.asarray(rows))
        return X, np.asarray(labels)

    @property
    def nr_inputs(self):
        """Number of feature columns."""
        return self.X.shape[1]

    def split(self, prop=.8):
        """Split the rows, in their current order, into two new Pima datasets.

        The first dataset holds the first ``int(len(self.y) * prop)`` rows;
        the second holds the rest. No shuffling is performed.
        """
        cutoff = int(len(self.y) * prop)
        return (Pima(self.X[:cutoff], self.y[:cutoff]),
                Pima(self.X[cutoff:], self.y[cutoff:]))

    def __len__(self):
        """Number of rows (examples)."""
        return self.X.shape[0]

    def __iter__(self):
        """Iterate over (feature row, label row) pairs."""
        # iter(zip(...)) works on Python 2 and 3 and returns an iterator, as
        # __iter__ must; itertools.izip_longest does not exist on Python 3.
        # X and y are built together (above and in split), so they are
        # expected to have equal length and zip's truncation drops nothing.
        return iter(zip(self.X, self.y))
# Build the datasets from the file, in file order. 70% of the rows go to
# training. The remaining 30% are split 70/30 again into validation (~21% of
# all rows) and test (~9%).
ds_train, ds_valid = Pima().split(0.7)
ds_valid, ds_test = ds_valid.split(0.7)
# Network architecture: one sigmoid hidden layer, then a softmax output.
# Hidden layer: 20 units, weights initialised within [-0.05, 0.05]
# (irange=.05), biases initialised to 1.
hidden_layer = mlp.Sigmoid(dim=20, irange=.05, init_bias=1.,
                           layer_name='hidden')
# Output layer: 2-way softmax, matching the two-column one-hot labels. Its
# name 'output' is the prefix of the 'output_misclass' monitoring channel.
output_layer = mlp.Softmax(2, 'output', irange=.05)
layers = [hidden_layer, output_layer]
# Stop training once 'output_misclass' (misclassification rate on the
# monitoring dataset, i.e. the validation set passed to SGD below) has not
# decreased at all (prop_decrease=0.0) over the last N=50 epochs.
termination_criterion = MonitorBased(channel_name='output_misclass',
                                     N=50, prop_decrease=0.0)
# Momentum: begins at initial_momentum. The adjustor moves it towards
# final_momentum between epochs `start` and `saturate` (linearly, per
# pylearn2's MomentumAdjustor documentation).
initial_momentum, final_momentum = .5, .99
start, saturate = 1, 50
momentum_adjustor = learning_rule.MomentumAdjustor(final_momentum, start,
                                                   saturate)
momentum_rule = learning_rule.Momentum(initial_momentum)
# Learning rate: decays over the same epoch window until it reaches
# decay_factor times its initial value (per LinearDecayOverEpoch's docs).
start, saturate = 1, 50
decay_factor = .1
learning_rate_adjustor = sgd.LinearDecayOverEpoch(start, saturate,
                                                  decay_factor)
# Multilayer perceptron whose visible (input) size equals the feature count.
ann = mlp.MLP(layers, nvis=ds_train.nr_inputs)
# Stochastic gradient descent with minibatches of 10 and base learning rate
# 0.05. Momentum comes from momentum_rule, monitoring runs on the validation
# set, and stopping is decided by termination_criterion.
trainer = sgd.SGD(
    learning_rate=.05,
    batch_size=10,
    learning_rule=momentum_rule,
    termination_criterion=termination_criterion,
    monitoring_dataset=ds_valid,
)
trainer.setup(ann, ds_train)
# Checkpointing extension. Per pylearn2's MonitorBasedSaveBest, each
# on_monitor call looks at the latest 'output_misclass' value and saves the
# model to /tmp/best.pkl when it is the lowest value seen so far.
monitor_save_best = best_params.MonitorBasedSaveBest('output_misclass',
                                                     '/tmp/best.pkl')
# Train one epoch at a time until the trainer's termination criterion says to
# stop. The trainer, monitor and extensions are driven by hand here, so the
# call order below matters.
while True:
    # One epoch of SGD over the training set.
    trainer.train(dataset=ds_train)
    # Mark the epoch as finished, then evaluate the monitoring channels
    # (computed on ds_valid, the trainer's monitoring_dataset).
    ann.monitor.report_epoch()
    ann.monitor()
    # Checkpoint if validation misclassification reached a new best.
    monitor_save_best.on_monitor(ann, ds_valid, trainer)
    # Ask the trainer (and so its termination criterion) whether to go on.
    if not trainer.continue_learning(ann):
        break
    # Update momentum and learning rate for the next epoch. These run only
    # when training continues, i.e. after the termination check.
    momentum_adjustor.on_monitor(ann, ds_valid, trainer)
    learning_rate_adjustor.on_monitor(ann, ds_valid, trainer)
# Swap the final-epoch model for the best checkpoint written above, so the
# scores below use the model with the lowest validation misclassification.
ann = serial.load('/tmp/best.pkl')
def classify(inp):
    """Return the predicted class index for a single feature vector.

    *inp* is any array-like holding ``ds_train.nr_inputs`` values. It is
    reshaped into a one-row batch and passed through the module-level
    ``ann``. The index of the largest output is returned (0 or 1 for the
    two-class softmax used here).

    Raises ValueError if *inp* does not hold exactly ``ds_train.nr_inputs``
    values.
    """
    # reshape() returns a new view (or copy). Assigning to ``.shape``, as
    # before, reshaped the caller's own array in place when *inp* was already
    # an ndarray, and failed outright for non-contiguous arrays.
    batch = np.asarray(inp).reshape(1, ds_train.nr_inputs)
    # NOTE(review): each call builds a fresh graph, so `.eval()` compiles a
    # new Theano function every time. Compile once with theano.function if
    # classifying many rows is too slow.
    outputs = ann.fprop(theano.shared(batch, name='inputs')).eval()
    return np.argmax(outputs)
def score(dataset):
    """Print how many rows of *dataset* ``classify`` labels correctly.

    *dataset* is iterated as (features, label) pairs, where label is a
    one-hot row as produced by Pima. A row counts as correct when the
    predicted class equals the position of the largest label entry. Prints
    "<correct>/<total> correct" and returns the number of correct rows
    (callers may ignore the return value).
    """
    nr_correct = sum(1 for features, label in dataset
                     if classify(features) == np.argmax(label))
    # print(...) with one argument gives the same output on Python 2 and 3;
    # the old print statement is a SyntaxError on Python 3.
    print('%s/%s correct' % (nr_correct, len(dataset)))
    return nr_correct
# Report accuracy on each split. print('...') with a single argument gives
# the same output on Python 2 and 3 (print('') emits a blank line like the
# old bare `print`), whereas print statements are a SyntaxError on Python 3.
print('')
print('Accuracy of train set:')
score(ds_train)
print('Accuracy of validation set:')
score(ds_valid)
print('Accuracy of test set:')
score(ds_test)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment