Skip to content

Instantly share code, notes, and snippets.

@Mark1626
Created May 16, 2023 09:15
Show Gist options
  • Save Mark1626/aaa9e4cde44af2ade72057ab13cc26b2 to your computer and use it in GitHub Desktop.
Save Mark1626/aaa9e4cde44af2ade72057ab13cc26b2 to your computer and use it in GitHub Desktop.
def serialise_array(clf):
    """Flatten every tree of an ensemble into one list of node records.

    Each node becomes a 5-element list::

        [is_leaf, feature_or_class, threshold, left, right]

    * ``is_leaf`` is 1 when the node's left child index is -1, else 0.
    * ``feature_or_class`` is the split feature index when that index is
      non-negative; otherwise it is the arg-max position of the node's
      ``value`` entry (the majority class index).
    * ``threshold`` is copied unchanged from ``tree.threshold``.
    * ``left`` / ``right`` are, for internal nodes, the child positions
      expressed *relative to the node's own position inside its tree*
      (``child_index - node_index``). Leaf nodes keep the child values
      exactly as stored in the tree.

    Args:
        clf: Object exposing ``estimators_``; each estimator must expose a
            ``tree_`` with ``children_left``, ``children_right``,
            ``feature``, ``threshold`` and ``value`` arrays.
            NOTE(review): presumably a fitted scikit-learn forest-style
            classifier -- confirm against the caller.

    Returns:
        A tuple ``(offsets, tree_ser)`` where ``tree_ser`` is the
        concatenation of all node records and ``offsets[k]`` is the index in
        ``tree_ser`` at which the k-th estimator's nodes start.
    """
    # Imported locally because this file has no module-level import of numpy;
    # without it the np.argmax call below fails with NameError.
    import numpy as np

    tree_ser = []
    offsets = []
    for estimator in clf.estimators_:
        # The tree about to be appended starts at the current end of the
        # flat list.
        offsets.append(len(tree_ser))
        tree = estimator.tree_

        # Convert numpy scalars to plain Python numbers so that the result
        # consists solely of built-in ints and floats.
        features = [int(f) for f in tree.feature]
        classes = [int(np.argmax(val)) for val in tree.value]
        thresholds = tree.threshold.tolist()
        children_left = [int(c) for c in tree.children_left]
        children_right = [int(c) for c in tree.children_right]

        nodes = []
        for index, (feature, cls, threshold, left, right) in enumerate(
            zip(features, classes, thresholds, children_left, children_right)
        ):
            # A negative feature index means there is no split feature to
            # store, so the slot carries the predicted class instead.
            ft_cls = feature if feature >= 0 else cls
            if left == -1:
                # Leaf: keep the child markers exactly as stored.
                nodes.append([1, ft_cls, threshold, left, right])
            else:
                # Internal node: store children as offsets from this node so
                # the record is independent of where the tree is placed.
                nodes.append(
                    [0, ft_cls, threshold, left - index, right - index]
                )

        if len(nodes) > 1000:
            print("WARNING: Number of nodes in one tree exceeds 1000")

        tree_ser += nodes
    return (offsets, tree_ser)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment