Optimize pickling disk space for deploying scikit-learn trees to production
# This file shows how to customize the pickling of sklearn trees to achieve significantly smaller file sizes.
# See https://tech.quantco.com/2022/06/23/dtype-reduction-sklearn.html for the corresponding blog post.
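#
# A minimal usage sketch (the __main__ block at the bottom runs a full example;
# "model.pkl" is a placeholder path):
#
#   with open("model.pkl", "wb") as f:
#       dump_dtype_reduction(fitted_model, f)
#   with open("model.pkl", "rb") as f:
#       model = pickle.load(f)  # plain pickle.load suffices for reading
#
# Note: pickle stores compressed_tree_unpickle by reference, so this module
# must be importable wherever the compressed pickle is loaded.
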
import pickle
import copyreg
from typing import Any, BinaryIO
import numpy as np
from sklearn.tree._tree import Tree

def dump_dtype_reduction(model: Any, file: BinaryIO):
    p = pickle.Pickler(file)
    p.dispatch_table = copyreg.dispatch_table.copy()
    p.dispatch_table[Tree] = compressed_tree_pickle
    p.dump(model)
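
# The dispatch_table mechanism above is generic. As a hypothetical sketch (not
# part of the original gist), any type can get a custom reducer that returns a
# (callable, args) tuple, exactly like compressed_tree_pickle below does:
#
#   class Point:
#       def __init__(self, x, y):
#           self.x, self.y = x, y
#
#   def pickle_point(point):
#       return Point, (point.x, point.y)
#
#   p.dispatch_table[Point] = pickle_point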

def compressed_tree_pickle(tree):
    assert isinstance(tree, Tree)
    cls, init_args, state = tree.__reduce__()
    compressed_state = compress_tree_state(state)
    return compressed_tree_unpickle, (cls, init_args, compressed_state)

def compressed_tree_unpickle(cls, init_args, state):
    tree = cls(*init_args)
    decompressed_state = decompress_tree_state(state)
    tree.__setstate__(decompressed_state)
    return tree

def compress_tree_state(state: dict):
    """Compress the state of a sklearn Tree via dtype reduction."""
    # nodes is a structured numpy array whose records have the form
    # (left_child, right_child, feature, threshold, impurity,
    #  n_node_samples, weighted_n_node_samples)
    nodes = state['nodes']
    # int16 suffices for trees with fewer than 2**15 nodes and features;
    # larger trees would overflow silently, so widen these dtypes if needed
    dtype_child = np.int16
    dtype_feature = np.int16
    dtype_threshold = np.float64
    dtype_value = np.float32
    children_left = nodes['left_child'].astype(dtype_child)
    is_leaf = children_left == -1
    is_not_leaf = np.logical_not(is_leaf)
    # feature, threshold and children are irrelevant for leaves,
    # so store them for inner nodes only
    children_left = children_left[is_not_leaf]
    children_right = nodes['right_child'][is_not_leaf].astype(dtype_child)
    features = nodes['feature'][is_not_leaf].astype(dtype_feature)
    # lossless compression for thresholds: half integers (e.g. 5.5, 10.5, ...)
    # are downcast to int8
    thresholds = nodes['threshold'][is_not_leaf].astype(dtype_threshold)
    thresholds = compress_half_int_float_array(thresholds)
    # values are irrelevant for inner nodes, so store them for leaves only
    values = state['values'][is_leaf].astype(dtype_value)
    return {'max_depth': state['max_depth'],
            'node_count': state['node_count'],
            'is_leaf': is_leaf,
            'children_left': children_left,
            'children_right': children_right,
            'features': features,
            'thresholds': thresholds,
            'values': values}
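
# For orientation: a stump with one split (node 0) and two leaves (nodes 1, 2)
# has an uncompressed nodes array roughly of the form
#   [( 1,  2, 0, 0.5, ...), (-1, -1, -2, -2.0, ...), (-1, -1, -2, -2.0, ...)]
# which compresses to is_leaf=[False, True, True], children_left=[1],
# children_right=[2], features=[0], a single threshold and two value rows.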

def decompress_tree_state(state: dict):
    """Decompress the state of a sklearn Tree back into sklearn's node layout."""
    is_leaf = state['is_leaf']
    is_not_leaf = np.logical_not(is_leaf)
    n_nodes = len(is_leaf)
    children_left = np.empty(n_nodes, dtype=np.int64)
    children_right = np.empty(n_nodes, dtype=np.int64)
    features = np.empty(n_nodes, dtype=np.int64)
    thresholds = np.empty(n_nodes, dtype=np.float64)
    # same trailing shape as the stored values, but one row per node
    # instead of one row per leaf
    values = np.zeros((n_nodes, *state['values'].shape[1:]), dtype=np.float64)
    children_left[is_not_leaf] = state['children_left']
    children_left[is_leaf] = -1  # sklearn marks leaves with child -1
    children_right[is_not_leaf] = state['children_right']
    children_right[is_leaf] = -1  # sklearn marks leaves with child -1
    features[is_not_leaf] = state['features']
    features[is_leaf] = -2  # sklearn marks leaves with feature -2
    thresholds[is_not_leaf] = decompress_half_int_float_array(state['thresholds'])
    thresholds[is_leaf] = -2  # sklearn marks leaves with threshold -2
    values[is_leaf] = state['values']
    # this record layout matches scikit-learn's node dtype up to version 1.2;
    # scikit-learn 1.3+ extends it with a 'missing_go_to_left' field
    dtype = np.dtype([('left_child', '<i8'), ('right_child', '<i8'),
                      ('feature', '<i8'), ('threshold', '<f8'),
                      ('impurity', '<f8'), ('n_node_samples', '<i8'),
                      ('weighted_n_node_samples', '<f8')])
    nodes = np.zeros(n_nodes, dtype=dtype)
    nodes['left_child'] = children_left
    nodes['right_child'] = children_right
    nodes['feature'] = features
    nodes['threshold'] = thresholds
    return {'max_depth': state['max_depth'],
            'node_count': state['node_count'],
            'nodes': nodes,
            'values': values}
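
# impurity, n_node_samples and weighted_n_node_samples are left at zero:
# prediction only walks children/feature/threshold and reads values, so
# predict() is unaffected, but quantities derived from impurity (e.g.
# feature_importances_) will not survive the round trip.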

def compress_half_int_float_array(a, compression_dtype=np.int8):
    """Losslessly compress a float array whose entries are mostly half integers."""
    info = np.iinfo(compression_dtype)
    a2 = 2. * a
    # an entry is compressible if 2*a is (numerically) an integer
    # that fits into the compression dtype
    is_compressible = (np.minimum(np.abs(a2 % 1 - 1), a2 % 1) < 1e-12) & \
                      (a2 >= info.min) & (a2 <= info.max)
    not_compressible = np.logical_not(is_compressible)
    a2_compressible = a2[is_compressible].astype(compression_dtype)
    a_incompressible = a[not_compressible]
    state = {
        "is_compressible": is_compressible,
        "a2_compressible": a2_compressible,
        "a_incompressible": a_incompressible,
    }
    return state

def decompress_half_int_float_array(state):
    """Inverse of compress_half_int_float_array."""
    is_compressible = state["is_compressible"]
    a = np.zeros(len(is_compressible), dtype="float64")
    a[is_compressible] = state["a2_compressible"] / 2.
    a[np.logical_not(is_compressible)] = state["a_incompressible"]
    return a
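
# A small round trip, assuming the two helpers above:
#
#   >>> state = compress_half_int_float_array(np.array([5.5, 10.5, 0.3]))
#   >>> state["a2_compressible"]
#   array([11, 21], dtype=int8)
#   >>> state["a_incompressible"]
#   array([0.3])
#   >>> decompress_half_int_float_array(state)
#   array([ 5.5, 10.5,  0.3])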

if __name__ == '__main__':
    # create a sklearn random forest regressor and fit it to the diabetes
    # dataset (the original gist used load_boston, which was removed in
    # scikit-learn 1.2)
    from sklearn.datasets import load_diabetes
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.model_selection import train_test_split

    X, y = load_diabetes(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    regressor = RandomForestRegressor(n_estimators=100, random_state=42)
    regressor.fit(X_train, y_train)

    # save the model with dtype reduction to a temporary file
    import tempfile
    with tempfile.NamedTemporaryFile(delete=False) as f:
        dump_dtype_reduction(regressor, f)
        print(f"Compressed model is {f.tell() / 2 ** 20:.2f} MB large.")
        path = f.name

    # make sure that the compressed model predicts the same things
    # as the original model
    with open(path, "rb") as f:
        regressor_dtype_reduction = pickle.load(f)
    np.testing.assert_allclose(regressor.predict(X_test), regressor_dtype_reduction.predict(X_test))

    with tempfile.NamedTemporaryFile() as f:
        pickle.dump(regressor, f)
        print(f"Uncompressed model is {f.tell() / 2 ** 20:.2f} MB large.")

    import os
    os.remove(path)  # clean up the temporary file