Skip to content

Instantly share code, notes, and snippets.

@ornithos
Created June 30, 2023 09:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ornithos/f63d5745cb440421b297df89cdec5823 to your computer and use it in GitHub Desktop.
Save ornithos/f63d5745cb440421b297df89cdec5823 to your computer and use it in GitHub Desktop.
Serialization / Saving sklearn object to disk without using pickle (using numpy)
class SklearnSerializer:
    """Save and restore an estimator's fitted attributes with numpy ``.npz`` files.

    Sklearn recommends joblib's dump/load, which relies on pickle. This
    serializer avoids pickle: it collects every public attribute of an object
    whose name ends in an underscore (sklearn's convention for fitted
    attributes) and writes them to a single ``.npz`` archive with ``np.savez``.

    Alongside the attributes, the archive holds two bookkeeping arrays,
    ``type_keys`` and ``type_vals``, recording the Python type *name* of each
    attribute. Plain Python scalars come back from the archive as 0-d arrays,
    so the recorded names are used on load to turn them back into
    ``bool`` / ``int`` / ``float`` / ``str``. Everything else is restored as
    the numpy array read from the archive.

    Only the fitted attributes are stored. To restore, construct a fresh
    object of the same kind and pass it to :meth:`load`.

    The enclosing module must have ``numpy`` imported as ``np``.
    """

    # Type names whose values np.savez stores as 0-d arrays, mapped to the
    # constructor that rebuilds the original Python scalar on load.
    _SCALAR_CASTS = {"bool": bool, "int": int, "float": float, "str": str}

    def __init__(self):
        pass

    @staticmethod
    def _unpackage_sklearn_object(obj):
        """Collect the fitted attributes of ``obj``.

        Returns:
            A pair ``(values, types)`` of dicts keyed by attribute name:
            the attribute values and their Python types.
        """
        save_out = {}
        save_type = {}
        for attr_name in dir(obj):
            # Fitted attributes are public (no leading "_") and end in "_".
            if attr_name.startswith("_") or not attr_name.endswith("_"):
                continue
            try:
                attr = getattr(obj, attr_name)
            except AttributeError:
                # dir() can list properties that are not available on this
                # instance; there is nothing to save for them.
                continue
            save_out[attr_name] = attr
            save_type[attr_name] = type(attr)
        return save_out, save_type

    @staticmethod
    def _repackage_sklearn_object(obj, components, types):
        """Set every entry of ``components`` as an attribute on ``obj``.

        Args:
            obj: Object that receives the attributes (modified in place).
            components: Mapping of attribute name to loaded value.
            types: Mapping of attribute name to the recorded type name.
        """
        for attr_name, attr in components.items():
            # str() normalises numpy string scalars before the lookup.
            cast = SklearnSerializer._SCALAR_CASTS.get(str(types[attr_name]))
            if cast is not None:
                attr = cast(attr)
            setattr(obj, attr_name, attr)

    @classmethod
    def save(cls, obj, filepath):
        """Write the fitted attributes of ``obj`` to ``filepath`` as ``.npz``.

        Args:
            obj: Fitted object whose trailing-underscore attributes are saved.
            filepath: Destination path (anything accepted by ``str()``).
        """
        save_out, save_type = cls._unpackage_sklearn_object(obj)
        type_keys = list(save_type)
        # Store type *names*: type objects themselves could only be written
        # by pickling, which this class exists to avoid.
        type_vals = [save_type[key].__name__ for key in type_keys]
        np.savez(str(filepath), type_keys=type_keys, type_vals=type_vals, **save_out)

    @classmethod
    def load(cls, obj, filepath):
        """Restore attributes saved by :meth:`save` onto ``obj`` in place.

        Args:
            obj: Fresh object that receives the saved attributes.
            filepath: Path of the ``.npz`` archive to read.

        Raises:
            ValueError: If the archive lacks ``type_keys`` or ``type_vals``.
        """
        loaded, types = {}, {}
        # The context manager closes the underlying zip file once read.
        with np.load(str(filepath)) as np_archive:
            for key in np_archive.files:
                if key == "type_keys":
                    types["keys"] = np_archive[key]
                elif key == "type_vals":
                    types["vals"] = np_archive[key]
                else:
                    loaded[key] = np_archive[key]
        if len(types) != 2:
            raise ValueError("Bad types specification")
        types = dict(zip(types["keys"], types["vals"]))
        cls._repackage_sklearn_object(obj, loaded, types)
@ornithos
Copy link
Author

ornithos commented Jun 30, 2023

Example usage:

# Example: round-trip a fitted PCA model through SklearnSerializer.
# Assumes numpy is imported as `np`, `PCA` is sklearn's PCA class, and `X`
# is the data array to fit on.

# Fit a PCA model and save via the above method
pca = PCA()
pca.fit(X)
SklearnSerializer.save(pca, "test_pca.npz")

# Create a new PCA model and load in the saved components from the original fit
pca2 = PCA()
SklearnSerializer.load(pca2, "test_pca.npz")

# Check that the loaded PCA model produces identical results.
# np.allclose takes the two arrays as separate arguments.
np.allclose(pca.transform(X), pca2.transform(X))  # True

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment