Created November 14, 2017 08:25
Save duarteocarmo/a4cf1ce54566e8991eceb74fa5243942 to your computer and use it in GitHub Desktop.
Classification - Decision Tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
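# -*- coding: utf-8 -*-
"""
Created on Sat Nov 11 09:16:48 2017

@author: sviglios

Decision-tree classification of the 'Dx' attribute: the maximum tree depth is
selected by K-fold cross-validation and the resulting classifier is evaluated
on the separate chest data set.
"""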
import math

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.pyplot import (
    boxplot, clim, colorbar, figure, imshow, ion, legend, plot, savefig,
    show, subplot, title, xlabel, xticks, ylabel, yticks,
)
from sklearn import tree
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold

try:
    # sklearn.cross_validation was removed in scikit-learn 0.20; the name is
    # kept (guarded) only for code that may still reference it.
    from sklearn import cross_validation
except ImportError:
    cross_validation = None

from Project_Clean_data import raw, header, is_binary
# Load the separate chest test set and the row indices to exclude from the
# training data. ndmin guarantees a 2-D matrix / 1-D index vector even when a
# file holds a single row or a single index (loadtxt would otherwise return a
# lower-rank array and X_chest[:, target_index] would fail).
X_chest = np.loadtxt('chest.txt', dtype=int, ndmin=2)
final_cand = np.loadtxt('final_cand.txt', dtype=int, ndmin=1)

# Select the attribute to predict.
target_attribute_name = 'Dx'
target_index = list(header).index(target_attribute_name)

# Prepare the data: split the target column off the attribute matrix, then
# drop the excluded rows (final_cand) from both the targets and the attributes.
y = np.delete(raw[:, target_index], final_cand)
X = np.delete(np.delete(raw, target_index, 1), final_cand, 0)
attributeNames = np.delete(header, target_index)
N, M = X.shape
C = 2  # number of target classes assumed by the plots below
# Tree complexity parameter: candidate values for the maximum tree depth.
tc = np.arange(2, 20, 1)

# K-fold cross-validation. sklearn.cross_validation.KFold was removed in
# scikit-learn 0.20; sklearn.model_selection.KFold (0.18+) replaces it.
K = 10
CV = KFold(n_splits=K, shuffle=True)

# Error_train/Error_test[i, k]: misclassification rate of depth tc[i] on fold k.
Error_train = np.empty((len(tc), K))
Error_test = np.empty((len(tc), K))
tree_models = []  # tree_models[k][i]: tree fitted on fold k with depth tc[i]
err_test = []     # err_test[k][i]: same value as Error_test[i, k]

for k, (train_index, test_index) in enumerate(CV.split(X)):
    print('Computing CV fold: {0}/{1}..'.format(k + 1, K))
    # Initialize the per-fold lists of models and test errors.
    tree_models.append([])
    err_test.append([])

    # Extract the training and test set for the current CV fold.
    X_train, y_train = X[train_index, :], y[train_index]
    X_test, y_test = X[test_index, :], y[test_index]

    for i, t in enumerate(tc):
        # Fit a decision tree classifier (Gini split criterion) limited to depth t.
        dtc = tree.DecisionTreeClassifier(criterion='gini', max_depth=t)
        dtc = dtc.fit(X_train, y_train.ravel())
        y_est_test = dtc.predict(X_test)
        y_est_train = dtc.predict(X_train)

        # Misclassification rate = fraction of wrong predictions. Comparing
        # with != is correct for any label encoding and dtype, whereas
        # |y_est - y| only counts errors correctly for 0/1 labels.
        misclass_rate_test = np.mean(y_est_test != y_test)
        misclass_rate_train = np.mean(y_est_train != y_train)
        Error_test[i, k] = misclass_rate_test
        Error_train[i, k] = misclass_rate_train
        tree_models[k].append(dtc)
        err_test[k].append(misclass_rate_test)
# Box plot of the per-fold test error for every candidate depth.
# (Figure.hold was removed in matplotlib 3.0; holding is the default anyway.)
f = figure()
boxplot(Error_test.T)
# boxplot places the boxes at x = 1..len(tc); label them with the real depths.
xticks(range(1, len(tc) + 1), [str(t) for t in tc])
xlabel('Model complexity (max tree depth)')
ylabel('Test error across CV folds, K={0}'.format(K))
savefig("boxplot_classification.png")

# Mean training and test error over the K folds as a function of depth.
f = figure()
plot(tc, Error_train.mean(1))
plot(tc, Error_test.mean(1))
xlabel('Model complexity (max tree depth)')
ylabel('Error (misclassification rate, CV K={0})'.format(K))
legend(['Error_train', 'Error_test'], loc=0)
savefig("misclassification_rate.png")
# Model selection: average each depth's test error over the K folds and keep
# the depth with the lowest mean (the first one on ties).
all_av_tree_len = list(Error_test.mean(1))
col_index = int(np.argmin(all_av_tree_len))  # index into tc
best_depth = int(tc[col_index])

# Refit the selected depth on all N training observations instead of reusing
# one of the per-fold trees: each fold's tree saw only (K-1)/K of the data,
# and any rule for picking one particular fold is arbitrary.
best_tree = tree.DecisionTreeClassifier(criterion='gini', max_depth=best_depth)
best_tree = best_tree.fit(X, y.ravel())
# Final test on the separate chest data: split off its target column and
# predict with the selected tree.
y_chest = X_chest[:, target_index]
X_chest = np.delete(X_chest, target_index, 1)
y_est_chest = best_tree.predict(X_chest)

# Fraction of wrong predictions (valid for any label encoding, unlike
# |y_est - y| which only works for 0/1 labels).
misclass_rate_chest = np.mean(y_est_chest != y_chest)
print(misclass_rate_chest)
error_rate = misclass_rate_chest
accuracy = 1 - error_rate

# Compute and plot the confusion matrix. Explicit labels (model classes plus
# any class seen in the chest set) keep the matrix square and the tick labels
# aligned even when a class never occurs in y_chest or in the predictions.
# (Figure.hold was removed in matplotlib 3.0; holding is the default anyway.)
class_labels = np.union1d(best_tree.classes_, y_chest)
f = figure()
cm = confusion_matrix(y_chest, y_est_chest, labels=class_labels)
imshow(cm, cmap='binary', interpolation='None')
colorbar()
xticks(range(len(class_labels)), [str(c) for c in class_labels])
yticks(range(len(class_labels)), [str(c) for c in class_labels])
xlabel('Predicted class')
ylabel('Actual class')
title('Accuracy: {0}, Error Rate: {1}'.format(accuracy, error_rate))
savefig('con_matrix_class.png')
show()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment