Skip to content

Instantly share code, notes, and snippets.

@juliobguedes
Last active January 20, 2022 15:03
Show Gist options
  • Save juliobguedes/00897a40d16863c7d9b02001ca929653 to your computer and use it in GitHub Desktop.
Save juliobguedes/00897a40d16863c7d9b02001ca929653 to your computer and use it in GitHub Desktop.
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
import itertools
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix',
                          cmap=None, figsize=None, savefig=None):
    """Draw a confusion matrix as an annotated heat map.

    Adapted from the scikit-learn confusion-matrix plotting example.

    Args:
        cm: 2-D array-like of values to plot. Rows are drawn along the
            "True label" axis and columns along the "Predicted label" axis.
        classes: Tick labels, one per row/column of ``cm``.
        normalize: If True, divide each row by its sum so cells show row
            fractions. Rows whose sum is zero are shown as 0 instead of NaN.
        title: Axes title.
        cmap: Matplotlib colormap or colormap name. ``None`` (the default)
            means ``'Blues'``, the colormap the original default selected.
        figsize: Optional ``(width, height)`` passed to ``plt.figure``.
            ``None`` uses Matplotlib's rc default size.
        savefig: Optional output path. When given, the figure is saved there
            at 300 dpi.
    """
    # Imported locally so the function does not depend on module-level setup.
    # The default argument also no longer needs ``plt`` at definition time.
    import matplotlib.pyplot as plt
    import numpy as np

    if cmap is None:
        # Pass a name that imshow resolves itself. This avoids
        # plt.cm.get_cmap, which was deprecated in Matplotlib 3.7 and
        # removed in 3.9.
        cmap = 'Blues'

    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Guard empty rows. A plain division would produce NaN, and NaN
        # would also poison cm.max() (the text-colour threshold) below.
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)
    img = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    fig.colorbar(img)

    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, rotation=45)

    # The 'd' format only accepts integers. Float matrices (normalized or
    # not) would raise ValueError with it, so they use two decimals.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # Cells darker than half the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], fmt),
                horizontalalignment="center",
                verticalalignment="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()
    if savefig is not None:
        fig.savefig(savefig, dpi=300)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment