Last active
June 21, 2020 06:54
-
-
Save amansk2050/26567e48cd98e2f46e2ce15d0c5e6950 to your computer and use it in GitHub Desktop.
Code for a decision tree classifier in Python using scikit-learn.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#importing libraries | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
%matplotlib inline | |
#loading dataset | |
dataset = pd.read_csv("/kaggle/input/decision-tree-data-set-from-stack-abuse/bill_authentication.csv") | |
#data analysis | |
dataset.shape | |
dataset.head() | |
#prepraing the data | |
X = dataset.drop('Class', axis=1) | |
y = dataset['Class'] | |
from sklearn.model_selection import train_test_split | |
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20) | |
#training and making predictions | |
from sklearn.tree import DecisionTreeClassifier | |
classifier = DecisionTreeClassifier() | |
classifier.fit(X_train, y_train) | |
y_pred = classifier.predict(X_test) | |
#evaluating the algorithm | |
from sklearn.metrics import classification_report, confusion_matrix | |
print(confusion_matrix(y_test, y_pred)) | |
print(classification_report(y_test, y_pred)) | |
from sklearn import metrics | |
print("Accuracy:",metrics.accuracy_score(y_test, y_pred)) | |
#some installation for printing the tree | |
pip install graphviz | |
pip install pydotplus | |
pip install --upgrade scikit-learn==0.20.3 | |
#code for printing the tree | |
from sklearn.tree import export_graphviz | |
from sklearn.externals.six import StringIO | |
from IPython.display import Image | |
import pydotplus | |
dot_data = StringIO() | |
export_graphviz(classifier, out_file=dot_data, | |
filled=True, rounded=True, | |
special_characters=True,feature_names = ['Variance','Skewness','Curtosis','Entropy'],class_names=['0','1']) | |
graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) | |
graph.write_png('diabetes.png') | |
Image(graph.create_png()) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment