Skip to content

Instantly share code, notes, and snippets.

@lotabout
Last active March 15, 2018 01:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lotabout/ae2401b091bd7faf4ae6230666f53568 to your computer and use it in GitHub Desktop.
Save lotabout/ae2401b091bd7faf4ae6230666f53568 to your computer and use it in GitHub Desktop.
Implement a decision tree
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
class DecisionTree:
    """A minimal categorical decision tree built by recursive splitting.

    Features are consumed in column order: the node at depth ``level``
    splits on column ``level``.  Leaves are labels; internal nodes are
    dicts mapping a feature value to a subtree.
    """

    def __init__(self, data):
        """Learn a decision tree from labelled samples.

        :data: List[List[val]], a list of M samples, each represented by a
               list whose last column is the label.

        The learned tree is stored in ``self.root``. It is a label when no
        split is needed, otherwise a nested ``{feature_value: subtree}``
        dict, and ``None`` when ``data`` is empty.
        """
        super().__init__()
        self.root = self._split(data)

    def _split_samples(self, samples, feature):
        """Split samples into subsets according to the value of a feature.

        :samples: List[List[val]]
        :feature: Int, the column index to split on
        :returns: {val: List[sample]} mapping each distinct value of the
                  feature (in first-seen order) to the samples having it
        """
        ret = {}
        for sample in samples:
            ret.setdefault(sample[feature], []).append(sample)
        return ret

    def _split(self, data, level=0):
        """Recursively split the data for one node.

        :data: List[sample]
        :level: Int, depth of this node (the root is 0)
        :returns: a label if the node is a leaf, a dict mapping each value
                  of the chosen feature to a subtree otherwise, or ``None``
                  when ``data`` is empty
        """
        if not data:
            # Empty input has no label to return; don't index data[0].
            return None
        if self._stop_now(data):
            return data[0][-1]
        # The last column is the label, so only the first len-1 columns are
        # features. Once every feature has been used, splitting again would
        # split on the label itself; fall back to the majority label.
        if level >= len(data[0]) - 1:
            return self._majority_label(data)
        feature = self._get_feature(data, level)
        subsets = self._split_samples(data, feature)
        return {key: self._split(subset, level + 1)
                for key, subset in subsets.items()}

    def _stop_now(self, data):
        """Check whether this node should become a leaf.

        :data: List[sample]
        :returns: True when all samples share one label (or there are none)
        """
        return len({sample[-1] for sample in data}) <= 1

    def _majority_label(self, data):
        """Return the most common label in ``data``.

        Ties are broken by first occurrence.

        :data: non-empty List[sample]
        :returns: the label value
        """
        import collections
        counts = collections.Counter(sample[-1] for sample in data)
        return counts.most_common(1)[0][0]

    def _get_feature(self, data, level):
        """Decide which feature to use to split the data.

        :data: List[sample]
        :level: Int, the level of the tree
        :returns: Int, the column of the data to split on; this simple
                  strategy uses column ``level``
        """
        return level
# Toy training set. Feature columns: 白/不白 (fair / not fair),
# 富/不富 (rich / not rich), 美/不美 (pretty / not pretty).
# The last column is the decision label: 去 (go), 犹豫 (hesitate),
# 不去 (don't go).
data = [
    ['白', '富', '美', '去'],
    ['白', '富', '不美', '去'],
    ['白', '不富', '美', '犹豫'],
    ['白', '不富', '不美', '犹豫'],
    ['不白', '富', '美', '去'],
    ['不白', '富', '不美', '去'],
    ['不白', '不富', '美', '犹豫'],
    ['不白', '不富', '不美', '不去'],
]

# Learn the tree and show its root structure.
tree = DecisionTree(data)
print(tree.root)
@pickfire
Copy link

我们可以使用

import collections

counts = collections.Counter(data[-1] for data in dataset)

代替

counts = {}
for data in dataset:
    label = data[-1]
    counts.setdefault(label, 0)
    counts[label] += 1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment