Last active
November 25, 2016 13:24
-
-
Save lopuhin/2df1d6764744016d2f1c1b5fe6649f78 to your computer and use it in GitHub Desktop.
We want to serialize a model that may have complex preprocessing (lambdas, bound methods, etc.) and therefore cannot be pickled directly. Here we add custom pickle support that serializes only the model's constructor arguments and its fitted parameters.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, Any | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.linear_model import LogisticRegressionCV | |
from sklearn.pipeline import make_pipeline, FeatureUnion | |
class BaseModel:
    """Base class for models that pickle only constructor kwargs and parameters.

    Subclasses implement get_params/set_params. Pickling stores the dict
    returned by get_params() together with the kwargs passed to __init__;
    unpickling re-runs __init__ with those kwargs and then passes the
    remaining entries to set_params as keyword arguments.
    """

    def __init__(self, **kwargs):
        # Constructor arguments, replayed by __setstate__ to rebuild the object.
        self._kwargs = kwargs

    def get_params(self) -> Dict[str, Any]:
        """Return the state that __getstate__ will serialize (subclass hook)."""
        raise NotImplementedError

    def set_params(self, **params) -> None:
        """Restore state produced by get_params (subclass hook)."""
        raise NotImplementedError

    def fit(self, xs, ys) -> None:
        """Train the model on inputs *xs* and targets *ys* (subclass hook)."""
        raise NotImplementedError

    def predict(self, xs):
        """Return predictions for *xs* (subclass hook)."""
        raise NotImplementedError

    def predict_proba(self, xs):
        """Return class probabilities for *xs* (subclass hook)."""
        raise NotImplementedError

    def __getstate__(self):
        # Copy so that a dict owned by the subclass (e.g. a cached params
        # dict) is not permanently polluted with the '_kwargs' entry.
        state = dict(self.get_params())
        state['_kwargs'] = self._kwargs
        return state

    def __setstate__(self, state):
        # Work on a copy: callers (e.g. copy.copy) may pass in a dict they
        # still hold a reference to.
        state = dict(state)
        kwargs = state.pop('_kwargs', {})
        # Rebuild the un-picklable parts (preprocessors, pipelines, ...) from
        # the constructor, then overlay the saved parameters.
        self.__init__(**kwargs)
        self.set_params(**state)
class DefaultModel(BaseModel):
    """TF-IDF features fed into a cross-validated logistic regression.

    Each input item is indexed as ``item['text']`` and, when ``use_url`` is
    true, also as ``item['url']``. Pickling goes through BaseModel, which
    stores only get_params() and the constructor kwargs, so the pipeline
    (whose vectorizers use bound methods of this instance as preprocessors)
    is rebuilt by __init__ on load rather than pickled.
    """

    def __init__(self, use_url=False):
        if not use_url:
            self.url_vec = None
        else:
            # Character 3- and 4-grams over the lower-cased URL.
            self.url_vec = TfidfVectorizer(
                analyzer='char',
                ngram_range=(3, 4),
                preprocessor=self.url_preprocessor,
            )
        # Stock preprocessing of a default TfidfVectorizer, applied to the
        # 'text' field by text_preprocessor.
        self.default_text_preprocessor = TfidfVectorizer().build_preprocessor()
        self.text_vec = TfidfVectorizer(preprocessor=self.text_preprocessor)
        union_parts = [('text', self.text_vec)]
        if use_url:
            # URL features come first in the union.
            union_parts.insert(0, ('url', self.url_vec))
        self.vec = FeatureUnion(union_parts)
        self.clf = LogisticRegressionCV()
        self.pipeline = make_pipeline(self.vec, self.clf)
        # Recorded by BaseModel so __setstate__ can rebuild the same layout.
        super().__init__(use_url=use_url)

    def text_preprocessor(self, item):
        """Apply the default TF-IDF text preprocessing to item['text']."""
        raw_text = item['text']
        return self.default_text_preprocessor(raw_text)

    def url_preprocessor(self, item):
        """Return item['url'] lower-cased."""
        url = item['url']
        return url.lower()

    def fit(self, xs, ys):
        """Fit the whole vectorizer + classifier pipeline."""
        self.pipeline.fit(xs, ys)

    def predict(self, xs):
        """Return the pipeline's predicted labels for *xs*."""
        predictions = self.pipeline.predict(xs)
        return predictions

    def predict_proba(self, xs):
        """Return the pipeline's class probabilities for *xs*."""
        probabilities = self.pipeline.predict_proba(xs)
        return probabilities

    def get_params(self):
        """Collect the fitted attributes of each component for pickling."""
        params = {}
        params['text_vec_attrs'] = get_attributes(self.text_vec)
        params['url_vec_attrs'] = get_attributes(self.url_vec)
        params['clf_attrs'] = get_attributes(self.clf)
        return params

    def set_params(self, *, text_vec_attrs, url_vec_attrs, clf_attrs):
        """Restore the fitted attributes collected by get_params."""
        restore_plan = (
            (self.text_vec, text_vec_attrs),
            (self.url_vec, url_vec_attrs),
            (self.clf, clf_attrs),
        )
        for component, attrs in restore_plan:
            set_attributes(component, attrs)
