Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Stacked ensemble of classifiers compatible with scikit-learn
class StackedEnsembleClassifier(sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin):
    """
    Ensemble of classifiers. Each classifier must have predict_proba and the head classifier is trained
    to predict the output based on the individual classifier outputs. Stratified k-fold cross-validation
    is used to get probabilities on held-out data.

    For every fold, clones of the base classifiers are fitted on the fold's training part. A clone of
    the head classifier (the fold's "arbiter") is then fitted on the base classifiers' concatenated
    ``predict_proba`` outputs for the fold's held-out part. At prediction time, every fold's
    base-classifier/arbiter pipeline scores the input, and the arbiters' probabilities are averaged.

    Parameters
    ----------
    classifiers : list of estimators
        Base classifiers; each must implement ``predict_proba``. They are cloned, never fitted in place.
    head_classifier : estimator
        Classifier trained on the stacked base-classifier probabilities; must implement
        ``predict_proba``. Cloned once per fold.
    num_folds : int, default=3
        Number of stratified folds, and therefore the number of fitted base-classifier sets and arbiters.

    Attributes
    ----------
    classes_ : ndarray
        Sorted unique labels seen in ``fit``; ``predict`` returns values from this array.
    classifiers_ : list of list of estimators
        ``classifiers_[fold]`` holds the base classifiers fitted on that fold's training part.
    arbiters_ : list of estimators
        ``arbiters_[fold]`` is the head classifier fitted on that fold's held-out probabilities.
    """

    def __init__(self, classifiers, head_classifier, num_folds=3):
        self.classifiers = classifiers
        self.head_classifier = head_classifier
        self.num_folds = num_folds
        # Fitted state; stays None until fit() succeeds.
        self.classifiers_ = None
        self.arbiters_ = None

    def fit(self, X, y, **kwargs):
        """
        Fit ``num_folds`` sets of base classifiers and one arbiter per fold.

        Parameters
        ----------
        X : array-like
            Training data. It is row-indexed with integer index arrays (``X[indexes]``), so it must
            support that, e.g. a numpy array or scipy sparse matrix.
        y : array-like
            Class labels, one per row of X.
        **kwargs
            Forwarded unchanged to every base classifier's ``fit``. NOTE: each base classifier only
            sees the fold's training rows. Per-sample arrays (e.g. ``sample_weight``) are NOT sliced
            and will therefore have the wrong length.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If the base classifiers' stacked ``predict_proba`` output does not have one column per
            (classifier, class) pair. This typically happens when a class is missing from some
            training fold.
        """
        # sklearn.cross_validation was removed in scikit-learn 0.20; model_selection replaces it.
        from sklearn.model_selection import StratifiedKFold

        y = numpy.asarray(y)
        self.classes_ = numpy.unique(y)
        # Every base classifier contributes one probability column per class (not a fixed 2,
        # so multiclass problems work).
        expected_columns = len(self.classifiers) * len(self.classes_)

        fitted_classifiers = []
        fitted_arbiters = []
        splitter = StratifiedKFold(n_splits=self.num_folds)
        for train_indexes, test_indexes in splitter.split(X, y):
            # train clones of base classifiers on this fold's training part
            fold_classifiers = [
                sklearn.base.clone(c).fit(X[train_indexes], y[train_indexes], **kwargs)
                for c in self.classifiers
            ]
            # test them all on the held-out part; these out-of-fold probabilities train the arbiter
            probas = numpy.hstack([c.predict_proba(X[test_indexes]) for c in fold_classifiers])
            assert probas.shape[0] == len(test_indexes)
            if probas.shape[1] != expected_columns:
                raise ValueError(
                    "base classifiers produced %d probability columns, expected %d "
                    "(%d classifiers x %d classes); does every class have at least "
                    "num_folds=%d samples?"
                    % (probas.shape[1], expected_columns, len(self.classifiers),
                       len(self.classes_), self.num_folds))
            arbiter = sklearn.base.clone(self.head_classifier)
            arbiter.fit(probas, y[test_indexes])
            fitted_classifiers.append(fold_classifiers)
            fitted_arbiters.append(arbiter)

        # Publish fitted state only once every fold succeeded.
        self.classifiers_ = fitted_classifiers
        self.arbiters_ = fitted_arbiters
        return self

    def predict_proba(self, X):
        """
        Return class probabilities averaged (arithmetic mean) over the per-fold arbiters.

        For each fold, that fold's base classifiers score X. Their ``predict_proba`` outputs are
        concatenated column-wise in the same classifier order used during ``fit``, and the result is
        fed to that fold's arbiter. Must be called after ``fit``.

        Columns follow the arbiters' class order. For scikit-learn classifiers this is presumably the
        sorted label order, i.e. ``self.classes_``.
        """
        arbiter_probas = []
        for classifier_list, arbiter in zip(self.classifiers_, self.arbiters_):
            assert len(classifier_list) == len(self.classifiers)
            # individual classifier probabilities, stacked side by side
            probas = numpy.hstack([c.predict_proba(X) for c in classifier_list])
            # arbiter probabilities for this fold
            arbiter_probas.append(arbiter.predict_proba(probas))
        # Shape (n_samples, n_classes, num_folds) -> average over the fold axis.
        return numpy.mean(numpy.dstack(arbiter_probas), axis=2)

    def predict(self, X):
        """
        Return the predicted class label for each row of X.

        The returned values are labels from ``self.classes_``, not column indexes. This makes them
        directly comparable with ``y`` (as ``ClassifierMixin.score`` requires).
        """
        return self.classes_[numpy.argmax(self.predict_proba(X), axis=1)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment