@jnothman
Last active August 17, 2017 01:15
Using a mixin to cache a transform method call in scikit-learn
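from sklearn.feature_extraction.text import CountVectorizer
from joblib import Memory
from sklearn.base import clone
from sklearn.datasets import fetch_20newsgroups


class CachedTransformMixin:
    # Class-level joblib cache shared by every instance; results are stored on disk.
    memory = Memory('/tmp/cache')

    def transform(self, *args, **kwargs):
        # Wrap the next transform in the MRO (here CountVectorizer.transform).
        # joblib hashes the bound method's instance together with the arguments,
        # so the cache key reflects the fitted state rather than object identity.
        return self.memory.cache(super().transform)(*args, **kwargs)


class CachedCountVectorizer(CachedTransformMixin, CountVectorizer):
    # The mixin is listed first so its transform overrides CountVectorizer's.
    pass


X = fetch_20newsgroups().data
est = CachedCountVectorizer().fit(X)

# %time is an IPython/Jupyter magic, so run the lines below in IPython.
%time Xt = est.transform(X)  # first call: computed, then written to the cache
%time Xt = est.transform(X)  # same estimator and input: loaded from the cache
est = clone(est).fit(X)
%time Xt = est.transform(X)  # refitted clone with identical state: should also hit the cache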
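The caching idea itself does not depend on joblib. The sketch below is a minimal, stdlib-only analogue of the same mixin pattern, written so it runs without scikit-learn or the 20 newsgroups download. `ToyCountVectorizer`, its `calls` counter and the in-memory `_cache` dict are hypothetical stand-ins, and the cache key is built from a JSON dump of the instance's attributes plus the call arguments, which assumes both are JSON-serializable. joblib's `Memory`, by contrast, hashes arbitrary picklable objects and persists results on disk.

```python
import hashlib
import json


class CachedTransformMixin:
    # Hypothetical stdlib-only analogue of the joblib-based mixin: results are
    # keyed on a hash of the instance's fitted state plus the call arguments.
    # Assumption: both are JSON-serializable (joblib.Memory has no such limit).
    _cache = {}  # class-level store shared by all instances, like the class-level Memory

    def transform(self, *args, **kwargs):
        payload = json.dumps(
            [type(self).__qualname__, vars(self), args, kwargs],
            sort_keys=True,
        )
        key = hashlib.sha256(payload.encode("utf-8")).hexdigest()
        if key not in self._cache:
            # Cache miss: delegate to the next class in the MRO and store the result.
            self._cache[key] = super().transform(*args, **kwargs)
        return self._cache[key]


class ToyCountVectorizer:
    """Hypothetical stand-in for CountVectorizer: counts whitespace tokens."""

    calls = 0  # class-level counter, so it never becomes part of the hashed state

    def fit(self, docs):
        words = sorted({w for doc in docs for w in doc.split()})
        self.vocabulary_ = {w: i for i, w in enumerate(words)}
        return self

    def transform(self, docs):
        ToyCountVectorizer.calls += 1  # counts real (uncached) transform runs
        rows = []
        for doc in docs:
            row = [0] * len(self.vocabulary_)
            for w in doc.split():
                if w in self.vocabulary_:
                    row[self.vocabulary_[w]] += 1
            rows.append(row)
        return rows


class CachedToyCountVectorizer(CachedTransformMixin, ToyCountVectorizer):
    pass


docs = ["the cat sat", "the dog sat down"]
est = CachedToyCountVectorizer().fit(docs)
first = est.transform(docs)    # miss: ToyCountVectorizer.transform runs
second = est.transform(docs)   # hit: served from the cache
refit = CachedToyCountVectorizer().fit(docs)  # new instance, same fitted state
third = refit.transform(docs)  # hit again: the key depends on state, not identity
print(ToyCountVectorizer.calls)  # 1: only the first call ran the real transform
```

Because the key is derived from `vars(self)` rather than from the object itself, a freshly fitted instance with the same vocabulary reuses the earlier result, mirroring the `clone(est).fit(X)` step above. One difference to keep in mind: this in-memory version returns the very same list object on every hit, whereas joblib loads a fresh copy from disk.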