Skip to content

Instantly share code, notes, and snippets.

@bufas
Last active August 29, 2015 14:07
Show Gist options
  • Save bufas/831b8225e94cd3fbf2c7 to your computer and use it in GitHub Desktop.
Save bufas/831b8225e94cd3fbf2c7 to your computer and use it in GitHub Desktop.
Barebones ID3 Classifier
# Future statements must precede every other statement in the module
# (only comments, blank lines and the module docstring may come first);
# placing it after the metadata assignments is a SyntaxError.
# True division is required by the probability ratios computed below.
from __future__ import division

import math

__author__ = 'Mathias Bak Bertelsen'
__email__ = 'bufas@cs.au.dk'
def segment_set_by_class(es):
    """Group the examples in ``es`` by their class label.

    The class label is taken to be the last column, where the column
    count is read from the first example. Returns a dict mapping each
    label to the list of examples carrying it.
    """
    class_idx = len(es[0]) - 1
    return segment_set_by_attr(es, class_idx)
def segment_set_by_attr(es, attr_idx):
    """Partition the examples in ``es`` by the value in column ``attr_idx``.

    Returns a plain dict mapping each distinct value found in that column
    to the list of examples having it; examples keep their input order
    within each list.
    """
    groups = {}
    for example in es:
        # setdefault creates the bucket on first sight of a value.
        groups.setdefault(example[attr_idx], []).append(example)
    return groups
def entropy(s):
    """Return the entropy of the class labels in the example set ``s``.

    Computed as ``sum(-p * log(p))`` over the class groups, where ``p`` is
    the fraction of examples in that group. ``math.log`` with one argument
    is the natural logarithm, so the result is in nats rather than bits.

    The ratios need true division; on Python 2 that comes from the
    module's ``from __future__ import division``.
    """
    total = len(s)
    # Only the group sizes matter, so iterate the values directly.
    # ``.values()`` works on both Python 2 and 3 (``.iteritems()`` does not).
    return sum(
        -len(Sc) / total * math.log(len(Sc) / total)
        for Sc in segment_set_by_class(s).values()
    )
def gain(s, attr_idx):
    """Return the information gain of splitting ``s`` on column ``attr_idx``.

    This is the entropy of ``s`` minus the size-weighted sum of the
    entropies of the subsets obtained by grouping ``s`` on that column.

    The weights need true division; on Python 2 that comes from the
    module's ``from __future__ import division``.
    """
    total = len(s)
    # ``.values()`` works on both Python 2 and 3 (``.iteritems()`` does not),
    # and the attribute value itself is not needed here.
    remainder = sum(
        len(sv) / total * entropy(sv)
        for sv in segment_set_by_attr(s, attr_idx).values()
    )
    return entropy(s) - remainder
def grow(train, attributes):
    """Build an ID3 decision tree from the examples in ``train``.

    ``train`` is a sequence of examples whose last column is the class
    label. ``attributes`` is only used for its length: every column index
    below ``len(attributes) - 1`` is treated as a candidate feature.

    The returned tree is a 2-tuple:

    * leaf:          ``('#LEAF#', class_label)``
    * internal node: ``(attr_idx, {attr_value: subtree, ...})``

    At each node the feature with the highest information gain is chosen;
    ties go to the lowest column index. Features that do not split the
    current examples into at least two groups are skipped, because
    splitting on them would recurse on the very same set forever. If no
    feature can split the set (for example identical feature vectors with
    conflicting labels, or no feature columns at all), the node becomes a
    leaf predicting the most frequent class.
    """
    # A node whose examples all share one class is a leaf.
    segmented = segment_set_by_class(train)
    if len(segmented) == 1:
        # next(iter(...)) works on Python 2 and 3; keys()[0] is Python 2 only.
        return '#LEAF#', next(iter(segmented))

    # Pick the splitting attribute among those that really partition the set.
    best_attr, best_gain, best_partition = None, -1, None
    for i in range(len(attributes) - 1):
        partition = segment_set_by_attr(train, i)
        if len(partition) < 2:
            # Constant column here: the single subset would equal `train`.
            continue
        information_gain = gain(train, i)
        if information_gain > best_gain:
            best_attr, best_gain, best_partition = i, information_gain, partition

    # Nothing can separate the remaining examples: predict the majority class.
    if best_attr is None:
        majority = max(segmented, key=lambda label: len(segmented[label]))
        return '#LEAF#', majority

    # Recurse on each subset: (attr, {val1: subtree, val2: subtree, ...})
    subtree_representation = {}
    for val, subset in best_partition.items():
        subtree_representation[val] = grow(subset, attributes)
    return best_attr, subtree_representation
def classify(tree, instance):
    """Return the class label the decision ``tree`` assigns to ``instance``.

    ``tree`` is either a leaf ``('#LEAF#', label)`` or an internal node
    ``(attr_idx, {attr_value: subtree})``. Starting at the root, follow
    the branch matching the instance's value for each node's attribute
    until a leaf is reached. A value with no matching branch raises
    ``KeyError``.
    """
    node = tree
    while node[0] != '#LEAF#':
        branches = node[1]
        node = branches[instance[node[0]]]
    return node[1]
# USAGE
# Create a training set: attribute names (last one is the class label)
# followed by one tuple per example, in the same column order.
a = ('outlook', 'temp', 'humidity', 'wind', 'play')
x = [('sunny', 'hot', 'high', 'weak', 'No'),
     ('sunny', 'hot', 'high', 'strong', 'No'),
     ('overcast', 'hot', 'high', 'weak', 'Yes'),
     ('rain', 'mild', 'high', 'weak', 'Yes'),
     ('rain', 'cool', 'normal', 'weak', 'Yes'),
     ('rain', 'cool', 'normal', 'strong', 'No'),
     ('overcast', 'cool', 'normal', 'strong', 'Yes'),
     ('sunny', 'mild', 'high', 'weak', 'No'),
     # NOTE(review): 'cold' occurs only in this row while the others use
     # 'cool' -- possibly a typo; confirm before changing the data.
     ('sunny', 'cold', 'normal', 'weak', 'Yes'),
     ('rain', 'mild', 'normal', 'weak', 'Yes'),
     ('sunny', 'mild', 'normal', 'strong', 'Yes'),
     ('overcast', 'mild', 'high', 'strong', 'Yes'),
     ('overcast', 'hot', 'normal', 'weak', 'Yes'),
     ('rain', 'mild', 'high', 'strong', 'No')]

# Train the model
my_tree = grow(x, a)

# Classify some instances (feature columns only, no class label).
# print(...) with a single argument behaves the same on Python 2 and 3.
print(classify(my_tree, ('sunny', 'cool', 'normal', 'strong')))
print(classify(my_tree, ('overcast', 'mild', 'normal', 'weak')))
print(classify(my_tree, ('rain', 'hot', 'high', 'strong')))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment