Skip to content

Instantly share code, notes, and snippets.

@richard-to
Created November 11, 2014 07:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save richard-to/c63300e0d1e54533ab1b to your computer and use it in GitHub Desktop.
Save richard-to/c63300e0d1e54533ab1b to your computer and use it in GitHub Desktop.
Decision tree implementation based on the examples from *Machine Learning in Action* by Peter Harrington.
import math
def createDataSet():
    """Build the small toy training set.

    Returns:
        A ``(dataSet, labels)`` tuple. ``dataSet`` is a fresh list of fresh
        rows of the form ``[feature0, feature1, classLabel]``, where both
        features are 0/1 and the class label is ``'yes'`` or ``'no'``.
        ``labels`` names the two feature columns, in column order.
    """
    samples = (
        (1, 1, 'yes'),
        (1, 1, 'yes'),
        (1, 0, 'no'),
        (0, 1, 'no'),
        (0, 1, 'no'),
    )
    # Convert each tuple to its own list so callers get independent rows.
    dataSet = [list(sample) for sample in samples]
    return dataSet, ['no surfacing', 'flippers']
def calcShannonEntropy(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in ``dataSet``.

    The class label of a row is its last element; nothing else in the row is
    read. An empty ``dataSet`` yields 0.

    Args:
        dataSet: sequence of rows (lists or tuples).

    Returns:
        ``-sum(p * log2(p))`` over the relative frequency ``p`` of each label.
    """
    n = float(len(dataSet))
    # Count occurrences of each class label (as floats).
    tally = {}
    for row in dataSet:
        key = row[-1]
        tally[key] = tally.get(key, 0.0) + 1
    result = 0
    for occurrences in tally.values():
        p = occurrences / n
        result -= p * math.log(p, 2)
    return result
def splitDataSet(dataSet, axis, value):
    """Select the rows whose column ``axis`` equals ``value``, dropping that column.

    Args:
        dataSet: list of sliceable rows (lists or tuples).
        axis: index of the column to test and then remove.
        value: value the column must equal (compared with ``==``).

    Returns:
        A new list of new rows, in their original order; ``dataSet`` itself
        is not modified.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split gives the highest information gain.

    Every column except the last is treated as a discrete feature; the last
    column is the class label. For each feature, rows are grouped by that
    feature's value. The gain is the entropy of the whole set minus the
    size-weighted entropy of the groups.

    Args:
        dataSet: list of rows; rows are assumed to share the same length.

    Returns:
        The feature column index with the largest gain. Ties go to the
        lowest index.

    Raises:
        IndexError: if ``dataSet`` is empty.
        ValueError: if the rows have no feature columns (``max`` of an
            empty list).
    """
    # Float so that the divisions below are true divisions on Python 2 too.
    total = float(len(dataSet))
    numFeatures = len(dataSet[0][:-1])
    baseEntropy = calcShannonEntropy(dataSet)
    # groupsByFeature[i] maps each value seen in column i to the rows having
    # that value, reduced to (value, classLabel) pairs. calcShannonEntropy
    # only reads a row's last element, so these pairs are enough to score a group.
    # range (not xrange) keeps this working on both Python 2 and Python 3.
    groupsByFeature = [{} for _ in range(numFeatures)]
    for row in dataSet:
        classLabel = row[-1]
        for i, value in enumerate(row[:-1]):
            groupsByFeature[i].setdefault(value, []).append((value, classLabel))
    infoGain = []
    for groups in groupsByFeature:
        # Entropy left after the split: size-weighted average over the groups.
        remaining = 0
        for group in groups.values():
            remaining += len(group) / total * calcShannonEntropy(group)
        infoGain.append(baseEntropy - remaining)
    # list.index returns the first match, so ties resolve to the lowest index.
    return infoGain.index(max(infoGain))
def _majorityClass(classifications):
    """Return the most frequent label; ties go to the label that appears first."""
    counts = {}
    for label in classifications:
        counts[label] = counts.get(label, 0) + 1
    # Iterating the original list (not a set or dict) makes tie-breaking
    # deterministic: max() returns the first maximal item it encounters.
    return max(classifications, key=counts.get)


def createTree(dataSet, labels):
    """Recursively build a decision tree from ``dataSet`` as nested dicts.

    At each node the feature chosen by ``chooseBestFeatureToSplit`` is used
    to split the rows. One subtree is built for every value of that feature
    present in the rows.

    Args:
        dataSet: list of rows; every column but the last is a discrete
            feature and the last column is the class label.
        labels: feature names, in column order. Not modified.

    Returns:
        Either a bare class label (a leaf) or a dict of the form
        ``{featureName: {featureValue: subtree, ...}}``.

    Raises:
        IndexError: if ``dataSet`` is empty.
    """
    classifications = [row[-1] for row in dataSet]
    first = classifications[0]
    # Pure node: every row has the same class.
    if all(first == c for c in classifications):
        return first
    # No feature columns left to split on: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return _majorityClass(classifications)
    bestFeature = chooseBestFeatureToSplit(dataSet)
    bestLabel = labels[bestFeature]
    # Slicing (rather than del) leaves the caller's labels list untouched.
    subLabels = labels[:bestFeature] + labels[bestFeature + 1:]
    branches = {}
    for value in set(row[bestFeature] for row in dataSet):
        branches[value] = createTree(
            splitDataSet(dataSet, bestFeature, value), subLabels)
    return {bestLabel: branches}
def classify(inputData, tree, labels):
    """Classify one feature vector by walking a tree of nested dicts.

    Args:
        inputData: feature values, in the same column order as ``labels``.
        tree: a leaf (any non-dict value) or a dict of the form
            ``{featureName: {featureValue: subtree, ...}}``.
        labels: the full list of feature names, used to find the column of
            each feature named in the tree.

    Returns:
        The leaf value reached.

    Raises:
        ValueError: if a node's feature name is not in ``labels``.
        KeyError: if the tree has no branch for a value in ``inputData``.
        IndexError: if an internal node is an empty dict.
    """
    node = tree
    while isinstance(node, dict):
        # Only the node's first key (its split feature) is used.
        # list(node)[0] works on Python 2 and 3, unlike node.keys()[0].
        featureName = list(node)[0]
        featureValue = inputData[labels.index(featureName)]
        node = node[featureName][featureValue]
    return node
# Demo: build a tree from the toy data set and classify one sample.
# print(...) with a single argument prints the same on Python 2 and 3,
# whereas the bare print statement is a SyntaxError on Python 3.
dataSet, labels = createDataSet()
tree = createTree(dataSet, labels)
print(classify([1, 0], tree, labels))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment