Skip to content

Instantly share code, notes, and snippets.

@youchan
Last active February 9, 2018 16:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save youchan/185c56328e4244f1aaa02017bd6d4897 to your computer and use it in GitHub Desktop.
Decision Tree in Ruby
require "datasets"
class Array
  # Arithmetic mean of the elements (or of the block's mapped values when a
  # block is given). Returns 0.0 for an empty array.
  #
  # BUG FIX: the original divided by the Integer +length+, so integer
  # elements were truncated ([1, 2].avg == 1). The empty-array branch already
  # returned the Float 0.0, so a Float mean was clearly intended; dividing by
  # +length.to_f+ makes the result a Float in every case.
  def avg(&block)
    return 0.0 if empty?
    sum(&block) / length.to_f
  end
end
# One node of the decision tree: holds the subset of records that reached it,
# the textual condition of the incoming edge, and links to parent/children.
class Node
  attr_reader :parent, :data, :name, :condition
  attr_accessor :ltree, :rtree

  # Monotonic counter so every node gets a unique graphviz-friendly name.
  @@count = 0

  # data:      records assigned to this node
  # condition: human-readable split condition on the edge leading here
  # parent:    parent Node, or nil for the root
  def initialize(data, condition, parent)
    @condition = condition
    @parent = parent
    @data = data
    @name = "node#{@@count}"
    @@count += 1
  end

  # Condition on the first line, then a per-class record-count summary,
  # e.g. "petal_length < 2.5\nClassA: 3, ClassB: 1".
  def to_s
    summary = @data.group_by(&:class).map { |klass, records| "#{klass}: #{records.length}" }
    "#{@condition}\n#{summary.join(', ')}"
  end
end
# Binary decision tree over records whose Ruby class is treated as the label
# and which respond to the four iris feature accessors used in #build.
class DecisionTree
  def initialize(data)
    @data = data
    # Distinct labels (record classes) seen in the training data; the entropy
    # normalization below divides by this count.
    @class_names = data.group_by(&:class).keys
  end

  # Shannon entropy of the label distribution in +data+, normalized by the
  # total number of known classes. Pure (single-class) data yields 0.0.
  def average_entropy(data)
    n = data.length
    data.group_by(&:class).map do |_label, members|
      p = members.length.to_f / n
      -p * Math.log2(p)
    end.sum / @class_names.length
  end

  # Scans every split point of (pre-sorted) +data+ and returns the prefix /
  # suffix pair minimizing the sum of the two halves' entropies. The seed
  # [1.0, 0] means "no split with entropy <= 1.0 found yet"; in that case the
  # returned :ltree is empty and the caller must not split.
  def split_min_entropy(data)
    min_sum_of_entropy, i = (1..(data.length - 1)).reduce([1.0, 0]) do |(best, best_idx), n|
      sum_of_entropy = average_entropy(data.take(n)) + average_entropy(data.last(data.length - n))
      # <= prefers the right-most of equally good split points.
      sum_of_entropy <= best ? [sum_of_entropy, n] : [best, best_idx]
    end
    { ltree: data.take(i), rtree: data.last(data.length - i), min_sum_of_entropy: min_sum_of_entropy }
  end

  # Recursively builds the tree, stopping when a node is pure, when the
  # entropy gain over the parent is below 0.1, or when no usable split exists.
  def build(data = @data, condition = "", parent = nil)
    node = Node.new(data, condition, parent)
    # Leaf: only one class left.
    return node if data.group_by(&:class).length == 1
    # Leaf: splitting further would gain too little entropy.
    return node if parent && (average_entropy(parent.data) - average_entropy(data)) < 0.1
    # Pick the feature whose best split point has the lowest summed entropy.
    split = [:sepal_length, :sepal_width, :petal_length, :petal_width].reduce({min_sum_of_entropy: 1.0}) do |memo, param_name|
      result = split_min_entropy(data.sort_by(&param_name))
      result[:min_sum_of_entropy] < memo[:min_sum_of_entropy] ? result.merge(param_name: param_name) : memo
    end
    # BUG FIX: when no feature beat the 1.0 seed, +split+ is still the seed
    # hash with no :ltree/:rtree keys, and the original code crashed with
    # NoMethodError on nil. Treat "no split found" like an empty split.
    return node if split[:ltree].nil? || split[:ltree].empty? || split[:rtree].empty?
    # Threshold halfway between the last left record and first right record.
    mid = (split[:ltree].last.yield_self(&split[:param_name]) + split[:rtree].first.yield_self(&split[:param_name])) / 2
    node.ltree = build(split[:ltree], "#{split[:param_name]} < #{mid}", node)
    node.rtree = build(split[:rtree], "#{split[:param_name]} >= #{mid}", node)
    node
  end
end
# Pre-order walker over trees of objects responding to #ltree and #rtree.
class TreeTraverser
  def initialize(tree)
    @tree = tree
  end

  # Depth-first, pre-order traversal. Yields each node with its depth
  # (root = 0); the left subtree is visited before the right one.
  def traverse(tree = @tree, depth = 0, &block)
    yield(tree, depth)
    [tree.ltree, tree.rtree].each do |child|
      traverse(child, depth + 1, &block) if child
    end
  end
end
# Load the iris dataset from the `datasets` (red-datasets) gem and
# materialize it as an array of records.
# NOTE(review): the tree groups records by their Ruby class — this assumes
# records of different species have distinct classes in the gem version used
# here; verify against the installed red-datasets release.
iris = Datasets::Iris.new
data = iris.each.to_a
# DecisionTree.new(data).split_min_entropy(data.sort_by(&:petal_length))
tree = DecisionTree.new(data).build
# Emit the tree as a Graphviz DOT digraph on stdout — render with e.g.
# `ruby this_script.rb | dot -Tpng -o tree.png`. The heredoc is the graph
# header: dark background, box-shaped nodes, white edges/labels.
puts <<EOD
digraph decision_tree {
graph [
charset = "UTF-8";
label = "Decision Tree",
labelloc = "t",
labeljust = "c",
bgcolor = "#343434",
fontcolor = white,
fontsize = 18,
style = "filled",
rankdir = TB,
margin = 0.2,
splines = spline,
ranksep = 1.0,
nodesep = 0.9
];
node [
colorscheme = "rdylgn11"
style = "solid,filled",
fontsize = 16,
fontcolor = 6,
fontname = "Migu 1M",
color = 7,
fillcolor = 11,
fixedsize = true,
height = 1.0,
width = 2.0,
shape = "box"
];
edge [
style = solid,
fontsize = 14,
fontcolor = white,
fontname = "Migu 1M",
color = white,
labelfloat = true,
labeldistance = 2.5,
labelangle = 70
];
EOD
# One DOT statement per tree node: an edge from the parent carrying the split
# condition, then the node itself labelled with its per-class record counts.
TreeTraverser.new(tree).traverse do |node, depth|
if node.parent
puts "#{node.parent.name} -> #{node.name} [label = \"#{node.condition}\"];"
end
puts "#{node.name} [label = \"#{node.data.group_by(&:class).map{|k, v| "#{k}: #{v.length}" }.join("\n")}\"];"
end
# Close the digraph opened in the heredoc above.
puts "}"
@youchan
Copy link
Author

youchan commented Feb 9, 2018

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment