Last active
January 20, 2022 15:03
-
-
Save juliobguedes/00897a40d16863c7d9b02001ca929653 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.metrics import confusion_matrix | |
from sklearn.utils.multiclass import unique_labels | |
import itertools | |
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix',
                          cmap=None, figsize=None, savefig=None):
    """Draw a confusion matrix as an annotated heat map.

    Adapted from the scikit-learn examples. The function relies on ``np``
    (NumPy), ``plt`` (``matplotlib.pyplot``) and ``itertools`` being
    available in the enclosing module namespace.

    Args:
        cm: Two-dimensional confusion matrix (rows are drawn on the
            "True label" axis, columns on the "Predicted label" axis).
            Anything ``np.asarray`` accepts is allowed.
        classes: Sequence of tick labels, one per row/column of ``cm``.
        normalize: If true, every row is divided by its sum before plotting
            and cells are printed with two decimals. Rows whose sum is zero
            are left as zeros instead of producing NaN.
        title: Title placed above the axes.
        cmap: Colormap name or colormap object forwarded to ``imshow``.
            ``None`` (the default) selects ``'Blues'``.
        figsize: Optional ``(width, height)`` forwarded to ``plt.figure``;
            ``None`` keeps Matplotlib's default size.
        savefig: Optional destination forwarded to ``Figure.savefig``
            (written at 300 dpi). Nothing is saved when ``None``.
    """
    # Resolve the default lazily: building a colormap object in the signature
    # runs at definition time and depended on the removed ``cm.get_cmap`` API.
    if cmap is None:
        cmap = 'Blues'

    # Accept nested lists as well as arrays.
    cm = np.asarray(cm)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Skip the division for rows with a zero total so those cells stay 0
        # rather than becoming NaN (which would also poison the threshold).
        cm = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )

    # ``figsize=None`` is Matplotlib's own default, so no branching is needed.
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)

    img = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    fig.colorbar(img)

    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, rotation=45)

    # The 'd' format code only works for integers; fall back to two decimals
    # for any non-integer matrix, not just the normalized one.
    is_integer = np.issubdtype(cm.dtype, np.integer)
    fmt = 'd' if is_integer and not normalize else '.2f'

    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], fmt),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()

    if savefig is not None:
        fig.savefig(savefig, dpi=300)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment