Skip to content

Instantly share code, notes, and snippets.

# youchan/decision_tree.rb (last active Feb 9, 2018)

Decision Tree in Ruby
 require "datasets"

# Arithmetic mean of the array's elements (or of the block's results).
# Returns 0.0 for an empty array.
# NOTE(review): this is a core-class monkey patch, and nothing in this script calls it.
class Array
  def avg(&block)
    return 0.0 if empty?

    # fdiv keeps the result a Float even for Integer elements
    # (plain `/` would truncate: [1, 2].avg would be 1).
    sum(&block).fdiv(length)
  end
end

# One node of the decision tree. It holds the training records that reached it,
# the textual condition on the edge from its parent, and its two children.
class Node
  attr_reader :parent, :data, :name, :condition
  attr_accessor :ltree, :rtree

  # Counter used to give every node a unique name. It is a class-level instance
  # variable rather than @@count, so the state is not shared through subclasses.
  @count = 0

  # Returns the next unique node name: "node0", "node1", ...
  def self.next_name
    "node#{@count}".tap { @count += 1 }
  end

  # @param data [Array] records that reach this node
  # @param condition [String] split condition text from the parent ("" for the root)
  # @param parent [Node, nil] parent node; nil for the root
  def initialize(data, condition, parent)
    @condition = condition
    @parent = parent
    @data = data
    @name = Node.next_name
    @ltree = nil
    @rtree = nil
  end

  # The condition line, followed by per-class record counts (e.g. "x < 1.0\nA: 3, B: 1").
  def to_s
    counts = @data.group_by(&:class).map { |label, records| "#{label}: #{records.length}" }
    "#{@condition}\n#{counts.join(', ')}"
  end
end

# Builds a binary decision tree over labelled records. At each node it recursively
# chooses the threshold split, on one feature, that minimises the sum of the two
# halves' average entropies.
#
# Records must respond to every feature name, and #class is used as the class label.
# NOTE(review): this only works if the record type overrides #class to return the
# label. That is presumably true for the Datasets::Iris records this script was
# written against; confirm for the gem version in use.
class DecisionTree
  # Feature accessors tried at every split (the four Iris measurements).
  DEFAULT_FEATURES = %i[sepal_length sepal_width petal_length petal_width].freeze

  # @param data [Array] training records
  # @param features [Array<Symbol>] feature accessors considered for splits
  # @param min_gain [Float] a node becomes a leaf when its average entropy is
  #   less than this much below its parent's
  def initialize(data, features: DEFAULT_FEATURES, min_gain: 0.1)
    @data = data
    @class_names = data.map(&:class).uniq
    @features = features
    @min_gain = min_gain
  end

  # Entropy (base 2) of the class distribution in `data`, divided by the number
  # of distinct classes in the full training set. Returns 0.0 for empty data.
  def average_entropy(data)
    return 0.0 if data.empty?

    n = data.length
    data.group_by(&:class).map do |_label, records|
      prob = records.length.to_f / n
      -prob * Math.log2(prob)
    end.sum / @class_names.length
  end

  # Finds the cut index into `data` (already sorted by the candidate feature)
  # that minimises average_entropy(left) + average_entropy(right).
  #
  # @param data [Array] records sorted by the candidate feature
  # @param key [Symbol, nil] the feature `data` is sorted by. When given, cuts
  #   between two records with equal values are skipped, because no `<` / `>=`
  #   threshold on that feature can separate them.
  # @return [Hash] :ltree and :rtree (the two halves) plus :min_sum_of_entropy.
  #   If no valid cut scores at or below 1.0, :ltree is empty and
  #   :min_sum_of_entropy is 1.0.
  def split_min_entropy(data, key: nil)
    best_score, best_cut = (1...data.length).reduce([1.0, 0]) do |(min_score, cut), n|
      next [min_score, cut] if key && data[n - 1].public_send(key) == data[n].public_send(key)

      score = average_entropy(data.take(n)) + average_entropy(data.drop(n))
      # `<=` keeps the last of several equally good cuts.
      score <= min_score ? [score, n] : [min_score, cut]
    end
    { ltree: data.take(best_cut), rtree: data.drop(best_cut), min_sum_of_entropy: best_score }
  end

  # Recursively builds the subtree rooted at a new node holding `data`.
  #
  # The node becomes a leaf in three cases:
  # - it is pure (at most one class);
  # - its average entropy is less than @min_gain below its parent's;
  # - no feature offers a usable split.
  #
  # @return [Node] root of the (sub)tree
  def build(data = @data, condition = "", parent = nil)
    node = Node.new(data, condition, parent)
    return node if data.map(&:class).uniq.length <= 1
    return node if parent && (average_entropy(parent.data) - average_entropy(data)) < @min_gain

    split = best_split(data)
    return node unless split

    feature = split[:param_name]
    left_max = split[:ltree].last.public_send(feature)
    right_min = split[:rtree].first.public_send(feature)
    # Threshold halfway between the two sides. Dividing by 2.0 avoids truncation
    # for Integer features, which would misroute the left side's maximum.
    mid = (left_max + right_min) / 2.0
    node.ltree = build(split[:ltree], "#{feature} < #{mid}", node)
    node.rtree = build(split[:rtree], "#{feature} >= #{mid}", node)
    node
  end

  private

  # Returns the best split of `data` across all features, tagged with
  # :param_name. Returns nil when no feature yields a split that scores below
  # 1.0 with both halves non-empty.
  def best_split(data)
    best = @features.reduce({ min_sum_of_entropy: 1.0 }) do |memo, feature|
      result = split_min_entropy(data.sort_by(&feature), key: feature)
      result[:min_sum_of_entropy] < memo[:min_sum_of_entropy] ? result.merge(param_name: feature) : memo
    end
    # If nothing beat 1.0, the seed hash survives. It has no :ltree, which must
    # not be dereferenced.
    return nil unless best[:param_name]
    return nil if best[:ltree].empty? || best[:rtree].empty?

    best
  end
end

# Depth-first, pre-order walk over a Node tree.
class TreeTraverser
  def initialize(tree)
    @tree = tree
  end

  # Yields (node, depth) for every node, visiting the parent before its children
  # and the left child before the right. The root has depth 0.
  # Returns an Enumerator when no block is given.
  def traverse(tree = @tree, depth = 0, &block)
    return enum_for(:traverse, tree, depth) unless block

    block.call(tree, depth)
    traverse(tree.ltree, depth + 1, &block) if tree.ltree
    traverse(tree.rtree, depth + 1, &block) if tree.rtree
  end
end

iris = Datasets::Iris.new
data = iris.each.to_a
tree = DecisionTree.new(data).build

# Print the tree as a Graphviz DOT digraph on stdout.
# NOTE(review): the DOT header and the traversal call appear to have been lost
# when this gist was scraped. HTML stripping removed the text from a `<<` heredoc
# opener up to the first `>` of an edge arrow. The header and call are
# reconstructed here from the surviving fragments:
# - the closing `puts "}"`;
# - the `-> #{node.name} [label = ...]` edge line;
# - the `end`s closing an `if` and the traversal block.
puts "digraph tree {"
TreeTraverser.new(tree).traverse do |node, _depth|
  if node.parent
    puts "#{node.parent.name} -> #{node.name} [label = \"#{node.condition}\"];"
  end
  counts = node.data.group_by(&:class).map { |k, v| "#{k}: #{v.length}" }.join("\n")
  puts "#{node.name} [label = \"#{counts}\"];"
end
puts "}"
Owner Author

### youchan commented Feb 9, 2018

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.