Skip to content

Instantly share code, notes, and snippets.

@wut0n9
Last active May 8, 2019 11:12
Show Gist options
  • Save wut0n9/e423fa9dd09eb166e913564c4be613a3 to your computer and use it in GitHub Desktop.
Save wut0n9/e423fa9dd09eb166e913564c4be613a3 to your computer and use it in GitHub Desktop.
绘制混淆矩阵 #confusion_matrix
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
def print_confusion_matrix(true_cls, pred_cls, cls_name):
    """Print a confusion matrix as text and plot it as a heatmap.

    All arguments hold the raw label text, not numeric label ids.

    Args:
        true_cls: Sequence of ground-truth labels, e.g.
            ['true_label_1', 'true_label_2', 'true_label_3', ...].
        pred_cls: Sequence of predicted labels, aligned element-wise with
            ``true_cls``, e.g. ['pred_label_1', 'pred_label_2', ...].
        cls_name: Iterable of class labels. It is passed to sklearn's
            ``confusion_matrix`` as ``labels`` (fixing the row/column
            order) and used as the tick labels on both plot axes.

    Returns:
        The matrix returned by ``sklearn.metrics.confusion_matrix``:
        rows are true labels, columns are predicted labels, both in
        ``cls_name`` order.
    """
    # Materialize once so len() and tick labelling work for any iterable
    # (e.g. a dict_values view), not only lists/arrays.
    labels = list(cls_name)
    n_classes = len(labels)

    # Get the confusion matrix using sklearn; `labels` fixes axis order.
    cm = confusion_matrix(y_true=true_cls, y_pred=pred_cls, labels=labels)

    # Print the confusion matrix as text.
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.colorbar()

    # One tick per class; a plain range avoids needing numpy here.
    tick_marks = range(n_classes)
    # Rotate x labels so long class names do not overlap each other.
    plt.xticks(tick_marks, labels, rotation=45, ha='right')
    plt.yticks(tick_marks, labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')

    # tight_layout must run after the colorbar, ticks and axis labels
    # exist; otherwise it cannot make room for them and they get clipped.
    plt.tight_layout()

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()
    return cm
# In Python 3, dict.values() returns a view; np.array() on a view yields a
# 0-d object array (len() on it raises TypeError), so build a real list.
# NOTE(review): axis order follows tid2topic's insertion order — confirm
# this is the intended class order.
cls_name = list(tid2topic.values())
print_confusion_matrix(df['topic'].tolist(), topic_pred, cls_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment