Created
August 12, 2018 02:02
-
-
Save goraj/70a3bebea8a90aff285abc2efeff4e2c to your computer and use it in GitHub Desktop.
custom oob score function sklearn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Created on Fri Aug 10 21:41:24 2018 | |
@author: goraj | |
""" | |
import numpy as np | |
from sklearn.ensemble import RandomForestClassifier | |
from warnings import warn | |
from sklearn.tree._tree import DTYPE, DOUBLE | |
from sklearn.ensemble.forest import _generate_unsampled_indices | |
from sklearn.utils import check_array | |
def custom_oob_score(self, X, y):
    """Compute the out-of-bag (OOB) score and decision function of a forest.

    Written to be bound to a forest classifier instance as a replacement
    for its ``_set_oob_score`` method. For every tree, the samples that
    were left out of that tree's bootstrap draw are scored with
    ``predict_proba``; the per-class probabilities are accumulated per
    sample and then normalised.

    Parameters
    ----------
    self : forest classifier
        Must provide ``estimators_``, ``n_outputs_`` and ``n_classes_``.
    X : array-like or CSR sparse matrix
        Training samples; validated/converted with ``check_array``.
    y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
        Training targets. They are compared directly against ``argmax``
        column indices, so they presumably hold encoded class indices
        (0..n_classes-1) rather than raw labels -- confirm against caller.

    Returns
    -------
    None
        Results are stored on ``self`` as ``oob_decision_function_``
        (a single array for one output, otherwise a list of arrays) and
        ``oob_score_`` (accuracy averaged over the outputs).
    """
    X = check_array(X, dtype=DTYPE, accept_sparse='csr')

    # Accept a flat target vector as well as the 2-D (n_samples, n_outputs)
    # layout; the column indexing below needs two dimensions.
    y = np.asarray(y)
    if y.ndim == 1:
        y = y.reshape(-1, 1)

    n_outputs = self.n_outputs_
    # Works whether n_classes_ is a per-output sequence or a single scalar.
    n_classes = np.atleast_1d(self.n_classes_)
    n_samples = y.shape[0]

    # One (n_samples, n_classes) accumulator of summed probabilities per output.
    predictions = [np.zeros((n_samples, n_classes[k]))
                   for k in range(n_outputs)]

    for estimator in self.estimators_:
        # Rows this tree did not draw during bootstrapping.
        unsampled_indices = _generate_unsampled_indices(
            estimator.random_state, n_samples)
        p_estimator = estimator.predict_proba(X[unsampled_indices, :],
                                              check_input=False)

        # Single-output trees return a bare array; wrap it so both cases
        # can be indexed by output below.
        if n_outputs == 1:
            p_estimator = [p_estimator]

        for k in range(n_outputs):
            predictions[k][unsampled_indices, :] += p_estimator[k]

    oob_decision_function = []
    oob_score = 0.0
    for k in range(n_outputs):
        votes = predictions[k]
        totals = votes.sum(axis=1)

        if (totals == 0).any():
            warn("Some inputs do not have OOB scores. "
                 "This probably means too few trees were used "
                 "to compute any reliable oob estimates.")

        # Rows with a zero total were never out-of-bag; the division leaves
        # NaN in those rows, which is what the warning above refers to.
        oob_decision_function.append(votes / totals[:, np.newaxis])
        oob_score += np.mean(y[:, k] == np.argmax(votes, axis=1), axis=0)

    if n_outputs == 1:
        self.oob_decision_function_ = oob_decision_function[0]
    else:
        self.oob_decision_function_ = oob_decision_function

    self.oob_score_ = oob_score / n_outputs
if __name__ == '__main__':
    # Small synthetic two-class problem so the example runs end to end;
    # fit() cannot be called without training data.
    rng = np.random.RandomState(0)
    X = rng.rand(200, 4)
    y = (X[:, 0] + X[:, 1] > 1.0).astype(int)

    m = RandomForestClassifier(oob_score=True)
    # Bind the custom scorer to this one instance so it shadows the class's
    # own method. NOTE(review): this relies on fit() calling the private
    # ``_set_oob_score`` hook, which depends on the installed scikit-learn
    # version -- confirm before upgrading.
    m._set_oob_score = custom_oob_score.__get__(m, type(m))
    m.fit(X, y)
    print(m.oob_score_)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment