sklearn_12052
"""
Fixes https://github.com/scikit-learn/scikit-learn/issues/12052

CalibratedClassifierGroupCV is a drop-in replacement for CalibratedClassifierCV that
supports group-aware cv splitters such as GroupKFold.

This is based on https://github.com/scikit-learn/scikit-learn/blob/0.24.1/sklearn/calibration.py.
If you are using a different version of sklearn, you can make similar modifications to
your version.

Example usage:

```
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupKFold

base_clf = RandomForestClassifier()
gkf = GroupKFold(n_splits=5)
calibrated_clf = CalibratedClassifierGroupCV(base_estimator=base_clf, method='isotonic', cv=gkf)
calibrated_clf.fit(X, y, groups=groups)
```

Before you pickle the classifier, you'll also need to monkey patch the class so that it
can be unpickled without importing `CalibratedClassifierGroupCV`:

```
import pickle
from sklearn.calibration import CalibratedClassifierCV

calibrated_clf.__class__ = CalibratedClassifierCV
pickle.dump(calibrated_clf, file)
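
# Later, in an environment that does not have this module, the pickle loads as a
# plain CalibratedClassifierCV (a sketch; `file` and `X_new` are placeholders for
# your own file handle and data):
loaded_clf = pickle.load(file)
loaded_clf.predict_proba(X_new)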
```
"""
import warnings
from contextlib import suppress
from functools import partial
from inspect import signature

import numpy as np
from joblib import Parallel
from sklearn.base import clone
from sklearn.calibration import (
    CalibratedClassifierCV, _get_prediction_method, _compute_predictions,
    _fit_calibrator, _fit_classifier_calibrator_pair)
from sklearn.model_selection import check_cv, cross_val_predict
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import LinearSVC
from sklearn.utils import indexable
from sklearn.utils.fixes import delayed
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_is_fitted, _check_sample_weight
class CalibratedClassifierGroupCV(CalibratedClassifierCV):

    # groups ADDED as a kwarg
    def fit(self, X, y, sample_weight=None, *, groups=None):
        """Fit the calibrated model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set. Only used in conjunction with a "Group" :term:`cv`
            instance (e.g., :class:`~sklearn.model_selection.GroupKFold`).

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        check_classification_targets(y)
        # THIS LINE WAS CHANGED TO ADD groups
        X, y, groups = indexable(X, y, groups)

        if self.base_estimator is None:
            # we want all classifiers that don't expose a random_state
            # to be deterministic (and we don't want to expose this one).
            base_estimator = LinearSVC(random_state=0)
        else:
            base_estimator = self.base_estimator

        self.calibrated_classifiers_ = []
        if self.cv == "prefit":
            # `classes_` and `n_features_in_` should be consistent with that
            # of base_estimator
            if isinstance(self.base_estimator, Pipeline):
                check_is_fitted(self.base_estimator[-1])
            else:
                check_is_fitted(self.base_estimator)
            with suppress(AttributeError):
                self.n_features_in_ = base_estimator.n_features_in_
            self.classes_ = self.base_estimator.classes_

            pred_method = _get_prediction_method(base_estimator)
            n_classes = len(self.classes_)
            predictions = _compute_predictions(pred_method, X, n_classes)

            calibrated_classifier = _fit_calibrator(
                base_estimator, predictions, y, self.classes_, self.method,
                sample_weight
            )
            self.calibrated_classifiers_.append(calibrated_classifier)
        else:
            X, y = self._validate_data(
                X, y, accept_sparse=['csc', 'csr', 'coo'],
                force_all_finite=False, allow_nd=True
            )
            # Set `classes_` using all `y`
            label_encoder_ = LabelEncoder().fit(y)
            self.classes_ = label_encoder_.classes_
            n_classes = len(self.classes_)

            # sample_weight checks
            fit_parameters = signature(base_estimator.fit).parameters
            supports_sw = "sample_weight" in fit_parameters
            if sample_weight is not None:
                sample_weight = _check_sample_weight(sample_weight, X)
                if not supports_sw:
                    estimator_name = type(base_estimator).__name__
                    warnings.warn(f"Since {estimator_name} does not support "
                                  "sample_weights, sample weights will only be"
                                  " used for the calibration itself.")

            # Check that each cross-validation fold can have at least one
            # example per class
            if isinstance(self.cv, int):
                n_folds = self.cv
            elif hasattr(self.cv, "n_splits"):
                n_folds = self.cv.n_splits
            else:
                n_folds = None
            if n_folds and np.any([np.sum(y == class_) < n_folds
                                   for class_ in self.classes_]):
                raise ValueError(f"Requesting {n_folds}-fold "
                                 "cross-validation but provided less than "
                                 f"{n_folds} examples for at least one class.")
            cv = check_cv(self.cv, y, classifier=True)

            if self.ensemble:
                parallel = Parallel(n_jobs=self.n_jobs)

                # groups checks
                split_parameters = signature(cv.split).parameters
                supports_groups = "groups" in split_parameters
                cv_split = partial(cv.split, X=X, y=y)
                if groups is not None:
                    if supports_groups:
                        # ADDED groups
                        cv_split = partial(cv.split, X=X, y=y, groups=groups)
                    else:
                        cv_name = type(cv).__name__
                        warnings.warn(f"{cv_name} does not support groups, "
                                      "so groups will be ignored.")
                self.calibrated_classifiers_ = parallel(
                    delayed(_fit_classifier_calibrator_pair)(
                        clone(base_estimator), X, y, train=train, test=test,
                        method=self.method, classes=self.classes_,
                        supports_sw=supports_sw, sample_weight=sample_weight)
                    for train, test in cv_split()
                )
            else:
                this_estimator = clone(base_estimator)
                method_name = _get_prediction_method(this_estimator).__name__
                # ADDED groups so that non-ensemble mode also works with
                # group-aware cv splitters such as GroupKFold
                pred_method = partial(
                    cross_val_predict, estimator=this_estimator, X=X, y=y,
                    groups=groups, cv=cv, method=method_name,
                    n_jobs=self.n_jobs
                )
                predictions = _compute_predictions(pred_method, X, n_classes)

                if sample_weight is not None and supports_sw:
                    this_estimator.fit(X, y, sample_weight)
                else:
                    this_estimator.fit(X, y)
                calibrated_classifier = _fit_calibrator(
                    this_estimator, predictions, y, self.classes_, self.method,
                    sample_weight
                )
                self.calibrated_classifiers_.append(calibrated_classifier)

        return self
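

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only, not part of the original fix):
    # synthetic data, a GroupKFold split, and the monkey patch described in the
    # module docstring. The data, file name, and hyperparameters are placeholders.
    import pickle

    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import GroupKFold

    rng = np.random.RandomState(0)
    X = rng.rand(200, 5)                   # 200 samples, 5 features
    y = rng.randint(0, 2, size=200)        # binary target
    groups = np.repeat(np.arange(20), 10)  # 20 groups of 10 samples each

    calibrated_clf = CalibratedClassifierGroupCV(
        base_estimator=RandomForestClassifier(n_estimators=50, random_state=0),
        method='isotonic',
        cv=GroupKFold(n_splits=5),
    )
    calibrated_clf.fit(X, y, groups=groups)
    print(calibrated_clf.predict_proba(X[:5]))

    # Patch the class before pickling so the model unpickles without this module.
    calibrated_clf.__class__ = CalibratedClassifierCV
    with open("calibrated_clf.pkl", "wb") as f:
        pickle.dump(calibrated_clf, f)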