Skip to content

Instantly share code, notes, and snippets.

@ogrisel
Created October 21, 2013 22:05
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ogrisel/7091781 to your computer and use it in GitHub Desktop.
Save ogrisel/7091781 to your computer and use it in GitHub Desktop.
Prototype for JSON-able IO of scikit-learn parameterized models (without the fitted parameters for now).
import json
import importlib
def model_to_type(model):
    """Return the dotted ``module.ClassName`` path of ``model``'s class.

    Parameters
    ----------
    model : object
        Any instance; only its class is inspected.

    Returns
    -------
    str
        The class' ``__module__`` and ``__name__`` joined by a dot.
    """
    model_class = type(model)
    return '.'.join((model_class.__module__, model_class.__name__))
def model_from_type(name):
    """Instantiate, with no arguments, the class named by a dotted path.

    Parameters
    ----------
    name : str
        Dotted path of the form ``package.module.ClassName``; the part after
        the last dot is looked up as an attribute of the imported module.

    Returns
    -------
    object
        A new instance obtained by calling the class without arguments.
    """
    # NOTE(review): this imports whatever module the path names; confirm the
    # descriptions passed in come from a trusted source.
    module_path, class_name = name.rsplit('.', 1)
    model_class = getattr(importlib.import_module(module_path), class_name)
    return model_class()
def describe_model(model, random_state_seed=None):
    """Return a JSON-friendly description of ``model`` and its parameters.

    An object exposing ``get_params`` is described as a dict holding a
    ``_type`` entry (the dotted class path from ``model_to_type``) plus one
    entry per parameter returned by ``get_params(deep=False)``. Parameter
    values are described recursively, so nested models become nested dicts.
    Any other object is returned unchanged and is assumed to be a literal.

    Parameters
    ----------
    model : object
        The model, or literal parameter value, to describe.
    random_state_seed : int or None, default None
        Value stored under ``random_state`` when that parameter holds an RNG
        instance (an object with a ``randint`` method). When None, a fresh
        seed is drawn from the RNG and the RNG is reseeded with it.

    Returns
    -------
    dict or object
        The description dict for a model, or ``model`` itself for a literal.
    """
    if not hasattr(model, 'get_params'):
        # Assume a literal parameter that maps to a pure data representation.
        return model

    model_description = dict(_type=model_to_type(model))
    for k, v in model.get_params(deep=False).items():
        if k == 'random_state' and hasattr(v, 'randint'):
            if random_state_seed is None:
                # Imported locally: the module itself only imports numpy
                # inside its ``__main__`` section, so relying on a global
                # ``np`` fails when this function is used from an import.
                import numpy as np

                # Export the RNG as the next int seed it generates, then
                # reseed the live RNG with that same value, so that the
                # description matches the state of the RNG at the end of the
                # export without embedding its full internal state.
                # The draw is bounded by the int32 maximum because
                # ``RandomState.seed`` rejects values above 2**32 - 1, and it
                # is cast to a plain ``int`` so it stays JSON-serializable.
                next_seed = int(v.randint(np.iinfo(np.int32).max))
                v.seed(next_seed)
                model_description[k] = next_seed
            else:
                # Fix the RNG value to the arbitrary provided seed.
                model_description[k] = random_state_seed
        else:
            # Recursive call to describe nested models.
            # NOTE(review): random_state_seed is not forwarded here, so it
            # only affects the top-level model; confirm this is intended.
            model_description[k] = describe_model(v)
    return model_description
def construct_model(model_description, random_state_seed=None):
    """Rebuild a model from a description dict carrying a ``_type`` entry.

    Parameters
    ----------
    model_description : object
        Either a mapping with a ``_type`` key (a dotted class path) and one
        entry per parameter, or any other value, which is treated as a
        literal parameter.
    random_state_seed : int or None, default None
        When not None, replaces the top-level ``random_state`` entry of the
        description; nested descriptions are rebuilt without it.

    Returns
    -------
    object
        The instantiated model with its parameters set through
        ``set_params``, or ``model_description`` itself for a literal.
    """
    describes_a_model = (hasattr(model_description, 'keys')
                         and '_type' in model_description.keys())
    if not describes_a_model:
        # Anything without a '_type' entry is a plain parameter value.
        return model_description

    remaining = model_description.copy()
    instance = model_from_type(remaining.pop('_type'))

    rebuilt_params = {}
    for key, value in remaining.items():
        override_seed = key == 'random_state' and random_state_seed is not None
        if override_seed:
            rebuilt_params[key] = random_state_seed
        else:
            # Nested descriptions are turned back into models recursively.
            rebuilt_params[key] = construct_model(value)

    instance.set_params(**rebuilt_params)
    return instance
if __name__ == "__main__":
    # Demo: describe two bagging classifiers, rebuild clones from the
    # descriptions and print cross-validation scores for originals and clones.
    import numpy as np
    from sklearn.ensemble import BaggingClassifier
    from sklearn.svm import SVC
    from sklearn.datasets import load_digits
    try:
        from sklearn.model_selection import cross_val_score
    except ImportError:
        # Older scikit-learn releases only ship the legacy module, which was
        # removed in scikit-learn 0.20.
        from sklearn.cross_validation import cross_val_score
    from pprint import pprint

    digits = load_digits()
    X, y = digits.data, digits.target

    print("First model")
    model_1 = BaggingClassifier(SVC(gamma=0.005, C=10), max_features=0.8)
    description_1 = describe_model(model_1)
    pprint(description_1)
    model_clone_1 = construct_model(description_1)
    print("CV score orig: {:.3}".format(
        np.mean(cross_val_score(model_1, X, y, cv=5))))
    print("CV score clone: {:.3}".format(
        np.mean(cross_val_score(model_clone_1, X, y, cv=5))))

    print("Model with fixed random state")
    model_2 = BaggingClassifier(random_state=1)
    description_2 = describe_model(model_2)
    pprint(description_2)
    # Round-trip through JSON to check the description is serializable.
    model_clone_2 = construct_model(json.loads(json.dumps(description_2)))
    print("CV score orig: {:.3}".format(
        np.mean(cross_val_score(model_2, X, y, cv=5))))
    print("CV score clone: {:.3}".format(
        np.mean(cross_val_score(model_clone_2, X, y, cv=5))))
@ogrisel
Copy link
Author

ogrisel commented Nov 29, 2013

This kind of model description could be used as a custom cache key (see joblib/joblib#69) in a pipeline; see scikit-learn/scikit-learn#2086.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment