import numpy
import sklearn.base as base
import sklearn.linear_model as lm
class ByThreshold(base.BaseEstimator, base.ClassifierMixin):
    """Wrap a probabilistic classifier and reject low-confidence predictions.

    A sample whose highest class probability (from the wrapped estimator's
    ``predict_proba``) is strictly below ``threshold`` is labelled ``-1``
    instead of receiving the wrapped estimator's prediction.

    Parameters
    ----------
    estimator : object
        Classifier exposing ``fit``, ``predict`` and ``predict_proba``.
        It is fitted in place by ``fit``.
    threshold : float, default 0.95
        Minimum top-class probability required to keep a prediction.

    Note: ``-1`` is the rejection marker, so it is ambiguous if ``-1`` is
    also one of the real class labels.
    """

    def __init__(self, estimator, threshold=0.95):
        # Store constructor arguments unchanged; scikit-learn's clone and
        # parameter machinery read them back through get_params.
        self.threshold = threshold
        self.estimator = estimator

    def get_params(self, deep=True):
        """Return this estimator's parameters as a dict.

        ``deep`` follows the scikit-learn contract (``clone`` calls
        ``get_params(deep=False)``): when true, the wrapped estimator's own
        parameters are added under ``estimator__<name>`` keys.
        """
        params = {
            'threshold': self.threshold,
            'estimator': self.estimator,
        }
        inner = self.estimator
        if deep and hasattr(inner, 'get_params') and not isinstance(inner, type):
            for key, value in inner.get_params(deep=True).items():
                params['estimator__' + key] = value
        return params

    def set_params(self, params=None, **kwargs):
        """Set parameters and return ``self``.

        Accepts a single dict (the original calling convention) and/or
        keyword arguments (the scikit-learn convention used by ``clone`` and
        grid search). Keys of the form ``estimator__<name>`` are forwarded to
        the wrapped estimator's ``set_params``.

        Raises
        ------
        ValueError
            If a key is not a parameter of this estimator.
        """
        merged = dict(params or {})
        merged.update(kwargs)
        nested = {}
        for key, value in merged.items():
            if key.startswith('estimator__'):
                nested[key[len('estimator__'):]] = value
            elif key in ('threshold', 'estimator'):
                setattr(self, key, value)
            else:
                raise ValueError('Invalid parameter %r for estimator %s'
                                 % (key, type(self).__name__))
        # Applied after the loop so a replacement ``estimator`` (if given in
        # the same call) receives the nested parameters.
        if nested:
            self.estimator.set_params(**nested)
        return self

    def fit(self, X, y):
        """Fit the wrapped estimator on ``(X, y)`` and return ``self``.

        Returning ``self`` (not the inner estimator) keeps chained calls such
        as ``clf.fit(X, y).predict(X)`` going through the threshold logic.
        """
        self.estimator.fit(X, y)
        return self

    def predict(self, X):
        """Predict labels for ``X``, replacing low-confidence ones with ``-1``.

        Returns a new array; the array produced by the wrapped estimator is
        not modified. Labels whose dtype cannot hold ``-1`` faithfully
        (bool, unsigned int, strings, ...) are returned as an object array.
        """
        proba = numpy.asarray(self.estimator.predict_proba(X))
        uncertain = proba.max(axis=1) < self.threshold
        # numpy.array copies, so the estimator's own output is never mutated.
        labels = numpy.array(self.estimator.predict(X))
        if labels.dtype.kind not in 'if':
            # Writing -1 into these dtypes would become True, overflow, or be
            # truncated to '-' for short strings; object dtype stores it as-is.
            labels = labels.astype(object)
        labels[uncertain] = -1
        return labels
# Demo: fit on a tiny dataset, then predict two samples; any prediction whose
# top-class probability is below 0.6 is reported as -1.
X = numpy.array([[0, 0, 1], [1, 0, 1], [1, 1, 1]])
y = numpy.array([0, 1, 1])
clf = ByThreshold(lm.LogisticRegression(), 0.6)
clf.fit(X, y)
# A single-argument print(...) behaves the same on Python 2 and 3; the
# original `print expr` statement is a SyntaxError on Python 3.
print(clf.predict([
    [1, 1, 0],
    [0, 0, 1],
]))
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.