Implement a decision tree
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math


class DecisionTree(object):
    def __init__(self, data, method="naive"):
        """Learn a decision tree from labelled data.
        :data: List[List[val]], a list of M samples; each sample is a list of
               feature values whose last column is the label
        :method: "naive" splits on feature ``level`` at depth ``level``;
                 "gain" picks the feature with the largest information gain
        The learned tree is stored in ``self.root``.
        """
        super().__init__()
        self.method = method
        self.root = self._split(data)

    def _split_samples(self, samples, feature):
        """Split samples into subsets according to their value of a feature.
        :samples: List[List[val]]
        :feature: Int, index of the feature to split on
        :returns: {val: List[sample]}, a dict mapping each feature value to
                  the samples that take that value
        """
        ret = {}
        for sample in samples:
            val = sample[feature]
            ret.setdefault(val, [])
            ret[val].append(sample)
        return ret
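
    # Illustration (values borrowed from the demo data below): splitting
    # [['白', '去'], ['不白', '去']] on feature 0 yields
    # {'白': [['白', '去']], '不白': [['不白', '去']]}.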

    def _split(self, data, level=0):
        """Recursively split the data to build a subtree.
        :data: List[sample]
        :returns: a label if splitting should stop, else an internal node
                  represented as {feature value: subtree}
        """
        if self._stop_now(data):
            return data[0][-1]
        # split the data on the chosen feature and recurse into each subset
        feature = self._get_feature(data, level)
        subsets = self._split_samples(data, feature)
        return {key: self._split(subset, level + 1)
                for key, subset in subsets.items()}
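
    # Note: a node dict does not record which feature index it split on, so
    # a separate convention is needed to classify new samples (see the
    # hypothetical `classify` helper sketched after the demo below).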

    def _stop_now(self, data):
        """Check whether splitting should stop at this node.
        :data: List[sample]
        :returns: Boolean, True when every sample carries the same label
        """
        labels = [d[-1] for d in data]
        return len(set(labels)) <= 1
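
    # Caveat: this stops only on pure nodes. If two samples had identical
    # features but different labels, the split would recurse indefinitely
    # (or, with the naive method, run out of features). The demo data below
    # never triggers this case.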

    def _get_feature(self, data, level):
        """Decide which feature to use to split the data.
        :data: List[sample]
        :level: Int, the depth of the current node in the tree
        :returns: Int, the index of the feature to split on
        """
        if self.method == 'gain':
            return self._gain_feature(data, level)
        else:
            return self._naive_feature(data, level)

    def _entropy(self, dataset):
        """Calculate the entropy of a dataset.
        :dataset: List[sample], each sample is List[val] whose last column
                  is the label
        :returns: Float, H(D) = -sum(p * log2(p)) over the label frequencies
        """
        counts = {}
        for data in dataset:
            label = data[-1]
            counts.setdefault(label, 0)
            counts[label] += 1
        total_num = len(dataset)
        return sum(-count / total_num * math.log2(count / total_num)
                   for count in counts.values())
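
    # Worked example: with 4 samples labelled '去' and 4 labelled '犹豫',
    # each label has probability 4/8 = 0.5, so the entropy is
    # -(0.5 * log2(0.5)) - (0.5 * log2(0.5)) = 1.0 bit.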

    def _conditional_entropy(self, dataset, feature):
        """Calculate the conditional entropy of the dataset on a feature.
        :dataset: List[sample]
        :feature: Int
        :returns: Float
        """
        subsets = self._split_samples(dataset, feature)
        # weight each subset by its share of the samples, not by the
        # number of distinct feature values
        total_num = len(dataset)
        return sum(len(subset) / total_num * self._entropy(subset)
                   for subset in subsets.values())
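
    # This computes H(D|A) = sum over values v of |D_v| / |D| * H(D_v),
    # the conditional entropy used by ID3-style information gain.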

    def _gain_feature(self, data, level):
        """Pick the feature with the largest information gain."""
        dimensions = len(data[0]) - 1
        entropy = self._entropy(data)
        gains = [entropy - self._conditional_entropy(data, i)
                 for i in range(dimensions)]
        return gains.index(max(gains))
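
    # Information gain of feature A: Gain(D, A) = H(D) - H(D|A); taking the
    # argmax over features is the ID3 selection rule.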

    def _naive_feature(self, data, level):
        """Naively split on feature 0 at the root, feature 1 below it, etc."""
        return level


# Toy dataset (in Chinese): the features are 白/不白 (fair-skinned or not),
# 富/不富 (rich or not), 美/不美 (beautiful or not); the label is
# 去 (go on the date), 犹豫 (hesitate) or 不去 (do not go).
data = [['白', '富', '美', '去'],
        ['白', '富', '不美', '去'],
        ['白', '不富', '美', '犹豫'],
        ['白', '不富', '不美', '犹豫'],
        ['不白', '富', '美', '去'],
        ['不白', '富', '不美', '去'],
        ['不白', '不富', '美', '犹豫'],
        ['不白', '不富', '不美', '不去']]

tree = DecisionTree(data, method='gain')
print(tree.root)
tree = DecisionTree(data)
print(tree.root)
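
A minimal sketch of how the learned tree could be used for prediction. The `classify` helper below is hypothetical, not part of the original snippet; it assumes a tree built with the default "naive" method, where the feature tested at depth d is simply dimension d, since the node dicts do not record which feature they split on.

def classify(node, sample, depth=0):
    """Walk the nested dicts until a leaf label is reached."""
    if not isinstance(node, dict):
        return node  # leaf: the predicted label
    return classify(node[sample[depth]], sample, depth + 1)

naive_tree = DecisionTree(data)
print(classify(naive_tree.root, ['白', '富', '美']))  # expected: '去'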