Skip to content

Instantly share code, notes, and snippets.

@amansk2050
Last active June 21, 2020 06:54
Show Gist options
  • Save amansk2050/26567e48cd98e2f46e2ce15d0c5e6950 to your computer and use it in GitHub Desktop.
Save amansk2050/26567e48cd98e2f46e2ce15d0c5e6950 to your computer and use it in GitHub Desktop.
Code for a decision tree classifier in Python using scikit-learn.
#importing libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
#loading dataset
# Reads the bill-authentication CSV from the Kaggle input directory.
dataset = pd.read_csv("/kaggle/input/decision-tree-data-set-from-stack-abuse/bill_authentication.csv")
#data analysis
# Bare expressions: a notebook echoes them only when they are the last
# statement of their cell.
dataset.shape
dataset.head()
#preparing the data
# Features are every column except the 'Class' label; 'Class' is the target.
X = dataset.drop('Class', axis=1)
y = dataset['Class']
from sklearn.model_selection import train_test_split
# Hold out 20% of the rows for testing. random_state fixes the shuffle so the
# same rows land in each split on every run; without it the reported metrics
# change from run to run and cannot be compared.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=42
)
#training and making predictions
from sklearn.tree import DecisionTreeClassifier

# fit() returns the estimator itself, so the tree is built and trained in a
# single expression using default hyper-parameters.
classifier = DecisionTreeClassifier().fit(X_train, y_train)

# Class labels predicted for the rows in X_test.
y_pred = classifier.predict(X_test)
#evaluating the algorithm
from sklearn import metrics
from sklearn.metrics import classification_report, confusion_matrix

# Counts of true labels (rows) against predicted labels (columns).
cm = confusion_matrix(y_test, y_pred)
print(cm)

# Per-class precision, recall, F1 and support as a text table.
report = classification_report(y_test, y_pred)
print(report)

# Fraction of X_test rows whose predicted label equals the true label.
accuracy = metrics.accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
#some installation for printing the tree
pip install graphviz
pip install pydotplus
pip install --upgrade scikit-learn==0.20.3
#code for printing the tree
from sklearn.tree import export_graphviz
from sklearn.externals.six import StringIO
from IPython.display import Image
import pydotplus
dot_data = StringIO()
export_graphviz(classifier, out_file=dot_data,
filled=True, rounded=True,
special_characters=True,feature_names = ['Variance','Skewness','Curtosis','Entropy'],class_names=['0','1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('diabetes.png')
Image(graph.create_png())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment