Skip to content

Instantly share code, notes, and snippets.

@lopuhin
Last active November 25, 2016 13:24
Show Gist options
  • Save lopuhin/2df1d6764744016d2f1c1b5fe6649f78 to your computer and use it in GitHub Desktop.
Save lopuhin/2df1d6764744016d2f1c1b5fe6649f78 to your computer and use it in GitHub Desktop.
We want to serialize a model that may include complex preprocessing, lambdas, and so on, and is therefore not directly picklable. Here we add custom pickle support that serializes only the model's parameters.
from typing import Dict, Any
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.pipeline import make_pipeline, FeatureUnion
class BaseModel:
    """Base class for models that pickle only their parameters.

    The pickled state is the dict returned by ``get_params()`` plus the
    constructor keyword arguments, stored under the reserved key
    ``'_kwargs'``. On unpickling the instance is rebuilt by calling
    ``__init__`` with those keyword arguments and then ``set_params`` with
    the remaining entries.

    Subclasses must implement ``get_params``, ``set_params``, ``fit``,
    ``predict`` and ``predict_proba``, and should forward their constructor
    arguments to ``super().__init__(**kwargs)`` so they can be replayed.
    """

    def __init__(self, **kwargs):
        """Remember the constructor keyword arguments for later re-creation."""
        self._kwargs = kwargs

    def get_params(self) -> Dict[str, Any]:
        """Return the parameters to serialize; keys must not be ``'_kwargs'``."""
        raise NotImplementedError

    def set_params(self, **params) -> None:
        """Restore parameters previously produced by ``get_params``."""
        raise NotImplementedError

    def fit(self, xs, ys) -> None:
        """Train the model on inputs ``xs`` and targets ``ys``."""
        raise NotImplementedError

    def predict(self, xs):
        """Return predictions for ``xs``."""
        raise NotImplementedError

    def predict_proba(self, xs):
        """Return prediction probabilities for ``xs``."""
        raise NotImplementedError

    def __getstate__(self):
        """Build the pickle state: model parameters plus constructor kwargs."""
        # Copy so that adding '_kwargs' never leaks into a dict that
        # get_params() may have returned by reference.
        state = dict(self.get_params())
        state['_kwargs'] = dict(self._kwargs)
        return state

    def __setstate__(self, state):
        """Re-run ``__init__`` with the saved kwargs, then restore parameters."""
        # Work on a copy so the caller's state dict is left untouched.
        params = dict(state)
        kwargs = params.pop('_kwargs', {})
        self.__init__(**kwargs)
        self.set_params(**params)
class DefaultModel(BaseModel):
    """TF-IDF features fed into a ``LogisticRegressionCV`` classifier.

    Each input item is indexed by ``'text'`` (and by ``'url'`` when
    ``use_url`` is true) inside the preprocessors below. The preprocessors
    are bound methods of this instance.
    """

    def __init__(self, use_url=False):
        """Build the vectorizers, classifier and pipeline.

        When ``use_url`` is true, a character n-gram (3-4) TF-IDF vectorizer
        over the lower-cased URL is added in front of the text vectorizer;
        otherwise ``self.url_vec`` is ``None``.
        """
        self.url_vec = (
            TfidfVectorizer(
                analyzer='char',
                ngram_range=(3, 4),
                preprocessor=self.url_preprocessor,
            )
            if use_url
            else None
        )
        # Reuse the stock TfidfVectorizer preprocessor for the text field.
        self.default_text_preprocessor = TfidfVectorizer().build_preprocessor()
        self.text_vec = TfidfVectorizer(preprocessor=self.text_preprocessor)

        # URL features (if any) come first, then text features.
        candidates = [('url', self.url_vec), ('text', self.text_vec)]
        self.vec = FeatureUnion(
            [(name, vec) for name, vec in candidates if vec is not None]
        )
        self.clf = LogisticRegressionCV()
        self.pipeline = make_pipeline(self.vec, self.clf)
        super().__init__(use_url=use_url)

    def text_preprocessor(self, item):
        """Apply the default TF-IDF preprocessor to ``item['text']``."""
        text = item['text']
        return self.default_text_preprocessor(text)

    def url_preprocessor(self, item):
        """Return ``item['url']`` lower-cased."""
        url = item['url']
        return url.lower()

    def fit(self, xs, ys):
        """Fit the whole pipeline on ``xs`` / ``ys``."""
        self.pipeline.fit(xs, ys)

    def predict(self, xs):
        """Return the pipeline's predictions for ``xs``."""
        return self.pipeline.predict(xs)

    def predict_proba(self, xs):
        """Return the pipeline's predicted probabilities for ``xs``."""
        return self.pipeline.predict_proba(xs)

    def get_params(self):
        """Collect the state of both vectorizers and the classifier.

        ``self.url_vec`` is passed to ``get_attributes`` as is, including
        when it is ``None``.
        """
        text_state = get_attributes(self.text_vec)
        url_state = get_attributes(self.url_vec)
        clf_state = get_attributes(self.clf)
        return {
            'text_vec_attrs': text_state,
            'url_vec_attrs': url_state,
            'clf_attrs': clf_state,
        }

    def set_params(self, *, text_vec_attrs, url_vec_attrs, clf_attrs):
        """Push previously collected state back into the components."""
        restore_plan = (
            (self.text_vec, text_vec_attrs),
            (self.url_vec, url_vec_attrs),
            (self.clf, clf_attrs),
        )
        for component, saved in restore_plan:
            set_attributes(component, saved)
def get_attributes(obj):
    """Return the state of ``obj`` that should be serialized.

    ``TfidfVectorizer`` instances are delegated to ``get_tfidf_attributes``.
    For any other object, every name listed by ``dir(obj)`` that ends with
    an underscore and does not start with one is collected into a dict.
    """
    if isinstance(obj, TfidfVectorizer):
        return get_tfidf_attributes(obj)

    collected = {}
    for name in dir(obj):
        # Keep only public names with a trailing underscore.
        if name.startswith('_') or not name.endswith('_'):
            continue
        collected[name] = getattr(obj, name)
    return collected
def set_attributes(obj, attributes):
    """Restore state previously collected for ``obj``.

    ``TfidfVectorizer`` instances are delegated to
    ``set_ifidf_attributes``; for any other object each key/value pair in
    ``attributes`` is assigned with ``setattr``.
    """
    if isinstance(obj, TfidfVectorizer):
        set_ifidf_attributes(obj, attributes)
        return

    for name, value in attributes.items():
        setattr(obj, name, value)
def get_tfidf_attributes(obj):
    """Return the two pieces of state read from a TF-IDF vectorizer.

    The result maps ``'_idf_diag'`` to ``obj._tfidf._idf_diag`` and
    ``'vocabulary_'`` to ``obj.vocabulary_``.
    """
    # NOTE(review): relies on the private ``_tfidf._idf_diag`` attribute;
    # confirm it exists in the scikit-learn version in use.
    idf_diag = obj._tfidf._idf_diag
    vocabulary = obj.vocabulary_
    return dict(_idf_diag=idf_diag, vocabulary_=vocabulary)
def set_tfidf_attributes(obj, attributes):
    """Restore TF-IDF vectorizer state from ``attributes``.

    Assigns ``attributes['_idf_diag']`` to ``obj._tfidf._idf_diag`` and
    ``attributes['vocabulary_']`` to ``obj.vocabulary_``.

    Raises:
        KeyError: if either key is missing from ``attributes``.
    """
    # NOTE(review): writes the private ``_tfidf._idf_diag`` attribute;
    # confirm it is still honoured by the scikit-learn version in use.
    obj._tfidf._idf_diag = attributes['_idf_diag']
    obj.vocabulary_ = attributes['vocabulary_']


# Backward-compatible alias: the function was originally published under
# this misspelled name ("ifidf"), which existing callers still use.
set_ifidf_attributes = set_tfidf_attributes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment