@benman1
Last active March 1, 2020 16:11
Lift chart for predictions. It should work for matching vectors, similar to a residuals plot. It can also be used for plotting variables against each other, similar to a scatter graph.
import numpy as np
import pandas as pd
import scipy.stats  # import the stats submodule explicitly
import seaborn as sns


def lift_chart(y_true, y_pred, bins=10, ax=None, normalize=False, labels=None):
    '''Given matched vectors of true versus predicted
    targets, plot them against each other.

    This can be useful for comparing predictions vs targets
    across the whole spectrum. Often a model fits well in
    the mid-range but does poorly at the extremes. This
    plot can help spot these problems.

    The plot is inspired by DataRobot's lift chart visualization:
    https://twitter.com/hackathorn/status/868136907301146624

    Parameters:
    -----------
    y_true : target values for model
    y_pred : predicted values
    bins : number of bins (an integer; equal-width bins over y_true)
    ax : axis, if the lift chart is to be included
        in an existing figure
    normalize : if y_true and y_pred should be (z-)normalized. This
        can be useful to look at variable correlations, where the
        variables can have a different scale.
    labels : the labels to show in the legend.

    Returns:
    --------
    ax : the axis holding the lift chart. The Spearman correlation
        between predicted and actual targets is printed, not returned.

    Example:
    --------
    >>> a = np.random.randn(10000)
    >>> lift_chart(a, np.random.randn(10000) * 0.9 + a, bins=100)
    '''
    if isinstance(y_true, (pd.Series, pd.DataFrame)):
        y_true = y_true.values
    if isinstance(y_pred, (pd.Series, pd.DataFrame)):
        y_pred = y_pred.values
    if normalize:
        y_pred = scipy.stats.zscore(y_pred)
        y_true = scipy.stats.zscore(y_true)
    # Bin by the target, then take the per-bin mean of target and prediction.
    means, _, _ = scipy.stats.binned_statistic(
        y_true, [y_true, y_pred], bins=bins
    )
    for i, l_mean in enumerate(means):
        if labels is None:
            label = 'target' if i == 0 else 'pred %d' % i
        else:
            label = labels[i]
        # Pass x and y by keyword; recent seaborn releases no longer
        # accept them as positional arguments.
        ax = sns.lineplot(
            x=np.arange(bins), y=l_mean, label=label,
            ax=ax
        )
    ax.set(xlabel='Target Ranking', ylabel='Target')
    ax.legend(
        bbox_to_anchor=(0.3, 0.99),
        ncol=1,
    )
    print(scipy.stats.spearmanr(y_true, y_pred))
    return ax
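
The numbers behind the chart can be inspected without plotting. The following is a minimal sketch, not part of the gist: `binned_means` is a hypothetical helper name, and the toy "model" that halves every target is an assumption chosen only to make the effect visible. It repeats the binning step with `scipy.stats.binned_statistic` and computes the Spearman correlation with `scipy.stats.spearmanr`.

```python
import numpy as np
from scipy import stats


def binned_means(y_true, y_pred, bins=10):
    """Hypothetical helper: per-bin means of target and prediction.

    With an integer `bins`, scipy.stats.binned_statistic splits the range
    of y_true into that many equal-width intervals.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Passing a list of value arrays yields one row of bin means per array.
    means, edges, _ = stats.binned_statistic(
        y_true, [y_true, y_pred], statistic="mean", bins=bins
    )
    return means[0], means[1], edges


rng = np.random.default_rng(0)
target = rng.normal(size=10_000)
pred = 0.5 * target  # assumed toy model: shrinks every prediction toward zero

true_mean, pred_mean, edges = binned_means(target, pred, bins=10)
corr, _ = stats.spearmanr(target, pred)

# Bins that received no samples hold NaN, so compare populated bins only.
populated = np.isfinite(true_mean)
gap = np.abs(true_mean - pred_mean)[populated]
```

In this toy case the rank correlation is perfect, yet `gap` grows toward the outer bins. That is the kind of mismatch at the extremes that the lift chart is meant to make visible, and it suggests reading the printed Spearman value together with the plot rather than on its own.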
benman1 commented Jan 15, 2020

image
