@shlomihod
Last active March 23, 2024 21:10
notebook.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shlomihod/a8ab81896e83eb65d2c971ed7e44b660/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mebjsaY78lQc"
},
"source": [
"![banner](https://learn.responsibly.ai/assets/banner.jpg)\n",
"\n",
"# Class 7 - Transparency and Explainability: Methods\n",
"\n",
"https://learn.responsibly.ai"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VQ-J_AQZHZrI"
},
"source": [
"If you have any questions, please post them in the `#ds` channel in Discord or join the office hours."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v3M06_wW8lQe"
},
"source": [
"## 1. Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gTp8BbyTHZrI"
},
"outputs": [],
"source": [
"!wget -q 'https://drive.google.com/uc?export=download&id=1aH9hrWCFMyaVy9DJFeQxgHz7zrmaQ1bf' -O GiveMeSomeCredit.zip\n",
"!unzip -q -o GiveMeSomeCredit.zip -d data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u5SoJCRe2kEY",
"inputHidden": false,
"outputHidden": false
},
"outputs": [],
"source": [
"%pip install -qq numpy pandas xlrd matplotlib seaborn tabulate missingno scikit-learn imbalanced-learn statsmodels\n",
"%pip install -qq lime shap alibi corels \"aix360<0.3.0\"\n",
"%pip install -qq git+https://github.com/MarcelRobeer/ContrastiveExplanation.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I9ja-Mt-8lQo"
},
"outputs": [],
"source": [
"import os\n",
"from contextlib import (redirect_stdout, redirect_stderr,\n",
" contextmanager, ExitStack)\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"from IPython import display\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pylab as plt\n",
"import seaborn as sns\n",
"import missingno as msno\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import (classification_report, precision_recall_curve,\n",
" RocCurveDisplay, PrecisionRecallDisplay, ConfusionMatrixDisplay)\n",
"from sklearn.preprocessing import minmax_scale\n",
"\n",
"\n",
"# https://stackoverflow.com/questions/50691545/how-to-use-a-with-statement-to-suppress-sys-stdout-or-sys-stderr\n",
"@contextmanager\n",
"def suppress(out=True, err=True):\n",
" with ExitStack() as stack:\n",
" with open(os.devnull, 'w') as null:\n",
" if out:\n",
" stack.enter_context(redirect_stdout(null))\n",
" if err:\n",
" stack.enter_context(redirect_stderr(null))\n",
" yield"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R64bgg_D8lRS"
},
"source": [
"## 2. Dataset\n",
"\n",
"Banks play a crucial role in market economies. They decide who can get finance and on what terms and can make or break investment decisions. For markets and society to function, individuals and companies need access to credit.\n",
"\n",
"Credit scoring algorithms, which make a guess at the probability of default, are the method banks use to determine whether or not a loan should be granted. This competition requires participants to improve on the state of the art in credit scoring, by predicting the probability that somebody will experience financial distress in the next two years.\n",
"\n",
"Source: [Kaggle - Give Me Some Credit](https://www.kaggle.com/c/GiveMeSomeCredit)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WOZ6D65N8lRZ"
},
"source": [
"### Columns\n",
"\n",
"`SeriousDlqin2yrs` is the target. Read [here](https://www.investopedia.com/ask/answers/062315/what-are-differences-between-delinquency-and-default.asp) about delinquency."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8NUQ9JGC8lRf"
},
"outputs": [],
"source": [
"display.HTML(pd.read_excel('data/Data Dictionary.xls', header=1).to_html())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2T3EHPiu8lRk"
},
"source": [
"### Loading data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UxBV9_DC8lRl"
},
"outputs": [],
"source": [
"df = pd.read_csv('data/cs-training.csv', index_col=0)\n",
"\n",
"target_name = 'SeriousDlqin2yrs'\n",
"class_names = ['Good', 'Bad']\n",
"\n",
"df[target_name] = df[target_name].astype(bool)\n",
"\n",
"feature_names = np.array(df.columns.drop(target_name))\n",
"\n",
"print('# Data:', len(df))\n",
"print('% Positive target:', df[target_name].mean() * 100)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AI6NI22z8lRq"
},
"source": [
"### Missing Data\n",
"\n",
"⚠️ For the sake of simplicity, we will remove missing data **before** downsampling. In real-world scenario, the missing data might be important for deployment setting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZAQ0Rd0Y8lRr"
},
"outputs": [],
"source": [
"assert not df[target_name].isnull().any()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u25PhovY8lRx"
},
"outputs": [],
"source": [
"msno.bar(df);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sai6Dm648lR4"
},
"outputs": [],
"source": [
"dropna_df = df.dropna()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IvClYuL18lR9"
},
"outputs": [],
"source": [
"print('# Data:', len(dropna_df))\n",
"print('% Dropped NA / Original:', 100 * len(dropna_df) / len(df))\n",
"print('% Positive target:', dropna_df[target_name].mean() * 100)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4xA-3atz8lSC"
},
"source": [
"### Downsample majority class from ~95% to 70% to reduce imbalance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "a8uU5YhI8lSD"
},
"outputs": [],
"source": [
"X, y = (RandomUnderSampler(sampling_strategy=1/3, random_state=42)\n",
" .fit_resample(dropna_df.drop(target_name, axis=1), dropna_df[target_name]))\n",
"\n",
"rus_df = pd.DataFrame(np.concatenate((X, y[:, None]), axis=1),\n",
" columns=list(feature_names) + [target_name])\n",
"\n",
"print('# Data:', len(rus_df))\n",
"print('% Undersampled + Dropped NA / Original:', 100 * len(rus_df) / len(df))\n",
"print('% Positive target:', rus_df[target_name].mean() * 100)"
]
},
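{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the resampling arithmetic, sketched on hypothetical toy data rather than the credit dataset: with a float `sampling_strategy` $r$, `RandomUnderSampler` keeps every minority sample and keeps $n_{minority} / r$ majority samples, so the minority share becomes $r / (1 + r)$, i.e. 25% for $r = 1/3$. The names `X_toy`, `y_toy`, `X_res` and `y_res` are introduced only for this illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"\n",
"# Hypothetical toy data: 930 negatives and 70 positives (7% positive)\n",
"rng = np.random.default_rng(0)\n",
"X_toy = rng.normal(size=(1000, 2))\n",
"y_toy = np.array([0] * 930 + [1] * 70)\n",
"\n",
"# sampling_strategy=1/3 requests n_minority / n_majority = 1/3 after resampling\n",
"X_res, y_res = (RandomUnderSampler(sampling_strategy=1/3, random_state=0)\n",
"                .fit_resample(X_toy, y_toy))\n",
"\n",
"print('Class counts after resampling:', np.bincount(y_res))  # all minority samples kept, majority cut to 3x\n",
"print('Minority share:', y_res.mean())  # r / (1 + r) = 0.25"
]
},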
{
"cell_type": "markdown",
"metadata": {
"id": "d6EWH7Xfqi18"
},
"source": [
"### Feature Engineering\n",
"Let's combine the three features that counts days past-due into one feature.\n",
"\n",
"⚠️ This choise was done for the sake of simpliciy, and after some exploratory data analysis ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "taGqDYLtqi18"
},
"outputs": [],
"source": [
"fe_df = rus_df.copy()\n",
"\n",
"fe_df['NumberOfTimeDaysPastDue'] = (fe_df['NumberOfTime30-59DaysPastDueNotWorse']\n",
" + fe_df['NumberOfTime60-89DaysPastDueNotWorse']\n",
" + fe_df ['NumberOfTimes90DaysLate'])\n",
"\n",
"fe_df = fe_df.drop(['NumberOfTime30-59DaysPastDueNotWorse',\n",
" 'NumberOfTime60-89DaysPastDueNotWorse',\n",
" 'NumberOfTimes90DaysLate'],\n",
" axis=1)\n",
"\n",
"feature_names = np.array(fe_df.columns.drop(target_name))\n",
"\n",
"# re-order columns that the target name is last\n",
"fe_df = fe_df[list(feature_names) + [target_name]]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Pe2L5f68lTE"
},
"source": [
"### Splitting into training and test datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YZ-QTZgO8lTF"
},
"outputs": [],
"source": [
"train_df, test_df = train_test_split(fe_df, test_size=0.3, random_state=42)\n",
"\n",
"print('# Training dataset:', len(train_df))\n",
"print('# Test dataset:', len(test_df))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2OncSnEx8lTh"
},
"source": [
"### Exploratory data analysis"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VXL3yobE8lTh"
},
"source": [
"#### Features"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "j8sjglRz8lTi"
},
"outputs": [],
"source": [
"train_df.info()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2YlFZXC48lTn"
},
"outputs": [],
"source": [
"train_df.describe().T"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rxLUNqO_8lT5"
},
"outputs": [],
"source": [
"sns.pairplot(train_df.sample(1000), hue=target_name);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "REZ0WAo58lUD"
},
"outputs": [],
"source": [
"plt.figure(figsize=(14,12))\n",
"sns.heatmap(train_df.corr(), annot=True, fmt=\".2g\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qy_Hmd6D8lUK"
},
"source": [
"#### Separate datasets to features and target"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7MDRkHxT8lUN"
},
"outputs": [],
"source": [
"X_train, y_train = train_df.drop(target_name, axis=1), train_df[target_name]\n",
"X_test, y_test = test_df.drop(target_name, axis=1), test_df[target_name]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "36QJ3T7m8lUm"
},
"source": [
"## 3. Model - Random Forest Classifier"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rfCgzWMk8lUo"
},
"source": [
"### Training Random Forest\n",
"\n",
"The hyperparameters were chosen with cross-validation grid search."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_TM2FB6T8lUq"
},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"model = (RandomForestClassifier(class_weight='balanced',\n",
" max_depth=30,\n",
" min_samples_leaf=2,\n",
" min_samples_split=100,\n",
" n_estimators=1200)\n",
" .fit(X_train, y_train))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mFMQhDz7qi2X"
},
"source": [
"### Prediction Widget\n",
"Shows the prediction probability of the positive class (untrustworthy)\n",
"\n",
"#### ⚠️ Ignore the setup code and go directly to the interactive widget"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5qMdT-_qqi2X"
},
"outputs": [],
"source": [
"from ipywidgets import interact, IntSlider, FloatSlider, Layout\n",
"\n",
"slider_setup = {'continuous_update': False,\n",
" 'style': {'description_width': 'initial'},\n",
" 'layout': Layout(width='50%')}\n",
"\n",
"@interact(rev_util=FloatSlider(description='Revolving Utilization Of Unsecured Lines',\n",
" min=0, max=1, step=0.05, value=0, **slider_setup),\n",
" age=IntSlider(description='Age',\n",
" min=18, max=91, value=27, **slider_setup),\n",
" debt_ratio=FloatSlider(description='Debt Ratio',\n",
" min=0, max=1, step=0.05, value=0, **slider_setup),\n",
" monthly_income=IntSlider(description='Monthly Income',\n",
" min=0, max=10000, value=5000, **slider_setup),\n",
" number_of_open_credit_lines_and_loans=IntSlider(description='Number Of Open Credit Lines And Loans',\n",
" min=0, max=11, value=0, **slider_setup),\n",
" number_of_real_estate_loans_or_lines=IntSlider(description='Number Real Estate Loans Or Lines',\n",
" min=0, max=11, value=0, **slider_setup),\n",
" number_of_dependents=IntSlider(description='Number Of Dependents',\n",
" min=0, max=8, value=0, **slider_setup),\n",
" number_of_times_days_past_due=IntSlider(description='Number Of Time Days Past Due',\n",
" min=0, max=31, value=0, **slider_setup))\n",
"def interactive_prediction(rev_util, age, debt_ratio, monthly_income,\n",
" number_of_open_credit_lines_and_loans,\n",
" number_of_real_estate_loans_or_lines,\n",
" number_of_dependents,\n",
" number_of_times_days_past_due):\n",
"\n",
" x = (rev_util, age, debt_ratio, monthly_income,\n",
" number_of_open_credit_lines_and_loans,\n",
" number_of_real_estate_loans_or_lines,\n",
" number_of_dependents,\n",
" number_of_times_days_past_due)\n",
"\n",
" y_pred_bad_prob = model.predict_proba([x])[0][1]\n",
"\n",
" _, ax = plt.subplots(1, figsize=(10, 1))\n",
" plt.barh(' ', y_pred_bad_prob)\n",
" plt.xlim(0, 1)\n",
" plt.xlabel('Untrustworthiness indicator', fontsize=20)\n",
" plt.xticks(np.arange(0, 1.1, 0.1), fontsize=15)\n",
" plt.ylabel('', fontsize=20)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qUjfqSrv8lUt"
},
"source": [
"### Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qIelMK0J8lUu"
},
"outputs": [],
"source": [
"print('Training dataset:')\n",
"print(classification_report(y_train, model.predict(X_train), target_names=class_names))\n",
"\n",
"print('Test dataset:')\n",
"print(classification_report(y_test, model.predict(X_test), target_names=class_names))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "995kjIMT8lUz"
},
"outputs": [],
"source": [
"ConfusionMatrixDisplay.from_estimator(model, X_test, y_test, normalize='all');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GKQEyW7e8lU6"
},
"outputs": [],
"source": [
"RocCurveDisplay.from_estimator(model, X_test, y_test);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FefkC0C18lWu"
},
"outputs": [],
"source": [
"PrecisionRecallDisplay.from_estimator(model, X_test, y_test);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-H6-HDID8lW1"
},
"source": [
"### [OPTIONAL] Choosing threshold using cost matrix\n",
"- False Negative: 5 (giving risky loan)\n",
"- True Negative: -1 (giving loan to a trustworthy person)\n",
"\n",
"⚠️ We shouldn't use the test dataset to choose the threshold, but we are doing so for this toy example; We should have had a validation dataset for that."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PB6SWWoz8lW2"
},
"outputs": [],
"source": [
"def get_best_threshold(model, X, y, cost_matrix):\n",
" \"\"\"TP, FN, FP, TN\"\"\"\n",
"\n",
" y_score = model.predict_proba(X)[:, 1]\n",
"\n",
" precision, recall, thresholds = precision_recall_curve(y, y_score)\n",
"\n",
" num_pos_class = y.sum()\n",
" num_neg_class = (1-y).sum()\n",
"\n",
" tp = recall * num_pos_class\n",
" fp = (tp / precision) - tp\n",
" tn = num_neg_class - fp\n",
" fn = num_pos_class - tp\n",
"\n",
" assert ((tp + fp + tn + fn) == num_pos_class + num_neg_class).all()\n",
"\n",
" # acc = (tp + tn) / (num_pos_class + num_neg_class)\n",
"\n",
" confusion_matrix = np.stack([tp, fn, fp, tn])\n",
" cost = cost_matrix @ confusion_matrix\n",
"\n",
" best_threshold_index = np.argmin(cost)\n",
" best_threshold = thresholds[best_threshold_index]\n",
"\n",
" return best_threshold, {'precision': precision[best_threshold_index],\n",
" 'recall/tpr': recall[best_threshold_index],\n",
" 'fpr': fp[best_threshold_index]/num_neg_class}\n",
"\n",
"\n",
"# [TP, FN, FP, TN]\n",
"COST_MATRIX = np.array([0, 5, 0, -1])\n",
"\n",
"best_threshod, metrics_ = get_best_threshold(model, X_test, y_test, COST_MATRIX)\n",
"\n",
"print('Best Threshod:', best_threshod)\n",
"pd.Series(metrics_).plot(kind='barh');"
]
},
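{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a cross-check of `get_best_threshold`, which reconstructs the confusion-matrix counts from the precision-recall curve, the total cost can also be computed by brute force: sweep candidate thresholds, build each confusion matrix directly, and keep the cheapest. This is a sketch on hypothetical toy scores; `threshold_costs` is a helper introduced here for illustration, not part of the notebook or of scikit-learn. Like `precision_recall_curve`, the sketch predicts positive when the score is greater than or equal to the threshold."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"def threshold_costs(y_true, y_score, thresholds, cost_matrix):\n",
"    # Total cost per threshold; cost_matrix is the per-instance cost of [TP, FN, FP, TN]\n",
"    totals = []\n",
"    for threshold in thresholds:\n",
"        y_pred = (y_score >= threshold).astype(int)\n",
"        # ravel() on the 2x2 matrix with labels=[0, 1] yields tn, fp, fn, tp\n",
"        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()\n",
"        totals.append(cost_matrix @ np.array([tp, fn, fp, tn]))\n",
"    return np.array(totals)\n",
"\n",
"# Hypothetical toy labels and scores\n",
"y_true_toy = np.array([0, 0, 0, 1, 1, 0, 1, 0, 0, 1])\n",
"y_score_toy = np.array([0.1, 0.4, 0.35, 0.8, 0.65, 0.2, 0.3, 0.7, 0.05, 0.9])\n",
"\n",
"# Same cost structure as COST_MATRIX above: a FN costs 5, a TN earns 1\n",
"toy_cost_matrix = np.array([0, 5, 0, -1])\n",
"candidates = np.unique(y_score_toy)\n",
"totals = threshold_costs(y_true_toy, y_score_toy, candidates, toy_cost_matrix)\n",
"\n",
"print('Cheapest threshold:', candidates[np.argmin(totals)], 'with total cost', totals.min())"
]
},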
{
"cell_type": "markdown",
"metadata": {
"id": "LpsgB5sp8lW5"
},
"source": [
"### Setting the Prediction Function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7syXrrET8lW6"
},
"outputs": [],
"source": [
"# predict_fn = lambda X: model.predict_proba(X)[:, 1] > best_threshod\n",
"predict_fn = lambda X: model.predict(X)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v98f8r_I8lXB"
},
"source": [
"## 4. Interpretable Machine Learning\n",
"\n",
"### Reference:\n",
"- Molnar, Christoph. [Interpretable machine learning](https://christophm.github.io/interpretable-ml-book/). 2019.\n",
"- Lakkaraju, Hima. [Interpretability and Explainability in Machine Learning Course](https://interpretable-ml-class.github.io/). Harvard. 2019."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QCJqvo0C8lXC"
},
"source": [
"Choose one random row:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LYPdizQX8lXE"
},
"outputs": [],
"source": [
"class_to_sample = 1 # change me!\n",
"assert class_to_sample in (0, 1)\n",
"\n",
"np.random.seed(123456789) # fix state for choosing random individual\n",
"idx = np.random.choice(y_test[y_test == class_to_sample].index)\n",
"\n",
"print('Chosen Sample:', idx)\n",
"print('True class:', y_test[idx])\n",
"print('Prediction class:', predict_fn(X_test.loc[idx][None, :])[0] == 1)\n",
"print('Prediction probabilities:', model.predict_proba(X_test.loc[idx][None, :]))\n",
"\n",
"X_test.loc[idx]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Y77J5vl8lXJ"
},
"source": [
"### 4.1. Feature Importance\n",
"\n",
"https://christophm.github.io/interpretable-ml-book/feature-importance.html\n",
"\n",
"https://scikit-learn.org/stable/auto_examples/inspection/plot_permutation_importance.html#sphx-glr-auto-examples-inspection-plot-permutation-importance-py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ike-T4zi8lXJ"
},
"source": [
"Random forest feature importance:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4n5td0xM8lXK"
},
"outputs": [],
"source": [
"tree_feature_importances = model.feature_importances_\n",
"sorted_idx = tree_feature_importances.argsort()\n",
"\n",
"y_ticks = np.arange(0, len(feature_names))\n",
"fig, ax = plt.subplots()\n",
"ax.barh(y_ticks, tree_feature_importances[sorted_idx])\n",
"ax.set_yticklabels(feature_names[sorted_idx])\n",
"ax.set_yticks(y_ticks)\n",
"ax.set_title('Random Forest Feature Importances (MDI)')\n",
"fig.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pyUEiekB8lXN"
},
"source": [
"Permuation feature importance:\n",
"\n",
"(The current implementation uses accuracy as metric)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JEttsotn8lXO",
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn.inspection import permutation_importance\n",
"\n",
"result = permutation_importance(model, X_test, y_test, scoring='average_precision', n_repeats=10,\n",
" random_state=42, n_jobs=2)\n",
"sorted_idx = result.importances_mean.argsort()\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.boxplot(result.importances[sorted_idx].T,\n",
" vert=False, labels=X_test.columns[sorted_idx])\n",
"ax.set_title('Permutation Importances (test set)')\n",
"fig.tight_layout()\n",
"plt.show()"
]
},
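{
"cell_type": "markdown",
"metadata": {},
"source": [
"The idea behind permutation importance, sketched by hand on hypothetical synthetic data from `make_classification`: shuffle one feature column at a time, which breaks its relationship with the target, and measure how much the score drops. Average precision is used to match the `scoring` argument above; `permutation_importance` additionally repeats every shuffle `n_repeats` times and reports the spread. `toy_model` and the other `toy` names are introduced only for this illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import average_precision_score\n",
"\n",
"# Hypothetical synthetic data: 5 features, only 2 of them informative\n",
"X_toy, y_toy = make_classification(n_samples=600, n_features=5, n_informative=2,\n",
"                                   n_redundant=0, random_state=0)\n",
"X_fit, X_val, y_fit, y_val = X_toy[:400], X_toy[400:], y_toy[:400], y_toy[400:]\n",
"toy_model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_fit, y_fit)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"baseline = average_precision_score(y_val, toy_model.predict_proba(X_val)[:, 1])\n",
"\n",
"for j in range(X_val.shape[1]):\n",
"    X_perm = X_val.copy()\n",
"    X_perm[:, j] = rng.permutation(X_perm[:, j])  # shuffle feature j only\n",
"    score = average_precision_score(y_val, toy_model.predict_proba(X_perm)[:, 1])\n",
"    print(f'feature {j}: importance = {baseline - score:.3f}')"
]
},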
{
"cell_type": "markdown",
"metadata": {
"id": "BoWLaojnqi28"
},
"source": [
"### 4.2. Partial Dependence Plot (PDP)\n",
"\n",
"https://christophm.github.io/interpretable-ml-book/pdp.html\n",
"\n",
"https://scikit-learn.org/stable/modules/partial_dependence.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i96_ACQHqi28"
},
"outputs": [],
"source": [
"from sklearn.inspection import PartialDependenceDisplay\n",
"\n",
"_, ax = plt.subplots(figsize=(20, 20))\n",
"\n",
"PartialDependenceDisplay.from_estimator(model, X_test, X_test.columns, ax=ax);"
]
},
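{
"cell_type": "markdown",
"metadata": {},
"source": [
"What a one-way partial dependence curve computes, sketched by hand on hypothetical synthetic data: for each grid value, set the feature to that value for *every* row, predict, and average the predicted probabilities. `partial_dependence_manual` is a helper introduced for illustration; `PartialDependenceDisplay` chooses its grid between percentiles of the feature (5% and 95% by default), which the sketch imitates."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"def partial_dependence_manual(estimator, X, feature_idx, grid):\n",
"    # Average positive-class probability with feature_idx fixed to each grid value\n",
"    averages = []\n",
"    for value in grid:\n",
"        X_mod = X.copy()\n",
"        X_mod[:, feature_idx] = value  # overwrite the feature for every row\n",
"        averages.append(estimator.predict_proba(X_mod)[:, 1].mean())\n",
"    return np.array(averages)\n",
"\n",
"# Hypothetical synthetic data and model\n",
"X_toy, y_toy = make_classification(n_samples=500, n_features=4, random_state=0)\n",
"toy_model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_toy, y_toy)\n",
"\n",
"# Grid between the 5th and 95th percentiles of feature 0\n",
"grid = np.linspace(*np.percentile(X_toy[:, 0], [5, 95]), num=20)\n",
"pdp_values = partial_dependence_manual(toy_model, X_toy, 0, grid)\n",
"\n",
"for value, average in zip(grid, pdp_values):\n",
"    print(f'{value:6.2f} -> {average:.3f}')"
]
},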
{
"cell_type": "markdown",
"metadata": {
"id": "VETdaQ9s8lXR"
},
"source": [
"### 4.3. LIME\n",
"\n",
"https://christophm.github.io/interpretable-ml-book/lime.html\n",
"\n",
"https://github.com/marcotcr/lime\n",
"\n",
"https://marcotcr.github.io/lime/tutorials/Tutorial%20-%20continuous%20and%20categorical%20features.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YztdIykL8lXR"
},
"outputs": [],
"source": [
"import lime\n",
"import lime.lime_tabular\n",
"\n",
"lime_explainer = lime.lime_tabular.LimeTabularExplainer(X_train.to_numpy(),\n",
" feature_names=feature_names,\n",
" class_names=class_names,\n",
" discretize_continuous=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "S_S7FETR8lXW"
},
"outputs": [],
"source": [
"exp = lime_explainer.explain_instance(X_test.loc[idx], model.predict_proba)\n",
"\n",
"exp.show_in_notebook(show_table=True, show_all=False)"
]
},
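{
"cell_type": "markdown",
"metadata": {},
"source": [
"A stripped-down sketch of LIME's core idea, on hypothetical synthetic data: sample perturbations around the instance, weight them by proximity with an exponential kernel, and fit a weighted linear model whose coefficients serve as the local explanation. This is a simplification: `lime_sketch` is a helper introduced here, and unlike `LimeTabularExplainer` above it skips discretization and samples Gaussian noise around the instance rather than using the training-data statistics. The kernel width of `0.75 * sqrt(n_features)` mirrors, to our understanding, the `lime` package's tabular default."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import Ridge\n",
"\n",
"def lime_sketch(predict_proba, x, scale, n_samples=5000, seed=0):\n",
"    rng = np.random.default_rng(seed)\n",
"    # Perturb the instance with Gaussian noise, using a per-feature scale\n",
"    Z = x + rng.normal(size=(n_samples, x.size)) * scale\n",
"    target = predict_proba(Z)[:, 1]\n",
"    # Exponential kernel on the standardized distance to the instance\n",
"    Z_std = (Z - x) / scale\n",
"    distances = np.sqrt((Z_std ** 2).sum(axis=1))\n",
"    kernel_width = 0.75 * np.sqrt(x.size)\n",
"    weights = np.sqrt(np.exp(-(distances ** 2) / kernel_width ** 2))\n",
"    # Weighted linear surrogate; its coefficients are the local feature effects\n",
"    surrogate = Ridge(alpha=1.0).fit(Z_std, target, sample_weight=weights)\n",
"    return surrogate.coef_\n",
"\n",
"# Hypothetical synthetic data and model\n",
"X_toy, y_toy = make_classification(n_samples=500, n_features=4, random_state=0)\n",
"toy_model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_toy, y_toy)\n",
"\n",
"coefs = lime_sketch(toy_model.predict_proba, X_toy[0], scale=X_toy.std(axis=0))\n",
"for j, coef in enumerate(coefs):\n",
"    print(f'feature {j}: {coef:+.3f}')"
]
},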
{
"cell_type": "markdown",
"metadata": {
"id": "G7J_Kam38lXi"
},
"source": [
"### 4.4. SHAP (SHapley Additive exPlanations)\n",
"\n",
"https://christophm.github.io/interpretable-ml-book/shap.html\n",
"\n",
"https://github.com/slundberg/shap\n",
"\n",
"https://slundberg.github.io/shap/notebooks/tree_explainer/Census%20income%20classification%20with%20LightGBM.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-EYA9jH38lXj"
},
"outputs": [],
"source": [
"X_test_sample = X_test.sample(1000)"
]
},
{
"cell_type": "code",
"source": [
"model.predict_proba(X_test_sample)"
],
"metadata": {
"id": "kt2bUy6HqDtZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2f0Iv6AP8lXo"
},
"outputs": [],
"source": [
"import shap\n",
"\n",
"# sample from X_test, because otherwise it takes a lot of time to run\n",
"# and making sure that the choosed row is the last one\n",
"X_test_sample = X_test.sample(1000)\n",
"if idx in X_test_sample:\n",
" X_test_sample = X_test_sample.drop(idx)\n",
"X_test_sample = X_test_sample.append(X_test.loc[idx])\n",
"\n",
"shap_explainer = shap.TreeExplainer(model)\n",
"shap_values = shap_explainer.shap_values(X_test_sample)"
]
},
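{
"cell_type": "markdown",
"metadata": {},
"source": [
"SHAP values satisfy *local accuracy*: the base value plus the sum of a row's SHAP values equals the model output for that row, which, as far as we know, is the predicted probability when `TreeExplainer` explains a scikit-learn random forest with its default settings. A sketch of that check on hypothetical synthetic data; the shape handling is hedged because older `shap` releases return a list with one array per class, while newer ones return a single `(n_samples, n_features, n_classes)` array, the layout the cells below index with `shap_values[:, :, 1]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import shap\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"# Hypothetical synthetic data and model\n",
"X_toy, y_toy = make_classification(n_samples=300, n_features=4, random_state=0)\n",
"toy_model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_toy, y_toy)\n",
"\n",
"toy_explainer = shap.TreeExplainer(toy_model)\n",
"toy_shap = toy_explainer.shap_values(X_toy[:5])\n",
"\n",
"# Select the positive class regardless of the returned layout\n",
"toy_shap_pos = toy_shap[1] if isinstance(toy_shap, list) else toy_shap[:, :, 1]\n",
"\n",
"# Base value + sum of SHAP values should reproduce the predicted probability\n",
"reconstructed = toy_explainer.expected_value[1] + toy_shap_pos.sum(axis=1)\n",
"print('Local accuracy holds:', np.allclose(reconstructed, toy_model.predict_proba(X_toy[:5])[:, 1]))"
]
},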
{
"cell_type": "markdown",
"metadata": {
"id": "1R6INttU8lXs"
},
"source": [
"#### Visualize a single prediction"
]
},
{
"cell_type": "code",
"source": [
"shap.initjs()\n",
"\n",
"shap.force_plot(\n",
" shap_explainer.expected_value[1],\n",
" shap_values[-1, :, 1],\n",
" X_test_sample.iloc[-1, :],\n",
" link=\"logit\",\n",
")"
],
"metadata": {
"id": "4q5ItydxuOL8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "TFcEjQFc8lXx"
},
"source": [
"#### Visualize many predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T0_khRty8lXy"
},
"outputs": [],
"source": [
"shap.initjs()\n",
"\n",
"shap.force_plot(\n",
" shap_explainer.expected_value[1],\n",
" shap_values[:, :, 1],\n",
" X_test_sample.iloc[:, :],\n",
" link=\"logit\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tlfe06Lx8lX1"
},
"source": [
"#### SHAP Summary Plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lfgXa2go8lX2"
},
"outputs": [],
"source": [
"shap.initjs()\n",
"\n",
"shap.summary_plot(shap_values[:, :, 1], X_test_sample, class_names=class_names)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0smKmqGG8lX6"
},
"source": [
"#### SHAP Dependence Plots"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xEhnzJMZ8lX7",
"scrolled": false
},
"outputs": [],
"source": [
"shap.initjs()\n",
"\n",
"for name in X_test_sample.columns:\n",
" shap.dependence_plot(name, shap_values[:, :, 1], X_test_sample)"
]
},
{
"cell_type": "code",
"source": [
"shap_values[:, :, 1].shape"
],
"metadata": {
"id": "N6YtfOwvvUtg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Fl_QoOu8lX_"
},
"source": [
"### 4.5. Anchor\n",
"\n",
"https://docs.seldon.io/projects/alibi/en/stable/index.html\n",
"\n",
"https://docs.seldon.io/projects/alibi/en/stable/examples/anchor_tabular_adult.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5STxeUER8lYA"
},
"outputs": [],
"source": [
"from alibi.explainers import AnchorTabular"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mzPCob768lYL"
},
"outputs": [],
"source": [
"anchor_explainer = (AnchorTabular(predict_fn, feature_names)\n",
" .fit(X_train.to_numpy()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_JhjbKIW8lYW"
},
"outputs": [],
"source": [
"explanation = anchor_explainer.explain(X_test.loc[idx][None, :], threshold=0.95)\n",
"\n",
"print('Anchor: %s' % (' AND '.join(explanation.anchor)))\n",
"print('Precision: %.2f' % explanation.precision)\n",
"print('Coverage: %.2f' % explanation.coverage)"
]
},
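{
"cell_type": "markdown",
"metadata": {},
"source": [
"What the reported precision and coverage mean, sketched on hypothetical synthetic data with a hand-written rule: coverage is the fraction of instances the rule applies to, and precision is the fraction of covered instances that receive the same prediction as the explained instance. This is a simplification: as far as we understand, `AnchorTabular` searches for the rule itself and estimates precision on perturbed samples rather than with a single pass over a dataset. `rule_holds` and the 0.5 tolerance are illustrative assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"# Hypothetical synthetic data and model\n",
"X_toy, y_toy = make_classification(n_samples=1000, n_features=4, random_state=0)\n",
"toy_model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_toy, y_toy)\n",
"\n",
"x_explained = X_toy[0]\n",
"x_pred = toy_model.predict(x_explained[None, :])[0]\n",
"\n",
"# Hypothetical anchor: features 0 and 1 each within 0.5 of the explained instance\n",
"def rule_holds(rows):\n",
"    return ((np.abs(rows[:, 0] - x_explained[0]) <= 0.5)\n",
"            & (np.abs(rows[:, 1] - x_explained[1]) <= 0.5))\n",
"\n",
"covered = rule_holds(X_toy)  # always includes the explained instance itself\n",
"coverage = covered.mean()\n",
"rule_precision = (toy_model.predict(X_toy[covered]) == x_pred).mean()\n",
"\n",
"print(f'Coverage: {coverage:.2f}, Precision: {rule_precision:.2f}')"
]
},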
{
"cell_type": "markdown",
"metadata": {
"id": "ZRw8yA8z8lYY"
},
"source": [
"### 4.6. Prototypes\n",
"\n",
"https://christophm.github.io/interpretable-ml-book/proto.html\n",
"\n",
"https://aix360.readthedocs.io/en/latest/die.html?protodash-explainer#protodash-explainer\n",
"\n",
"https://nbviewer.jupyter.org/github/IBM/AIX360/blob/master/examples/tutorials/HELOC.ipynb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W3qElUaC8lYZ"
},
"outputs": [],
"source": [
"from aix360.algorithms.protodash import ProtodashExplainer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4K_01a8p8lYr"
},
"source": [
"#### Normalize the data and chose a particular applicant, whose profile is displayed below"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k19ys7218lYs"
},
"outputs": [],
"source": [
"X_idx = X_test.loc[idx][None, :]\n",
"y_idx = predict_fn(X_idx)[None, :]\n",
"Xy_idx = np.hstack((X_idx, y_idx))\n",
"\n",
"y_pred_train = predict_fn(X_train)\n",
"Xy_train = np.hstack((X_train,\n",
" y_pred_train[:, None]))\n",
"\n",
"# Zy_train = np.hstack((minmax_scale(X_train, (-0.5, 0.5), axis=0),\n",
"# y_pred_train[:, None]))\n",
"\n",
"Zy_train = Xy_train\n",
"\n",
"Xy_train_by_prediction = Xy_train[y_pred_train == y_idx[0]]\n",
"Zy_train_by_prediction = Zy_train[y_pred_train == y_idx[0]]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YCsqdcb48lYx"
},
"source": [
"#### Find similar applicants predicted as \"bad\" using the protodash explainer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AQ5cyjeJ8lYy"
},
"outputs": [],
"source": [
"prototype_explainer = ProtodashExplainer()\n",
"\n",
"n_prototypes = 5\n",
"\n",
"with suppress():\n",
" weights, prototype_indices, set_values_obj_func = prototype_explainer.explain(Xy_idx,\n",
" Xy_train_by_prediction,\n",
" n_prototypes)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m4XQoneW8lY2"
},
"source": [
"#### Display similar applicant user profiles and the extent to which they are similar to the chosen applicant as indicated by the last row in the table below labelled as \"Weight\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7Qus6l6y8lY3"
},
"outputs": [],
"source": [
"prototype_df = pd.DataFrame.from_records(Xy_train_by_prediction[prototype_indices, 0:-1].astype('double'),\n",
" columns=list(feature_names))\n",
"prototype_classes = []\n",
"for prototype_index in prototype_indices:\n",
" prototype_classes.append(class_names[int(Zy_train_by_prediction[prototype_index, -1])]) # Append class names\n",
"\n",
"prototype_df[target_name] = prototype_classes\n",
"prototype_df['Weight'] = (np.around(weights, n_prototypes)\n",
" / np.sum(np.around(weights, n_prototypes))) # Calculate normalized importance weights\n",
"prototype_df = prototype_df.sort_values('Weight', ascending=False)\n",
"\n",
"prototype_df.transpose()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DyKeubeV8lY-"
},
"outputs": [],
"source": [
"X_test.loc[idx]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "enJlpWUm8lZP"
},
"source": [
"#### Compute how similar a feature of a prototypical user is to the chosen applicant.\n",
"\n",
"Closer to 1 is similar."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FCM5ZBS88lZQ"
},
"outputs": [],
"source": [
"EPS = 1e-10 # Small constant defined to eliminate divide-by-zero errors\n",
"\n",
"Z_train_by_prediction_prototypes = Zy_train_by_prediction[prototype_indices, 0:-1] # Store chosen prototypes\n",
"\n",
"log_feature_weights = (-np.abs(X_idx[0] - Z_train_by_prediction_prototypes)\n",
" / (np.std(Z_train_by_prediction_prototypes, axis=0) + EPS))\n",
"\n",
"feature_weights = np.exp(log_feature_weights)\n",
"\n",
"feature_weights_df = pd.DataFrame.from_records(np.around(feature_weights.astype('double'), 2),\n",
" columns=feature_names)\n",
"feature_weights_df.transpose()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bC4T1rLy8lZT"
},
"source": [
"### 4.7. Contrastive explanations\n",
"\n",
"https://christophm.github.io/interpretable-ml-book/counterfactual.html\n",
"\n",
"https://github.com/MarcelRobeer/ContrastiveExplanation\n",
"\n",
"https://nbviewer.jupyter.org/github/MarcelRobeer/ContrastiveExplanation/blob/master/Contrastive%20explanation%20-%20example%20usage.ipynb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OEmT1_MS8lZT"
},
"outputs": [],
"source": [
"import contrastive_explanation as ce"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JX-Bb8I88lZW"
},
"outputs": [],
"source": [
"# Create a domain mapper for the Pandas DataFrame\n",
"dm = ce.domain_mappers.DomainMapperTabular(X_train.to_numpy(),\n",
" feature_names=feature_names,\n",
" contrast_names=class_names)\n",
"\n",
"\n",
"# Create the contrastive explanation object (default is a Foil Tree explanator)\n",
"contrastive_explainer = ce.ContrastiveExplanation(dm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WS47sqQ58lZZ"
},
"outputs": [],
"source": [
"# Explain the instance (sample) for the given model\n",
"explanation = contrastive_explainer.explain_instance_domain(model.predict_proba, X_test.loc[idx])\n",
"\n",
"print(explanation)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mYbzLCAS8lZd"
},
"source": [
"### 4.8. Global Surrogate - Decision Tree\n",
"\n",
"https://christophm.github.io/interpretable-ml-book/global.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QG8tbBCID2Fi"
},
"outputs": [],
"source": [
"from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BGrT-t378lZg"
},
"outputs": [],
"source": [
"y_train_pred = predict_fn(X_train)\n",
"y_test_pred = predict_fn(X_test)\n",
"\n",
"global_surrogate_model = (DecisionTreeClassifier(class_weight='balanced',\n",
" max_depth=3)\n",
" .fit(X_train, y_train_pred))\n",
"\n",
"y_test_global_surrogate = global_surrogate_model.predict(X_test)\n",
"\n",
"print(classification_report(y_test_pred, y_test_global_surrogate, target_names=class_names))\n",
"\n",
"\n",
"_, ax = plt.subplots(figsize=(20, 10))\n",
"\n",
"plot_tree(global_surrogate_model,\n",
" feature_names=feature_names, class_names=class_names,\n",
" filled=True, proportion=True, fontsize=10, ax=ax);"
]
},
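{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the `classification_report` above compares the surrogate with the random forest's *predictions*, not with the true labels; this agreement is usually called *fidelity*. A compact sketch of the same procedure on hypothetical synthetic data, measuring fidelity as accuracy against the black-box predictions on held-out rows for a few tree depths (`black_box` and `surrogate` are illustrative names):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"# Hypothetical synthetic data and black-box model\n",
"X_toy, y_toy = make_classification(n_samples=1000, n_features=6, random_state=0)\n",
"X_fit, X_hold = X_toy[:700], X_toy[700:]\n",
"black_box = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_fit, y_toy[:700])\n",
"\n",
"# The surrogate learns the black box's predictions, not the true labels\n",
"for depth in (1, 2, 3, 5):\n",
"    surrogate = (DecisionTreeClassifier(max_depth=depth, random_state=0)\n",
"                 .fit(X_fit, black_box.predict(X_fit)))\n",
"    fidelity = accuracy_score(black_box.predict(X_hold), surrogate.predict(X_hold))\n",
"    print(f'max_depth={depth}: fidelity = {fidelity:.3f}')"
]
},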
{
"cell_type": "markdown",
"metadata": {
"id": "ntQhV9J88lZr"
},
"source": [
"### 4.9. Interpretable Model - CORELS\n",
"\n",
"https://corels.eecs.harvard.edu/\n",
"\n",
"https://christophm.github.io/interpretable-ml-book/rules.html"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t4k26yt9qi4h"
},
"source": [
"Discretize features into binary variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cx-l-wpfqi4h"
},
"outputs": [],
"source": [
"from sklearn.preprocessing import KBinsDiscretizer\n",
"\n",
"# we let the discretization run automatically\n",
"# we probably could come up with better way to do so with some feature engineering\n",
"kbins_disc = KBinsDiscretizer(n_bins=10, encode='onehot-dense')\n",
"X_train_binarized = kbins_disc.fit_transform(X_train)\n",
"\n",
"binarized_features= [f'{feature}_[{rh:.2f},{lh:.2f})'\n",
" for feature, bin_edges in zip(X_train.columns, kbins_disc.bin_edges_)\n",
" for rh, lh in zip(bin_edges[:-1], bin_edges[1:])]\n",
"\n",
"X_train_binarized = pd.DataFrame(X_train_binarized,\n",
" columns=binarized_features, index=X_train.index)\n",
"\n",
"X_test_binarized = kbins_disc.transform(X_test)\n",
"X_test_binarized = pd.DataFrame(X_test_binarized,\n",
" columns=binarized_features, index=X_test.index)"
]
},
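{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the binarization produces, a sketch on a single hypothetical feature: with its default `strategy='quantile'`, `KBinsDiscretizer` places the bin edges at quantiles of the data, and `encode='onehot-dense'` turns each bin into one 0/1 column, named here with the same `feature_[low,high)` pattern as above. CORELS then builds its rules from binary columns like these. The `ages` values are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.preprocessing import KBinsDiscretizer\n",
"\n",
"# Hypothetical feature: 12 ages\n",
"ages = pd.DataFrame({'age': [22, 25, 31, 35, 38, 41, 45, 52, 58, 63, 67, 75]})\n",
"\n",
"disc = KBinsDiscretizer(n_bins=3, encode='onehot-dense')\n",
"onehot = disc.fit_transform(ages)\n",
"\n",
"# One column per bin; each row has exactly one 1\n",
"edges = disc.bin_edges_[0]\n",
"names = [f'age_[{low:.2f},{high:.2f})' for low, high in zip(edges[:-1], edges[1:])]\n",
"print(pd.DataFrame(onehot, columns=names, index=ages['age']))"
]
},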
{
"cell_type": "markdown",
"metadata": {
"id": "zziu9ykaqi4k"
},
"source": [
"CORELS Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JUNXZlAUqi4k"
},
"outputs": [],
"source": [
"from corels import CorelsClassifier\n",
"\n",
"np.bool = bool\n",
"\n",
"corels_model = (CorelsClassifier()\n",
" .fit(X_train_binarized, y_train, binarized_features))\n",
"\n",
"print()\n",
"y_pred_corels = corels_model.predict(X_test_binarized)\n",
"print(classification_report(y_test, y_pred_corels))"
]
}
],
"metadata": {
"colab": {
"name": "notebook.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernel_info": {
"name": "python3"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"nteract": {
"version": "0.12.3"
},
"vscode": {
"interpreter": {
"hash": "55bbdba5d2159c30191d9b81156a2ec7ece345201aa1fcd9b85bbc484276dddb"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}