Created
May 9, 2020 12:51
-
-
Save VincentTatan/b833c2c2997e2b0fe6e407511576bef6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN if the denominator is 0.

    Avoids numpy's divide-by-zero RuntimeWarning (and ZeroDivisionError for
    plain Python numbers) when a class never occurs in the labels or the
    predictions.
    """
    denominator = float(denominator)
    if denominator == 0:
        return float('nan')
    return float(numerator) / denominator


def _binary_stats(cm):
    """Compute summary metrics from a 2x2 count matrix laid out as [[TP, FN], [FP, TN]].

    Rows are true labels and columns are predicted labels, with the positive
    class in row/column 0.
    """
    tp, fn = cm[0][0], cm[0][1]
    fp, tn = cm[1][0], cm[1][1]
    accuracy = _safe_ratio(np.trace(cm), np.sum(cm))
    return {
        'accuracy': accuracy,
        'misclass': 1 - accuracy,
        'precision': _safe_ratio(tp, tp + fp),
        'recall': _safe_ratio(tp, tp + fn),
        # Standard definitions: FNR = FN / actual positives (= 1 - recall) and
        # FPR = FP / actual negatives. The earlier formulas FN/(TN+FN) and
        # FP/(TP+FP) were the false omission rate and the false discovery rate.
        'fn_rate': _safe_ratio(fn, tp + fn),
        'fp_rate': _safe_ratio(fp, fp + tn),
    }


def plot_confusion_matrix(cm_list, target_names, title_list, cmap=None,
                          normalize=True, float_format_str='{:,.2f}'):
    """Plot 2x2 confusion matrices side by side and return a table of metrics.

    Each matrix has true labels on the rows and predicted labels on the
    columns, with the positive class first: ``[[TP, FN], [FP, TN]]``.

    Args:
        cm_list: Non-empty sequence of 2x2 numpy count arrays, one per model.
        target_names: Class names used for the tick labels and the printed
            class counts (positive class first), or None to skip both.
        title_list: Model names, parallel to ``cm_list``. They are used as
            subplot titles and as the index of the returned table.
        cmap: Matplotlib colormap; defaults to ``plt.get_cmap('Blues')``.
        normalize: If True, show each matrix row-normalised (the fraction of
            its true class, rounded to 3 decimals) instead of raw counts.
            Metrics are always computed from the raw counts.
        float_format_str: Format template forwarded to ``generate_stats_df``.

    Returns:
        The DataFrame from ``generate_stats_df``: one row per model, indexed
        by model name, with columns accuracy, misclass, precision, recall,
        fn_rate and fp_rate. A metric whose denominator is zero is NaN.

    Raises:
        ValueError: If ``cm_list`` is empty.
    """
    if len(cm_list) == 0:
        raise ValueError('cm_list must contain at least one confusion matrix')
    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(10, 5))
    if target_names is not None:
        # Row sums of the first matrix are the true-class counts.
        print('{}_count={:d}\n{}_count={:d}'.format(
            target_names[0], cm_list[0][0].sum(),
            target_names[1], cm_list[0][1].sum()))

    stats_list = []
    for idx, cm in enumerate(cm_list):
        model_name = title_list[idx]
        # Compute metrics from the raw counts, before any normalisation.
        stats = _binary_stats(cm)

        if normalize:
            cm = (cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]).round(3)

        plt.subplot(1, len(cm_list), idx + 1)
        # Draw after normalising so the colours match the annotated values.
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title("Confusion Matrix " + model_name, fontsize=10)
        if target_names is not None:
            tick_marks = np.arange(len(target_names))
            plt.xticks(tick_marks, target_names, rotation=45)
            plt.yticks(tick_marks, target_names)

        for row, col in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(col, row, "{:,}".format(cm[row, col]),
                     horizontalalignment="center",
                     color="red",
                     fontsize=15)

        plt.ylabel('True label', fontsize=10)
        plt.xlabel('Predicted label\naccuracy={:0.3f}; misclass={:0.3f}; \nprecision={:0.3f} ; recall={:0.3f} ; \nfn_rate={:0.3f} ; fp_rate={:0.3f} '.format(
            stats['accuracy'], stats['misclass'], stats['precision'],
            stats['recall'], stats['fn_rate'], stats['fp_rate']), fontsize=10)

        stats_list.append({'model_name': model_name, **stats})

    plt.tight_layout()
    plt.show()
    return generate_stats_df(stats_list, float_format_str)
def generate_stats_df(stats_list, float_format_str):
    """Tabulate per-model statistics and set pandas' global float display format.

    Args:
        stats_list: Sequence of dicts, each with a ``'model_name'`` key plus
            metric entries.
        float_format_str: A ``str.format`` template such as ``'{:,.3f}'``.
            Its bound ``format`` method becomes pandas' global
            ``display.float_format``, so it affects every DataFrame displayed
            afterwards, not only the returned one.

    Returns:
        A new DataFrame with one row per dict, indexed by ``model_name``.
    """
    # Set globally rather than via option_context, so the format still
    # applies when the caller later displays the returned frame.
    pd.set_option('display.float_format', float_format_str.format)
    frame = pd.DataFrame(stats_list)
    return frame.set_index('model_name')
def generate_confusion_matrix(X_test, y_test, models, Xlabels=(True, False),
                              target_names=('True', 'False'), normalize=False):
    """Build a confusion matrix per model, plot them, and return the metrics table.

    Args:
        X_test: Features passed to each model's ``predict``.
        y_test: True labels for ``X_test``.
        models: Iterable of fitted models exposing ``predict``. Each model's
            class name is used as its title.
        Xlabels: Label order passed to ``confusion_matrix``. The first label
            is treated as the positive class by ``plot_confusion_matrix``.
            Defaults to a tuple to avoid a shared mutable default.
        target_names: Display names for ``Xlabels``, in the same order.
        normalize: Forwarded to ``plot_confusion_matrix``.

    Returns:
        The stats DataFrame produced by ``plot_confusion_matrix``. The
        previous version discarded it and always returned None.
    """
    models = list(models)  # iterated twice below
    confusion_matrices = [
        confusion_matrix(y_test, model.predict(X_test), labels=Xlabels)
        for model in models
    ]
    title_list = [type(model).__name__ for model in models]
    return plot_confusion_matrix(cm_list=confusion_matrices,
                                 target_names=target_names,
                                 normalize=normalize,
                                 title_list=title_list,
                                 float_format_str='{:,.3f}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment