Custom function to generate a report for a binary classification model that includes the model, scores, and confusion matrix.
# Function for generating model scores and confusion matrices with custom colors and descriptive labels
# https://stackoverflow.com/questions/70097754/confusion-matrix-with-different-colors
# https://medium.com/@dtuk81/confusion-matrix-visualization-fc31e3f30fea
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix

def report_scores(model, features, labels):
    '''
    Generate model scores and a confusion matrix with custom colors and descriptive labels.
    model    = fitted model variable
    features = features of the desired split
    labels   = labels of the desired split
    '''
    y_pred = model.predict(features)
    accuracy = accuracy_score(labels, y_pred) * 100
    precision = precision_score(labels, y_pred) * 100
    recall = recall_score(labels, y_pred) * 100
    cm = confusion_matrix(labels, y_pred)
    cm_norm = confusion_matrix(labels, y_pred, normalize='true')
    cm_colors = sns.color_palette(['gainsboro', 'cornflowerblue'])
    # Axis tick labels for the confusion matrix plot
    cm_y_labels = ['0', '1']  # row labels
    cm_x_labels = ['0', '1']  # column labels
    # Confusion matrix cell labels
    # Review and update to match the appropriate labels for your data set
    group_names = ['True Negative', 'False Positive', 'False Negative', 'True Positive']
    group_counts = ['{0:0.0f}'.format(value) for value in cm.flatten()]
    group_percentages = ['{0:.2%}'.format(value) for value in cm_norm.flatten()]
    group_labels = [f'{v1}\n{v2}\n{v3}' for v1, v2, v3 in
                    zip(group_names, group_percentages, group_counts)]
    group_labels = np.asarray(group_labels).reshape(2, 2)
    # Begin plot setup
    fig, ax = plt.subplots(figsize=(4.2, 4.2))
    # Heatmap: np.eye(2) colors the diagonal (correct predictions) differently from the off-diagonal
    sns.heatmap(np.eye(2), annot=group_labels, annot_kws={'size': 11}, fmt='',
                cmap=cm_colors, cbar=False,
                yticklabels=cm_y_labels, xticklabels=cm_x_labels, ax=ax)
    # Axis elements
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    ax.tick_params(labelsize=10, length=0)
    ax.set_xlabel('Predicted Values', size=10)
    ax.set_ylabel('Actual Values', size=10)
    # Position group labels and set colors
    for text_elt, group_label in zip(ax.texts, group_labels):
        ax.text(*text_elt.get_position(), '\n', color=text_elt.get_color(),
                ha='center', va='top')
    # Title for each plot
    # Adjust pad to provide room for the score report below the title and above the confusion matrix plot
    plt.title(f'{model}', pad=80, loc='left', fontsize=16, fontweight='bold')
    # Score report beneath each title
    # Adjust x and y to fit the report
    plt.figtext(0.21, 0.81,
                f'Accuracy: {round(accuracy, 2)}%\nPrecision: {round(precision, 2)}%\nRecall: {round(recall, 2)}%',
                wrap=True, ha='left', fontsize=10)
    # Display the plot!
    plt.tight_layout()
    plt.subplots_adjust(left=0.2)
    print('\n')  # Add a blank line for improved spacing
    plt.show()
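A minimal usage sketch follows. The model, dataset, and variable names below (make_classification data, X_test/y_test, log_reg) are assumptions for illustration only; substitute your own fitted binary classifier and evaluation split.

# Minimal usage sketch (assumed names: X_train, X_test, y_train, y_test, log_reg)
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# Hypothetical binary classification data and train/test split
X, y = make_classification(n_samples=1000, n_features=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Fit any binary classifier; LogisticRegression is just an example
log_reg = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Report scores and the confusion matrix for the test split
report_scores(log_reg, X_test, y_test)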