def get_attributes(obj):
    """Return the fitted attributes of *obj* as a dict, for pickling.

    TfidfVectorizer instances are delegated to get_tfidf_attributes. For any
    other object, the instance attributes whose names end with '_' and do not
    start with '_' are returned (scikit-learn's naming convention for fitted
    attributes). ``None`` (e.g. a disabled optional component) yields ``{}``.
    """
    if obj is None:
        return {}
    if isinstance(obj, TfidfVectorizer):
        return get_tfidf_attributes(obj)
    try:
        # vars() rather than dir(): dir() also lists class-level names such
        # as properties, which getattr may fail on and which set_attributes
        # cannot restore with setattr.
        namespace = vars(obj)
    except TypeError:
        # No instance __dict__ (e.g. a __slots__ class): fall back to dir().
        names = [name for name in dir(obj)
                 if not name.startswith('_') and name.endswith('_')]
        return {name: getattr(obj, name) for name in names}
    return {name: value for name, value in namespace.items()
            if not name.startswith('_') and name.endswith('_')}
def set_attributes(obj, attributes):
    """Restore onto *obj* the attributes produced by get_attributes.

    TfidfVectorizer instances are delegated to set_ifidf_attributes; for any
    other object each key/value pair is assigned with setattr (so an empty
    dict is a no-op, whatever *obj* is).
    """
    if not isinstance(obj, TfidfVectorizer):
        for name, value in attributes.items():
            setattr(obj, name, value)
        return
    set_ifidf_attributes(obj, attributes)
def get_tfidf_attributes(obj):
    """Return the IDF diagonal and vocabulary of a fitted TfidfVectorizer.

    NOTE(review): reads the private attributes ``_tfidf._idf_diag``, which
    may not exist in every scikit-learn version — confirm against the
    installed version.
    """
    transformer = obj._tfidf
    state = {'_idf_diag': transformer._idf_diag}
    state['vocabulary_'] = obj.vocabulary_
    return state
def set_tfidf_attributes(obj, attributes):
    """Restore onto a TfidfVectorizer the state from get_tfidf_attributes.

    *attributes* must contain the keys '_idf_diag' and 'vocabulary_'; a
    missing key raises KeyError.
    NOTE(review): writes the private attribute ``_tfidf._idf_diag``, which
    may not exist in every scikit-learn version — confirm against the
    installed version.
    """
    obj._tfidf._idf_diag = attributes['_idf_diag']
    obj.vocabulary_ = attributes['vocabulary_']


# Backward-compatible alias for the original (misspelled) name, kept for
# existing callers such as set_attributes.
set_ifidf_attributes = set_tfidf_attributes
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment