Skip to content

Instantly share code, notes, and snippets.

Created May 14, 2019 14:28
Show Gist options
  • Save gopigof/2ff6b140e6d1b623f26d742993789155 to your computer and use it in GitHub Desktop.
Save gopigof/2ff6b140e6d1b623f26d742993789155 to your computer and use it in GitHub Desktop.
ID3 decision tree implementation with chi-squared pruning
from pprint import pprint
from scipy.stats import chi2, chi2_contingency
from sys import argv
import numpy as np
import pandas as pd
from sklearn import metrics, model_selection
# Path of the comma-separated data file: required first command-line argument.
file_name = argv[1]
# Significance level (alpha) for chi-square pruning: optional second
# command-line argument. The default only applies when none is supplied, so a
# user-provided value is no longer overwritten.
if len(argv) > 2:
    alpha = float(argv[2])
else:
    alpha = 0.05  # 5%
def partition(a):
    """Group the positions of an attribute's values by distinct value.

    Args:
        a (list): values of one attribute, one entry per example. A plain
            list or a numpy array is accepted.

    Returns:
        dict: maps each distinct value in ``a`` to a numpy array holding the
            positions at which that value occurs.
    """
    # Coerce to an array so that ``a == c`` is an element-wise comparison even
    # when a plain Python list is passed in.
    a = np.asarray(a)
    return {c: (a == c).nonzero()[0] for c in np.unique(a)}
def entropy(s):
    """Return the Shannon entropy (base 2) of a collection of values.

    Args:
        s (list): values of an attribute or class, one entry per example.

    Returns:
        float: entropy of the value distribution; 0.0 for empty input or
            when every entry has the same value.
    """
    # An empty collection carries no information; returning early also avoids
    # dividing the counts by a zero length.
    if len(s) == 0:
        return 0.0
    _, counts = np.unique(s, return_counts=True)
    # Every count reported by np.unique is at least 1, so each frequency is
    # strictly positive and log2 is always defined.
    freqs = counts.astype('float') / len(s)
    # Subtract from 0.0 (rather than negating) so a pure set yields 0.0 and
    # not -0.0.
    return 0.0 - float(np.sum(freqs * np.log2(freqs)))
def information_gain(y, x):
    """Return the information gain obtained by splitting ``y`` on ``x``.

    The gain is the entropy of ``y`` minus the weighted average of the
    entropies of the subsets of ``y`` induced by each distinct value of ``x``.

    Args:
        y (list): class values, one entry per example.
        x (list): values of a single attribute, aligned with ``y``.

    Returns:
        float: information gain of the attribute with respect to ``y``.
    """
    # Work on the arguments that were passed in, not on a module-level
    # training set: recursive callers pass subsets whose length differs from
    # the full training data.
    y = np.asarray(y)
    x = np.asarray(x)
    res = entropy(y)
    # We partition x according to its distinct attribute values
    val, counts = np.unique(x, return_counts=True)
    freqs = counts.astype('float') / len(x)
    # Subtract the weighted average of the per-partition entropies
    for p, v in zip(freqs, val):
        res -= p * entropy(y[x == v])
    return res
def is_pure(s):
    """Tell whether every class value in ``s`` is the same.

    Args:
        s (list): class values, one entry per example.

    Returns:
        bool: True when ``s`` contains exactly one distinct value, False
            otherwise (including when ``s`` is empty).
    """
    return len(set(s)) == 1
def recursive_split(x, y, attr_name):
    """Recursively build a decision tree by splitting on the best attribute.

    Args:
        x (list of list): examples as a 2-D numpy array, one row per example
            and one column per attribute.
        y (list): class values of the examples as a numpy array aligned with
            the rows of ``x``.
        attr_name (list of string): names of the attributes, in the same
            order as the columns of ``x``.

    Returns:
        One of:
        * str: the class label, when every example has the same class;
        * the original ``y``, when no split is possible (no examples, no
          attributes left, or no attribute yields any gain);
        * dict: keys are ``"<attribute> = <value>"`` strings for the
          attribute with the highest information gain, and values are the
          results of recursing on the matching subset.
    """
    # Nothing to label or split: hand the (empty) set back instead of
    # indexing y[0], which would raise IndexError.
    if len(y) == 0:
        return y
    # All examples agree on the class: this is a leaf
    if is_pure(y):
        return str(y[0])
    # Information gain of every remaining attribute (one per column of x)
    gain = np.array([information_gain(y, x_attr) for x_attr in x.T])
    # If no attribute is left, or there's no gain at all, nothing has to be
    # done; just return the original set. This is checked before argmax
    # because argmax raises on an empty array.
    if gain.size == 0 or np.all(gain < 1e-6):
        return y
    # We split using the attribute that gives the highest gain
    selected_attr = np.argmax(gain)
    sets = partition(x[:, selected_attr])
    # The remaining attribute names are the same for every branch
    attr_name_subset = attr_name[:selected_attr] + attr_name[selected_attr + 1:]
    res = {}
    for k, v in sets.items():
        # Create subsets of data based on the split, dropping the used column
        y_subset = y.take(v, axis=0)
        x_subset = np.delete(x.take(v, axis=0), selected_attr, 1)
        # Recurse on the subset of data left
        branch = str(attr_name[selected_attr]) + " = " + str(k)
        res[branch] = recursive_split(x_subset, y_subset, attr_name_subset)
    return res
def pruneLeaves(obj):
    """Run one chi-square pruning pass over a decision tree.

    A node whose values contain no nested dict is treated as a leaf node. A
    leaf node is pruned when the attribute it splits on (the text before the
    first space of its first key) is not listed in the module-level
    ``satisfied_attributes``. Pruned children are replaced by ``None`` in
    their parent, and the module-level flag ``pruned`` is set to True so the
    caller knows another pass may be needed.

    Args:
        obj (dict): decision tree encoded as nested dictionaries. The tree
            is modified in place.

    Returns:
        ``'pruned'`` when ``obj`` itself is a leaf node that must be pruned,
        otherwise ``obj``.
    """
    global pruned
    # Nothing to inspect in an empty node or a non-dict value
    if not isinstance(obj, dict) or not obj:
        return obj
    subtree_keys = [key for key in obj if isinstance(obj[key], dict)]
    if not subtree_keys:
        # Leaf node: every key has the form "<attribute> = <value>"
        attribute = next(iter(obj)).split(' ')[0]
        if attribute not in satisfied_attributes:
            pruned = True
            return 'pruned'
        return obj
    # Visit every subtree, and drop only those that actually report
    # 'pruned'; a plain truthiness test would also drop the subtrees that
    # were returned unchanged.
    for key in subtree_keys:
        if pruneLeaves(obj[key]) == 'pruned':
            obj[key] = None
    return obj
# Load the comma-separated data set as integers
data = np.loadtxt(file_name, delimiter=",", dtype='int')
# Column names of the data set; the final name is the class column
attr_name = ['Serial_No','Refr_Index', 'NA2O', 'MGO', 'AL2O3', 'SIO2', 'K2O', 'CAO', 'BAO', 'FE2O3', 'TYPE']
# The last column holds the class value, the remaining columns the attributes
X = data[:, :-1]
y = data[:, -1]
# train_test_split returns the train/test pair of each input array in turn,
# so the examples and the classes must be passed as two separate arrays.
x_train, x_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.2, random_state=31)
# Call recursive_split to train the decision tree
tree = recursive_split(x_train, y_train, attr_name)
# Collect the attributes whose chi-square statistic against the class exceeds
# the critical value at significance level alpha; pruneLeaves keeps only the
# leaf nodes that split on one of these attributes.
satisfied_attributes = []
for i in range(x_train.shape[1]):
    contengency = pd.crosstab(x_train.T[i], y_train)
    c, p, dof, expected = chi2_contingency(contengency)
    if c > chi2.isf(q=alpha, df=dof):
        satisfied_attributes.append(attr_name[i])
print('\nDecision tree before pruning-\n')
pprint(tree)
# NOTE(review): this is an alias, not a copy; pruneLeaves modifies the tree
# in place, so valid_tree sees the pruning too. Confirm whether a deep copy
# is intended.
valid_tree = tree
print('\nDecision tree after pruning-\n')
pruned = True
while pruned:
    # Keep pruning till no leaf nodes can be pruned or till the whole tree
    # has been pruned
    pruned = False
    tree = pruneLeaves(tree)
    if tree == 'pruned':
        break
pprint(tree)
def test(x,y, tree):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment