
@lotabout
Last active March 15, 2018 01:42
Implement decision tree
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math


class DecisionTree(object):
    def __init__(self, data, method="naive"):
        """Learn a decision tree from data and labels.

        :data: List[List[val]], a list of M samples, each sample represented
            by a list whose last column is the label
        :returns: the root of a decision tree
        """
        super(DecisionTree, self).__init__()
        self.method = method
        self.root = self._split(data)

    def _split_samples(self, samples, feature):
        """Split samples into subsets according to the given feature.

        :samples: List[List[val]]
        :feature: Int
        :returns: {val: List[data]}, a dict mapping each feature value to the
            subset of samples that carry it
        """
        ret = {}
        for sample in samples:
            val = sample[feature]
            ret.setdefault(val, [])
            ret[val].append(sample)
        return ret
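
    # Hypothetical example: splitting [[0, 'a'], [0, 'b'], [1, 'c']] on
    # feature 0 yields {0: [[0, 'a'], [0, 'b']], 1: [[1, 'c']]}.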

    def _split(self, data, level=0):
        """Recursively split the data for a node.

        :data: List[data]
        :returns: a label if splitting should stop, else a node of the tree
        """
        if self._stop_now(data):
            return data[0][-1]
        # split the data on the chosen feature
        feature = self._get_feature(data, level)
        subsets = self._split_samples(data, feature)
        return {key: self._split(subset, level+1) for key, subset in subsets.items()}
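
    # The returned tree is thus a nested dict: each internal node maps a
    # feature value to a subtree, and each leaf is a plain label.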

    def _stop_now(self, data):
        """Check if we need to stop now.

        :data: List[data]
        :returns: Boolean
        """
        labels = [d[-1] for d in data]
        return len(set(labels)) <= 1

    def _get_feature(self, data, level):
        """Decide which feature to use for splitting the data.

        :data: List[data]
        :level: Int, the level of the tree
        :returns: Int, the dimension of the data to split on
        """
        if self.method == 'gain':
            return self._gain_feature(data, level)
        else:
            return self._naive_feature(data, level)

    def _entropy(self, dataset):
        """Calculate the entropy of a dataset.

        :dataset: List[data], each data is List[val], last column is the label
        :returns: Float
        """
        counts = {}
        for data in dataset:
            label = data[-1]
            counts.setdefault(label, 0)
            counts[label] += 1
        total_num = len(dataset)
        return sum(-count/total_num * math.log2(count/total_num)
                   for count in counts.values())
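
    # This is the usual Shannon entropy H = -sum(p_i * log2(p_i)) over the
    # label frequencies p_i. Quick hand check: a made-up dataset with labels
    # ['a', 'a', 'b', 'b'] has p = 0.5 for each label, so H = 1.0 bit.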

    def _conditional_entropy(self, dataset, feature):
        """Calculate the conditional entropy of the dataset on the feature.

        :dataset: List[data]
        :feature: Int
        :returns: Float
        """
        subsets = self._split_samples(dataset, feature)
        # weight by the total number of samples, not the number of subsets
        total_num = len(dataset)
        return sum(len(subset)/total_num * self._entropy(subset)
                   for subset in subsets.values())
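
    # Each subset's entropy is weighted by the fraction of samples it holds:
    # H(Y|X) = sum over values v of |S_v|/|S| * H(S_v). The information gain
    # used below is then H(Y) - H(Y|X).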

    def _gain_feature(self, data, level):
        """Choose the feature with the largest information gain."""
        dimensions = len(data[0]) - 1
        entropy = self._entropy(data)
        gains = [entropy - self._conditional_entropy(data, i) for i in range(dimensions)]
        return gains.index(max(gains))

    def _naive_feature(self, data, level):
        """Naively use the tree level itself as the feature index."""
        return level


# Toy dataset (translated from the Chinese original): the features are
# fair-skinned, rich, and pretty; the label says whether to go on the date.
data = [['fair', 'rich', 'pretty', 'go'],
        ['fair', 'rich', 'not pretty', 'go'],
        ['fair', 'not rich', 'pretty', 'hesitate'],
        ['fair', 'not rich', 'not pretty', 'hesitate'],
        ['not fair', 'rich', 'pretty', 'go'],
        ['not fair', 'rich', 'not pretty', 'go'],
        ['not fair', 'not rich', 'pretty', 'hesitate'],
        ['not fair', 'not rich', 'not pretty', 'no go']]

tree = DecisionTree(data, method='gain')
print(tree.root)

tree = DecisionTree(data)
print(tree.root)
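
For reference, on the translated dataset the two prints should produce nested dicts along these lines (worked out by hand; dict key order follows insertion order on Python 3.7+ and may differ elsewhere). The 'gain' method splits on the rich feature first, since it carries the largest information gain, while the naive method splits on feature 0, then 1, then 2:

    {'rich': 'go', 'not rich': {'fair': 'hesitate', 'not fair': {'pretty': 'hesitate', 'not pretty': 'no go'}}}
    {'fair': {'rich': 'go', 'not rich': 'hesitate'}, 'not fair': {'rich': 'go', 'not rich': {'pretty': 'hesitate', 'not pretty': 'no go'}}}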
@pickfire

We can use

import collections

counts = collections.Counter(data[-1] for data in dataset)

instead of

counts = {}
for data in dataset:
    label = data[-1]
    counts.setdefault(label, 0)
    counts[label] += 1
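
For reference, collections.Counter is a dict subclass, so the rest of _entropy keeps working unchanged. A quick check in the REPL:

    >>> import collections
    >>> collections.Counter(['a', 'a', 'b'])
    Counter({'a': 2, 'b': 1})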
