Implement decision tree
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math


class DecisionTree(object):

    def __init__(self, data, method="naive"):
        """Learn a decision tree from labeled data.

        :data: List[List[val]], a list of M samples; each sample is a list
               whose last column is the label
        :method: "naive" splits on features in input order; "gain" picks
                 the feature with the highest information gain
        :returns: None; the learned tree is stored in self.root
        """
        super(DecisionTree, self).__init__()
        self.method = method
        self.root = self._split(data)
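
    # Note: the learned tree is nested dicts keyed by feature value, with a
    # plain label at each leaf.  For the toy data below, for instance,
    # DecisionTree(data, method="gain").root maps '富' to the leaf '去' and
    # '不富' to a subtree.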

    def _split_samples(self, samples, feature):
        """Split samples into subsets according to the value of a feature.

        :samples: List[List[val]]
        :feature: Int
        :returns: {val: List[sample]}, a dict mapping each feature value to
                  the subset of samples with that value
        """
        ret = {}
        for sample in samples:
            val = sample[feature]
            ret.setdefault(val, [])
            ret[val].append(sample)
        return ret
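
    # For example, splitting the toy data below on feature 0 returns a dict
    # with two keys, {'白': [...4 samples...], '不白': [...4 samples...]}.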

    def _split(self, data, level=0):
        """Recursively split the data to build a subtree.

        :data: List[sample]
        :returns: a label if splitting should stop, else an internal node
                  represented as {feature_value: subtree}
        """
        if self._stop_now(data):
            return data[0][-1]
        # choose a feature and partition the data on its values
        feature = self._get_feature(data, level)
        subsets = self._split_samples(data, feature)
        return {key: self._split(subset, level + 1)
                for key, subset in subsets.items()}

    def _stop_now(self, data):
        """Check if we need to stop now (all labels identical).

        :data: List[sample]
        :returns: Boolean
        """
        labels = [d[-1] for d in data]
        return len(set(labels)) <= 1

    def _get_feature(self, data, level):
        """Decide which feature to split the data on.

        :data: List[sample]
        :level: Int, the depth of the current node in the tree
        :returns: Int, the index of the feature (column) to split on
        """
        if self.method == 'gain':
            return self._gain_feature(data, level)
        else:
            return self._naive_feature(data, level)

    def _entropy(self, dataset):
        """Calculate the entropy of a dataset.

        :dataset: List[sample], each sample is List[val], last column is the label
        :returns: Float
        """
        counts = {}
        for data in dataset:
            label = data[-1]
            counts.setdefault(label, 0)
            counts[label] += 1
        total_num = len(dataset)
        return sum(-count / total_num * math.log2(count / total_num)
                   for count in counts.values())
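
    # Worked example: for labels [A, A, B, B] each class has p = 0.5, so the
    # entropy is -2 * (0.5 * log2(0.5)) = 1.0 bit; a pure set has entropy 0.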

    def _conditional_entropy(self, dataset, feature):
        """Calculate the conditional entropy of the dataset given a feature.

        :dataset: List[sample]
        :feature: Int
        :returns: Float
        """
        subsets = self._split_samples(dataset, feature)
        total_num = len(dataset)  # weight each subset by its share of the samples
        return sum(len(subset) / total_num * self._entropy(subset)
                   for subset in subsets.values())
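
    # Information gain of feature A on dataset D is g(D, A) = H(D) - H(D|A),
    # where H(D|A) = sum_v |D_v|/|D| * H(D_v) over the values v of A; picking
    # the feature with the largest gain is the ID3 splitting criterion.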

    def _gain_feature(self, data, level):
        dimensions = len(data[0]) - 1
        entropy = self._entropy(data)
        gains = [entropy - self._conditional_entropy(data, i)
                 for i in range(dimensions)]
        return gains.index(max(gains))

    def _naive_feature(self, data, level):
        return level
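

# The tree stores only feature *values* at each node, not which feature was
# tested, so a sample cannot be routed through a tree built with the "gain"
# method.  A minimal classifier sketch (not part of the original gist) that
# works for the "naive" method, where the node at depth `level` always
# splits on feature `level`:
def classify_naive(root, sample):
    """Predict the label for `sample` by walking a naive-method tree."""
    node, level = root, 0
    while isinstance(node, dict):
        node = node[sample[level]]  # KeyError for values unseen in training
        level += 1
    return node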


# Toy dataset: the three features mean fair-skinned / rich / pretty (或 not:
# 不白 / 不富 / 不美), and the label is the decision: 去 = go, 犹豫 = hesitate,
# 不去 = don't go.
data = [['白', '富', '美', '去'],
        ['白', '富', '不美', '去'],
        ['白', '不富', '美', '犹豫'],
        ['白', '不富', '不美', '犹豫'],
        ['不白', '富', '美', '去'],
        ['不白', '富', '不美', '去'],
        ['不白', '不富', '美', '犹豫'],
        ['不白', '不富', '不美', '不去']]

tree = DecisionTree(data, method='gain')
print(tree.root)

tree = DecisionTree(data)
print(tree.root)
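
# Usage of the hypothetical classify_naive helper above; `tree` now holds
# the naive-method tree, so node depth matches feature index.
print(classify_naive(tree.root, ['白', '富', '不美']))  # -> '去' (go)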