Skip to content

Instantly share code, notes, and snippets.

@kuczmama
Last active March 5, 2020 12:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kuczmama/9ad5075657c57f12e738ed8730cc6297 to your computer and use it in GitHub Desktop.
A decision tree written in Ruby
#!/usr/bin/env ruby
# frozen_string_literal: true
require 'set'
require 'pry'
require 'csv'
require 'optparse'
require 'time'
# Make predictions with trees
# A CART-style decision tree. Builds a binary tree by greedily choosing, at
# each node, the feature/value question with the highest information gain
# (reduction in Gini impurity). Numeric features split with >=, all other
# features with equality.
class DecisionTree
  attr_accessor :root

  # training_data: array of hashes, each mapping feature names to values and
  #                containing the label column.
  # label_name:    key of the dependent-variable column (symbol or string).
  def initialize(training_data, label_name = :label)
    @label_name = label_name
    @root = build_tree(training_data)
  end

  # Decision Tree Node: either an internal split ("Is label >=/== value?")
  # or a leaf carrying the predicted labels.
  class Node
    attr_accessor :label, :value, :left, :right, :predictions

    def to_s
      result = ''
      unless label.nil? || value.nil?
        # Numeric features split with >=, everything else with equality.
        comparator = value.is_a?(Numeric) ? '>=' : '=='
        result += "Is #{label} #{comparator} #{value}"
      end
      result += "--> Predictions: #{predictions}" unless predictions.nil?
      result
    end
  end

  def to_s
    print_tree(@root)
  end

  # Returns the array of unique predicted labels for a hash of features.
  def predict(features)
    predict_helper(features, @root)
  end

  private

  # Maps each feature to the Set of values observed in rows, skipping the
  # label column.
  # e.g. {:weight=>#<Set: {10, 3, 5}>, :color=>#<Set: {"Green", "Orange"}>}
  def find_unique_features(rows)
    unique_features = Hash.new { |hash, key| hash[key] = Set.new }
    rows.each do |row|
      row.each do |label, value|
        next if label == @label_name # Ignore the label
        unique_features[label] << value
      end
    end
    unique_features
  end

  # Recursively grows the tree. A node whose best question yields zero
  # information gain becomes a leaf holding the remaining unique labels.
  # (Fix: the original assigned node.label/node.value twice.)
  def build_tree(rows)
    best_question = find_best_question(rows)
    node = Node.new
    if best_question[:info_gain].zero?
      node.predictions = labels(rows).uniq
      return node
    end
    node.label = best_question[:label]
    node.value = best_question[:value]
    left, right = partition(rows, node.label, node.value)
    node.left = build_tree(left)
    node.right = build_tree(right)
    node
  end

  # Gini impurity of two label partitions, weighted by partition size.
  def calc_weighted_uncertainty(left, right)
    left_weight = left.length / (left.length + right.length).to_f
    left_weight * gini(left) + (1 - left_weight) * gini(right)
  end

  # Extracts the label column from rows; tolerates symbol or string keys.
  def labels(rows)
    rows.map { |row| row[@label_name.to_sym] || row[@label_name.to_s] }
  end

  # Finds the split with the highest information gain over every feature and
  # every observed value. Returns { label:, value:, info_gain: };
  # info_gain stays 0.0 when no question separates the rows.
  def find_best_question(rows)
    best_label = nil
    best_value = nil
    best_gain = 0.0
    current_uncertainty = gini(labels(rows))
    find_unique_features(rows).each do |label, values|
      values.each do |value|
        left, right = partition(rows, label, value)
        next if left.empty? || right.empty? # question separates nothing
        info_gain = current_uncertainty - calc_weighted_uncertainty(labels(left), labels(right))
        next unless info_gain > best_gain
        best_gain = info_gain
        best_label = label
        best_value = value
      end
    end
    { label: best_label, value: best_value, info_gain: best_gain }
  end

  # Gini impurity of an array of labels (equivalent to 1 - sum(p_i^2)),
  # computed as the mean per-sample misclassification probability.
  def gini(labels)
    label_counts = Hash.new(0.0)
    labels.each { |label| label_counts[label] += 1.0 }
    result = 0.0
    labels.each do |label|
      result += 1.0 / labels.length * (1 - label_counts[label] / labels.length)
    end
    result
  end

  # Answers the node's question for a concrete value:
  # booleans and strings compare with ==, numerics with >=.
  def match(value, question_value)
    return value == question_value if !!question_value == question_value # boolean
    return value == question_value if question_value.is_a? String
    return value >= question_value if question_value.is_a? Numeric
    raise "typeof #{question_value.class} is not supported"
  end

  # Splits rows into [matching, non-matching] against the question.
  def partition(rows, label, question_value)
    trues = []
    falses = []
    rows.each do |row|
      if match(row[label], question_value)
        trues << row
      else
        falses << row
      end
    end
    [trues, falses]
  end

  # Pretty-prints the tree, indenting one tab per level.
  def print_tree(root, spacing = '')
    return if root.nil?
    puts "#{spacing}#{root}"
    if root.left
      puts "#{spacing}-->true"
      print_tree(root.left, "#{spacing}\t")
    end
    if root.right
      puts "#{spacing}-->false:"
      print_tree(root.right, "#{spacing}\t")
    end
  end

  # Walks the tree with a feature hash until a leaf's predictions are reached.
  def predict_helper(features, root = nil)
    return root.predictions unless root.predictions.nil?
    if match(features[root.label], root.value)
      predict_helper(features, root.left)
    else
      predict_helper(features, root.right)
    end
  end
end
# Container for the script's parsed command-line options.
Options = Struct.new(:max_train_rows, :label, :verbose)

# Thin wrapper around OptionParser that fills an Options struct from ARGV.
class Parser
  def self.parse(options)
    args = Options.new(nil, :label, false)
    # TODO: add test_data
    OptionParser.new do |opts|
      opts.banner = "Usage: #{$PROGRAM_NAME} TRAINING_DATA.csv [options]"
      opts.on('-m ROWS', '--max-train-rows=ROWS',
              'Max rows to read in from the TRAINING_DATA csv, default read in all rows') do |rows|
        args.max_train_rows = rows
      end
      opts.on('-h', '--help', 'Prints this help') do
        puts opts
        exit
      end
      opts.on('-v', '--verbose', 'Run verbosely') { args.verbose = true }
      opts.on('-l LABEL', '--label=LABEL',
              "The column name that is the dependent variable. Default 'label'") do |label|
        args.label = label
      end
      # TODO: - date labels, regression, classifier
    end.parse!(options)
    args
  end
end
# With no arguments, show the usage text and exit; otherwise parse the flags.
options = Parser.parse(ARGV.length.zero? ? %w[--help] : ARGV)
# NOTE(review): monkey-patches core String. These predicates are only used by
# the CSV ingestion below; a refinement or helper module would be safer.
class String
  # True when the string round-trips through Float (e.g. "1.5", not "1.50").
  def is_float?
    self == to_f.to_s
  end

  # True when the string round-trips through Integer (e.g. "42", not "4.2").
  def is_int?
    self == to_i.to_s
  end

  # True when Time.parse accepts the string, false on any parse error.
  def is_date?
    Time.parse(self)
    true
  rescue StandardError
    false
  end
end
# Feature hashes parsed from the CSV, one per row.
training_data = []
# Rows ingested so far; compared against --max-train-rows in the CSV loop.
i = 0
# Expands a Time into a hash of calendar/timezone attributes (month, weekday
# flags, UTC offset, ...) keyed by method name, so a date column can feed the
# decision tree as several simple feature columns.
def date_features(date_time)
  %w[month wday yday dst? gmtoff gmt_offset utc_offset utc? gmt?
     sunday? tuesday? monday? thursday? wednesday? saturday?
     friday?].each_with_object({}) do |feature, acc|
    acc[feature] = date_time.public_send(feature)
  end
end
# Ingest the training CSV: nil cells become 0, numeric columns are coerced to
# Float (regression), date columns are exploded into calendar feature columns,
# and anything else stays a string (classification).
CSV.foreach(ARGV[0], headers: true) do |csv_row|
  # Stop once --max-train-rows rows have been read (nil means read them all).
  break if !options.max_train_rows.nil? && i >= options.max_train_rows.to_i
  row = {}
  csv_row.each do |k, v|
    if v.nil?
      row[k] = 0
    elsif v.is_float? || v.is_int?
      row[k] = v.to_f # Regression
    elsif v.is_date?
      # BUG FIX: the original assigned the value of the `if` expression
      # (the merged hash itself) back to row[k], storing the whole feature
      # hash under the date column. The date features replace column k.
      row.merge!(date_features(Time.parse(v)))
    else
      row[k] = v # Classifier - string
    end
  end
  i += 1
  training_data << row
end
# training_data = [
# { weight: 10, color: 'Green', label: 'Apple' },
# { weight: 10, color: 'Orange', label: 'Orange' },
# { weight: 3, color: 'Green', label: 'Grape' },
# { weight: 5, color: 'Green', label: 'Grape' }
# ]
# TODO: change this...
# Simple 90/10 train/test split (no shuffling).
train_length = (training_data.length * 0.9).to_i
train = training_data[0...train_length]
test_data = training_data[train_length..-1]
if options[:verbose]
  puts "Creating a decision tree with #{training_data.length} rows from #{ARGV[0]}..."
end
decision_tree = DecisionTree.new(train, options.label.to_s)
puts 'finding accuracy: '
test_data.each do |data|
  # Remove the label so the tree only sees the independent features.
  actual = data.delete(options.label.to_s)
  # BUG FIX: predict once per row (the original called predict twice) and
  # only compute a numeric diff for regression labels — string labels
  # previously crashed on `actual - prediction`.
  prediction = decision_tree.predict(data)
  predicted = prediction && prediction[0]
  diff = actual.is_a?(Numeric) && predicted.is_a?(Numeric) ? actual - predicted : 'n/a'
  puts "diff: #{diff} actual: #{actual} prediction: #{prediction}"
end
puts decision_tree
# probability_tree = decision_tree.predict(training_data[0..1000])
# binding.pry
# puts predict(probability_tree, bid_size: 80.0, ask_size: 200.0, previous_price: 20.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment