Skip to content

Instantly share code, notes, and snippets.

@KMarkert
Last active February 26, 2021 04:51
Show Gist options
  • Save KMarkert/15469eb59efc15ef655b2b3f51e9db01 to your computer and use it in GitHub Desktop.
import numpy as np
import pandas as pd
def sklearn_tree_to_ee_string(estimator, feature_names):
    """Convert a fitted sklearn decision tree into an R-style tree string.

    The resulting string encodes one node per line (leaves flagged with
    ``" *"``) in the format expected by Google Earth Engine's
    ``ee.Classifier.decisionTree()``.

    Args:
        estimator: a fitted sklearn decision tree (e.g. one member of a
            RandomForest's ``estimators_``); only its ``tree_`` attribute
            is read.
        feature_names: sequence of band/feature names indexed by the
            tree's per-node feature indices.

    Returns:
        str: the encoded tree string.

    Raises:
        RuntimeError: if the tree's value array has an unexpected rank.
    """
    # extract out the information needed to build the tree string
    n_nodes = estimator.tree_.node_count
    children_left = estimator.tree_.children_left
    children_right = estimator.tree_.children_right
    feature_idx = estimator.tree_.feature
    impurities = estimator.tree_.impurity
    n_samples = estimator.tree_.n_node_samples
    thresholds = estimator.tree_.threshold
    features = [feature_names[i] for i in feature_idx]

    raw_vals = estimator.tree_.value
    if raw_vals.ndim == 3:
        # classifier case: (n_nodes, 1, n_classes) -> take argmax along
        # the class axis to get the predicted class per node
        values = np.squeeze(raw_vals.argmax(axis=-1))
    elif raw_vals.ndim == 2:
        # regressor case: take values and drop unneeded axis
        values = np.squeeze(raw_vals)
    else:
        raise RuntimeError("could not understand estimator type and parse out the values")

    # use iterative pre-order search to extract node depth and leaf information
    node_ids = np.zeros(shape=n_nodes, dtype=np.int64)
    node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, -1)]  # seed is the root node id and its parent depth
    while len(stack) > 0:
        node_id, parent_depth = stack.pop()
        node_depth[node_id] = parent_depth + 1
        node_ids[node_id] = node_id
        # If we have a test node (sklearn leaves have children_left == children_right == -1)
        if children_left[node_id] != children_right[node_id]:
            stack.append((children_left[node_id], parent_depth + 1))
            stack.append((children_right[node_id], parent_depth + 1))
        else:
            is_leaves[node_id] = True

    # create a table of the initial structure
    # each row is a node or leaf
    df = pd.DataFrame(
        {
            "node_id": node_ids,
            "node_depth": node_depth,
            "is_leaf": is_leaves,
            "children_left": children_left,
            "children_right": children_right,
            "value": values,
            "criterion": impurities,
            "n_samples": n_samples,
            "threshold": thresholds,
            "feature_name": features,
            "sign": ["<="] * n_nodes,
        }
    )

    # the table representation does not have left vs right node structure
    # so we need to add in right nodes in the correct location
    # we do this by first calculating which nodes are right and then insert them at the correct index
    # get a dict of right node rows and assign key based on index where to insert
    inserts = {}
    for row in df.itertuples():
        child_r = row.children_right
        if child_r > row.Index:
            ordered_row = np.array(row)
            ordered_row[-1] = ">"  # the inserted copy represents the "greater than" branch
            inserts[child_r] = ordered_row[1:]  # drop index value

    # sort the inserts as to keep track of the additive indexing
    inserts_sorted = {k: inserts[k] for k in sorted(inserts.keys())}

    # loop through the row inserts and add to table (array)
    table_values = df.values
    for i, k in enumerate(inserts_sorted.keys()):
        table_values = np.insert(table_values, (k + i), inserts_sorted[k], axis=0)

    # make the ordered table array into a dataframe
    # note: df is dtype "object", need to cast later on
    ordered_df = pd.DataFrame(table_values, columns=df.columns)

    max_depth = np.max(ordered_df.node_depth.astype(int))

    # root line; 9999 placeholders stand in for unused split/threshold fields
    tree_str = f"1) root {n_samples[0]} 9999 9999 ({impurities.sum()})\n"
    previous_depth = -1
    cnts = []
    # loop through the nodes and calculate the node number and values per node
    for row in ordered_df.itertuples():
        node_depth = int(row.node_depth)
        left = int(row.children_left)
        right = int(row.children_right)
        if left != right:
            if row.Index == 0:
                cnt = 2
            elif previous_depth > node_depth:
                # popped back up the tree: continue numbering from the last
                # node seen at this depth
                depths = ordered_df.node_depth.values[: row.Index]
                idx = np.where(depths == node_depth)[0][-1]
                cnt = cnts[idx] + 1
            elif previous_depth < node_depth:
                cnt = cnts[row.Index - 1] * 2
            elif previous_depth == node_depth:
                cnt = cnts[row.Index - 1] + 1

            if node_depth == (max_depth - 1):
                # deepest test nodes: the next table row is this branch's
                # leaf, so report its stats and mark the line terminal
                value = float(ordered_df.iloc[row.Index + 1].value)
                samps = int(ordered_df.iloc[row.Index + 1].n_samples)
                criterion = float(ordered_df.iloc[row.Index + 1].criterion)
                tail = " *\n"
            else:
                # if this branch's child is a leaf that appears later in the
                # table, report the leaf's stats and mark the line terminal;
                # otherwise report this test node's own stats
                if (
                    (bool(ordered_df.loc[ordered_df.node_id == left].iloc[0].is_leaf))
                    and (
                        bool(
                            int(row.Index)
                            < int(ordered_df.loc[ordered_df.node_id == left].index[0])
                        )
                    )
                    and (str(row.sign) == "<=")
                ):
                    rowx = ordered_df.loc[ordered_df.node_id == left].iloc[0]
                    tail = " *\n"
                    value = float(rowx.value)
                    samps = int(rowx.n_samples)
                    criterion = float(rowx.criterion)
                elif (
                    (bool(ordered_df.loc[ordered_df.node_id == right].iloc[0].is_leaf))
                    and (
                        bool(
                            int(row.Index)
                            < int(ordered_df.loc[ordered_df.node_id == right].index[0])
                        )
                    )
                    and (str(row.sign) == ">")
                ):
                    rowx = ordered_df.loc[ordered_df.node_id == right].iloc[0]
                    tail = " *\n"
                    value = float(rowx.value)
                    samps = int(rowx.n_samples)
                    criterion = float(rowx.criterion)
                else:
                    value = float(row.value)
                    samps = int(row.n_samples)
                    criterion = float(row.criterion)
                    tail = "\n"

            # extract out the information needed in each line
            spacing = (node_depth + 1) * " "  # for pretty printing
            fname = str(row.feature_name)  # name of the feature (i.e. band name)
            tresh = float(row.threshold)  # threshold
            sign = str(row.sign)
            tree_str += f"{spacing}{cnt}) {fname} {sign} {tresh:.6f} {samps} {criterion:.4f} {value:.6f}{tail}"

        # tracked for every row (including leaves) so the numbering lookups
        # by table index above stay aligned
        previous_depth = node_depth
        cnts.append(cnt)

    return tree_str
if __name__ == "__main__":
    # do your model training here
    # `model` is an sklearn RandomForestClassifier or RandomForestRegressor
    # and `features` is the list of band/feature names used to train it
    trees = []
    for estimator in model.estimators_:
        # BUG FIX: the original called trees.append(trees), appending the
        # list to itself instead of the generated tree string
        tree_string = sklearn_tree_to_ee_string(estimator, features)
        trees.append(tree_string)
    # save tree strings to text files on GCS
    # or use directly with ee
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment