Skip to content

Instantly share code, notes, and snippets.

@jnothman
Created August 23, 2017 03:07
Show Gist options
  • Save jnothman/4807b1b0266613c20ba4d1f88d0f8cf5 to your computer and use it in GitHub Desktop.
Save jnothman/4807b1b0266613c20ba4d1f88d0f8cf5 to your computer and use it in GitHub Desktop.
multilabel decision_function and predict_proba output shapes
import warnings
import sklearn
warnings.simplefilter('ignore')
from sklearn import *
def describe_output_shape(output):
    """Return the shape(s) of a classifier's scoring output.

    Parameters
    ----------
    output : object with a ``shape`` attribute, or an iterable of such objects
        The value returned by ``decision_function`` or ``predict_proba``.

    Returns
    -------
    tuple or list of tuple
        ``output.shape`` when ``output`` itself has a ``shape`` attribute;
        otherwise a list holding the ``shape`` of every element of ``output``.
    """
    if hasattr(output, 'shape'):
        return output.shape
    return [part.shape for part in output]


def main():
    """Print the output shape of each scoring method of several classifiers.

    A random multilabel problem is generated, every classifier is fitted on
    all but the last 10 samples, and each available scoring method is called
    on the last 3 samples. One line is printed per (classifier, method) pair:
    the classifier's class name, the method name, and the output shape(s).
    """
    X, y = datasets.make_multilabel_classification()
    classifiers = [
        tree.DecisionTreeClassifier(),
        neighbors.KNeighborsClassifier(),
        neural_network.MLPClassifier(),
        multioutput.MultiOutputClassifier(linear_model.LogisticRegression()),
        multiclass.OneVsRestClassifier(linear_model.LogisticRegression()),
    ]
    for clf in classifiers:
        # The last 10 samples are held out of training; the last 3 are scored.
        clf.fit(X[:-10], y[:-10])
        for method in ('decision_function', 'predict_proba'):
            # Not every classifier implements both methods; skip missing ones.
            if not hasattr(clf, method):
                continue
            scores = getattr(clf, method)(X[-3:])
            print(type(clf).__name__, method, describe_output_shape(scores))


if __name__ == '__main__':
    main()
"""Output:
DecisionTreeClassifier predict_proba [(3, 2), (3, 2), (3, 2), (3, 2), (3, 2)]
KNeighborsClassifier predict_proba [(3, 2), (3, 2), (3, 2), (3, 2), (3, 2)]
MLPClassifier predict_proba (3, 5)
MultiOutputClassifier predict_proba [(3, 2), (3, 2), (3, 2), (3, 2), (3, 2)]
OneVsRestClassifier decision_function (3, 5)
OneVsRestClassifier predict_proba (3, 5)
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment