Skip to content

Instantly share code, notes, and snippets.

@VincentTatan
Created May 9, 2020 12:51
Show Gist options
  • Save VincentTatan/b833c2c2997e2b0fe6e407511576bef6 to your computer and use it in GitHub Desktop.
Save VincentTatan/b833c2c2997e2b0fe6e407511576bef6 to your computer and use it in GitHub Desktop.
def plot_confusion_matrix(cm_list, target_names, title_list, cmap=None,
                          normalize=True, float_format_str='{:,.2f}'):
    """Draw one confusion-matrix heatmap per model and tabulate their metrics.

    Each matrix is read as a binary confusion matrix whose first row/column
    is the positive class: ``cm[0][0]`` is TP, ``cm[0][1]`` is FN,
    ``cm[1][0]`` is FP and ``cm[1][1]`` is TN.  Rows are true labels and
    columns are predicted labels.

    Args:
        cm_list: NumPy confusion matrices, one per model.
        target_names: Class names used for the axis ticks and for the printed
            per-class counts, or ``None`` to skip both.
        title_list: Model names; ``title_list[k]`` labels ``cm_list[k]``.
        cmap: Matplotlib colormap for the heatmaps; ``None`` selects 'Blues'.
        normalize: When true, the numbers written into the cells are
            row-normalised fractions instead of the values passed in.
        float_format_str: Format string forwarded to ``generate_stats_df``.

    Returns:
        Whatever ``generate_stats_df(stats_list, float_format_str)`` returns
        for the collected per-model metrics.
    """
    plt.figure(figsize=(10, 5))
    if target_names is not None:
        # Per-class sample counts, taken from the row sums of the first matrix.
        print('{}_count={:d}\n{}_count={:d}'.format(
            target_names[0], cm_list[0][0].sum(),
            target_names[1], cm_list[0][1].sum()))
    if cmap is None:
        cmap = plt.get_cmap('Blues')

    n_plots = len(cm_list)
    stats_list = []
    for plot_idx, cm in enumerate(cm_list):
        model_name = title_list[plot_idx]
        tp, fn = cm[0][0], cm[0][1]
        fp, tn = cm[1][0], cm[1][1]

        accuracy = np.trace(cm) / float(np.sum(cm))
        misclass = 1 - accuracy
        precision = tp / float(tp + fp)
        recall = tp / float(tp + fn)
        # NOTE(review): these two are FN/(TN+FN) and FP/(TP+FP), i.e. the
        # share of wrong predictions among predicted negatives / positives.
        # That differs from the conventional FNR = FN/(FN+TP) and
        # FPR = FP/(FP+TN); kept as is so reported numbers do not change.
        fn_rate = fn / float(tn + fn)
        fp_rate = fp / float(tp + fp)

        plt.subplot(1, n_plots, plot_idx + 1)
        # The heatmap colours come from the matrix as passed in; only the
        # cell text below switches to normalised values.
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title("Confusion Matrix " + model_name, fontsize=10)
        if target_names is not None:
            tick_marks = np.arange(len(target_names))
            plt.xticks(tick_marks, target_names, rotation=45)
            plt.yticks(tick_marks, target_names)

        cell_values = cm
        if normalize:
            cell_values = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        # Rounding keeps the cell text short; it is a no-op for integer counts.
        cell_values = cell_values.round(3)
        for row, col in itertools.product(range(cell_values.shape[0]),
                                          range(cell_values.shape[1])):
            # matplotlib text coordinates are (x, y) == (column, row).
            plt.text(col, row, "{:,}".format(cell_values[row, col]),
                     horizontalalignment="center",
                     color="red",
                     fontsize=15)

        plt.ylabel('True label', fontsize=10)
        plt.xlabel('Predicted label\naccuracy={:0.3f}; misclass={:0.3f}; \nprecision={:0.3f} ; recall={:0.3f} ; \nfn_rate={:0.3f} ; fp_rate={:0.3f} '.format(accuracy, misclass, precision, recall, fn_rate, fp_rate), fontsize=10)

        stats_list.append({
            'model_name': model_name,
            'accuracy': accuracy,
            'misclass': misclass,
            'precision': precision,
            'recall': recall,
            'fn_rate': fn_rate,
            'fp_rate': fp_rate,
        })

    plt.tight_layout()
    plt.show()
    return generate_stats_df(stats_list, float_format_str)
def generate_stats_df(stats_list, float_format_str):
    """Build a metrics table indexed by model name.

    Args:
        stats_list: Records (one dict per model) that each contain a
            ``'model_name'`` key alongside the metric values.
        float_format_str: ``str.format`` template; its bound ``format``
            method is installed as pandas' display-wide float formatter.

    Returns:
        A DataFrame with one row per record, indexed by ``'model_name'``.
    """
    # Side effect: this changes how pandas renders floats for every
    # DataFrame displayed afterwards, not just the one returned here.
    pd.set_option('display.float_format', float_format_str.format)
    stats_frame = pd.DataFrame(stats_list)
    return stats_frame.set_index('model_name')
def generate_confusion_matrix(X_test, y_test, models, Xlabels=[True, False], target_names=['True', 'False'], normalize=False):
    """Compute and plot a confusion matrix for every model on one test set.

    Args:
        X_test: Features passed to each model's ``predict``.
        y_test: True labels compared against the predictions.
        models: Fitted estimators exposing ``predict``; each plot is titled
            with the estimator's class name.
        Xlabels: Label ordering forwarded to ``confusion_matrix(labels=...)``.
        target_names: Class names forwarded to ``plot_confusion_matrix``.
        normalize: Forwarded to ``plot_confusion_matrix``.

    Returns:
        The per-model statistics returned by ``plot_confusion_matrix``
        (previously this result was discarded and ``None`` was returned).
    """
    # The list defaults are kept for interface compatibility; they are only
    # read here, never mutated.
    confusion_matrices = [
        confusion_matrix(y_test, model.predict(X_test), labels=Xlabels)
        for model in models
    ]
    title_list = [type(model).__name__ for model in models]
    return plot_confusion_matrix(cm_list=confusion_matrices,
                                 target_names=target_names,
                                 normalize=normalize,
                                 title_list=title_list,
                                 float_format_str='{:,.3f}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment