Skip to content

Instantly share code, notes, and snippets.

@romainmartinez
Last active January 17, 2023 19:26
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save romainmartinez/d1aa798896d2f8cde62e40a3e59ec4a5 to your computer and use it in GitHub Desktop.
Save romainmartinez/d1aa798896d2f8cde62e40a3e59ec4a5 to your computer and use it in GitHub Desktop.
Sensitivity analysis of a (scikit-learn) machine learning model
from sklearn.datasets import make_regression
import pandas as pd
from xgboost import XGBRegressor
import matplotlib.pyplot as plt
import seaborn as sns
# Synthetic regression problem: 500 samples and 4 features, 2 of which are
# informative, wrapped in a DataFrame so the features can be addressed by name.
X, y = make_regression(
    n_samples=500,
    n_features=4,
    n_informative=2,
    noise=0.3,
)
X = pd.DataFrame(X, columns=list('ABCD'))

# Model whose sensitivity to each feature is analysed further down.
model = XGBRegressor()
model.fit(X, y)
class Simulate:
    """One-at-a-time sensitivity analysis around a single observation.

    Each variable listed in ``var`` is increased by a percentage of its own
    value while the other columns keep their observed values; the model then
    predicts again so the result can be compared with the prediction for the
    unmodified observation.
    """

    def __init__(self, obs, var):
        """Store the observation and the variables to perturb.

        Args:
            obs: One-row ``pandas.DataFrame`` holding the observation
                (e.g. ``X.iloc[[29]]``).
            var: Iterable of column names of ``obs`` to perturb in turn.
        """
        self.obs = obs
        self.var = var

    def simulate_increase(self, model, percentage):
        """Predict the target after increasing each variable separately.

        Args:
            model: Fitted estimator exposing ``predict``.
            percentage: Size of the increase, in percent of the variable's
                own value (so a negative value moves further from zero).

        Returns:
            ``pandas.DataFrame`` with one row per variable and the columns
            ``test`` (variable name), ``simulated`` (prediction after the
            increase) and ``baseline`` (prediction for the unmodified
            observation).

        Raises:
            ValueError: If ``obs`` does not contain exactly one row.
        """
        # The result holds one scalar per variable, so only a single
        # observation makes sense; fail early with a readable message.
        if len(self.obs) != 1:
            raise ValueError(
                f'obs must contain exactly one row, got {len(self.obs)}'
            )

        baseline = model.predict(self.obs)

        simulated = {}
        for name in self.var:
            # Work on a copy so the stored observation is never modified.
            perturbed = self.obs.copy()
            perturbed[name] = perturbed[name] + perturbed[name] * (percentage / 100)
            simulated[name] = model.predict(perturbed)[0]

        result = pd.DataFrame({
            'test': list(simulated),
            'simulated': list(simulated.values()),
        })
        result['baseline'] = baseline[0]
        return result

    @staticmethod
    def _padded_limits(values, pad=0.1):
        """Return ``(low, high)`` enclosing ``values`` with a relative margin.

        The margin is computed from the absolute value of each bound so that
        it always widens the range, including for negative numbers.
        """
        low, high = min(values), max(values)
        low -= abs(low) * pad
        high += abs(high) * pad
        if low == high:
            # Only possible when every value is zero; avoid a flat axis.
            low, high = low - 1, high + 1
        return low, high

    @staticmethod
    def plot_simulation(d, **kwargs):
        """Draw the simulated predictions as bars against the baseline.

        Args:
            d: DataFrame as returned by ``simulate_increase``.
            **kwargs: Optional ``title`` for the figure.
        """
        baseline = d['baseline'].values[0]

        _, ax = plt.subplots()
        sns.barplot(x='test', y='simulated', data=d, palette='deep', ax=ax)
        # Labelling the reference line directly gives it a legend entry.
        ax.axhline(baseline, color='grey', linestyle='--', linewidth=2,
                   label='baseline')

        # Include the baseline in the range so the reference line is always
        # visible, whatever the sign or magnitude of the predictions.
        ax.set_ylim(Simulate._padded_limits([*d['simulated'], baseline]))

        ax.set_xlabel('Simulated variables')
        ax.set_ylabel('Target value')
        ax.set_title(kwargs.get('title') or '')
        ax.legend()
        ax.grid(axis='y', linewidth=.3)
        sns.despine(offset=10, trim=True)
        plt.tight_layout()
        plt.show()
# Scenario: increase A, B and C in turn by PERC percent for one observation.
PERC = 5
VAR_OPTIMIZE = ['A', 'B', 'C']
ROW = X.iloc[[29]]  # double brackets keep the selection as a one-row DataFrame

S = Simulate(var=VAR_OPTIMIZE, obs=ROW)
d = S.simulate_increase(percentage=PERC, model=model)
S.plot_simulation(
    d,
    title=f'Impact of a {PERC}% increase of {VAR_OPTIMIZE} in target value',
)
@jbanerje
Copy link

Please also include these two imports:

import matplotlib.pyplot as plt
import seaborn as sns

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment