Last active
May 10, 2022 16:13
-
-
Save lorenzwalthert/51371894225c7b530b66bdabfad60327 to your computer and use it in GitHub Desktop.
scikit-learn-style ordinal classifier implementing "A Simple Approach to Ordinal Classification" (Frank & Hall, 2001): https://link.springer.com/chapter/10.1007/3-540-44795-4_13
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# adapted from https://towardsdatascience.com/simple-trick-to-train-an-ordinal-regression-with-any-classifier-6911183d2a3c
class OrdinalClassifier:
    """Ordinal classifier assembled from K - 1 binary threshold models.

    For K ordered class values V1 < ... < VK, one copy of the wrapped
    classifier is trained per threshold V1 .. V(K-1) to estimate
    Pr(y > Vi).  Class probabilities are then recovered as differences of
    consecutive threshold probabilities.

    Attributes set by the code below:
        clf: the prototype classifier handed to the constructor; it is
            copied with ``clone`` for every threshold and must provide
            ``fit`` and ``predict_proba``.
        clfs: dict mapping each threshold level to its fitted binary model.
        classes_: sorted array of the distinct labels seen by ``fit``.
    """

    def __init__(self, clf):
        """Store the prototype classifier; no training happens here."""
        self.clf = clf
        self.clfs = {}

    def fit(self, X, y):
        """Train one binary "y > level" model per class value except the largest.

        Args:
            X: training samples, passed unchanged to each binary model.
            y: ordinal labels; converted with ``np.asarray`` so plain lists
                work as well as arrays.

        Returns:
            self, so calls can be chained.

        Raises:
            ValueError: if ``y`` contains no labels.
        """
        y = np.asarray(y)
        # np.unique already returns the distinct values in sorted order.
        classes = np.unique(y)
        if classes.shape[0] == 0:
            raise ValueError("y must contain at least one label")

        # Train into a fresh dict so models from an earlier fit never leak
        # into this one, and so a failing fit leaves the old state intact.
        fitted = {}
        # Every level except the largest yields a two-class problem; this
        # also covers the plain binary case (a single threshold).  With one
        # class there is nothing to train.
        for level in classes[:-1]:
            binary_y = (y > level).astype(np.uint8)
            clf = clone(self.clf)
            clf.fit(X, binary_y)
            fitted[level] = clf

        self.classes_ = classes
        self.clfs = fitted
        return self

    def predict_proba(self, X):
        """Return per-class probabilities, one row per sample.

        Columns follow the order of ``classes_``.  Raw values are

            Pr(V1) = 1 - Pr(y > V1)
            Pr(Vi) = Pr(y > V(i-1)) - Pr(y > Vi)
            Pr(VK) = Pr(y > V(K-1))

        Because the threshold models are trained independently, a difference
        can come out negative; values are clipped to [0, 1] and every row is
        rescaled to sum to one.
        """
        thresholds = self.classes_[:-1]
        if len(thresholds) == 0:
            # Only one class was seen during fit: it is certain.
            n_samples = X.shape[0] if hasattr(X, "shape") else len(X)
            return np.ones((n_samples, 1))

        # Column i holds Pr(y > thresholds[i]) (the positive-class column of
        # the corresponding binary model).
        exceed = np.column_stack(
            [self.clfs[level].predict_proba(X)[:, 1] for level in thresholds]
        )
        n_samples = exceed.shape[0]
        # Pad with Pr(y > "below V1") = 1 on the left and Pr(y > VK) = 0 on
        # the right so all K class probabilities are one vectorised difference.
        upper = np.hstack([np.ones((n_samples, 1)), exceed])
        lower = np.hstack([exceed, np.zeros((n_samples, 1))])
        probs = np.clip(upper - lower, 0, 1)
        # The unclipped row sums to exactly one, so after clipping negatives
        # the sum is at least one and the division is safe.
        return probs / probs.sum(axis=1, keepdims=True)

    def predict(self, X):
        """Return the most probable class label for every sample in ``X``."""
        return self.classes_[np.argmax(self.predict_proba(X), axis=1)]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment