Skip to content

Instantly share code, notes, and snippets.

@hoenirvili
Last active November 7, 2017 13:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hoenirvili/081383c0b18bc9c652b82f42130aa683 to your computer and use it in GitHub Desktop.
Save hoenirvili/081383c0b18bc9c652b82f42130aa683 to your computer and use it in GitHub Desktop.
ID3 algorithm — entropy/information-gain decision tree (with an sklearn/graphviz comparison script)
usoara mirositoare arepete neteda comestibila
1 0 0 0 1
1 0 1 0 1
0 1 0 1 1
0 0 0 1 0
1 1 1 0 0
1 0 1 1 0
1 0 0 1 0
0 1 0 0 0
#!/usr/bin/env python3
import pandas as pd
import sklearn
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
import graphviz
def retrieve_target_names(dataset):
    """Return (target, class_names) for the mushrooms dataframe.

    dataset: DataFrame containing a 'comestibila' (edible) column of
    0/1 labels.
    Returns a tuple of the label array (in the dataframe's row order)
    and the human-readable class names, index 0 = 'necomestibila',
    index 1 = 'comestibila'.

    Fix: the original returned ``dataset['comestibila'].sort_values()``;
    sorting the labels shuffles them out of alignment with the feature
    rows that the caller later passes to ``fit``, silently training the
    classifier on wrong labels.  The column must keep its row order.
    """
    target = dataset['comestibila'].values
    return (target, np.array(['necomestibila', 'comestibila']))
def main():
    """Train an entropy-criterion decision tree on data.csv and render
    it with graphviz to files named "mushrooms"."""
    mushrooms = pd.read_csv('data.csv')
    # the first four columns are features, the last one is the target
    feature_names = mushrooms.columns.tolist()[:4]
    target, target_names = retrieve_target_names(mushrooms)
    features = mushrooms[feature_names].values
    classifier = tree.DecisionTreeClassifier(criterion='entropy')
    classifier.fit(features, target)
    dot_data = tree.export_graphviz(
        classifier,
        out_file=None,
        feature_names=feature_names,
        class_names=target_names,
        filled=True,
        rounded=True,
        special_characters=True,
    )
    graphviz.Source(dot_data).render("mushrooms")


if __name__ == '__main__':
    main()
#!/usr/bin/env python3
import csv
import sys
import math
import copy
def entropy(partition):
    """Return the Shannon entropy (base 2) of a class-count partition.

    partition: list of per-class counts, e.g. [2, 5].
    Classes with a zero count contribute nothing (lim p->0 of
    -p*log2(p) is 0).  An empty / all-zero partition divides by zero,
    matching the original behaviour.
    """
    total = sum(partition)
    result = 0.0
    for count in partition:
        if count == 0:
            continue  # skip empty classes instead of log2(0)
        p = count / total  # probability of this class
        result -= p * math.log2(p)
    return result


def ig(partitions):
    """Return the information gain of a split.

    partitions: one per-class count list per attribute value,
    e.g. [[1, 2], [2, 5]] -- two attribute values, two classes.

    Fix: the pre-split (root) class distribution is the element-wise
    column sum of the partitions.  The original double loop indexed
    ``partitions[row][column]``, which only produced the right sums
    when the number of partitions happened to equal the number of
    classes (a square table) and raised IndexError otherwise -- e.g.
    whenever a data subset contained only one value of an attribute.
    """
    n_classes = len(partitions[0])
    # per-class totals over all attribute values
    root = [sum(p[j] for p in partitions) for j in range(n_classes)]
    root_entropy = entropy(root)
    instances = sum(root)  # total number of instances in the split
    avg_entropy = 0.0
    for partition in partitions:
        weight = sum(partition) / instances
        avg_entropy += weight * entropy(partition)
    return root_entropy - avg_entropy
def dmap(attributes, data):
    """Map each attribute name to its column of values.

    attributes: column names, e.g.
        ['usoara', 'mirositoare', 'arepete', 'neteda', 'comestibila']
    data: list of rows, one string value per attribute, e.g.
        [['1', '0', '0', '0', '1'], ...]
    Returns {attribute: [column values, in row order]}.
    """
    n_cols = len(data[0])  # also asserts data is non-empty, as before
    columns = [[row[j] for row in data] for j in range(n_cols)]
    return dict(zip(attributes, columns))
def decision_stamps(dmap):
    """Count decisions per (attribute, attribute value).

    dmap: {attribute: column of values}; the LAST key is treated as
    the decision/target column.
    Returns, for every non-decision attribute,
        {attribute: {attribute_value: {decision_value: count}}}
    e.g. {'usoara': {'0': {'0': 2, '1': 1}, '1': {'0': 3, '1': 2}}, ...}
    """
    keys = list(dmap)
    decision_key = keys[-1]
    decisions = dmap[decision_key]
    decision_values = list(set(decisions))
    stamps = {}
    for attribute in keys[:-1]:
        column = dmap[attribute]
        stamps[attribute] = {}
        for attr_value in set(column):
            # decisions of the rows where this attribute takes attr_value
            matched = [d for v, d in zip(column, decisions) if v == attr_value]
            stamps[attribute][attr_value] = {
                d: matched.count(d) for d in decision_values
            }
    return stamps
def all_partitions(decision_stamps):
    """Strip the decision labels out of the stamps.

    decision_stamps: {attribute: {vertice: {decision: count}}}
    Returns {attribute: [[counts for one vertice], ...]}, keeping the
    vertice order of each stamp, e.g.
        {'usoara': [[2, 3], [1, 2]], ...}
    """
    return {
        attribute: [list(counts.values()) for counts in stamp.values()]
        for attribute, stamp in decision_stamps.items()
    }
def best_attribute(partitions):
    """Return the name of the attribute with maximum information gain.

    partitions: {attribute: [[per-class counts per vertice], ...]},
    e.g. {'usoara': [[2, 3], [1, 2]], ..., 'neteda': [[1, 3], [2, 2]]}
    -> 'neteda'.
    Ties keep the first attribute seen; if every gain is <= 0 the
    empty string is returned (same as the original).
    """
    names = list(partitions)
    if len(partitions) == 2:
        # original shortcut: with exactly two entries take the first
        return names[0]
    best_name, best_gain = '', 0
    for name, parts in partitions.items():
        gain = ig(parts)
        if gain > best_gain:
            best_gain, best_name = gain, name
    return best_name
def filter_data(data, attributes, attribute, attribute_value):
    """Return the rows of `data` whose value for `attribute` equals
    `attribute_value`.

    The rows are deep copies, so the caller can mutate the result
    without touching the original data.
    Raises ValueError if `attribute` is not in `attributes`.
    """
    index = attributes.index(attribute)
    rows = copy.deepcopy(data)
    return [row for row in rows if row[index] == attribute_value]
def remove_column(data, col):
    """Return a deep copy of `data` with column index `col` dropped
    from every row; the input rows are left untouched."""
    rows = copy.deepcopy(data)
    for row in rows:
        del row[col]
    return rows
def pick_best_attribute(attributes, data):
    """Return (attribute, decision_stamps) for the attribute of
    `attributes` with the highest information gain over `data`.

    attributes includes the decision column as its last element, so
    len(attributes) == 2 means only one real attribute is left.

    Fix: the two-attribute shortcut called ``dmap(attribute, data)``
    with a single attribute *string*.  ``dmap`` zips its first
    argument, so the string's characters became the dict keys and the
    caller's ``ds[attribute]`` lookup failed with KeyError.  The full
    attribute list must be passed, exactly like the general case.
    """
    if len(attributes) == 2:
        # only one candidate attribute remains (the other entry is the
        # decision column), so it is the best by definition
        attribute = attributes[0]
        ds = decision_stamps(dmap(attributes, data))
        return (attribute, ds)
    # general case: build the value map, count the decision stamps,
    # and score every attribute's partitions by information gain
    ds = decision_stamps(dmap(attributes, data))
    parts = all_partitions(ds)
    attribute = best_attribute(parts)
    return (attribute, ds)
def pick_subset(data, attributes, attribute, vertice):
    """Return (sub_data, sub_attributes) for the branch `vertice` of
    `attribute`: the rows where the attribute takes that value, with
    the attribute's column and name removed.

    Both `data` and `attributes` are left untouched (copies are made),
    so the caller can reuse them for other vertices.
    """
    subset = filter_data(data, attributes, attribute, vertice)
    remaining = copy.copy(attributes)
    column = remaining.index(attribute)  # column to drop from the rows
    remaining.remove(attribute)
    subset = remove_column(subset, column)
    return (subset, remaining)
class Node:
    """A single decision node in the ID3 tree.

    Holds the attribute the node splits on, the attribute's decision
    stamp ({vertice: {decision_value: count}}) and `neighbours`, which
    maps each vertice to either a final decision value, a child Node,
    or None while the vertice is not classified yet.
    """

    def __init__(self, attribute=None, stamp=None):
        # Node() with no arguments builds an explicit "empty" node,
        # used by Tree as a fillable placeholder
        if stamp is None and attribute is None:
            self.attribute = None
            self.stamp = None
            self.neighbours = None
            return
        self.attribute = attribute
        self.stamp = stamp
        self.neighbours = {}
        # resolve every vertice whose partition already classifies its
        # instances perfectly; ambiguous ones stay None until a subtree
        # is pushed for them
        for vertice, counts in self.stamp.items():
            self.neighbours[vertice] = self._decision(counts)

    def __repr__(self):
        return self.__str__()

    @property
    def vertices(self):
        # all vertice labels of this node's attribute
        return [*self.neighbours]

    def __str__(self):
        message = ''
        message += '[NB attribute = {}, '.format(self.attribute)
        message += 'stamp = {}, '.format(self.stamp)
        for key, value in self.neighbours.items():
            message += 'vertice:{} => decision|node {} '.format(key, value)
        message += 'NE]'
        return message

    def _decision(self, counts):
        """Return the decision value when exactly one decision class
        has a positive count in `counts`; otherwise return None (the
        partition does not classify the instances perfectly yet).
        """
        positive = [decision for decision, count in counts.items() if count > 0]
        if len(positive) != 1:
            # zero or several classes present -> no clean decision
            # (the original raised ValueError on an all-zero stamp;
            # returning None is the safer behaviour)
            return None
        return positive[0]

    def empty(self):
        """True for the argument-less placeholder Node()."""
        return (self.stamp is None and
                self.attribute is None and
                self.neighbours is None)

    def push(self, node):
        """Fill this node if it is empty, otherwise attach `node` to
        the first vertice that has no decision yet.

        Fix: the original looped over ``self.neighbours.values()`` and
        assigned `node` to the loop variable, which only rebinds a
        local name and never stores anything; assign through the dict
        key instead (and stop after the first empty slot).
        """
        if self.empty():
            self.stamp = node.stamp
            self.attribute = node.attribute
            self.neighbours = node.neighbours
            return
        for vertice, neighbour in self.neighbours.items():
            if neighbour is None:
                self.neighbours[vertice] = node
                return

    def push_neighbours(self, vertice, node):
        """Attach `node` under `vertice`.

        Raises ValueError if that vertice already holds a decision or
        a subtree.
        """
        if self.empty():
            self.stamp = node.stamp
            self.attribute = node.attribute
            self.neighbours = node.neighbours
            return
        if self.neighbours[vertice] is not None:
            raise ValueError(
                'An already decision was made for vertice {} decision {}'.
                format(vertice, self.neighbours[vertice])
            )
        self.neighbours[vertice] = node

    def neighbour(self, vertice):
        """Return the decision / child Node / None stored for `vertice`."""
        return self.neighbours[vertice]

    def node_enighbours_are_classified(self):
        """True when every vertice has a decision or subtree attached,
        False for an empty node or any vertice still None.

        (sic: misspelled name kept because callers depend on it)
        """
        if self.empty():
            return False
        for value in self.neighbours.values():
            if value is None:
                return False
        return True
class Tree(object):
    """General-purpose container for ID3 Nodes.

    Tracks `root` (the starting node) and `current` (the most recently
    pushed node that still has unclassified neighbours).

    Fix: `root` and `current` were *class* attributes, so every Tree
    instance shared the same fallback state; they are now initialised
    per instance in __init__.
    """

    def __init__(self):
        self.root = None
        self.current = None

    def empty(self):
        """True while no node has been pushed yet."""
        return self.root is None and self.current is None

    def classify(self, attributes, instance):
        """Walk the tree from the root and return the decision for
        `instance` (a row of values aligned with `attributes`), or
        None if an unclassified vertice is reached.
        """
        node = self.root
        decision = None
        while node is not None:
            idx = attributes.index(node.attribute)
            vertice = instance[idx]  # the instance's value for this attribute
            value = node.neighbour(vertice)
            if not isinstance(value, Node):
                # a plain value is a final decision -- stop here
                decision = value
                break
            node = value  # descend into the child node
        return decision

    def push(self, new_node):
        """Push `new_node`, maintaining root and current."""
        if self.empty():
            self.root = new_node
            self.current = new_node
            return
        # attach to the current node's first unclassified neighbour
        self.current.push(new_node)

    def push_neighbours(self, vertice, new_node):
        """Attach `new_node` under `vertice` of the current node.

        Fix: the original guarded the reassignment of `current` with
        ``if isinstance(value):`` -- `value` was undefined (NameError)
        and one-argument isinstance is a TypeError, so the branch could
        never run.  Descend into any node that still has work left.
        """
        if self.empty():
            raise ValueError("Can't push neighbours to a empty tree")
        self.current.push_neighbours(vertice, new_node)
        if not new_node.node_enighbours_are_classified():
            self.current = new_node
def id3(data, attributes):
    """Build an ID3 decision tree for `data` / `attributes` (the last
    attribute is the decision column) and return a Tree.

    Fixes:
    - the original returned `tree.root` (a Node) when the root was
      fully classified but `tree` (a Tree) otherwise; recursive calls
      could therefore push a whole Tree in as a neighbour, which
      `Tree.classify` would mistake for a final decision.  A Tree is
      always returned now and the recursion pushes the subtree's root
      Node.
    - vertices that already hold a decision are skipped; the original
      recursed into every vertice, and `push_neighbours` raises
      ValueError for a vertice that is already decided.
    """
    tree = Tree()
    attribute, ds = pick_best_attribute(attributes, data)
    root = Node(attribute, ds[attribute])
    tree.push(root)
    if root.node_enighbours_are_classified():
        return tree
    for vertice in root.vertices:
        if root.neighbour(vertice) is not None:
            continue  # this vertice already classifies perfectly
        subset = pick_subset(data, attributes, attribute, vertice)
        subtree = id3(*subset)
        tree.push_neighbours(vertice, subtree.root)
    return tree
def make_decisions(test_data, attributes, tree):
    """Classify every row of `test_data` with `tree` and return the
    decisions in the same order as the rows."""
    return [tree.classify(attributes, row) for row in test_data]
def main():
    """Read the csv file named on the command line, build the ID3
    tree, print its root, then classify three hand-written test
    instances and print each row with its decision."""
    if len(sys.argv) < 2 or sys.argv[1] is None:
        raise ValueError("Please specify csv data file")
    with open(sys.argv[1], mode='r') as file:
        rows = [row for row in csv.reader(file)]
    # first row is the header (attribute names), the rest is data
    attributes, data = rows[0], rows[1:]
    tree = id3(data, attributes)
    print(tree.root)
    test_data = [['0', '1', '1', '1'],
                 ['0', '1', '0', '1'],
                 ['1', '1', '0', '0']]
    decisions = make_decisions(test_data, attributes, tree)
    print()
    for attr in attributes:
        print('{} '.format(attr), end='')
    print()
    for row, decision in zip(test_data, decisions):
        for value in row:
            print('{} '.format(value), end='')
        print('Decision: {} '.format(decision))


if __name__ == '__main__':
    main()
# Default target: build/print the hand-rolled ID3 tree on the sample data.
# NOTE(review): recipe lines must be tab-indented in a real Makefile;
# the indentation was lost when this gist page was scraped — restore tabs.
all:
./id3.py data.csv
# Render the sklearn/graphviz decision tree for the same data.
graph:
./graph.py data.csv
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment