@jnothman
Created March 6, 2018 23:00
Scikit-learn: Cache an estimator's fit with a mixin
import joblib  # scikit-learn 0.23+ no longer ships sklearn.externals.joblib
from sklearn.linear_model import LogisticRegression
from sklearn.feature_selection import RFE
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_classification

memory = joblib.Memory('/tmp')

class MemoryFit:
    def fit(self, *args, **kwargs):
        # Cache the parent class's bound fit(); joblib hashes the instance
        # (i.e. its parameters) together with the training data.
        fit = memory.cache(super(MemoryFit, self).fit)
        cached_self = fit(*args, **kwargs)
        # Copy the fitted attributes (coef_, classes_, ...) onto self.
        vars(self).update(vars(cached_self))
        return self  # scikit-learn convention: fit() returns the estimator

class CachedLogisticRegression(MemoryFit, LogisticRegression):
    pass

gs = GridSearchCV(RFE(CachedLogisticRegression()),
                  {'n_features_to_select': [1, 2, 3]}, verbose=10)
gs.fit(*make_classification())
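To check that the mixin really reuses fits, you can count how often the wrapped fit() body actually executes. This is a sketch under a few assumptions: recent scikit-learn and joblib are installed, the code runs as a script file (joblib pickles the fitted estimator, so its class must be importable), a throwaway tempfile.mkdtemp() directory replaces '/tmp', and CALLS and CountingLogisticRegression are hypothetical helpers added only for this illustration. Because joblib includes the instance behind a cached bound method in its hash, two fresh estimators with the same parameters, fitted on the same data, should share one cache entry, while a different C should not.

```python
import tempfile

import joblib
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Throwaway cache directory instead of '/tmp' (assumption for the demo)
memory = joblib.Memory(tempfile.mkdtemp(), verbose=0)
CALLS = []  # hypothetical counter: one entry per real, uncached fit


class MemoryFit:
    def fit(self, *args, **kwargs):
        # joblib hashes the bound method's instance (its parameters) plus X, y
        fit = memory.cache(super(MemoryFit, self).fit)
        cached_self = fit(*args, **kwargs)
        vars(self).update(vars(cached_self))  # copy coef_, classes_, ...
        return self


class CountingLogisticRegression(LogisticRegression):
    # Hypothetical helper: records every time the real fit() body runs
    def fit(self, X, y):
        CALLS.append(1)
        return super().fit(X, y)


class CachedLogisticRegression(MemoryFit, CountingLogisticRegression):
    pass


X, y = make_classification(random_state=0)
first = CachedLogisticRegression(C=0.5).fit(X, y)
second = CachedLogisticRegression(C=0.5).fit(X, y)  # same key: cache hit
third = CachedLogisticRegression(C=2.0).fit(X, y)  # new C: cache miss
print(len(CALLS), np.allclose(first.coef_, second.coef_))
```

One consequence of hashing the instance: refitting an already fitted estimator changes the key (its coef_ is now part of its state), so the cache pays off when each fit() runs on a fresh clone, which is what RFE and GridSearchCV do.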
@hermidalc

Hi, is there a way to pass in the memory as a kwarg at instantiation? I tried writing this:

from sklearn.svm import LinearSVC
from sklearn.utils.validation import check_memory

class MemoryFit:
    def fit(self, *args, **kwargs):
        fit = self.memory.cache(super(MemoryFit, self).fit)
        cached_self = fit(*args, **kwargs)
        vars(self).update(vars(cached_self))

class CachedLinearSVC(MemoryFit, LinearSVC):
    def __init__(self, **kwargs):
        self.memory = kwargs.pop('memory')
        super(CachedLinearSVC, self).__init__(**kwargs)

Then, in my script:

cachedir = mkdtemp()
memory = Memory(cachedir=cachedir, verbose=0)
grid = GridSearchCV(RFE(CachedLinearSVC(memory=memory, class_weight='balanced')), verbose=10)
grid.fit(X, y)

The first instantiation of CachedLinearSVC works fine, but then scikit-learn instantiates CachedLinearSVC again (possibly during cloning?), and the **kwargs it passes in are not the original ones from my script: memory is missing, since I popped it off before passing **kwargs to super:

$ ./test2.py 
Fitting 1 folds for each of 5 candidates, totalling 5 fits
Traceback (most recent call last):
  File "./test2.py", line 86, in <module>
    grid.fit(X, y)
  File "/home/hermidalc/soft/anaconda3/lib/python3.6/site-packages/sklearn/model_selection/_search.py", line 625, in fit
    base_estimator = clone(self.estimator)
  File "/home/hermidalc/soft/anaconda3/lib/python3.6/site-packages/sklearn/base.py", line 62, in clone
    new_object_params[name] = clone(param, safe=False)
  File "/home/hermidalc/soft/anaconda3/lib/python3.6/site-packages/sklearn/base.py", line 50, in clone
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/home/hermidalc/soft/anaconda3/lib/python3.6/site-packages/sklearn/base.py", line 50, in <listcomp>
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/home/hermidalc/soft/anaconda3/lib/python3.6/site-packages/sklearn/base.py", line 50, in clone
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/home/hermidalc/soft/anaconda3/lib/python3.6/site-packages/sklearn/base.py", line 50, in <listcomp>
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/home/hermidalc/soft/anaconda3/lib/python3.6/site-packages/sklearn/base.py", line 62, in clone
    new_object_params[name] = clone(param, safe=False)
  File "/home/hermidalc/soft/anaconda3/lib/python3.6/site-packages/sklearn/base.py", line 63, in clone
    new_object = klass(**new_object_params)
  File "/home/hermidalc/projects/github/hermidalc/nci-lhc-nsclc/luad/svm.py", line 14, in __init__
    self.memory = kwargs.pop('memory')
KeyError: 'memory'
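The traceback points at clone(), and as far as I can tell from scikit-learn's BaseEstimator, clone() rebuilds an estimator from get_params(), which only discovers parameters that appear explicitly in the class's __init__ signature; a **kwargs catch-all is skipped. With `def __init__(self, **kwargs)` that list is empty, so clone() ends up calling CachedLinearSVC() with no arguments at all, and kwargs.pop('memory') fails. Below is a hedged sketch of one fix, assuming a recent scikit-learn: declare memory as an explicit keyword argument, store it unchanged, and resolve it with check_memory() (from sklearn.utils.validation, already imported in the snippet above) at fit time. KwargsLinearSVC is a hypothetical name for the original pattern, and only C, class_weight and max_iter are re-declared for brevity; any other LinearSVC parameter you need must be added to the signature too, since parameters left out cannot be passed to the constructor and stay at LinearSVC's defaults.

```python
import os
import tempfile

import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC
from sklearn.utils.validation import check_memory


class MemoryFit:
    def fit(self, *args, **kwargs):
        # Turn None, a path string or a joblib.Memory into a Memory object
        memory = check_memory(self.memory)
        fit = memory.cache(super(MemoryFit, self).fit)
        cached_self = fit(*args, **kwargs)
        vars(self).update(vars(cached_self))
        return self


class KwargsLinearSVC(LinearSVC):
    # The **kwargs pattern from the question, kept to show why clone() fails
    def __init__(self, **kwargs):
        self.memory = kwargs.pop('memory')
        super().__init__(**kwargs)


class CachedLinearSVC(MemoryFit, LinearSVC):
    # Explicit keyword arguments, stored unchanged, are what get_params() and
    # therefore clone() can see. Only a subset of LinearSVC's parameters is
    # re-declared here; add any others you need (assumption for brevity).
    def __init__(self, memory=None, C=1.0, class_weight=None, max_iter=1000):
        super().__init__(C=C, class_weight=class_weight, max_iter=max_iter)
        self.memory = memory


broken = KwargsLinearSVC(memory=None, class_weight='balanced')
print(broken.get_params())  # {}: parameters hidden in **kwargs are not found

cachedir = tempfile.mkdtemp()
est = CachedLinearSVC(memory=cachedir, class_weight='balanced')
X, y = make_classification(random_state=0)
a = clone(est).fit(X, y)
b = clone(est).fit(X, y)  # identical clone and data: read back from cache
print(os.listdir(cachedir), np.allclose(a.coef_, b.coef_))
```

Passing a joblib.Memory object, as in your script, should work the same way, since check_memory() hands back any object with a cache() method unchanged.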
