Skip to content

Instantly share code, notes, and snippets.

@chkoar
Created March 30, 2018 09:28
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chkoar/2993a6e3f6bae1887eabc3fa27bb06a6 to your computer and use it in GitHub Desktop.
Save chkoar/2993a6e3f6bae1887eabc3fa27bb06a6 to your computer and use it in GitHub Desktop.
Savable Classifier
import tempfile
import uuid
from pathlib import Path

from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectPercentile
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC

try:
    import joblib
except ImportError:
    # Fallback for very old scikit-learn installs that only shipped the
    # vendored copy; ``sklearn.externals.joblib`` was removed in 0.23.
    from sklearn.externals import joblib
class SavableClassifier(BaseEstimator, ClassifierMixin):
    """Meta-classifier that pickles every fitted copy of its estimator to disk.

    Each call to :meth:`fit` clones ``estimator``, fits the clone and dumps it
    with ``joblib`` to a uniquely named ``.pkl`` file inside
    ``export_directory``. When used with e.g. ``cross_val_score`` this keeps
    one saved model per fold.

    Parameters
    ----------
    estimator : estimator object
        Unfitted scikit-learn compatible classifier (or pipeline) to wrap.
        It is cloned on every fit, so the object passed in is never fitted.
    export_directory : str or path-like
        Directory that receives the pickled models. It is created (with any
        missing parents) on the first save if it does not exist yet.

    Attributes
    ----------
    estimator_ : estimator object
        The fitted clone of ``estimator``.
    saved_path_ : pathlib.Path
        File the most recently fitted model was written to.
    classes_ : array-like
        Class labels of the fitted estimator; only set when the wrapped
        estimator exposes ``classes_``.
    """

    def __init__(self, estimator, export_directory):
        # scikit-learn convention: store constructor arguments unchanged so
        # that get_params / set_params / clone work on this wrapper.
        self.estimator = estimator
        self.export_directory = export_directory

    def fit(self, X, y, **fit_params):
        """Fit a clone of ``estimator`` on ``(X, y)`` and save it to disk.

        Extra keyword arguments are forwarded unchanged to the wrapped
        estimator's ``fit`` (e.g. ``sample_weight``).

        Returns
        -------
        self
        """
        self.estimator_ = clone(self.estimator).fit(X, y, **fit_params)
        # Re-expose the label set so code that looks for ``classes_`` on a
        # fitted classifier also works through this wrapper.
        classes = getattr(self.estimator_, "classes_", None)
        if classes is not None:
            self.classes_ = classes
        self.saved_path_ = self._save()
        return self

    def predict(self, X):
        """Predict class labels for ``X`` with the fitted clone."""
        return self.estimator_.predict(X)

    def _save(self):
        """Dump ``estimator_`` to a new ``<uuid4 hex>.pkl`` file and return its path."""
        directory = Path(self.export_directory)
        # joblib.dump does not create missing directories, so do it here.
        directory.mkdir(parents=True, exist_ok=True)
        # A fresh uuid4 per fit means earlier models (e.g. other CV folds)
        # are never overwritten.
        fullpath = directory / "{}.pkl".format(uuid.uuid4().hex)
        # str() because older joblib releases may only accept str filenames.
        joblib.dump(self.estimator_, str(fullpath))
        return fullpath
# Demo: 10-fold cross-validation of a feature-selection + linear-SVM pipeline,
# with every fold model pickled into EXPORT_DIRECTORY by SavableClassifier.
X, y = make_classification(random_state=0)
model = make_pipeline(
    # ``percentile`` is on a 0-100 scale: 50 keeps the top half of the
    # features. The previous value 0.5 meant 0.5%, i.e. only the single
    # best of the 20 generated features.
    SelectPercentile(percentile=50),
    LinearSVC(),
)
# Use a writable temp location instead of the hard-coded "/temp", which
# typically does not exist, making every fold fail when saving.
EXPORT_DIRECTORY = Path(tempfile.gettempdir(), "savable_classifier")
sc = SavableClassifier(model, EXPORT_DIRECTORY)

if __name__ == '__main__':
    # Ensure the export directory exists before any fold tries to save.
    EXPORT_DIRECTORY.mkdir(parents=True, exist_ok=True)
    scores = cross_val_score(sc, X, y, cv=10, n_jobs=1)
    print("Accuracy per fold:", scores)
    print("Mean accuracy: {:.3f}".format(scores.mean()))
    print("Models saved in:", EXPORT_DIRECTORY)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment