# Last active: February 9, 2018 16:32
# Save youchan/185c56328e4244f1aaa02017bd6d4897 to your computer and use it in GitHub Desktop.
# Decision Tree in Ruby
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
require "datasets"
# Extends Array with an arithmetic-mean helper.
class Array
  # Returns the arithmetic mean of the elements, or of the block's results
  # when a block is given.
  #
  # @yield [element] optional mapping applied to each element before summing
  # @return [Float] the mean; 0.0 for an empty array
  def avg(&block)
    return 0.0 if empty?

    # Divide by a Float so Integer elements are not truncated
    # ([1, 2].avg is 1.5, not 1), matching the Float returned for [].
    sum(&block) / length.to_f
  end
end
# One node of the decision tree: the records that reached it, the textual
# condition on the edge leading to it, a link to its parent and (optionally)
# its two children.
class Node
  attr_reader :parent, :data, :name, :condition
  attr_accessor :ltree, :rtree

  # Number of nodes created so far. Kept on Node itself (instead of a @@
  # class variable) and always accessed through Node.next_name, so every node
  # draws from the same sequence.
  @count = 0

  class << self
    # Returns the next sequential node label ("node0", "node1", ...).
    #
    # @return [String]
    def next_name
      label = "node#{@count}"
      @count += 1
      label
    end
  end

  # @param data [Array] records that belong to this node
  # @param condition [String] description of the test that leads to this node
  # @param parent [Node, nil] parent node; nil for the root
  def initialize(data, condition, parent)
    @condition = condition
    @parent = parent
    @data = data
    @name = Node.next_name
  end

  # Renders the condition on the first line and, on the second, how many
  # records fall under each value returned by the records' `class` method.
  #
  # @return [String]
  def to_s
    counts = @data.group_by(&:class).map { |label, members| "#{label}: #{members.length}" }
    "#{@condition}\n#{counts.join(', ')}"
  end
end
# Builds a binary decision tree by repeatedly splitting the records on the
# feature and position that minimise the summed entropy of the two halves.
#
# Records are grouped with `group_by(&:class)`, so each record is expected to
# expose its label through a `class` reader.
# NOTE(review): for ordinary Ruby objects `class` is the Ruby class, so this
# presumably relies on the dataset records overriding it; confirm against
# the version of the datasets gem in use.
class DecisionTree
  # Feature readers tried at every split unless others are supplied.
  DEFAULT_FEATURES = %i[sepal_length sepal_width petal_length petal_width].freeze

  # A child whose entropy is less than this much below its parent's becomes a leaf.
  DEFAULT_MIN_GAIN = 0.1

  # @param data [Array] training records
  # @param features [Array<Symbol>] names of the numeric readers to split on
  # @param min_gain [Numeric] minimum entropy drop (parent minus child)
  #   required to keep splitting a child
  def initialize(data, features: DEFAULT_FEATURES, min_gain: DEFAULT_MIN_GAIN)
    @data = data
    @features = features
    @min_gain = min_gain
    @class_names = data.group_by(&:class).keys
  end

  # Shannon entropy (base 2) of the label distribution in +data+, divided by
  # the number of distinct labels seen in the full training set.
  #
  # @param data [Array] records to measure
  # @return [Float] 0.0 when +data+ is empty or no labels are known
  def average_entropy(data)
    return 0.0 if data.empty? || @class_names.empty?

    total = data.length
    entropy = data.group_by(&:class).map do |_label, members|
      ratio = members.length.to_f / total
      -ratio * Math.log2(ratio)
    end.sum
    entropy / @class_names.length
  end

  # Finds the cut position in already-sorted +data+ that minimises the sum of
  # the average entropies of the two sides. Ties go to the later position.
  # Only positions are examined, so records with equal feature values may end
  # up on different sides of the cut.
  #
  # @param data [Array] records sorted by the feature under consideration
  # @return [Hash] :ltree and :rtree (the two sides) and :min_sum_of_entropy;
  #   :ltree is empty when no cut scored at or below 1.0
  def split_min_entropy(data)
    best_sum = 1.0
    best_index = 0
    (1...data.length).each do |cut|
      candidate = average_entropy(data.take(cut)) + average_entropy(data.drop(cut))
      next unless candidate <= best_sum

      best_sum = candidate
      best_index = cut
    end
    { ltree: data.take(best_index), rtree: data.drop(best_index), min_sum_of_entropy: best_sum }
  end

  # Recursively builds the tree and returns its root node.
  #
  # @param data [Array] records for this subtree (defaults to the training set)
  # @param condition [String] label describing the edge into this node
  # @param parent [Node, nil] parent node; nil for the root
  # @return [Node]
  def build(data = @data, condition = "", parent = nil)
    node = Node.new(data, condition, parent)
    # Nothing to separate: no records, or all records share one label.
    return node if data.group_by(&:class).length <= 1
    # Stop when this node is not sufficiently purer than its parent.
    return node if parent && (average_entropy(parent.data) - average_entropy(data)) < @min_gain

    split = best_split(data)
    # No feature produced a usable split (for example, an empty feature list).
    return node unless split[:ltree]
    return node if split[:ltree].empty? || split[:rtree].empty?

    feature = split[:param_name]
    # Threshold halfway between the two records adjacent to the cut; 2.0
    # keeps the midpoint exact for Integer-valued features.
    mid = (split[:ltree].last.public_send(feature) + split[:rtree].first.public_send(feature)) / 2.0
    node.ltree = build(split[:ltree], "#{feature} < #{mid}", node)
    node.rtree = build(split[:rtree], "#{feature} >= #{mid}", node)
    node
  end

  private

  # Tries every configured feature and returns the split with the strictly
  # lowest entropy sum (the first feature wins ties), tagged with :param_name.
  # Returns a hash without :ltree when no feature scores below 1.0.
  def best_split(data)
    best = { min_sum_of_entropy: 1.0 }
    @features.each do |feature|
      candidate = split_min_entropy(data.sort_by(&feature))
      next unless candidate[:min_sum_of_entropy] < best[:min_sum_of_entropy]

      best = candidate.merge(param_name: feature)
    end
    best
  end
end
# Walks a binary tree whose nodes respond to `ltree` and `rtree`, visiting
# each node before its children (left subtree first).
class TreeTraverser
  # @param tree [Object, nil] root node of the tree to walk
  def initialize(tree)
    @tree = tree
  end

  # Yields every node together with its depth (the root has depth 0).
  #
  # @param tree [Object, nil] node to start from (defaults to the root)
  # @param depth [Integer] depth reported for +tree+
  # @yield [node, depth]
  # @return [Enumerator, nil] an Enumerator when no block is given, otherwise nil
  def traverse(tree = @tree, depth = 0, &block)
    return enum_for(:traverse, tree, depth) unless block
    # A missing child (or an empty tree) simply ends this branch.
    return if tree.nil?

    block.call(tree, depth)
    traverse(tree.ltree, depth + 1, &block)
    traverse(tree.rtree, depth + 1, &block)
    nil
  end
end
# Script entry point: train a decision tree on the Iris dataset and print it
# to stdout as a Graphviz DOT digraph.
iris = Datasets::Iris.new
data = iris.each.to_a

# DecisionTree.new(data).split_min_entropy(data.sort_by(&:petal_length))

tree = DecisionTree.new(data).build

# Graph-wide, node and edge styling. The squiggly heredoc strips the common
# leading indentation; DOT itself ignores the remaining whitespace.
puts <<~EOD
  digraph decision_tree {
    graph [
      charset = "UTF-8";
      label = "Decision Tree",
      labelloc = "t",
      labeljust = "c",
      bgcolor = "#343434",
      fontcolor = white,
      fontsize = 18,
      style = "filled",
      rankdir = TB,
      margin = 0.2,
      splines = spline,
      ranksep = 1.0,
      nodesep = 0.9
    ];
    node [
      colorscheme = "rdylgn11"
      style = "solid,filled",
      fontsize = 16,
      fontcolor = 6,
      fontname = "Migu 1M",
      color = 7,
      fillcolor = 11,
      fixedsize = true,
      height = 1.0,
      width = 2.0,
      shape = "box"
    ];
    edge [
      style = solid,
      fontsize = 14,
      fontcolor = white,
      fontname = "Migu 1M",
      color = white,
      labelfloat = true,
      labeldistance = 2.5,
      labelangle = 70
    ];
EOD

# One DOT statement per node (labelled with its per-class record counts) and
# one per parent -> child edge (labelled with the split condition).
TreeTraverser.new(tree).traverse do |node, _depth|
  if node.parent
    puts "#{node.parent.name} -> #{node.name} [label = \"#{node.condition}\"];"
  end
  puts "#{node.name} [label = \"#{node.data.group_by(&:class).map { |k, v| "#{k}: #{v.length}" }.join("\n")}\"];"
end

puts "}"
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.
# http://blog.youchan.org/2018-02-10