Instantly share code, notes, and snippets.

Last active March 15, 2018 01:42
Show Gist options
• Save lotabout/ae2401b091bd7faf4ae6230666f53568 to your computer and use it in GitHub Desktop.
Implement decision tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 #!/usr/bin/env python3 # -*- coding: utf-8 -*- import math class DecisionTree(object): def __init__(self, data, method="naive"): """Learn a decision tree from data and label :data: List[List[val]], a list contains M sample, each sample is represented by a List The last column of the sample is the label :returns: The root of a decision tree """ super(DecisionTree, self).__init__() self.method = method self.root = self._split(data) def _split_samples(self, samples, feature): """Split samples into subsets, according to the feature :samples: List[List[val]] :feature: Int :returns: {val: List[data]} a dict contains the data of subsets """ ret = {} for sample in samples: val = sample[feature] ret.setdefault(val, []) ret[val].append(sample) return ret def _split(self, data, level=0): """recursively split the data for node :data: List[data] :returns: label if should stop, else a node of the tree """ if self._stop_now(data): return data[0][-1] # split the data feature = self._get_feature(data, level) subsets = self._split_samples(data, feature) return {key: self._split(subset, level+1) for key, subset in subsets.items()} def _stop_now(self, data): """check if we need to stop now :data: List[data] :returns: Boolean """ labels = [d[-1] for d in data] return len(set(labels)) <= 1 def _get_feature(self, data, level): """Decide which feature to be used to split data :data: List[data] :level: Int the level of the tree :returns: Int the dimension of the data to be used for split data """ if self.method == 'gain': return self._gain_feature(data, level) else: return self._naive_feature(data, level) def _entropy(self, dataset): """calculate the entropy of a dataset :dataset: List[data], each data is List[val], last column is label :returns: Float """ counts = {} for data in dataset: label = data[-1] counts.setdefault(label, 0) counts[label] += 1 total_num = len(dataset) return sum([-count/total_num * math.log2(count/total_num) for count in counts.values()]) def 
_conditional_entropy(self, dataset, feature): """calculate the conditional entropy of dataset on feature :dataset: List[data] :feature: Int :returns: Float """ subsets = self._split_samples(dataset, feature) total_num = len(subsets) return sum([len(subset)/total_num * self._entropy(subset) for subset in subsets.values()]) def _gain_feature(self, data, level): dimensions = len(data[0]) - 1 entropy = self._entropy(data) gains = [entropy - self._conditional_entropy(data, i) for i in range(dimensions)] return gains.index(max(gains)) def _naive_feature(self, data, level): return level data = [['白', '富', '美', '去'], ['白', '富', '不美', '去'], ['白', '不富', '美', '犹豫'], ['白', '不富', '不美', '犹豫'], ['不白', '富', '美', '去'], ['不白', '富', '不美', '去'], ['不白', '不富', '美', '犹豫'], ['不白', '不富', '不美', '不去']] tree = DecisionTree(data, method='gain') print(tree.root) tree = DecisionTree(data) print(tree.root)

### pickfire commented Mar 15, 2018

```python
import collections

counts = collections.Counter(data[-1] for data in dataset)
```

```python
counts = {}
for data in dataset:
    label = data[-1]
    counts.setdefault(label, 0)
    counts[label] += 1
```