Created April 7, 2021 02:15
-
-
Save JosephCatrambone/eccd21728c9b267bc83ca8e578e5909f to your computer and use it in GitHub Desktop.
A Rust-language decision tree implementation with no external dependencies.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//use serde::{Serialize, Deserialize}; | |
use std::collections::HashMap; | |
/// A binary classifier built from axis-aligned threshold splits over `f32` feature vectors.
///
/// Every trained node stores the majority label of the training data that reached it
/// (`class`) and the fraction of that data carrying the majority label
/// (`class_probability`). An internal node additionally routes an example to its left
/// subtree when `example[feature] < threshold` and to its right subtree otherwise; a node
/// without both subtrees is a leaf.
// If you have serde, serialization is available with this derive chain:
//#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[derive(Clone, Debug, Default)]
pub struct DecisionTree {
    feature: usize,
    threshold: f32,
    class: bool,
    class_probability: f32,
    left_subtree: Option<Box<DecisionTree>>,
    right_subtree: Option<Box<DecisionTree>>,
}

impl DecisionTree {
    /// Creates an untrained tree. Until [`DecisionTree::train`] is called it predicts `false`
    /// with confidence `0.0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the index of the feature tested at the root of this tree (`0` for a leaf).
    pub fn get_decision_feature(&self) -> usize {
        self.feature
    }

    /// Returns the threshold tested at the root of this tree: examples whose root feature
    /// is `< threshold` go left, all others go right. It is `0.0` for a leaf.
    pub fn get_decision_threshold(&self) -> f32 {
        self.threshold
    }

    /// Predicts the label of `example`.
    ///
    /// # Panics
    /// Panics if `example` is too short for a feature index tested along its path.
    pub fn predict(&self, example: &[f32]) -> bool {
        self.predict_with_confidence(example).0
    }

    /// Predicts the label of `example` together with the confidence of the leaf it reaches.
    ///
    /// The confidence is the fraction of that leaf's training labels equal to the predicted
    /// label.
    ///
    /// # Panics
    /// Panics if `example` is too short for a feature index tested along its path.
    pub fn predict_with_confidence(&self, example: &[f32]) -> (bool, f32) {
        let mut node = self;
        // Descend until we reach a node that does not have both children (a leaf).
        while let (Some(left), Some(right)) =
            (node.left_subtree.as_deref(), node.right_subtree.as_deref())
        {
            node = if example[node.feature] < node.threshold { left } else { right };
        }
        (node.class, node.class_probability)
    }

    /// Fits this tree to `examples` and `labels`, discarding anything learned before.
    ///
    /// A node stops splitting when:
    /// - its labels are all identical,
    /// - `max_depth` reaches zero, or
    /// - no feature takes at least two distinct non-NaN values.
    ///
    /// For each feature, the candidate thresholds are the midpoints between consecutive
    /// distinct sorted values. The (feature, threshold) pair with the lowest weighted Gini
    /// impurity wins. Ties go to the lower feature index, then to the lower threshold. For
    /// 0/1-valued features the only candidate threshold is `0.5`.
    ///
    /// An empty training set yields a leaf predicting `false` with confidence `0.5`.
    ///
    /// # Panics
    /// - Panics if an example is shorter than the first one.
    /// - In debug builds, also panics if `examples` and `labels` differ in length.
    pub fn train(&mut self, examples: &[&Vec<f32>], labels: &[bool], max_depth: u32) {
        debug_assert_eq!(
            examples.len(),
            labels.len(),
            "every example needs exactly one label"
        );
        // Start from a blank node so retraining never keeps stale subtrees or split info.
        *self = Self::new();

        let num_positive = labels.iter().filter(|&&label| label).count();
        let num_negative = labels.len() - num_positive;
        let total = labels.len();
        // Ties (including the empty case) resolve to `false`.
        self.class = num_positive > num_negative;
        self.class_probability = if total == 0 {
            // No evidence either way; avoid the NaN that 0 / 0 would produce.
            0.5
        } else {
            num_positive.max(num_negative) as f32 / total as f32
        };

        // A pure node, or one with no depth budget left, stays a leaf.
        if num_positive == 0 || num_negative == 0 || max_depth == 0 {
            return;
        }

        let (feature, threshold) = match best_split(examples, labels) {
            Some(split) => split,
            // No feature separates anything, so every split would be degenerate.
            None => return,
        };
        self.feature = feature;
        self.threshold = threshold;

        let mut left_examples = Vec::new();
        let mut right_examples = Vec::new();
        let mut left_labels = Vec::new();
        let mut right_labels = Vec::new();
        for (&example, &label) in examples.iter().zip(labels) {
            if example[feature] < threshold {
                left_examples.push(example);
                left_labels.push(label);
            } else {
                right_examples.push(example);
                right_labels.push(label);
            }
        }

        self.left_subtree = Some(Self::grown(&left_examples, &left_labels, max_depth - 1));
        self.right_subtree = Some(Self::grown(&right_examples, &right_labels, max_depth - 1));
    }

    /// Trains and returns a fresh boxed subtree on the given partition.
    fn grown(examples: &[&Vec<f32>], labels: &[bool], max_depth: u32) -> Box<DecisionTree> {
        let mut subtree = Box::new(DecisionTree::new());
        subtree.train(examples, labels, max_depth);
        subtree
    }
}

/// Finds the (feature, threshold) split with the lowest weighted Gini impurity.
///
/// Returns `None` when no feature takes two distinct non-NaN values. Ties keep the earliest
/// candidate: the lower feature index, then the lower threshold. Every candidate is scored
/// with a full pass over the data, so the cost is O(features * distinct values * examples).
fn best_split(examples: &[&Vec<f32>], labels: &[bool]) -> Option<(usize, f32)> {
    let num_features = examples.first().map_or(0, |example| example.len());
    // (feature, threshold, impurity) of the best candidate seen so far.
    let mut best: Option<(usize, f32, f32)> = None;
    for feature in 0..num_features {
        let mut values: Vec<f32> = examples
            .iter()
            .map(|example| example[feature])
            .filter(|value| !value.is_nan())
            .collect();
        values.sort_by(|a, b| a.partial_cmp(b).expect("NaN values were filtered out"));
        values.dedup();
        for pair in values.windows(2) {
            let threshold = split_point(pair[0], pair[1]);
            let impurity = gini_impurity(examples, labels, feature, threshold);
            if best.map_or(true, |(_, _, best_impurity)| impurity < best_impurity) {
                best = Some((feature, threshold, impurity));
            }
        }
    }
    best.map(|(feature, threshold, _)| (feature, threshold))
}

/// Returns a threshold `t` with `lo < t <= hi` (given `lo < hi`), normally their midpoint.
///
/// This guarantees `lo` lands on the left of the split and `hi` on the right.
fn split_point(lo: f32, hi: f32) -> f32 {
    // Halve before adding so values of huge magnitude cannot overflow to infinity.
    let midpoint = lo * 0.5 + hi * 0.5;
    // For adjacent floats (or infinities) the midpoint can round onto `lo` or become NaN;
    // fall back to `hi`, which still separates the two values.
    if midpoint > lo {
        midpoint
    } else {
        hi
    }
}

/// Computes the weighted Gini impurity of splitting on `example[feature] < thresh`.
///
/// Examples satisfying the test form the left side; all others (including NaN feature
/// values) form the right side. The result is the sum over sides of
/// `(side size / total) * (1 - p_true^2 - p_false^2)`.
///
/// An empty side contributes nothing. A perfect split, or an empty input, scores exactly
/// `0.0`.
fn gini_impurity(examples: &[&Vec<f32>], labels: &[bool], feature: usize, thresh: f32) -> f32 {
    // counts[side][label]: side 0 = left (< thresh), 1 = right; label 0 = true, 1 = false.
    let mut counts = [[0usize; 2]; 2];
    for (example, &label) in examples.iter().zip(labels) {
        let side = if example[feature] < thresh { 0 } else { 1 };
        let column = if label { 0 } else { 1 };
        counts[side][column] += 1;
    }
    let total = counts.iter().flatten().sum::<usize>();
    if total == 0 {
        return 0.0;
    }
    counts
        .iter()
        .map(|&[n_true, n_false]| {
            let n = n_true + n_false;
            if n == 0 {
                return 0.0;
            }
            let p_true = n_true as f32 / n as f32;
            let p_false = n_false as f32 / n as f32;
            (n as f32 / total as f32) * (1.0 - p_true * p_true - p_false * p_false)
        })
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Asserts that two numbers agree to within 1e-3 after conversion to `f32`.
    macro_rules! assert_approx_eq {
        ($a:expr, $b:expr) => {
            assert!(($a as f32 - $b as f32).abs() < 1e-3f32, "{} !~= {}", $a, $b);
        };
    }

    #[test]
    fn test_tree_sanity() {
        let mut dt = DecisionTree::new();
        // Is this game for children?
        // Feat: exploding heads, talking animals, chainsaws, shotgun blasts
        let doom = vec![1f32, 0f32, 1f32, 1f32];
        let animal_crossing = vec![0f32, 1f32, 1f32, 0f32];
        let gta = vec![0f32, 0f32, 1f32, 1f32];
        let pokemon = vec![0f32, 1f32, 0f32, 0f32];
        let quake = vec![1f32, 0f32, 0f32, 1f32];
        let fez = vec![0f32, 0f32, 0f32, 0f32];
        let examples = vec![&doom, &animal_crossing, &gta, &pokemon, &quake, &fez];
        let labels = vec![false, true, false, true, false, true];
        dt.train(&examples, &labels, 4);
        let wolfenstein = vec![1f32, 0f32, 1f32, 1f32];
        let terraria = vec![0f32, 1f32, 0f32, 0f32];
        assert!(!dt.predict(&wolfenstein));
        assert!(dt.predict(&terraria));
        // "Shotgun blasts" splits the training games perfectly, so the leaf is fully confident.
        let (_, confidence) = dt.predict_with_confidence(&terraria);
        assert_approx_eq!(confidence, 1.0);
    }

    #[test]
    fn test_tree_sanity_2() {
        let mut dt = DecisionTree::new();
        // Factor B is the only one that should actually matter.
        let a = vec![0f32, 0f32, 0f32];
        let b = vec![0f32, 1f32, 0f32];
        let c = vec![1f32, 0f32, 0f32];
        let d = vec![1f32, 1f32, 0f32];
        let examples = vec![&a, &b, &c, &d];
        let labels = vec![false, true, false, true];
        dt.train(&examples, &labels, 5);
        assert_eq!(dt.get_decision_feature(), 1);
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment