Skip to content

Instantly share code, notes, and snippets.

@8bit-pixies
Last active January 22, 2021 04:57
Show Gist options
  • Save 8bit-pixies/1529a5bcc730de83f476da2a1096c6fe to your computer and use it in GitHub Desktop.
Save 8bit-pixies/1529a5bcc730de83f476da2a1096c6fe to your computer and use it in GitHub Desktop.
This demonstrates a naive way to convert a tree trained in River so that it works with the SHAP library. It is intended as a starting point for discussion: https://github.com/online-ml/river/issues/437
# The goal of this is to try to make use of Shap to explain a tree built in river.
# https://github.com/online-ml/river/issues/437
from functools import reduce
import operator
import numpy as np
import pandas as pd
import pprint
from sklearn import datasets
import lightgbm as lgb
import matplotlib.pyplot as plt
import shap
from river import tree
from river import stream
from river.utils.skmultiflow_utils import normalize_values_in_dict
from river.utils.skmultiflow_utils import round_sig_fig
def get_all_path(tree):
    """Map every node of an edge listing to its path from the root.

    Parameters
    ----------
    tree
        An iterable of ``(parent_no, child_no, parent, child, branch_id)``
        edge tuples (e.g. ``_tree_root.iter_edges()``). The root edge is the
        one whose ``parent_no`` is None. Edges are expected parent-before-
        child; an edge whose parent has not been seen yet is skipped.

    Returns
    -------
    dict
        ``{child_no: [(branch_id, node_no), ...]}``: the list starts at the
        root and ends at ``child_no``, recording for each node the branch id
        of the edge that reached it.
    """
    path_dict = {}
    for parent_no, child_no, _parent, _child, branch_id in tree:
        step = (branch_id, child_no)
        if parent_no is None:
            path_dict[child_no] = [step]
            continue
        parent_path = path_dict.get(parent_no)
        if parent_path is not None:
            # Concatenation builds a new list, so sibling paths never share state.
            path_dict[child_no] = parent_path + [step]
    return path_dict
def count_leaf(tree):
    """Count the edges of ``tree`` whose child node is a leaf.

    ``tree`` is an iterable of ``(parent_no, child_no, parent, child,
    branch_id)`` tuples; an edge counts when ``child.is_leaf()`` is truthy.
    """
    return sum(1 for _, _, _, child, _ in tree if child.is_leaf())
# https://stackoverflow.com/questions/14692690/access-nested-dictionary-items-via-a-list-of-keys
def get_by_path(root, items):
    """Follow the keys/indices in ``items`` from ``root`` and return what is reached.

    An empty ``items`` yields ``root`` itself; a missing key propagates the
    container's own lookup error (e.g. KeyError or IndexError).
    """
    node = root
    for key in items:
        node = node[key]
    return node
def set_by_path(root, items, value):
    """Store ``value`` at the nested location named by the key sequence ``items``.

    All keys but the last are followed from ``root``; the last key is then
    assigned on the container reached. ``items`` must be a non-empty
    sliceable sequence (an empty one raises IndexError).
    """
    container = root
    for key in items[:-1]:
        container = container[key]
    container[items[-1]] = value
class FakeLightGBMBooster:
    """Stand-in for a LightGBM ``Booster`` so SHAP can load a hand-built tree dump.

    ``dump_model()`` returns ``model_struct`` wrapped the way a LightGBM
    booster dump wraps its trees, and ``predict`` forwards rows to a River
    model one at a time.

    Parameters
    ----------
    model
        A River model exposing ``predict_one(x: dict)``.
    model_struct
        A single tree dump (``{"tree_structure": ..., "num_leaves": ...}``),
        returned as the only entry of ``tree_info``.
    feature_names
        Optional column names used as the keys of the feature dicts passed to
        ``predict_one``; they should match the keys the model was trained on.
        When None, ``stream.iter_array`` picks its default keys (presumably
        integer column positions — confirm for the installed River version).
    """

    def __init__(self, model, model_struct, feature_names=None):
        self.model_struct = model_struct
        self.model = model
        self.feature_names = feature_names
        self.params = {"objective": "binary"}

    def dump_model(self):
        """Return the tree wrapped in a single-element ``tree_info`` list."""
        return {"tree_info": [self.model_struct]}

    def predict(self, X, *args, **kwargs):
        """Predict every row of ``X`` with the wrapped River model.

        Extra positional/keyword arguments are accepted for signature
        compatibility with ``Booster.predict`` and ignored.
        """
        # stream.iter_array yields (x, y) pairs; only the feature dict x is
        # a valid input for predict_one.
        rows = stream.iter_array(X, feature_names=self.feature_names)
        return np.array([self.model.predict_one(x) for x, _ in rows])

    # NOTE(review): Python looks __instancecheck__ up on the metaclass, so
    # defining it on this class does not affect isinstance(x, FakeLightGBMBooster).
    # Kept only for backward compatibility.
    @classmethod
    def __instancecheck__(cls, instance):
        return isinstance(instance, lgb.basic.Booster)

    # isinstance() falls back to obj.__class__, so reporting Booster here is
    # what makes isinstance(fake, lgb.basic.Booster) return True.
    @property
    def __class__(self):
        return lgb.basic.Booster
def dump_tree_model(htc: tree.HoeffdingTreeClassifier, feature_names=None):
    """Export a River Hoeffding tree in the layout of one LightGBM ``tree_info`` entry.

    Every node gets ``internal_value``/``leaf_value``, ``internal_count``/
    ``leaf_count`` and ``default_left``. Split nodes additionally get
    ``split_index``, ``split_feature``, ``threshold`` and ``decision_type``,
    and leaves get ``leaf_index``. Children are nested under ``left_child`` /
    ``right_child``. Keys starting with ``_`` are River bookkeeping for debugging.

    Parameters
    ----------
    htc
        A trained River ``HoeffdingTreeClassifier``. Split nodes are read via
        ``split_test._att_idx`` / ``split_test._att_value``, and only branch
        ids 0 and 1 are mapped, so binary split tests are assumed.
    feature_names
        Ordered feature names; ``split_feature`` is the position of a split's
        attribute in this sequence. Defaults to the module-level
        ``dataset.feature_names`` for backward compatibility.

    Returns
    -------
    dict
        ``{"tree_structure": <nested node dict>, "num_leaves": int}``.
    """
    if feature_names is None:
        # Backward-compatible fallback on the module-level `dataset`; pass
        # feature_names explicitly to avoid depending on that global.
        feature_names = dataset.feature_names
    feature_names_mapping = {name: idx for idx, name in enumerate(feature_names)}
    num_leaves = count_leaf(htc._tree_root.iter_edges())
    path_dict = get_all_path(htc._tree_root.iter_edges())
    left_right = {0: "left_child", 1: "right_child"}

    info_json = None
    # LightGBM numbers split nodes and leaves separately, each from 0 in visit
    # order; replicate that rather than reusing River's `child_no`.
    curr_leaf_index = 0
    curr_node_index = 0
    for parent_no, child_no, _parent, child, branch_id in htc._tree_root.iter_edges():
        # LightGBM reports value/count statistics for every node, not just leaves.
        # Assumes a classifier, i.e. `stats` maps class label -> vote weight.
        stats = child.stats
        max_class = max(stats, key=stats.get)
        sum_votes = sum(stats.values())
        # NOTE(review): with no votes the value falls back to the majority
        # class label itself rather than a probability.
        internal_value = max_class
        if sum_votes > 0:
            normalized = normalize_values_in_dict(stats, factor=sum_votes, inplace=False)
            # Share of the majority class, rounded with River's round_sig_fig.
            internal_value = round_sig_fig(max(normalized.values()))
        count = max(int(sum_votes), 1)
        stat_dict = {
            "internal_value": internal_value,
            "internal_count": count,
            "leaf_value": internal_value,
            "leaf_count": count,
            "default_left": True,
            "_parent": parent_no,
            "_node": child_no,
            "_branch_id": branch_id,
        }
        if child.is_leaf():
            # Sequential leaf numbering (what SHAP indexes by), not River's child_no.
            stat_dict["leaf_index"] = curr_leaf_index
            curr_leaf_index += 1
        else:
            stat_dict["split_index"] = curr_node_index
            curr_node_index += 1
            split_test = child.split_test
            stat_dict["split_feature"] = feature_names_mapping[split_test._att_idx]
            # Keep the threshold numeric for every split node (previously the
            # root kept this value while other nodes got a re-parsed string).
            stat_dict["threshold"] = split_test._att_value
            # decision_type describes the test that sends a sample LEFT, i.e.
            # branch 0 of this node's own split test. Describing `branch_id`
            # (this node's position under its parent) gave right children ">"
            # and left the root without a decision_type.
            condition = split_test.describe_condition_for_branch(0)
            if condition is not None:
                # "<feature name> <op> <value>": the feature name may contain spaces.
                stat_dict["decision_type"] = condition.strip().rsplit(" ", 2)[-2]
            if branch_id in left_right:
                # Which side of its parent this split node hangs from.
                stat_dict["_child_type"] = left_right[branch_id]
        if info_json is None:
            info_json = stat_dict
        else:
            # Translate the root-to-node branch ids into left_child/right_child
            # keys and attach this node at that nested location.
            map_keys = [left_right[b] for b, _ in path_dict[child_no] if b is not None]
            set_by_path(info_json, map_keys, stat_dict)
    return {
        "tree_structure": info_json,
        "num_leaves": num_leaves,
    }
#################
# Demo script. It runs at import time; there is no `if __name__ == "__main__":`
# guard because dump_tree_model() reads the module-level `dataset` as a global.
# Load the data
dataset = datasets.load_breast_cancer()
X, y = dataset.data, dataset.target  # NOTE(review): X and y are not used below
# Reference model: a single-tree LightGBM classifier, used only to show the
# JSON dump layout that dump_tree_model() imitates.
lgdt = lgb.LGBMClassifier(n_estimators=1)
htc = tree.HoeffdingTreeClassifier()
lgdt.fit(dataset.data, dataset.target)
json_model = (
    lgdt.booster_.dump_model()
)  # the tree_info entries (SingleTree format) are what shap consumes.
tree_info_json = json_model["tree_info"][
    0
]  # index 0: with n_estimators=1 there is exactly one tree
# A tree_info entry has the keys tree_index, num_leaves, num_cat, shrinkage and tree_structure.
# shap only needs tree_structure, plus num_leaves to derive the number of parent (split) nodes.
tree_struct = tree_info_json["tree_structure"]
print("This is an example of the LightGBM tree dump:")
pprint.pprint(tree_struct)
print("---------")
# Per node, shap then uses: left_child, split_index, right_child, leaf_index, threshold,
# internal_value and internal_count, where
# internal_count: number of training records that fall into this non-leaf node
# internal_value: raw value this node would predict if it were a leaf
# Train the River Hoeffding tree online: 10 passes over the dataset, one sample at a time.
for _ in range(10):
    for xi, yi in stream.iter_sklearn_dataset(dataset):
        htc.learn_one(xi, yi)
# htc._tree_root is the root node of the trained tree. Its iter_edges() method
# yields the (parent_no, child_no, parent, child, branch_id) tuples that
# dump_tree_model() walks to build a LightGBM-style JSON tree for export.
# shap only supports binary trees, so this relies on the Hoeffding tree's
# splits being binary.
# References:
# https://github.com/slundberg/shap/blob/474fc74bc0a93911879248ee9f651dcea67270fd/shap/explainers/_tree.py#L1119
# https://github.com/online-ml/river/blob/bf012736ee4bb5152d2e20ab11beedc6957e8294/river/tree/_base_tree.py
print("Replicating this tree format here...")
tree_struct = dump_tree_model(htc)
pprint.pprint(tree_struct)
print("---------")
fake_model = FakeLightGBMBooster(htc, tree_struct)
explainer = shap.TreeExplainer(fake_model)
# explainer.model --> shap's TreeEnsemble object
# explainer.model.trees --> the internal tree objects.
# The overrides below presumably make shap treat the exported node values as raw
# outputs with no objective applied.
# NOTE(review): the effect of these attributes depends on the installed shap version — confirm.
explainer.model.tree_output = "raw_value"
explainer.model.objective = None
explainer.model.model_type = "internal"
# TODO: check_additivity is disabled, so there is no guarantee that the SHAP values
# add up to the model output for this converted tree.
shap_values = explainer.shap_values(dataset.data, check_additivity=False)
# shap.force_plot(explainer.expected_value, shap_values, dataset.data).html()
shap.summary_plot(shap_values, dataset.feature_names, show=False)
# Save the figure summary_plot drew on (the current matplotlib figure).
f = plt.gcf()
f.savefig("output.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment