Instantly share code, notes, and snippets.

Embed
What would you like to do?
Decision Tree in Ruby
require "datasets"
class Array
  # Arithmetic mean of the elements, or of the block's results when a block
  # is given (the block is forwarded to Array#sum).
  #
  # Returns 0.0 for an empty array. The division is always done in floating
  # point so that integer elements are not truncated ([1, 2].avg is 1.5, not 1).
  def avg(&block)
    return 0.0 if empty?

    sum(&block) / length.to_f
  end
end
# A single vertex of the decision tree.
#
# Holds the records that reached this vertex (+data+), the textual condition
# that led here from the parent (+condition+), a link to the parent, and the
# two optional children (+ltree+ / +rtree+), which are assigned after creation.
class Node
  attr_accessor :ltree, :rtree
  attr_reader :parent, :data, :name, :condition

  # Number of nodes created so far; used to give every node a distinct name
  # ("node0", "node1", ...).
  @@count = 0

  def initialize(data, condition, parent)
    @data = data
    @parent = parent
    @condition = condition
    @name = "node#{@@count}"
    @@count += 1
  end

  # The condition on the first line, followed by "key: count" pairs obtained by
  # grouping the records on the value of their +class+ method.
  # NOTE(review): presumably the records expose their label through a member
  # called +class+ -- confirm against the dataset in use.
  def to_s
    header = @condition + "\n"
    tallies = @data.group_by { |row| row.class }.map do |key, rows|
      "#{key}: #{rows.size}"
    end
    header + tallies.join(', ')
  end
end
# Builds a binary decision tree of Node objects by repeatedly cutting the
# records, sorted on one numeric feature, at the position that minimises the
# summed entropy of the two halves.
#
# Records are grouped by the value returned from their +class+ method.
# NOTE(review): this presumably relies on each record exposing its label
# through a member named +class+ -- confirm against the dataset in use.
class DecisionTree
  # Feature accessors tried at every node when none are given explicitly.
  DEFAULT_FEATURES = %i[sepal_length sepal_width petal_length petal_width].freeze

  # A child is not split further when its entropy is less than this much
  # below its parent's.
  MIN_ENTROPY_GAIN = 0.1

  # data     - Array of records; each must respond to every feature name.
  # features - Array of Symbols naming the numeric accessors to split on.
  def initialize(data, features: DEFAULT_FEATURES)
    @data = data
    @features = features
    @class_names = data.group_by(&:class).keys
  end

  # Entropy (base 2) of the label distribution in +data+, divided by the number
  # of distinct labels seen in the data the tree was constructed with.
  # Returns 0.0 for an empty +data+.
  def average_entropy(data)
    return 0.0 if data.empty?

    total = data.length
    entropy = data.group_by(&:class).map do |_label, rows|
      ratio = rows.length.to_f / total
      -ratio * Math.log2(ratio)
    end.sum
    entropy / @class_names.length
  end

  # Tries every cut position of the (already sorted) +data+ and keeps the one
  # whose two halves have the smallest summed average_entropy.
  #
  # Returns a Hash with :ltree (records before the cut), :rtree (records from
  # the cut onwards) and :min_sum_of_entropy. When no cut scores <= 1.0 the cut
  # index stays 0, so :ltree is empty and :min_sum_of_entropy is 1.0.
  def split_min_entropy(data)
    best_sum = 1.0
    best_index = 0
    (1...data.length).each do |cut|
      candidate = average_entropy(data.take(cut)) + average_entropy(data.drop(cut))
      # "<=" means a later cut wins a tie with an earlier one.
      if candidate <= best_sum
        best_sum = candidate
        best_index = cut
      end
    end
    { ltree: data.take(best_index), rtree: data.drop(best_index), min_sum_of_entropy: best_sum }
  end

  # Recursively builds the (sub)tree for +data+ and returns its root Node.
  #
  # data      - records for this node (defaults to the constructor's data).
  # condition - label describing how this node was reached from +parent+.
  # parent    - the parent Node, or nil for the root.
  #
  # A node becomes a leaf when it holds at most one label (which includes an
  # empty dataset), when its entropy improved on the parent's by less than
  # MIN_ENTROPY_GAIN, or when no usable split exists.
  def build(data = @data, condition = "", parent = nil)
    node = Node.new(data, condition, parent)
    return node if data.group_by(&:class).length <= 1
    return node if parent && (average_entropy(parent.data) - average_entropy(data)) < MIN_ENTROPY_GAIN

    split = best_split(data)
    return node if split.nil? || split[:ltree].empty? || split[:rtree].empty?

    feature = split[:param_name]
    # Threshold halfway between the two records adjacent to the cut; dividing
    # by 2.0 keeps the midpoint exact for integer-valued features as well.
    mid = (split[:ltree].last.public_send(feature) + split[:rtree].first.public_send(feature)) / 2.0
    node.ltree = build(split[:ltree], "#{feature} < #{mid}", node)
    node.rtree = build(split[:rtree], "#{feature} >= #{mid}", node)
    node
  end

  private

  # Best split over all configured features: the first feature whose split
  # score is strictly below 1.0 and below every earlier feature's score.
  # Returns the split Hash with :param_name added, or nil when none qualifies.
  def best_split(data)
    @features.reduce(nil) do |best, feature|
      candidate = split_min_entropy(data.sort_by(&feature))
      threshold = best ? best[:min_sum_of_entropy] : 1.0
      candidate[:min_sum_of_entropy] < threshold ? candidate.merge(param_name: feature) : best
    end
  end
end
# Walks a tree of nodes that respond to +ltree+ and +rtree+.
class TreeTraverser
  def initialize(tree)
    @tree = tree
  end

  # Pre-order walk: calls the block with (node, depth) for +tree+ itself, then
  # for everything under its left child, then everything under its right child.
  # Children that are nil/false are skipped. Always returns nil.
  def traverse(tree = @tree, depth = 0, &block)
    block.call(tree, depth)
    %i[ltree rtree].each do |side|
      child = tree.public_send(side)
      traverse(child, depth + 1, &block) if child
    end
    nil
  end
end
# Load the Iris records, build the tree, then print it to stdout as a
# Graphviz "digraph" description.
records = Datasets::Iris.new.each.to_a
# DecisionTree.new(data).split_min_entropy(data.sort_by(&:petal_length))
root = DecisionTree.new(records).build

# Fixed preamble: graph-, node- and edge-level styling.
puts <<EOD
digraph decision_tree {
graph [
charset = "UTF-8";
label = "Decision Tree",
labelloc = "t",
labeljust = "c",
bgcolor = "#343434",
fontcolor = white,
fontsize = 18,
style = "filled",
rankdir = TB,
margin = 0.2,
splines = spline,
ranksep = 1.0,
nodesep = 0.9
];
node [
colorscheme = "rdylgn11"
style = "solid,filled",
fontsize = 16,
fontcolor = 6,
fontname = "Migu 1M",
color = 7,
fillcolor = 11,
fixedsize = true,
height = 1.0,
width = 2.0,
shape = "box"
];
edge [
style = solid,
fontsize = 14,
fontcolor = white,
fontname = "Migu 1M",
color = white,
labelfloat = true,
labeldistance = 2.5,
labelangle = 70
];
EOD

# One edge line per non-root node (labelled with the condition that leads to
# it), followed by the node's own line listing its per-group record counts.
TreeTraverser.new(root).traverse do |node, _depth|
  puts "#{node.parent.name} -> #{node.name} [label = \"#{node.condition}\"];" if node.parent

  counts = node.data.group_by(&:class).map { |label, rows| "#{label}: #{rows.length}" }.join("\n")
  puts "#{node.name} [label = \"#{counts}\"];"
end

puts "}"
@youchan

This comment has been minimized.

Show comment
Hide comment
Owner

youchan commented Feb 9, 2018

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment