Skip to content

Instantly share code, notes, and snippets.

@gsingh93
Last active October 27, 2016 13:38
Show Gist options
  • Save gsingh93/dc5ebe6c8a1582731918 to your computer and use it in GitHub Desktop.
Save gsingh93/dc5ebe6c8a1582731918 to your computer and use it in GitHub Desktop.
Segment tree implementation in Rust
use std::default::Default;
#[cfg(test)]
use std::fmt::Debug;
use std::ops::Add;

/// An immutable segment tree that answers "combine everything in
/// `start..=end`" queries over a fixed sequence of values.
///
/// Values are combined with `+`, so `T` can be any type whose addition is
/// meaningful for the caller (numbers, or any other `Add` implementation).
struct SegmentTree<T> {
    /// Number of elements the tree was built from.
    size: usize,
    /// Root node; covers the whole index range `0..=size - 1`.
    /// For an empty tree this is a childless node holding `T::default()`.
    root: Node<T>,
}

/// One node of the tree: either a leaf (both children `None`) holding a
/// single element, or an internal node holding `left.val + right.val`.
struct Node<T> {
    left: Option<Box<Node<T>>>,
    right: Option<Box<Node<T>>>,
    val: T,
}

impl<T: Default + Clone + Add<Output = T>> SegmentTree<T> {
    /// Builds a tree over a copy of `elts`.
    ///
    /// An empty slice yields an empty tree on which every `query` returns
    /// `Err`.
    pub fn new(elts: &[T]) -> SegmentTree<T> {
        // Handle the empty case before computing `len() - 1`, which would
        // underflow for a zero-length slice.
        let root = if elts.is_empty() {
            Node { left: None, right: None, val: Default::default() }
        } else {
            Self::build(elts, 0, elts.len() - 1)
        };
        SegmentTree { size: elts.len(), root }
    }

    /// Returns the combined value of the elements at indices `start..=end`
    /// (both bounds inclusive).
    ///
    /// # Errors
    ///
    /// * `"Out of bounds"` if `end` is not a valid index (this includes
    ///   every query on an empty tree).
    /// * A descriptive message if `start > end`.
    pub fn query(&self, start: usize, end: usize) -> Result<T, String> {
        if end >= self.size {
            return Err("Out of bounds".to_string());
        } else if start > end {
            return Err("Start of query range can't be greater \
                        than end of range".to_string());
        }
        Ok(self.query_(0, self.size - 1, start, end, &self.root))
    }

    /// Recursively builds the subtree covering `elts[left..=right]`.
    ///
    /// Callers must guarantee `left <= right < elts.len()`.
    fn build(elts: &[T], left: usize, right: usize) -> Node<T> {
        if left == right {
            return Node { left: None, right: None, val: elts[left].clone() };
        }
        // Same value as `(left + right) / 2`, written so the intermediate
        // sum cannot overflow. `query_` must split at the same point.
        let mid = left + (right - left) / 2;
        let lhs = Self::build(elts, left, mid);
        let rhs = Self::build(elts, mid + 1, right);
        // `Add` consumes its operands, so combine clones and keep the
        // children's own values intact.
        let val = lhs.val.clone() + rhs.val.clone();
        Node { left: Some(Box::new(lhs)), right: Some(Box::new(rhs)), val }
    }

    /// Answers the query `start..=end` inside the subtree `cur`, which covers
    /// the index range `left..=right`.
    ///
    /// The caller guarantees that the two ranges intersect; `query` checks
    /// this for the root and every recursive call below preserves it.
    fn query_(&self, left: usize, right: usize, start: usize, end: usize,
              cur: &Node<T>) -> T {
        // The node lies entirely inside the query range: its cached value is
        // the answer. Testing containment (rather than exact equality of the
        // bounds) is what stops the descent early on partially overlapping
        // queries. Leaves always take this branch because a visited leaf
        // intersects, and therefore lies within, the query range.
        if start <= left && right <= end {
            return cur.val.clone();
        }
        let (lhs, rhs) = match (cur.left.as_deref(), cur.right.as_deref()) {
            (Some(l), Some(r)) => (l, r),
            _ => unreachable!("segment tree node covering {}..={} has no children",
                              left, right),
        };
        let mid = left + (right - left) / 2;
        if start > mid {
            // Query lies entirely in the right half.
            self.query_(mid + 1, right, start, end, rhs)
        } else if end <= mid {
            // Query lies entirely in the left half.
            self.query_(left, mid, start, end, lhs)
        } else {
            // Query straddles the midpoint: combine both halves.
            self.query_(left, mid, start, end, lhs)
                + self.query_(mid + 1, right, start, end, rhs)
        }
    }
}

#[test]
fn segment_tree_test() {
    let v: Vec<i32> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    let tree: SegmentTree<i32> = SegmentTree::new(&v);
    check(&tree, &v, 0, 9);
    check(&tree, &v, 0, 1);
    check(&tree, &v, 4, 7);
}

#[test]
fn segment_tree_empty_test() {
    let v: Vec<i32> = vec![];
    let tree = SegmentTree::new(&v);
    // Nothing to query in an empty tree.
    assert!(tree.query(0, 0).is_err());
}

#[test]
fn segment_tree_out_of_range_test() {
    let v: Vec<i32> = vec![1];
    let tree = SegmentTree::new(&v);
    check(&tree, &v, 0, 0);
    assert!(tree.query(0, 1).is_err());
}

#[test]
fn segment_tree_backwards_range_test() {
    let v: Vec<i32> = vec![1, 2, 3];
    let tree = SegmentTree::new(&v);
    check(&tree, &v, 0, 2);
    assert!(tree.query(2, 0).is_err());
}

/// Asserts that the tree's answer for `start..=end` matches a brute-force
/// fold over the same slice.
#[cfg(test)]
fn check<T: Default + Add<Output = T> + Debug + Eq + Clone>(tree: &SegmentTree<T>,
                                                            elts: &[T], start: usize,
                                                            end: usize) {
    assert_eq!(tree.query(start, end).unwrap(), query(elts, start, end));
}

/// Reference implementation: folds `elts[start..=end]` left to right,
/// starting from `T::default()`.
#[cfg(test)]
fn query<T: Add<Output = T> + Debug + Default + Clone>(elts: &[T], start: usize,
                                                       end: usize) -> T {
    elts.iter()
        .skip(start)
        .take(end - start + 1)
        .cloned()
        .fold(Default::default(), |acc: T, x| acc + x)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment