Skip to content

Instantly share code, notes, and snippets.

@goraj
Created August 12, 2018 02:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save goraj/70a3bebea8a90aff285abc2efeff4e2c to your computer and use it in GitHub Desktop.
Save goraj/70a3bebea8a90aff285abc2efeff4e2c to your computer and use it in GitHub Desktop.
custom oob score function sklearn
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 10 21:41:24 2018
@author: goraj
"""
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from warnings import warn
from sklearn.tree._tree import DTYPE, DOUBLE
from sklearn.ensemble.forest import _generate_unsampled_indices
from sklearn.utils import check_array
def custom_oob_score(self, X, y):
    """Compute the out-of-bag (OOB) score and decision function of a forest.

    Written to be bound onto a forest classifier in place of its private
    ``_set_oob_score`` hook (see the ``__main__`` section of this file).

    Parameters
    ----------
    self : forest classifier
        Must expose ``n_classes_``, ``n_outputs_`` and ``estimators_``.
        ``n_classes_`` may be a per-output sequence or a single integer.
    X : array-like or CSR sparse matrix
        Samples the forest was trained on; validated with ``check_array``.
    y : array-like
        Targets with one row per sample and one column per output.  A 1-D
        array is treated as a single output column.  Values are compared
        directly with ``argmax`` column indices, so they are expected to be
        class *indices* rather than raw labels.

    Notes
    -----
    Stores the results on ``self`` and returns ``None``:

    * ``oob_decision_function_`` -- per-sample class probabilities averaged
      over the trees selected by ``_generate_unsampled_indices`` (a list of
      arrays when there is more than one output).  Rows that received no
      vote contain NaN.
    * ``oob_score_`` -- accuracy of the OOB predictions, averaged over
      outputs.

    NOTE(review): ``_generate_unsampled_indices`` is a private scikit-learn
    helper; confirm its signature against the installed release.
    """
    X = check_array(X, dtype=DTYPE, accept_sparse='csr')

    # Accept both the 2-D targets used inside ``fit`` and a plain 1-D
    # target vector; the column indexing below needs two dimensions.
    y = np.asarray(y)
    if y.ndim == 1:
        y = y.reshape(-1, 1)

    # ``n_classes_`` may be a scalar or a per-output sequence; normalise it
    # so that it can always be indexed by output number.
    n_classes = np.atleast_1d(self.n_classes_)
    n_outputs = self.n_outputs_
    n_samples = y.shape[0]

    # One (n_samples, n_classes) accumulator of summed probabilities per output.
    predictions = [np.zeros((n_samples, n_classes[k]))
                   for k in range(n_outputs)]

    for estimator in self.estimators_:
        # Presumably the samples this tree did not see in its bootstrap
        # draw (helper is defined in scikit-learn, not in this file).
        unsampled_indices = _generate_unsampled_indices(
            estimator.random_state, n_samples)
        p_estimator = estimator.predict_proba(X[unsampled_indices, :],
                                              check_input=False)

        # Single-output trees return a bare array; wrap it so both cases
        # can be handled by the same per-output loop.
        if n_outputs == 1:
            p_estimator = [p_estimator]

        for k in range(n_outputs):
            predictions[k][unsampled_indices, :] += p_estimator[k]

    oob_decision_function = []
    oob_score = 0.0
    for k in range(n_outputs):
        votes = predictions[k]
        row_totals = votes.sum(axis=1)

        if (row_totals == 0).any():
            warn("Some inputs do not have OOB scores. "
                 "This probably means too few trees were used "
                 "to compute any reliable oob estimates.")

        # Normalise the summed probabilities so every row sums to one.
        oob_decision_function.append(votes / row_totals[:, np.newaxis])
        oob_score += np.mean(y[:, k] == np.argmax(votes, axis=1), axis=0)

    if n_outputs == 1:
        self.oob_decision_function_ = oob_decision_function[0]
    else:
        self.oob_decision_function_ = oob_decision_function

    self.oob_score_ = oob_score / n_outputs
if __name__ == '__main__':
    # Small synthetic two-class problem so the demo runs end to end.
    rng = np.random.RandomState(0)
    X = rng.rand(200, 4)
    y = (X[:, 0] + X[:, 1] > 1.0).astype(int)

    m = RandomForestClassifier(n_estimators=100, oob_score=True,
                               random_state=0)
    # Bind the custom function as a method on this instance so that the
    # instance attribute shadows the class's own ``_set_oob_score``.
    m._set_oob_score = custom_oob_score.__get__(m, type(m))

    # ``fit`` requires the training data; calling it without arguments
    # raises TypeError.
    m.fit(X, y)
    print(m.oob_score_)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment