@glemaitre
Last active March 29, 2022 09:55
TreeSHAP bug reproducer
{
"cells": [
{
"cell_type": "markdown",
"id": "9e424e70",
"metadata": {},
"source": [
"## Comparing the TreeSHAP and `Exact` explainers"
]
},
{
"cell_type": "markdown",
"id": "f3cb9aac",
"metadata": {},
"source": [
"We define the same dataset as the one described in [gh-2345](https://github.com/slundberg/shap/issues/2345)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "708e24a8",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Two binary features: the target is 0 when both are 0, 100 when both are 1,\n",
"# and 50 otherwise\n",
"X = np.vstack([\n",
"    [[0, 0]] * 400,\n",
"    [[0, 1]] * 100,\n",
"    [[1, 0]] * 100,\n",
"    [[1, 1]] * 400,\n",
"])\n",
"y = np.array(\n",
"    [0] * 400 +\n",
"    [50] * 100 +\n",
"    [50] * 100 +\n",
"    [100] * 400\n",
")"
]
},
{
"cell_type": "markdown",
"id": "793075d6",
"metadata": {},
"source": [
"### Define two trees with different split ordering\n",
"\n",
"We vary `random_state` so that:\n",
"\n",
"* `tree_1` uses the first feature (`X[:, 0]`) as its root split;\n",
"* `tree_2` uses the second feature (`X[:, 1]`) as its root split.\n",
"\n",
"Note that the two trees make the same prediction on any data point: they are two distinct implementations of the same mathematical decision function."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "16949644",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"from sklearn.tree import plot_tree\n",
"\n",
"_, ax = plt.subplots(figsize=(10, 6))\n",
"tree_1 = DecisionTreeRegressor(random_state=2)\n",
"tree_1.fit(X, y)\n",
"_ = plot_tree(tree_1, ax=ax)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "39b1a909",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(figsize=(10, 6))\n",
"tree_2 = DecisionTreeRegressor(random_state=0)\n",
"tree_2.fit(X, y)\n",
"_ = plot_tree(tree_2, ax=ax)"
]
},
{
"cell_type": "markdown",
"id": "2844808b",
"metadata": {},
"source": [
"### Explaining the prediction for a given data point with TreeSHAP"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1b6bb090",
"metadata": {},
"outputs": [],
"source": [
"X_test = np.array([[1, 1]])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6609a1f9",
"metadata": {},
"outputs": [],
"source": [
"from shap import Explainer"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9ec30981",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
".values =\n",
"array([[32.5, 17.5]])\n",
"\n",
".base_values =\n",
"array([[50.]])\n",
"\n",
".data =\n",
"array([[1, 1]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree_1_explainer = Explainer(model=tree_1, algorithm=\"tree\")\n",
"tree_1_explainer(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0ffbae96",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
".values =\n",
"array([[17.5, 32.5]])\n",
"\n",
".base_values =\n",
"array([[50.]])\n",
"\n",
".data =\n",
"array([[1, 1]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree_2_explainer = Explainer(model=tree_2, algorithm=\"tree\")\n",
"tree_2_explainer(X_test)"
]
},
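{
"cell_type": "markdown",
"id": "832baa37",
"metadata": {},
"source": [
"We reproduce the asymmetry of SHAP values reported in the original issue linked above. The decision function and the data distribution are both symmetric in the two features, so one would expect both features to receive the same SHAP value. Instead, TreeSHAP attributes `32.5` to the first feature and `17.5` to the second one with `tree_1`, and the swapped values with `tree_2`, even though both trees implement the same decision function."
]
},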
{
"cell_type": "markdown",
"id": "67eac2e3",
"metadata": {},
"source": [
"### Explaining the prediction for a given data point with the `Exact` explainer"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a347415e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(array([1, 1]),)\n"
]
},
{
"data": {
"text/plain": [
".values =\n",
"array([[25., 25.]])\n",
"\n",
".base_values =\n",
"array([50.])\n",
"\n",
".data =\n",
"array([[1, 1]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from shap.explainers import Exact\n",
"from shap.maskers import Independent\n",
"\n",
"explainer = Exact(tree_1.predict, Independent(X, max_samples=X.shape[0]))\n",
"explainer(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "55f11777",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(array([1, 1]),)\n"
]
},
{
"data": {
"text/plain": [
".values =\n",
"array([[25., 25.]])\n",
"\n",
".base_values =\n",
"array([50.])\n",
"\n",
".data =\n",
"array([[1, 1]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"explainer = Exact(tree_2.predict, Independent(X, max_samples=X.shape[0]))\n",
"explainer(X_test)"
]
},
{
"cell_type": "markdown",
"id": "07e936db",
"metadata": {},
"source": [
"The `Exact` explainer is not subject to this asymmetry: it attributes `25` to each feature, whichever tree is explained.\n",
"\n",
"We can also check that a pure Python implementation of the same decision function leads to the same explanation."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f7e6d7af",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(array([1, 1]),)\n"
]
},
{
"data": {
"text/plain": [
".values =\n",
"array([[25., 25.]])\n",
"\n",
".base_values =\n",
"array([50.])\n",
"\n",
".data =\n",
"array([[1, 1]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
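],
"source": [
"def my_predict_one(x):\n",
"    # Hand-written version of the decision function learned by both trees\n",
"    if x[0] < 0.5 and x[1] < 0.5:\n",
"        return 0\n",
"    elif x[0] > 0.5 and x[1] > 0.5:\n",
"        return 100\n",
"    else:\n",
"        return 50\n",
"\n",
"\n",
"def my_predict(X):\n",
"    # Apply the decision function to each row and return a 1D array of predictions\n",
"    return np.array([my_predict_one(x) for x in X])\n",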
"\n",
"explainer = Exact(my_predict, masker=Independent(X, max_samples=X.shape[0]))\n",
"explainer(X_test)"
]
},
{
"cell_type": "markdown",
"id": "b6d3450d",
"metadata": {},
"source": [
"### Checking the ACV implementation of SHAP values"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "3939f707",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1216.45it/s]\n"
]
},
{
"data": {
"text/plain": [
"array([[[25.],\n",
" [25.]]])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from acv_explainers import ACVTree\n",
"\n",
"ACVTree(tree_1, X).shap_values(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e0b6b424",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1452.82it/s]\n"
]
},
{
"data": {
"text/plain": [
"array([[[25.],\n",
" [25.]]])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ACVTree(tree_2, X).shap_values(X_test)"
]
},
{
"cell_type": "markdown",
"id": "464b4598",
"metadata": {},
"source": [
"### Checking the FastTreeSHAP implementation\n",
"\n",
"We now observe that this bug is also present in the FastTreeSHAP implementation."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b5437856",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"poetry 1.1.13 requires packaging<21.0,>=20.4, but you have packaging 21.3 which is incompatible.\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install -q fasttreeshap"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b2c9535d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
".values =\n",
"array([[32.5, 17.5]])\n",
"\n",
".base_values =\n",
"array([[50.]])\n",
"\n",
".data =\n",
"array([[1, 1]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import fasttreeshap\n",
"\n",
"fasttreeshap.TreeExplainer(tree_1, algorithm=\"v2\")(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "fd412556",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
".values =\n",
"array([[17.5, 32.5]])\n",
"\n",
".base_values =\n",
"array([[50.]])\n",
"\n",
".data =\n",
"array([[1, 1]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fasttreeshap.TreeExplainer(tree_2, algorithm=\"v2\")(X_test)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
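
The asymmetry can also be traced by hand on this small example. The following sketch uses only the Python standard library. It is not the shap, ACV or FastTreeSHAP code, and all names in it (`TREE_1`, `tree_predict`, `path_dependent_value`, ...) are hypothetical. It rebuilds `tree_1` and `tree_2` as hand-written replicas of the plotted trees, assuming a threshold of `0.5` at every split and node covers equal to the training counts. It then computes exact Shapley values for `X_test = [1, 1]` with two value functions:

1. The interventional expectation, where the features outside a coalition are taken from every training row. We expect this to be what the `Exact` explainer with an `Independent` masker over the whole of `X` computes.
2. The path-dependent expectation, where a split on a feature outside the coalition is averaged over both children, weighted by their cover. We assume this is what TreeSHAP falls back to when no background data is passed (`feature_perturbation="tree_path_dependent"` in shap).

```python
from itertools import combinations
from math import factorial

# Training data of the notebook: (x0, x1) -> number of rows.
# The target is 0, 50, 50 and 100 for these four points respectively.
COUNTS = {(0, 0): 400, (0, 1): 100, (1, 0): 100, (1, 1): 400}


def leaf(value, cover):
    return {"value": value, "cover": cover}


def split(feature, left, right, threshold=0.5):
    # Rows with x[feature] <= threshold go left; cover = number of training rows.
    return {"feature": feature, "threshold": threshold, "left": left,
            "right": right, "cover": left["cover"] + right["cover"]}


# Hypothetical hand-built replicas of the two fitted trees (assumed structure).
TREE_1 = split(0, split(1, leaf(0, 400), leaf(50, 100)),
                  split(1, leaf(50, 100), leaf(100, 400)))
TREE_2 = split(1, split(0, leaf(0, 400), leaf(50, 100)),
                  split(0, leaf(50, 100), leaf(100, 400)))


def tree_predict(node, x):
    """Follow x from the root down to a leaf."""
    while "value" not in node:
        go_left = x[node["feature"]] <= node["threshold"]
        node = node["left"] if go_left else node["right"]
    return node["value"]


def interventional_value(tree, x):
    """v(S): mean prediction when features in S come from x and the others from each training row."""
    n_rows = sum(COUNTS.values())

    def v(S):
        total = 0.0
        for z, count in COUNTS.items():
            hybrid = tuple(x[j] if j in S else z[j] for j in range(len(x)))
            total += count * tree_predict(tree, hybrid)
        return total / n_rows

    return v


def path_dependent_value(tree, x):
    """v(S): follow x on features in S, otherwise average both children by cover."""

    def expvalue(node, S):
        if "value" in node:
            return node["value"]
        left, right = node["left"], node["right"]
        if node["feature"] in S:
            go_left = x[node["feature"]] <= node["threshold"]
            return expvalue(left if go_left else right, S)
        return (left["cover"] * expvalue(left, S)
                + right["cover"] * expvalue(right, S)) / node["cover"]

    return lambda S: expvalue(tree, S)


def shapley(v, n_features=2):
    """Exact Shapley values of the set function v, enumerating every coalition."""
    phi = []
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        contribution = 0.0
        for size in range(len(others) + 1):
            weight = (factorial(size) * factorial(n_features - size - 1)
                      / factorial(n_features))
            for members in combinations(others, size):
                coalition = frozenset(members)
                contribution += weight * (v(coalition | {i}) - v(coalition))
        phi.append(contribution)
    return phi


x_test = (1, 1)
for name, tree in [("tree_1", TREE_1), ("tree_2", TREE_2)]:
    print(name, "path-dependent:", shapley(path_dependent_value(tree, x_test)))
    print(name, "interventional:", shapley(interventional_value(tree, x_test)))
```

Running it should print `[32.5, 17.5]` and `[17.5, 32.5]` for the path-dependent values of `tree_1` and `tree_2`, which match the TreeSHAP and FastTreeSHAP outputs above. It should print `[25.0, 25.0]` for the interventional values of both trees, which matches the `Exact` explainer.

With `tree_1`, knowing only the first feature yields the mean of the right subtree (`90`). Knowing only the second feature yields `75`, because the root split on the unknown first feature is averaged by cover. With `tree_2` the roles are swapped. Under these assumptions, the asymmetry comes from the path-dependent estimate of the expectations rather than from the Shapley weighting itself.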