Pretty print a confusion matrix with seaborn
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def print_confusion_matrix(confusion_matrix, class_names, figsize=(10, 7), fontsize=14):
    """Prints a confusion matrix, as returned by sklearn.metrics.confusion_matrix, as a heatmap.

    Note that because this function returns the created figure object, calling it in a
    notebook displays the figure twice. To prevent this, either append ; to your
    function call, or modify the function by commenting out the return expression.

    Arguments
    ---------
    confusion_matrix: numpy.ndarray
        The numpy.ndarray object returned from a call to sklearn.metrics.confusion_matrix.
        Similarly constructed ndarrays can also be used.
    class_names: list
        An ordered list of class names, in the order they index the given confusion matrix.
    figsize: tuple
        A 2-long tuple, the first value determining the horizontal size of the output figure,
        the second determining the vertical size. Defaults to (10,7).
    fontsize: int
        Font size for axes labels. Defaults to 14.

    Returns
    -------
    matplotlib.figure.Figure
        The resulting confusion matrix figure
    """
    # Rows are true labels, columns are predicted labels (sklearn's layout).
    df_cm = pd.DataFrame(
        confusion_matrix, index=class_names, columns=class_names,
    )
    fig = plt.figure(figsize=figsize)
    try:
        # fmt="d" formats integer counts; float values make this raise ValueError.
        heatmap = sns.heatmap(df_cm, annot=True, fmt="d")
    except ValueError:
        raise ValueError("Confusion matrix values must be integers.")
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Because the figure is returned, a notebook displays it twice. To prevent this,
    # either append ; to your function call, or comment out this return expression.
    return fig
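
A minimal sketch of preparing the inputs, assuming scikit-learn is installed; the labels below are made-up illustration data. sklearn.metrics.confusion_matrix puts true labels on the rows and predicted labels on the columns, and passing labels= pins the order, so the same list can be reused as class_names:

```python
import pandas as pd
from sklearn.metrics import confusion_matrix

# Hypothetical labels, for illustration only.
y_true = ["cat", "dog", "cat", "bird", "dog"]
y_pred = ["cat", "cat", "cat", "bird", "dog"]

# Passing labels fixes the row/column order explicitly, so class_names can reuse it.
class_names = ["bird", "cat", "dog"]
cm = confusion_matrix(y_true, y_pred, labels=class_names)

# Same construction as inside print_confusion_matrix:
# rows are true labels, columns are predicted labels.
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
print(df_cm)
```

In a notebook, cm and class_names would then be passed as print_confusion_matrix(cm, class_names); with the trailing ; to avoid the double display.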

@Yumin-Sun-00 commented Mar 22, 2018

Thanks for sharing.

@indielyt commented Apr 12, 2018

Very nice, excellent documentation. Thank you!

@Znigneering commented Jun 4, 2018

Thank you!

@BrunoGomesCoelho commented Oct 6, 2018

Hey, thanks for this! I modified it to work for normalized confusion matrices as well; you can just add this:

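    # Requires `import numpy as np` and a `normalize=False` argument on the function.
    if normalize:
        # Divide each row by its total, so every true-class row sums to 1.
        confusion_matrix = confusion_matrix.astype('float') / confusion_matrix.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')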

and then add the format option:

    fmt = '.2f' if normalize else 'd'

and pass fmt=fmt to the sns.heatmap call inside your try/except :)
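A small standalone illustration of the two pieces above, using NumPy only and a made-up 2x2 matrix: the row normalization makes each true-class row sum to 1, and the resulting floats are why fmt has to switch from 'd' to '.2f'. Python's 'd' format code rejects floats with a ValueError, which is the error the original function catches:

```python
import numpy as np

# Made-up integer confusion matrix (rows = true class, columns = predicted class).
cm = np.array([[3, 1],
               [2, 6]])

# Row normalization, as in the snippet above: divide each row by its own total.
cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
print(cm_norm)

# Integer counts format fine with "d"; floats need a float format such as ".2f".
print(format(int(cm[0, 0]), "d"))
print(format(cm_norm[0, 0], ".2f"))
try:
    format(cm_norm[0, 0], "d")
except ValueError as err:
    print("ValueError:", err)
```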

@Akashpatel579 commented Jan 20, 2019

Thanks for sharing

@huntermaxfield commented Mar 5, 2019

I can't figure out why my heatmap is printing twice. Here is my code:

    y_pred_knn_raw = knn_model.predict(X_test)
    print(metrics.confusion_matrix(y_true=y_test, y_pred=y_pred_knn))  # print basic confusion matrix
    print('Accuracy = ', metrics.accuracy_score(y_true=y_test, y_pred=y_pred_knn))  # print accuracy
    print_confusion_matrix(metrics.confusion_matrix(y_true=y_test, y_pred=y_pred_knn), labels)  # print cm heatmap

Could someone help me see what I'm doing wrong?

@guifeliper commented Mar 13, 2019

Printing twice for me too!

@alanaor commented Mar 29, 2019

Add a ';' after the function call. The figure should then be displayed only once.

@jeonghyunwoo commented Jun 18, 2019

Thank you, very helpful

@pranjalchaubey commented Sep 19, 2019

Prints twice for no apparent reason.

@laubosslink commented Oct 3, 2019

Just remove `return fig`.
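An alternative that keeps the return fig line, offered as a sketch rather than something from the original gist: in a notebook, the inline backend displays every still-open pyplot figure at the end of the cell, and the returned Figure is then rendered a second time as the cell's output. Closing the figure with plt.close(fig) before returning removes it from pyplot's list of open figures, so only the returned object should be displayed. The function name plot_and_close is hypothetical, plain imshow stands in for sns.heatmap so the example needs only matplotlib, and the matplotlib.use("Agg") line exists only so the sketch runs as a script (drop it in a notebook):

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend so this sketch runs outside a notebook
import matplotlib.pyplot as plt
import numpy as np


def plot_and_close(matrix, class_names, figsize=(10, 7)):
    """Hypothetical variant that closes its figure before returning it."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(matrix, cmap="Blues")  # plain imshow stands in for sns.heatmap here
    ticks = range(len(class_names))
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)
    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    # plt.close removes the figure from pyplot's registry of open figures, so an
    # end-of-cell display no longer picks it up; the Figure object itself survives.
    plt.close(fig)
    return fig


fig = plot_and_close(np.array([[5, 1], [2, 7]]), ["neg", "pos"])
print(plt.get_fignums())  # no pyplot figures remain open

# The closed figure can still be rendered, e.g. saved to an in-memory PNG.
buf = io.BytesIO()
fig.savefig(buf, format="png")
print(len(buf.getvalue()) > 0)
```

The trade-off versus deleting return fig is that the caller keeps a handle for fig.savefig(...) or further tweaks.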

@nkc512 commented May 5, 2020

Thank you

@hector-margarito commented Sep 15, 2020

Awesome

@RunningJingJing commented May 11, 2021

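Awesome, thanks.
BTW, sklearn.metrics.confusion_matrix orders the classes in sorted (alphabetical) order unless you pass labels, so class_names must follow that same order. When there are a lot of classes, it is better to handle the labels as Scott Boston suggested here:
https://stackoverflow.com/questions/54875846/how-to-print-labels-and-column-names-for-confusion-matrix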

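Also, @BrunoGomesCoelho: the normalize argument of confusion_matrix ({'true', 'pred', 'all'}, default=None) can handle the normalization when the confusion matrix is generated. For example:

    confusion_matrix_array = confusion_matrix(true_label, predict_label, normalize='all')

normalize='true' divides each row by its total (the same row normalization as in your snippet), 'pred' divides each column by its total, and 'all' divides by the sum of the whole matrix.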
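A short check of how the normalize options relate to @BrunoGomesCoelho's manual row normalization, assuming a scikit-learn release that has the normalize argument (0.22 or later); the labels are made up:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Made-up labels for illustration.
y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1]

counts = confusion_matrix(y_true, y_pred)                    # raw counts
by_row = confusion_matrix(y_true, y_pred, normalize="true")  # each row sums to 1
by_all = confusion_matrix(y_true, y_pred, normalize="all")   # whole matrix sums to 1

# The manual row normalization from the earlier comment matches normalize="true".
manual = counts.astype("float") / counts.sum(axis=1)[:, np.newaxis]
print(counts)
print(by_row)
print(by_all)
print(np.allclose(manual, by_row))
```

Either normalized matrix holds floats, so the heatmap still needs a float annotation format such as fmt='.2f' instead of 'd'.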
