Last active
August 17, 2017 01:15
-
-
Save jnothman/019d594d197c98a3d6192fa0cb19c850 to your computer and use it in GitHub Desktop.
Using a mixin to cache a transform method call in scikit-learn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

from joblib import Memory
from sklearn.base import clone
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
class CachedTransformMixin:
    """Mixin that memoizes ``transform`` calls on disk via joblib.

    Mix in ahead of any scikit-learn transformer: the MRO routes
    ``transform`` through this class, which delegates to the parent
    transformer's ``transform`` wrapped in a joblib disk cache, so
    repeated calls with the same arguments are served from disk.
    """

    # Shared class-level cache; every subclass instance writes to the
    # same directory, so identical inputs hit across instances too.
    # NOTE(review): '/tmp/cache' is POSIX-specific — consider
    # tempfile.gettempdir() for portability.
    memory = Memory('/tmp/cache')

    def transform(self, *args, **kwargs):
        """Delegate to the parent transformer's ``transform``, cached on disk."""
        # Zero-argument super() (Python 3) replaces the redundant
        # super(CachedTransformMixin, self) spelling.
        return self.memory.cache(super().transform)(*args, **kwargs)
class CachedCountVectorizer(CachedTransformMixin, CountVectorizer):
    """A ``CountVectorizer`` whose ``transform`` results are disk-cached."""
def _timed_transform(est, X):
    """Run one transform call and print its wall-clock duration."""
    start = time.perf_counter()
    Xt = est.transform(X)
    print('transform took {:.3f}s'.format(time.perf_counter() - start))
    return Xt


# Demo: the second call (and the call after cloning) should be served
# from the joblib disk cache and run much faster than the first.
# The original gist used IPython `%time` magics, which are not valid
# Python in a plain module; time.perf_counter reproduces the intent.
X = fetch_20newsgroups().data
est = CachedCountVectorizer().fit(X)
Xt = _timed_transform(est, X)  # cold cache: computes and stores
Xt = _timed_transform(est, X)  # warm cache: loaded from disk
est = clone(est).fit(X)        # the class-level Memory survives cloning
Xt = _timed_transform(est, X)  # still a cache hit on the cloned estimator
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment