@data-enhanced
Created February 25, 2022 02:31
Custom function to generate a report for a binary classification model: the model name, its scores, and a labeled confusion matrix.
# Function for generating model scores and confusion matrices with custom colors and descriptive labels
# https://stackoverflow.com/questions/70097754/confusion-matrix-with-different-colors
# https://medium.com/@dtuk81/confusion-matrix-visualization-fc31e3f30fea
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix

def report_scores(model, features, labels):
    '''
    Generate model scores and a confusion matrix with custom colors and descriptive labels.
    model = fitted model variable
    features = features of the desired split
    labels = labels of the desired split
    '''
    y_pred = model.predict(features)
    accuracy = accuracy_score(labels, y_pred) * 100
    precision = precision_score(labels, y_pred) * 100
    recall = recall_score(labels, y_pred) * 100
    cm = confusion_matrix(labels, y_pred)
    cm_norm = confusion_matrix(labels, y_pred, normalize='true')  # row-normalized (per actual class)
    cm_colors = sns.color_palette(['gainsboro', 'cornflowerblue'])
    # Axis tick labels for the confusion matrix plot
    cm_y_labels = ['0', '1']  # row labels (actual values)
    cm_x_labels = ['0', '1']  # column labels (predicted values)
    # Confusion matrix cell labels
    # Review and update to match the appropriate labels for your data set
    group_names = ['True Negative', 'False Positive', 'False Negative', 'True Positive']
    group_counts = ['{0:0.0f}'.format(value) for value in cm.flatten()]
    group_percentages = ['{0:.2%}'.format(value) for value in cm_norm.flatten()]
    group_labels = [f'{v1}\n{v2}\n{v3}' for v1, v2, v3 in
                    zip(group_names, group_percentages, group_counts)]
    group_labels = np.asarray(group_labels).reshape(2, 2)
    # Begin plot setup
    fig, ax = plt.subplots(figsize=(4.2, 4.2))
    # Heatmap: plotting np.eye(2) colors the diagonal (correct predictions) with the
    # second palette color and the off-diagonal (errors) with the first
    sns.heatmap(np.eye(2), annot=group_labels, annot_kws={'size': 11}, fmt='',
                cmap=cm_colors, cbar=False,
                yticklabels=cm_y_labels, xticklabels=cm_x_labels, ax=ax)
    # Axis elements
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    ax.tick_params(labelsize=10, length=0)
    ax.set_xlabel('Predicted Values', size=10)
    ax.set_ylabel('Actual Values', size=10)
    # Hook for appending extra text beneath each cell annotation in the matching color
    # (adapted from the Stack Overflow answer above; currently draws only a blank line)
    for text_elt, group_label in zip(ax.texts, group_labels.flatten()):
        ax.text(*text_elt.get_position(), '\n', color=text_elt.get_color(),
                ha='center', va='top')
    # Title for each plot
    # Adjust pad to provide room for the score report below the title and above the confusion matrix plot
    plt.title(f'{model}', pad=80, loc='left', fontsize=16, fontweight='bold')
    # Score report beneath the title
    # Adjust x and y to fit the report
    plt.figtext(0.21, 0.81,
                f'Accuracy: {round(accuracy, 3)}%\nPrecision: {round(precision, 2)}%\nRecall: {round(recall, 2)}%',
                wrap=True, ha='left', fontsize=10)
    # Display the plot!
    plt.tight_layout()
    plt.subplots_adjust(left=0.2)
    print('\n')  # Add a blank line for improved spacing
    plt.show()
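
For reference, a minimal usage sketch. The dataset, split, and model here are illustrative assumptions (scikit-learn's breast-cancer data and a logistic regression), not part of the original gist; any fitted binary classifier with a `.predict()` method works.

# Minimal usage sketch (assumes scikit-learn; dataset and model choices are illustrative)
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

model = LogisticRegression(max_iter=5000).fit(X_train, y_train)
report_scores(model, X_test, y_test)  # report scores and plot the matrix for the test split

Note the design choice behind the two-color scheme: the heatmap is drawn from np.eye(2), whose 1s on the diagonal pick up the palette's second color (highlighting the true negative and true positive cells) while the 0s off the diagonal get the first; the actual counts and percentages come entirely from the annotation labels.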