Created
November 11, 2014 07:37
-
-
Save richard-to/c63300e0d1e54533ab1b to your computer and use it in GitHub Desktop.
Decision tree implementation based on the examples in "Machine Learning in Action" by Peter Harrington.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math
from collections import Counter
def createDataSet():
    """Return the toy training set together with its feature names.

    Every row holds two binary feature values followed by the class label
    ('yes' or 'no') in the last position. The second element of the returned
    pair names the feature columns, in column order.
    """
    featureNames = ['no surfacing', 'flippers']
    # Two identical positive examples (built as distinct list objects),
    # followed by the three negative examples.
    positives = [[1, 1, 'yes'] for _ in range(2)]
    negatives = [[1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    return positives + negatives, featureNames
def calcShannonEntropy(dataSet):
    """Return the Shannon entropy (in bits) of the class labels in dataSet.

    Args:
        dataSet: List of rows; the class label is read from the last element
            of each row. Any other elements of the row are ignored.

    Returns:
        The entropy as a float. An empty dataSet, or one whose rows all share
        a single label, yields 0.0.
    """
    total = float(len(dataSet))  # float so the division below is true division on Python 2
    counts = Counter(row[-1] for row in dataSet)
    entropy = 0.0
    for count in counts.values():
        probability = count / total
        # Subtract term by term (rather than negating a sum) so a pure
        # dataSet gives 0.0 and not -0.0.
        entropy -= probability * math.log(probability, 2)
    return entropy
def splitDataSet(dataSet, axis, value):
    """Select the rows whose column `axis` equals `value`, dropping that column.

    The input rows are left untouched; each returned row is a new list made of
    the elements before and after position `axis`.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the highest information gain.

    Information gain is the entropy of the whole dataSet minus the
    size-weighted entropy of the groups obtained by splitting on a feature.

    Args:
        dataSet: Non-empty list of rows; every element of a row except the
            last is a feature value, and the last element is the class label.

    Returns:
        The zero-based column index of the best feature. When several features
        tie, the lowest index wins.

    Raises:
        IndexError: If dataSet is empty.
        ValueError: If the first row has no feature columns (nothing to choose from).
    """
    total = float(len(dataSet))
    # One {featureValue: [rows]} mapping per feature column. Built with a
    # comprehension instead of xrange so it runs on both Python 2 and 3.
    branchesByFeature = [{} for _ in dataSet[0][:-1]]
    baseEntropy = calcShannonEntropy(dataSet)

    for data in dataSet:
        for i, featureValue in enumerate(data[:-1]):
            # The entropy helper only reads the last element of each row, so
            # the original row can be grouped as-is.
            branchesByFeature[i].setdefault(featureValue, []).append(data)

    infoGain = []
    for branches in branchesByFeature:
        splitEntropy = 0
        for branch in branches.values():
            splitEntropy += len(branch) / total * calcShannonEntropy(branch)
        infoGain.append(baseEntropy - splitEntropy)

    return infoGain.index(max(infoGain))
def createTree(dataSet, labels):
    """Recursively build a decision tree from labelled examples.

    Args:
        dataSet: Non-empty list of rows; each row is a list of feature values
            followed by the class label as its last element.
        labels: Feature names, one per feature column of dataSet, in column
            order. The list is not modified.

    Returns:
        The class label itself when the node is a leaf; otherwise a nested
        dict of the form {featureName: {featureValue: subtree}}.

    Raises:
        IndexError: If dataSet is empty.
    """
    classifications = [example[-1] for example in dataSet]
    firstClass = classifications[0]
    if all(classification == firstClass for classification in classifications):
        # Pure node: every example agrees on the class.
        return firstClass

    if len(dataSet[0]) == 1:
        # No features left to split on: fall back to a majority vote.
        # Scanning the label list (not a set) makes ties resolve to the
        # label that appears first, independent of hash ordering.
        counts = Counter(classifications)
        return max(classifications, key=counts.get)

    bestFeature = chooseBestFeatureToSplit(dataSet)
    bestLabel = labels[bestFeature]
    # The chosen column is removed from each subset, so drop its name too.
    remainingLabels = labels[:bestFeature] + labels[bestFeature + 1:]

    branches = {}
    for value in {example[bestFeature] for example in dataSet}:
        subset = splitDataSet(dataSet, bestFeature, value)
        branches[value] = createTree(subset, remainingLabels)
    return {bestLabel: branches}
def classify(inputData, tree, labels):
    """Classify one example by walking a decision tree.

    Args:
        inputData: Sequence of feature values for the example, ordered like
            labels.
        tree: Either a leaf (any non-dict value, returned as the class) or a
            nested dict of the form {featureName: {featureValue: subtree}}.
        labels: Feature names; the position of a name in this list is the
            position of that feature's value in inputData.

    Returns:
        The leaf value reached by following inputData through the tree.

    Raises:
        ValueError: If a feature name in the tree is not present in labels.
        KeyError: If the tree has no branch for the example's feature value.
    """
    node = tree
    while isinstance(node, dict):
        # Each internal node has the feature name as its key.
        # next(iter(...)) works on Python 2 and 3, unlike keys()[0].
        label = next(iter(node))
        node = node[label][inputData[labels.index(label)]]
    return node
# Demo: build the tree from the toy data set and classify one example.
dataSet, labels = createDataSet()
tree = createTree(dataSet, labels)
# Call form of print with a single argument behaves the same on Python 2 and 3.
print(classify([1, 0], tree, labels))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment