Skip to content

Instantly share code, notes, and snippets.

@rust-play
Created February 20, 2019 13:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save rust-play/a57cdaaa637540caaee506854adc5606 to your computer and use it in GitHub Desktop.
Save rust-play/a57cdaaa637540caaee506854adc5606 to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
#[macro_use]
extern crate error_chain; // 0.12.0
extern crate num_traits; // 0.2.6
use num_traits::AsPrimitive;
use std::fmt::Debug;
// Declares the crate's error types via `error_chain!`. The `types` entries name,
// in error_chain's positional order: the error type (`TractError`), its kind enum
// (`TractErrorKind`), the result-extension trait (`TractResultExt`) and the result
// alias (`TractResult<T>`) returned by the fallible functions in this file.
error_chain! {
types { TractError, TractErrorKind, TractResultExt, TractResult; }
// No foreign error types are wrapped.
foreign_links {}
// NOTE(review): `TFString` carries no fields and no description/display, and
// nothing in this file constructs it — confirm it is still needed.
errors { TFString {} }
}
/// Comparison operator a [`Branch`] applies between an input feature (left
/// operand) and its threshold (right operand).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cmp {
    /// `x <= y`
    LessEqual,
    /// `x < y`
    Less,
    /// `x >= y`
    GreaterEqual,
    /// `x > y`
    Greater,
    /// `x == y`
    Equal,
    /// `x != y`
    NotEqual,
}

impl Cmp {
    /// Returns whether `x <op> y` holds, where `<op>` is the operator this
    /// variant names.
    ///
    /// Plain IEEE-754 `f32` comparisons are used, so when either operand is NaN
    /// every variant yields `false` except [`Cmp::NotEqual`], which yields `true`.
    pub fn compare(&self, x: f32, y: f32) -> bool {
        match self {
            // Equality tests.
            Cmp::Equal => x == y,
            Cmp::NotEqual => x != y,
            // Strict orderings.
            Cmp::Less => x < y,
            Cmp::Greater => x > y,
            // Non-strict orderings.
            Cmp::LessEqual => x <= y,
            Cmp::GreaterEqual => x >= y,
        }
    }
}
/// A decision node: tests one input feature against a threshold and names the
/// child node to visit next.
#[derive(Copy, Clone, Debug)]
pub struct Branch {
    /// Operator applied as `feature <cmp> value`.
    pub cmp: Cmp,
    /// Index of the input feature this branch reads.
    pub feature_id: usize,
    /// Threshold the feature is compared against.
    pub value: f32,
    /// Child node id returned when the test holds.
    pub true_id: usize,
    /// Child node id returned when the test does not hold.
    pub false_id: usize,
    /// Test outcome used for a NaN feature instead of running the comparison.
    pub nan_is_true: bool,
}

impl Branch {
    /// Returns the id of the child selected for `feature`.
    ///
    /// A NaN feature selects `true_id` iff `nan_is_true`; any other value
    /// selects `true_id` iff `self.cmp.compare(feature, self.value)` holds,
    /// and `false_id` otherwise.
    pub fn child_id(&self, feature: f32) -> usize {
        // A NaN feature never reaches the comparison: its routing is fixed per branch.
        let takes_true_branch = match feature {
            f if f.is_nan() => self.nan_is_true,
            f => self.cmp.compare(f, self.value),
        };
        if takes_true_branch {
            self.true_id
        } else {
            self.false_id
        }
    }
}
/// One entry in a [`Tree`]'s node list: either a decision node or a terminal
/// value of type `L`.
#[derive(Copy, Clone, Debug)]
pub enum Node<L> {
    /// Decision node; its `true_id`/`false_id` are indices of other entries in
    /// the same node list.
    Branch(Branch),
    /// Terminal node; evaluation stops here and returns a reference to this value.
    Leaf(L),
}
/// A single decision tree stored as a flat list of nodes.
///
/// Branches refer to their children by index into `nodes`; evaluation starts at
/// `root_id` and follows branches until it reaches a leaf.
#[derive(Clone, Debug)]
pub struct Tree<L> {
    nodes: Vec<Node<L>>,
    /// Index into `nodes` where evaluation starts; validated to be `< nodes.len()`.
    root_id: usize,
    /// Largest `feature_id` of any branch, or `None` if the tree has no branches.
    /// Cached at construction so evaluation can check the input length once.
    max_feature: Option<usize>,
}

impl<L: Debug + Clone> Tree<L> {
    /// Evaluates the tree on the feature vector `x` and returns the leaf it reaches.
    ///
    /// Each feature is converted to `f32` before being compared. The input length
    /// is checked once up front; the node and feature lookups inside the walk then
    /// skip bounds checks (hence the name).
    ///
    /// # Errors
    ///
    /// Returns an error if `x` is too short to contain every feature the tree
    /// tests, or if the walk visits as many nodes as the tree holds without
    /// reaching a leaf (the branch links form a cycle).
    pub fn eval_unchecked<X, T>(&self, x: X) -> TractResult<&L>
    where
        X: AsRef<[T]>, // contiguous slices only; strided views (e.g. ndarray) don't fit
        T: AsPrimitive<f32>,
    {
        let x = x.as_ref();
        if let Some(max_feature) = self.max_feature {
            ensure!(
                max_feature < x.len(),
                "Invalid input: expected feature_id = {} < len = {}",
                max_feature, x.len()
            );
        }
        let mut node_id = self.root_id;
        // Without a cycle a walk visits each node at most once, so `nodes.len()`
        // steps are always enough to reach a leaf.
        for _ in 0..self.nodes.len() {
            // SAFETY: `root_id` and every branch's `true_id`/`false_id` were checked
            // to be `< nodes.len()` at construction, and `nodes` is never mutated.
            let node = unsafe { self.nodes.get_unchecked(node_id) };
            match node {
                Node::Branch(b) => {
                    // SAFETY: a branch exists, so `max_feature` is `Some`, and
                    // `b.feature_id <= max_feature < x.len()` was checked above.
                    let feature = unsafe { *x.get_unchecked(b.feature_id) };
                    node_id = b.child_id(feature.as_());
                }
                Node::Leaf(leaf) => return Ok(leaf),
            }
        }
        Err(format!(
            "Invalid tree: no leaf reached after visiting {} nodes (cycle in branch links)",
            self.nodes.len()
        )
        .into())
    }

    /// Iterates over the branch nodes, in storage order.
    fn branches(&self) -> impl Iterator<Item = &Branch> {
        self.nodes.iter().filter_map(|node| match node {
            Node::Branch(ref branch) => Some(branch),
            _ => None,
        })
    }

    /// Iterates over the leaf values, in storage order.
    fn leaves(&self) -> impl Iterator<Item = &L> {
        self.nodes.iter().filter_map(|node| match node {
            Node::Leaf(ref leaf) => Some(leaf),
            _ => None,
        })
    }

    /// Returns the largest `feature_id` tested by any branch, or 0 if the tree
    /// has no branches.
    pub fn max_feature_id(&self) -> usize {
        self.max_feature.unwrap_or(0)
    }

    /// Builds a tree from `nodes`, with evaluation starting at node 0.
    ///
    /// # Errors
    ///
    /// Same as [`Tree::from_nodes_with_root`].
    pub fn from_nodes(nodes: &[Node<L>]) -> TractResult<Self> {
        Self::from_nodes_with_root(nodes, 0)
    }

    /// Builds a tree from `nodes`, with evaluation starting at `root_id`.
    ///
    /// Feature ids are not checked here (they index the *input*, not `nodes`);
    /// the input length is checked against them at evaluation time.
    ///
    /// # Errors
    ///
    /// Returns an error if `nodes` is empty, or if `root_id` or any branch's
    /// `true_id`/`false_id` is not a valid index into `nodes`.
    pub fn from_nodes_with_root(nodes: &[Node<L>], root_id: usize) -> TractResult<Self> {
        let len = nodes.len();
        ensure!(len > 0, "Invalid tree: expected non-zero node count");
        ensure!(
            root_id < len,
            "Invalid tree: expected root_id = {} < len = {}",
            root_id, len
        );
        let mut max_feature: Option<usize> = None;
        for node in nodes {
            if let Node::Branch(b) = node {
                ensure!(
                    b.true_id < len,
                    "Invalid node: {:?} (expected true_id = {} < len = {})",
                    node, b.true_id, len
                );
                ensure!(
                    b.false_id < len,
                    "Invalid node: {:?} (expected false_id = {} < len = {})",
                    node, b.false_id, len
                );
                // `None < Some(_)` for `Option`, so this keeps the largest id seen.
                max_feature = max_feature.max(Some(b.feature_id));
            }
        }
        Ok(Self { nodes: nodes.to_vec(), root_id, max_feature })
    }
}
/// A collection of decision trees that all read from the same input features.
#[derive(Clone, Debug)]
pub struct TreeEnsemble<L> {
    trees: Vec<Tree<L>>,
    /// Largest `feature_id` tested by any branch of any tree (0 if there are none).
    max_feature_id: usize,
}

impl<L: Debug + Clone> TreeEnsemble<L> {
    /// Builds an ensemble from clones of `trees`.
    ///
    /// An empty slice yields an empty ensemble whose `max_feature_id()` is 0.
    pub fn from_trees(trees: &[Tree<L>]) -> Self {
        let max_feature_id = trees.iter().map(Tree::max_feature_id).max().unwrap_or(0);
        Self { trees: trees.to_vec(), max_feature_id }
    }

    /// Returns the largest `feature_id` tested by any branch of any tree in the
    /// ensemble, or 0 if no tree has a branch.
    pub fn max_feature_id(&self) -> usize {
        self.max_feature_id
    }
}
// Empty entry point so the snippet builds as a binary crate; the definitions
// above are only compiled, never exercised.
fn main() {}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment