@psteinb
Last active May 15, 2024 03:09
Example on how to use shap for sksurv's random survival forest
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# Using Random Survival Forests\n",
"\n",
"This notebook demonstrates how to use [Random Survival Forests](https://scikit-survival.readthedocs.io/en/latest/api/generated/sksurv.ensemble.RandomSurvivalForest.html#sksurv.ensemble.RandomSurvivalForest) introduced in [scikit-survival](https://github.com/sebp/scikit-survival) 0.11.\n",
"\n",
"Like its popular counterparts for classification and regression, a Random Survival Forest is an ensemble\n",
"of tree-based learners. A Random Survival Forest ensures that individual trees are de-correlated by 1)\n",
"building each tree on a different bootstrap sample of the original training data, and 2)\n",
"evaluating the split criterion at each node for only a randomly selected subset of\n",
"features and thresholds. Predictions are formed by aggregating predictions of individual\n",
"trees in the ensemble.\n",
"\n",
"To demonstrate Random Survival Forest, we are going to use data from the German Breast Cancer Study Group (GBSG-2) on the treatment of node-positive breast cancer patients. It contains data on 686 women\n",
"and 8 prognostic factors:\n",
"1. age,\n",
"2. estrogen receptor (`estrec`),\n",
"3. whether or not a hormonal therapy was administered (`horTh`),\n",
"4. menopausal status (`menostat`),\n",
"5. number of positive lymph nodes (`pnodes`),\n",
"6. progesterone receptor (`progrec`),\n",
"7. tumor size (`tsize`),\n",
"8. tumor grade (`tgrade`).\n",
"\n",
"The goal is to predict recurrence-free survival time."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"%matplotlib inline\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"\n",
"from sksurv.datasets import load_gbsg2\n",
"from sksurv.preprocessing import OneHotEncoder\n",
"from sksurv.ensemble import RandomSurvivalForest"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"First, we need to load the data and transform it into numeric values."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X, y = load_gbsg2()\n",
"\n",
"grade_str = X.loc[:, \"tgrade\"].astype(object).values[:, np.newaxis]\n",
"grade_num = OrdinalEncoder(categories=[[\"I\", \"II\", \"III\"]]).fit_transform(grade_str)\n",
"\n",
"X_no_grade = X.drop(\"tgrade\", axis=1)\n",
"Xt = OneHotEncoder().fit_transform(X_no_grade)\n",
"Xt = np.column_stack((Xt.values, grade_num))\n",
"\n",
"feature_names = X_no_grade.columns.tolist() + [\"tgrade\"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"Next, the data is split into 75% for training and 25% for testing, so we can determine\n",
"how well our model generalizes."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"random_state = 20\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" Xt, y, test_size=0.25, random_state=random_state)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## Training\n",
"\n",
"Several split criteria have been proposed in the past, but the most widespread one is based\n",
"on the log-rank test, which you probably know from comparing survival curves among two or more\n",
"groups. Using the training data, we fit a Random Survival Forest comprising 1000 trees."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"RandomSurvivalForest(max_features='sqrt', min_samples_leaf=15,\n",
" min_samples_split=10, n_estimators=1000, n_jobs=-1,\n",
" random_state=20)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rsf = RandomSurvivalForest(n_estimators=1000,\n",
" min_samples_split=10,\n",
" min_samples_leaf=15,\n",
" max_features=\"sqrt\",\n",
" n_jobs=-1,\n",
" random_state=random_state)\n",
"rsf.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"We can check how well the model performs by evaluating it on the test data."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.6759696016771488"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rsf.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"This gives a concordance index of 0.68, which is a good value and matches the results\n",
"reported in the [Random Survival Forests paper](https://projecteuclid.org/euclid.aoas/1223908043)."
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## Predicting\n",
"\n",
"For prediction, a sample is dropped down each tree in the forest until it reaches a terminal node.\n",
"Data in each terminal node is used to non-parametrically estimate the survival and cumulative hazard\n",
"functions using the Kaplan-Meier and Nelson-Aalen estimators, respectively. In addition, a risk score\n",
"can be computed that represents the expected number of events for one particular terminal node.\n",
"The ensemble prediction is simply the average across all trees in the forest.\n",
"\n",
"Let's first select a couple of patients from the test data\n",
"according to the number of positive lymph nodes and age."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>estrec</th>\n",
" <th>horTh</th>\n",
" <th>menostat</th>\n",
" <th>pnodes</th>\n",
" <th>progrec</th>\n",
" <th>tsize</th>\n",
" <th>tgrade</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>33.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>26.0</td>\n",
" <td>35.0</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>34.0</td>\n",
" <td>37.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>40.0</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>36.0</td>\n",
" <td>14.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>76.0</td>\n",
" <td>36.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>65.0</td>\n",
" <td>64.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>26.0</td>\n",
" <td>2.0</td>\n",
" <td>70.0</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>80.0</td>\n",
" <td>59.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>39.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>72.0</td>\n",
" <td>1091.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>36.0</td>\n",
" <td>2.0</td>\n",
" <td>34.0</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age estrec horTh menostat pnodes progrec tsize tgrade\n",
"0 33.0 0.0 0.0 0.0 1.0 26.0 35.0 2.0\n",
"1 34.0 37.0 0.0 0.0 1.0 0.0 40.0 2.0\n",
"2 36.0 14.0 0.0 0.0 1.0 76.0 36.0 1.0\n",
"3 65.0 64.0 0.0 1.0 26.0 2.0 70.0 2.0\n",
"4 80.0 59.0 0.0 1.0 30.0 0.0 39.0 1.0\n",
"5 72.0 1091.0 1.0 1.0 36.0 2.0 34.0 2.0"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = np.empty(X_test.shape[0], dtype=[(\"age\", float), (\"pnodes\", float)])\n",
"a[\"age\"] = X_test[:, 0]\n",
"a[\"pnodes\"] = X_test[:, 4]\n",
"\n",
"sort_idx = np.argsort(a, order=[\"pnodes\", \"age\"])\n",
"X_test_sel = pd.DataFrame(\n",
" X_test[np.concatenate((sort_idx[:3], sort_idx[-3:]))],\n",
" columns=feature_names)\n",
"\n",
"X_test_sel"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"The predicted risk scores indicate that the risk for the last three patients is considerably\n",
"higher than that of the first three patients."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0 91.477609\n",
"1 102.897552\n",
"2 75.883786\n",
"3 170.502092\n",
"4 171.210066\n",
"5 148.691835\n",
"dtype: float64"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.Series(rsf.predict(X_test_sel))"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"We can gain more detailed insight by considering the predicted survival function. It shows that the biggest difference occurs roughly within the first 750 days."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"surv = rsf.predict_survival_function(X_test_sel, return_array=True)\n",
"\n",
"for i, s in enumerate(surv):\n",
" plt.step(rsf.event_times_, s, where=\"post\", label=str(i))\n",
"plt.ylabel(\"Survival probability\")\n",
"plt.xlabel(\"Time in days\")\n",
"plt.legend()\n",
"plt.grid(True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"Alternatively, we can also plot the predicted cumulative hazard function."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"chf = rsf.predict_cumulative_hazard_function(X_test_sel, return_array=True)\n",
"\n",
"for i, s in enumerate(chf):\n",
" plt.step(rsf.event_times_, s, where=\"post\", label=str(i))\n",
"plt.ylabel(\"Cumulative hazard\")\n",
"plt.xlabel(\"Time in days\")\n",
"plt.legend()\n",
"plt.grid(True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## Permutation-based Feature Importance\n",
"\n",
"The implementation is based on scikit-learn's Random Forest implementation and inherits many\n",
"features, such as building trees in parallel. What is currently missing are feature importances\n",
"via the `feature_importances_` attribute.\n",
"This is due to the way scikit-learn's implementation computes importances. It relies on\n",
"a measure of *impurity* for each child node, and defines importance as the amount of\n",
"decrease in impurity due to a split. For traditional regression, impurity would be measured by the variance, but for survival analysis\n",
"there is no per-node impurity measure due to censoring. Instead, one could use the\n",
"magnitude of the log-rank test statistic as an importance measure, but scikit-learn's\n",
"implementation doesn't seem to allow this.\n",
"\n",
"Fortunately, this is not a big concern, as scikit-learn's definition\n",
"of feature importance is non-standard and differs from what Leo Breiman\n",
"[proposed in the original Random Forest paper](https://github.com/scikit-learn/scikit-learn/pull/8027#issuecomment-327595859).\n",
"Instead, we can use permutation to estimate feature importance, which is\n",
"preferred over scikit-learn's definition. This is implemented in the\n",
"[ELI5](https://eli5.readthedocs.io/en/latest/overview.html) library,\n",
"which is fully compatible with scikit-survival."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false,
"tags": [
"nbval-skip"
]
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <style>\n",
" table.eli5-weights tr:hover {\n",
" filter: brightness(85%);\n",
" }\n",
"</style>\n",
"\n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
" <table class=\"eli5-weights eli5-feature-importances\" style=\"border-collapse: collapse; border: none; margin-top: 0em; table-layout: auto;\">\n",
" <thead>\n",
" <tr style=\"border: none;\">\n",
" <th style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">Weight</th>\n",
" <th style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">Feature</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 80.00%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0676\n",
" \n",
" &plusmn; 0.0229\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" pnodes\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 91.29%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0206\n",
" \n",
" &plusmn; 0.0139\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" age\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 92.19%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0177\n",
" \n",
" &plusmn; 0.0468\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" progrec\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 95.29%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0086\n",
" \n",
" &plusmn; 0.0098\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" horTh\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 97.61%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0032\n",
" \n",
" &plusmn; 0.0198\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" tsize\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 97.63%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0032\n",
" \n",
" &plusmn; 0.0060\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" tgrade\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(0, 100.00%, 99.21%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" -0.0007\n",
" \n",
" &plusmn; 0.0018\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" menostat\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(0, 100.00%, 96.19%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" -0.0063\n",
" \n",
" &plusmn; 0.0207\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" estrec\n",
" </td>\n",
" </tr>\n",
" \n",
" \n",
" </tbody>\n",
"</table>\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
"\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import eli5\n",
"from eli5.sklearn import PermutationImportance\n",
"\n",
"perm = PermutationImportance(rsf, n_iter=15, random_state=random_state)\n",
"perm.fit(X_test, y_test)\n",
"eli5.show_weights(perm, feature_names=feature_names)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"The result shows that the number of positive lymph nodes (`pnodes`) is by far the most important\n",
"feature. If its relationship to survival time is removed (by random shuffling),\n",
"the concordance index on the test data drops on average by 0.0676 points.\n",
"Again, this agrees with the results from the original\n",
"[Random Survival Forests paper](https://projecteuclid.org/euclid.aoas/1223908043)."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import shap"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"ename": "Exception",
"evalue": "The passed model is not callable and cannot be analyzed directly with the given masker! Model: RandomSurvivalForest(max_features='sqrt', min_samples_leaf=15,\n min_samples_split=10, n_estimators=1000, n_jobs=-1,\n random_state=20)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_52797/2067042318.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mexpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mshap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mExplainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrsf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/foo/venv/bar/lib64/python3.9/site-packages/shap/explainers/_explainer.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, model, masker, link, algorithm, output_names, feature_names, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;31m# if we get here then we don't know how to handle what was given to us\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The passed model is not callable and cannot be analyzed directly with the given masker! Model: \"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 169\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;31m# build the right subclass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mException\u001b[0m: The passed model is not callable and cannot be analyzed directly with the given masker! Model: RandomSurvivalForest(max_features='sqrt', min_samples_leaf=15,\n min_samples_split=10, n_estimators=1000, n_jobs=-1,\n random_state=20)"
]
}
],
"source": [
"expl = shap.Explainer(rsf)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"ename": "Exception",
"evalue": "Model type not yet supported by TreeExplainer: <class 'sksurv.ensemble.forest.RandomSurvivalForest'>",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_52797/1379721347.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mexpl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mshap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTreeExplainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrsf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/foo/venv/bar/lib64/python3.9/site-packages/shap/explainers/_tree.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, model, data, model_output, feature_perturbation, feature_names, **deprecated_options)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_perturbation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfeature_perturbation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexpected_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTreeEnsemble\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_missing\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_output\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;31m#self.model_output = self.model.model_output # this allows the TreeEnsemble to translate model outputs types by how it loads the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/foo/venv/bar/lib64/python3.9/site-packages/shap/explainers/_tree.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, model, data, data_missing, model_output)\u001b[0m\n\u001b[1;32m 975\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_offset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_params\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mparam_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 977\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Model type not yet supported by TreeExplainer: \"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 978\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 979\u001b[0m \u001b[0;31m# build a dense numpy version of all the tree objects\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mException\u001b[0m: Model type not yet supported by TreeExplainer: <class 'sksurv.ensemble.forest.RandomSurvivalForest'>"
]
}
],
"source": [
"expl = shap.TreeExplainer(rsf)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"argv": [
"python",
"-m",
"ipykernel_launcher",
"-f",
"{connection_file}"
],
"display_name": "Python 3 (ipykernel)",
"env": null,
"interrupt_mode": "signal",
"language": "python",
"metadata": {
"debugger": true
},
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"name": "random-survival-forest-shap.ipynb"
},
"nbformat": 4,
"nbformat_minor": 2
}
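The notebook reports two metrics: a concordance index from `rsf.score` and permutation importances from ELI5. The following plain-NumPy sketch shows what each one computes. It assumes that `rsf.score` returns Harrell's concordance index, which is how scikit-survival describes `score`. It mimics ELI5's `PermutationImportance` but is not its code. The helpers `harrell_c_index` and `permutation_importance` are hypothetical, pairs with tied observed times are skipped for brevity, and synthetic data stands in for GBSG-2.

```python
import numpy as np


def harrell_c_index(event, time, risk):
    """Harrell's concordance index for right-censored data (simplified).

    A pair (i, j) is comparable when subject i had an event and subject j was
    observed for longer. It is concordant when i also has the higher predicted
    risk; equal risks count as half. Pairs with tied observed times are skipped.
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    concordant, comparable = 0.0, 0
    for i in np.flatnonzero(event):  # censored subjects cannot anchor a pair
        longer = time > time[i]
        comparable += int(longer.sum())
        concordant += float((risk[i] > risk[longer]).sum())
        concordant += 0.5 * float((risk[i] == risk[longer]).sum())
    return concordant / comparable


def permutation_importance(predict, X, event, time, n_iter=15, seed=0):
    """Mean and std of the C-index drop when each column of X is shuffled."""
    rng = np.random.default_rng(seed)
    baseline = harrell_c_index(event, time, predict(X))
    drops = np.zeros((n_iter, X.shape[1]))
    for it in range(n_iter):
        for j in range(X.shape[1]):
            X_perm = X.copy()
            # shuffling column j breaks its link to the outcome
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops[it, j] = baseline - harrell_c_index(event, time, predict(X_perm))
    return drops.mean(axis=0), drops.std(axis=0)


# Synthetic stand-in for GBSG-2: only feature 0 drives the hazard.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
time = rng.exponential(scale=np.exp(-2.0 * X[:, 0]))
event = rng.random(200) < 0.7  # about 30% of subjects marked as censored


def predict(A):
    return 2.0 * A[:, 0]  # risk score: higher means earlier expected event


mean_drop, std_drop = permutation_importance(predict, X, event, time, n_iter=5)
print(mean_drop)
```

Columns the model ignores get a drop of exactly zero, because shuffling them leaves every prediction unchanged. With the notebook's objects, the call would presumably be `permutation_importance(rsf.predict, X_test, y_test["cens"], y_test["time"])`. This assumes `cens` and `time` are the field names of the structured array returned by `load_gbsg2`. Each shuffle re-predicts the test set with all 1000 trees, so keep `n_iter` small.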
@NafisAnwari

NafisAnwari commented May 1, 2023

As the output shows, `TreeExplainer` fails with "Model type not yet supported by TreeExplainer: <class 'sksurv.ensemble.forest.RandomSurvivalForest'>".

There are some custom packages on GitHub that compute SHAP values for Random Survival Forests, but they are difficult to adapt to your own parameters unless you have a fair amount of coding experience.
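The first error in the notebook says the forest "is not callable". SHAP's model-agnostic explainers expect a function that maps a feature matrix to predictions, plus background data for the masker. A commonly suggested workaround, not run here, is `shap.Explainer(rsf.predict, X_train)` or `shap.KernelExplainer(rsf.predict, shap.sample(X_train, 100))`. Both would explain the risk score returned by `predict`.

The sketch below shows what such an explainer estimates. It computes exact interventional Shapley values by enumerating every feature coalition, assuming that features absent from a coalition are filled in from background rows. `exact_shapley` is a hypothetical helper. Its 2^d coalitions are only manageable for a handful of columns, such as the 8 used in this notebook.

```python
import itertools
import math

import numpy as np


def exact_shapley(f, x, background):
    """Exact interventional Shapley values of a black-box f at a single row x.

    The value of a coalition S is the mean of f over background rows in which
    the features in S are overwritten with x's values. All 2**d coalitions are
    enumerated, so this is only practical for small d.
    """
    x = np.asarray(x, dtype=float)
    background = np.asarray(background, dtype=float)
    d = x.shape[0]

    def value(subset):
        cols = list(subset)
        z = background.copy()
        z[:, cols] = x[cols]  # features in the coalition take x's values
        return float(np.mean(f(z)))

    # evaluate every coalition once and cache it by its set of feature indices
    v = {frozenset(s): value(s)
         for r in range(d + 1)
         for s in itertools.combinations(range(d), r)}

    phi = np.zeros(d)
    for i in range(d):
        others = [j for j in range(d) if j != i]
        for r in range(d):  # coalition sizes 0 .. d-1 that exclude feature i
            weight = math.factorial(r) * math.factorial(d - r - 1) / math.factorial(d)
            for s in itertools.combinations(others, r):
                s = frozenset(s)
                phi[i] += weight * (v[s | {i}] - v[s])
    base_value = v[frozenset()]  # mean prediction over the background
    return phi, base_value


# Toy check with a linear model, where the answer is known in closed form:
# phi_i = w_i * (x_i - mean of background column i).
rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))
w = np.array([2.0, -1.0, 0.5])


def linear_model(Z):
    return Z @ w + 1.0


x = np.array([1.0, 2.0, 3.0])
phi, base = exact_shapley(linear_model, x, background)
print(phi, base)
```

The attributions plus the base value add up to the model's prediction for `x`, which is the property that SHAP plots rely on. With the notebook's objects, the call would look like `exact_shapley(rsf.predict, X_test[0], X_train[:50])`. That means 256 coalitions times 50 background rows pushed through 1000 trees, so expect it to be slow.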

@AJErisa

AJErisa commented May 15, 2024

Hi @NafisAnwari
Did you manage to solve the problem?
I am also having trouble applying survSHAP to custom data: I cannot get either the global or the local explanations.
Do you have any suggestions or a workaround?
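One possible workaround for getting both local and global explanations from a survival model works in three steps:

1. Reduce the prediction to a scalar, such as the predicted survival probability at a fixed horizon.
2. Explain that scalar for each patient (local explanation).
3. Average the absolute attributions over patients (global explanation).

The sketch below assumes a `predict_surv` callable that returns an (n_samples, n_times) array on a grid `event_times`. That is the shape `rsf.predict_survival_function(X, return_array=True)` and `rsf.event_times_` have in the notebook above. The helpers and the exponential toy model are hypothetical, and Shapley values are estimated by Monte Carlo permutation sampling rather than exactly. This is not SurvSHAP(t) itself. As we understand it, SurvSHAP(t) explains the whole survival curve over time, whereas this sketch explains a single horizon. Repeating it over a grid of horizons gives a crude time-dependent version.

```python
import numpy as np


def survival_at(predict_surv, event_times, horizon):
    """Turn a survival model into a scalar function x -> S(horizon | x).

    Assumes predict_surv(X) returns an (n_samples, n_times) array of survival
    probabilities on the sorted grid event_times, read as a right-continuous
    step function.
    """
    idx = np.searchsorted(event_times, horizon, side="right") - 1

    def f(X):
        surv = np.asarray(predict_surv(X))
        if idx < 0:
            return np.ones(surv.shape[0])  # before the first event time, S = 1
        return surv[:, idx]

    return f


def permutation_shapley(f, x, background, n_perm=100, seed=0):
    """Monte Carlo Shapley estimate for one row x (local explanation).

    Each sample draws a random feature order and a random background row,
    then switches features from the background row to x one at a time and
    credits each feature with the change in f.
    """
    rng = np.random.default_rng(seed)
    d = x.shape[0]
    phi = np.zeros(d)
    for _ in range(n_perm):
        order = rng.permutation(d)
        z = background[rng.integers(len(background))].astype(float)  # astype copies
        prev = f(z[None, :])[0]
        for j in order:
            z[j] = x[j]
            cur = f(z[None, :])[0]
            phi[j] += cur - prev
            prev = cur
    return phi / n_perm


# Hypothetical exponential survival model: S(t | x) = exp(-t * exp(x @ w)).
event_times = np.linspace(0.1, 5.0, 50)
w = np.array([1.0, 0.0, -0.5])  # feature 1 has no effect by construction


def toy_predict_surv(X):
    hazard = np.exp(X @ w)[:, None]
    return np.exp(-hazard * event_times[None, :])


rng = np.random.default_rng(1)
background = rng.normal(size=(100, 3))
patients = rng.normal(size=(20, 3))

f = survival_at(toy_predict_surv, event_times, horizon=1.0)
# local: one attribution vector per patient
local_phi = np.array([permutation_shapley(f, p, background, seed=k)
                      for k, p in enumerate(patients)])
# global: mean absolute attribution per feature across patients
global_importance = np.abs(local_phi).mean(axis=0)
print(global_importance)
```

With the forest from the notebook, the wrapper would be `survival_at(lambda X: rsf.predict_survival_function(X, return_array=True), rsf.event_times_, horizon=750)`, with the horizon in days. Explaining survival probability rather than the risk score keeps the attributions on an interpretable probability scale.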

@NafisAnwari


Sorry, I was not able to solve the problem.
