Skip to content

Instantly share code, notes, and snippets.

@GoodNovember
Last active April 7, 2023 03:54
Show Gist options
  • Save GoodNovember/0dd9f731b5728379c94191500a0adb73 to your computer and use it in GitHub Desktop.
Save GoodNovember/0dd9f731b5728379c94191500a0adb73 to your computer and use it in GitHub Desktop.
A Rust Octree implementation that uses Bevy's math code.
use bevy::math::{ Vec3 };
/// An axis-aligned bounding box in 3D space, stored as a center point plus the
/// half-size of the box along each axis.
///
/// The box spans `center - half_extents ..= center + half_extents` on every axis;
/// containment and overlap tests treat the faces as inside the box.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BoundingBox {
    /// Center of the box.
    pub center: Vec3,
    /// Half of the box's size along x, y and z (assumed non-negative; not validated).
    pub half_extents: Vec3,
}

impl BoundingBox {
    /// Creates a box from its center and per-axis half extents.
    pub fn new(center: Vec3, half_extents: Vec3) -> Self {
        Self { center, half_extents }
    }

    /// Creates a cube centered on `point` whose half extent is `radius` on every axis.
    pub fn new_from_point_and_radius(point: Vec3, radius: f32) -> Self {
        Self {
            center: point,
            half_extents: Vec3::splat(radius),
        }
    }

    /// Returns the point inside (or on the surface of) the box closest to `point`.
    ///
    /// Each coordinate is clamped to the box's `[min, max]` range independently.
    pub fn get_closest_point(&self, point: &Vec3) -> Vec3 {
        let (min, max) = (self.min(), self.max());
        let mut closest_point = *point;
        for i in 0..3 {
            if point[i] < min[i] {
                closest_point[i] = min[i];
            } else if point[i] > max[i] {
                closest_point[i] = max[i];
            }
        }
        closest_point
    }

    /// Euclidean distance from `point` to the box (0 when the point is inside).
    pub fn distance_to_point(&self, point: Vec3) -> f32 {
        (self.get_closest_point(&point) - point).length()
    }

    /// A degenerate box at the origin with zero size.
    pub fn zero() -> Self {
        Self::new(Vec3::ZERO, Vec3::ZERO)
    }

    /// Exact component-wise equality of center and half extents.
    pub fn equals(&self, other: &BoundingBox) -> bool {
        self == other
    }

    /// Returns the smallest box containing both `aabb_a` and `aabb_b`.
    pub fn merge_two_create_new(aabb_a: &Self, aabb_b: &Self) -> Self {
        aabb_a.merged(aabb_b)
    }

    /// Returns the smallest box containing both `self` and `other`.
    ///
    /// Starts from `self` (not from a zero box) so the origin is not pulled
    /// into the result.
    pub fn merged(&self, other: &BoundingBox) -> Self {
        let mut aabb = *self;
        aabb.merge_with_aabb(other);
        aabb
    }

    /// Grows `self` in place to the smallest box containing both `self` and `other`.
    pub fn merge_with_aabb(&mut self, other: &BoundingBox) {
        // Work in corner space: the union's corners are the per-axis min/max of
        // both boxes' corners, then convert back to center/half-extent form.
        let min = self.min().min(other.min());
        let max = self.max().max(other.max());
        self.center = (min + max) * 0.5;
        self.half_extents = (max - min) * 0.5;
    }

    /// Whether the two boxes overlap; boxes that only touch on a face or edge count.
    pub fn intersects(&self, other: &BoundingBox) -> bool {
        // Separating-axis test for AABBs: on every axis the distance between
        // centers must not exceed the sum of the half extents.
        let gap = (self.center - other.center).abs();
        let reach = self.half_extents + other.half_extents;
        gap.x <= reach.x && gap.y <= reach.y && gap.z <= reach.z
    }

    /// Returns the smallest box containing both `self` and `other` (same as [`Self::merged`]).
    pub fn join_and_make_new(&self, other: &BoundingBox) -> Self {
        self.merged(other)
    }

    /// Whether `point` lies inside the box, faces included.
    pub fn contains_point(&self, point: &Vec3) -> bool {
        let (min, max) = (self.min(), self.max());
        point.x >= min.x && point.x <= max.x
            && point.y >= min.y && point.y <= max.y
            && point.z >= min.z && point.z <= max.z
    }

    /// Returns a random point inside the box, using `rand::random::<f32>()` per axis.
    pub fn get_random_point_within(&self) -> Vec3 {
        let min = self.min();
        let size = self.max() - min;
        Vec3::new(
            rand::random::<f32>() * size.x + min.x,
            rand::random::<f32>() * size.y + min.y,
            rand::random::<f32>() * size.z + min.z,
        )
    }

    /// The corner with the largest coordinates.
    pub fn max(&self) -> Vec3 {
        self.center + self.half_extents
    }

    /// The corner with the smallest coordinates.
    pub fn min(&self) -> Vec3 {
        self.center - self.half_extents
    }

    /// Center of the box.
    pub fn center(&self) -> Vec3 {
        self.center
    }

    /// Moves the box along x so its minimum x face sits at `x`, keeping its size.
    pub fn set_x_min(&mut self, x: f32) {
        self.center.x = x + self.half_extents.x;
    }

    /// Moves the box along x so its maximum x face sits at `x`, keeping its size.
    pub fn set_x_max(&mut self, x: f32) {
        self.center.x = x - self.half_extents.x;
    }

    /// Moves the box along y so its minimum y face sits at `y`, keeping its size.
    pub fn set_y_min(&mut self, y: f32) {
        self.center.y = y + self.half_extents.y;
    }

    /// Moves the box along y so its maximum y face sits at `y`, keeping its size.
    pub fn set_y_max(&mut self, y: f32) {
        self.center.y = y - self.half_extents.y;
    }

    /// Moves the box along z so its minimum z face sits at `z`, keeping its size.
    pub fn set_z_min(&mut self, z: f32) {
        self.center.z = z + self.half_extents.z;
    }

    /// Moves the box along z so its maximum z face sits at `z`, keeping its size.
    pub fn set_z_max(&mut self, z: f32) {
        self.center.z = z - self.half_extents.z;
    }
}
use bevy::math::Vec3;
/// A point in 3D space tagged with an identifier.
///
/// The `id` is what the octree uses to find a point again (for example when
/// moving it), so callers should keep ids unique.
#[derive(Debug, Clone, Copy)]
pub struct DataPoint {
    /// Identifier used to look the point up.
    pub id: u32,
    /// Location of the point.
    pub position: Vec3,
}

impl DataPoint {
    /// Builds a point from its identifier and location.
    pub fn new(id: u32, position: Vec3) -> Self {
        DataPoint {
            position,
            id,
        }
    }
}
use bevy::math::Vec3;
use super::octree_node::OctreeNode;
use super::bounding_box::BoundingBox;
use super::data_point::DataPoint;
/// A point octree: a spatial index over [`DataPoint`]s inside an axis-aligned region.
///
/// The region starts as the boundary passed to [`Octree::new`]. It only grows when
/// [`Octree::update_position`] moves a point outside of it.
pub struct Octree {
    root: OctreeNode,
}

impl Octree {
    /// Creates an empty octree covering `boundary`.
    ///
    /// `capacity` is the number of points a node stores itself before splitting.
    pub fn new(boundary: BoundingBox, capacity: usize) -> Self {
        Self {
            root: OctreeNode::new(boundary, capacity),
        }
    }

    /// Inserts `data_point`.
    ///
    /// Returns `false` if its position lies outside the current root boundary.
    /// The tree is not grown automatically here.
    pub fn insert(&mut self, data_point: DataPoint) -> bool {
        self.root.insert(data_point)
    }

    /// Appends every stored point inside `region` (faces inclusive) to `results`.
    pub fn query(&self, region: &BoundingBox, results: &mut Vec<DataPoint>) {
        self.root.query(region, results)
    }

    /// Moves the point with the given `id` to `new_position`.
    ///
    /// If `new_position` is outside the root boundary, the root is doubled repeatedly
    /// until the position fits. Returns `false` if no point has that `id`. Also
    /// returns `false` if the tree cannot grow to reach `new_position` (non-finite
    /// position, or zero/non-finite root extents); in that case the point is put back
    /// at its old position.
    pub fn update_position(&mut self, id: u32, new_position: Vec3) -> bool {
        let mut data_point = match self.root.remove_data_point(id) {
            Some(data_point) => data_point,
            None => return false,
        };
        let old_position = data_point.position;
        // Update the position *before* growing: growth must head towards the new location.
        data_point.position = new_position;
        if self.grow_to_contain(&data_point) {
            return self.insert(data_point);
        }
        data_point.position = old_position;
        self.insert(data_point);
        false
    }

    /// Grows the root until it contains `data_point.position`.
    ///
    /// Returns `false` when that is impossible: the position is not finite, or the
    /// root has a zero, negative or non-finite half extent, so doubling would never
    /// reach the target.
    fn grow_to_contain(&mut self, data_point: &DataPoint) -> bool {
        let target = data_point.position;
        if !target.is_finite() {
            return false;
        }
        while !self.root.boundary.contains_point(&target) {
            let half = self.root.boundary.half_extents;
            let can_grow = half.is_finite() && half.x > 0.0 && half.y > 0.0 && half.z > 0.0;
            if !can_grow {
                return false;
            }
            self.root.grow(data_point);
        }
        true
    }

    /// Appends every stored point whose distance from `center` is at most `radius`.
    pub fn query_within_radius(&self, center: &Vec3, radius: f32, results: &mut Vec<DataPoint>) {
        // Nodes are pruned against the cube that bounds the sphere; the exact
        // distance test is done per point.
        let region = BoundingBox::new(*center, Vec3::splat(radius));
        self.root.query_within_radius(*center, radius * radius, &region, results);
    }

    /// Current boundary of the root node.
    pub fn get_root_boundary(&self) -> BoundingBox {
        self.root.boundary
    }

    /// Center of the root node's boundary.
    pub fn get_root_center(&self) -> Vec3 {
        self.root.boundary.center
    }
}

/// Builds the octree shared by the tests.
///
/// It covers a cube of half extent 100 centered on the origin, with node capacity 4.
/// It holds ids `0..=20`, where point `i` sits at `(i, i, i)`.
#[cfg(test)]
fn build_diagonal_octree() -> Octree {
    let mut octree = Octree::new(BoundingBox::new(Vec3::ZERO, Vec3::splat(100.0)), 4);
    for i in 0..=20u32 {
        assert!(octree.insert(DataPoint::new(i, Vec3::splat(i as f32))));
    }
    octree
}

/// Number of points inside the cube of the given half extent around the origin.
#[cfg(test)]
fn count_in_box(octree: &Octree, half_extent: f32) -> usize {
    let mut results = Vec::new();
    octree.query(&BoundingBox::new(Vec3::ZERO, Vec3::splat(half_extent)), &mut results);
    results.len()
}

/// Number of points within `radius` of the origin.
#[cfg(test)]
fn count_within_radius(octree: &Octree, radius: f32) -> usize {
    let mut results = Vec::new();
    octree.query_within_radius(&Vec3::ZERO, radius, &mut results);
    results.len()
}

#[test]
fn test_basic_operation() {
    let octree = build_diagonal_octree();
    // Point i lies in the cube of half extent h around the origin iff i <= h.
    assert_eq!(count_in_box(&octree, 10.0), 11);
    assert_eq!(count_in_box(&octree, 5.0), 6);
    assert_eq!(count_in_box(&octree, 1.0), 2);
    assert_eq!(count_in_box(&octree, 0.1), 1);
    // Every one of the 21 points is stored and found.
    assert_eq!(count_in_box(&octree, 100.0), 21);
}

#[test]
fn test_query_within_radius() {
    let octree = build_diagonal_octree();
    // Point i is at distance i * sqrt(3) from the origin.
    assert_eq!(count_within_radius(&octree, 10.0), 6);
    assert_eq!(count_within_radius(&octree, 5.0), 3);
    assert_eq!(count_within_radius(&octree, 1.0), 1);
    assert_eq!(count_within_radius(&octree, 0.1), 1);
    assert_eq!(count_within_radius(&octree, 100.0), 21);
}

#[test]
fn test_update_position() {
    let mut octree = build_diagonal_octree();
    assert!(!octree.update_position(99, Vec3::ZERO));
    // Move the first point onto the last point's position.
    assert!(octree.update_position(0, Vec3::splat(20.0)));
    assert_eq!(count_within_radius(&octree, 10.0), 5);
    assert_eq!(count_within_radius(&octree, 5.0), 2);
    assert_eq!(count_within_radius(&octree, 1.0), 0);
    assert_eq!(count_within_radius(&octree, 0.1), 0);
    assert_eq!(count_within_radius(&octree, 100.0), 21);
    // Move the first point back to the origin.
    assert!(octree.update_position(0, Vec3::ZERO));
    assert_eq!(count_within_radius(&octree, 10.0), 6);
    assert_eq!(count_within_radius(&octree, 5.0), 3);
    assert_eq!(count_within_radius(&octree, 1.0), 1);
    assert_eq!(count_within_radius(&octree, 0.1), 1);
    assert_eq!(count_within_radius(&octree, 100.0), 21);
}

#[test]
fn test_update_position_grows_root() {
    let mut octree = build_diagonal_octree();
    let far = Vec3::new(250.0, -30.0, 5.0);
    assert!(octree.update_position(20, far));
    assert!(octree.get_root_boundary().contains_point(&far));
    let mut results = Vec::new();
    octree.query_within_radius(&far, 0.5, &mut results);
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].id, 20);
    // Nothing lost and no stale duplicate left behind.
    results.clear();
    octree.query(&octree.get_root_boundary(), &mut results);
    assert_eq!(results.len(), 21);
    // An unreachable position is rejected and the point stays where it was.
    assert!(!octree.update_position(1, Vec3::splat(f32::NAN)));
    results.clear();
    octree.query_within_radius(&Vec3::splat(1.0), 0.01, &mut results);
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].id, 1);
}

#[test]
fn test_removal_keeps_deeper_points() {
    // With capacity 1 these three points end up on three successive levels.
    let mut octree = Octree::new(BoundingBox::new(Vec3::ZERO, Vec3::splat(100.0)), 1);
    for id in 0..3u32 {
        assert!(octree.insert(DataPoint::new(id, Vec3::splat(id as f32 + 1.0))));
    }
    assert!(octree.update_position(0, Vec3::splat(-90.0)));
    let mut results = Vec::new();
    octree.query(&octree.get_root_boundary(), &mut results);
    assert_eq!(results.len(), 3);
}

#[test]
fn test_many_coincident_points() {
    // More coincident points than the tree can ever separate must not recurse forever.
    let mut octree = Octree::new(BoundingBox::new(Vec3::ZERO, Vec3::splat(100.0)), 1);
    for id in 0..200u32 {
        assert!(octree.insert(DataPoint::new(id, Vec3::splat(3.0))));
    }
    let mut results = Vec::new();
    octree.query_within_radius(&Vec3::splat(3.0), 0.0, &mut results);
    assert_eq!(results.len(), 200);
}
use super::bounding_box::BoundingBox;
use super::data_point::DataPoint;
use bevy::math::Vec3;

/// One node of the octree.
///
/// A node stores up to `capacity` points itself. Once `divided`, it also owns eight
/// children that tile its `boundary`.
#[derive(Debug, Clone)]
pub struct OctreeNode {
    /// Region covered by this node.
    pub boundary: BoundingBox,
    /// Number of points stored directly in this node before new points go to children.
    pub capacity: usize,
    /// Points stored directly in this node.
    pub data_points: Vec<DataPoint>,
    /// Whether `children` has been populated.
    pub divided: bool,
    /// Child octants. Bit 0/1/2 of the index set means the +x/+y/+z half.
    pub children: [Option<Box<OctreeNode>>; 8],
}

impl OctreeNode {
    /// Creates an empty, undivided node covering `boundary`.
    pub fn new(boundary: BoundingBox, capacity: usize) -> Self {
        Self {
            boundary,
            capacity,
            data_points: Vec::new(),
            divided: false,
            children: Default::default(),
        }
    }

    /// Inserts `data_point` into this node or one of its descendants.
    ///
    /// Returns `false` only when the point lies outside `boundary`. The node keeps
    /// up to `capacity` points itself; further points go to the child octant that
    /// contains them, and children are created on demand. In two cases the point is
    /// kept in this node instead of being lost:
    /// - the region can no longer be split in `f32` precision;
    /// - rounding at a shared face makes the chosen child reject the point.
    pub fn insert(&mut self, data_point: DataPoint) -> bool {
        if !self.boundary.contains_point(&data_point.position) {
            return false;
        }
        if self.data_points.len() < self.capacity {
            self.data_points.push(data_point);
            return true;
        }
        if !self.divided && self.can_subdivide() {
            self.subdivide();
        }
        if self.divided {
            let index = self.get_child_index(&data_point.position);
            if let Some(child) = self.children[index].as_mut() {
                if child.insert(data_point) {
                    return true;
                }
            }
        }
        // Overflow storage: the point is inside `boundary`, so keeping it here is valid
        // and guarantees termination for coincident points.
        self.data_points.push(data_point);
        true
    }

    /// Removes and returns the point with the given `id` from this subtree.
    ///
    /// On the way back up, each node on the path is collapsed via [`Self::merge`]
    /// when possible.
    pub fn remove_data_point(&mut self, id: u32) -> Option<DataPoint> {
        let removed = match self.data_points.iter().position(|dp| dp.id == id) {
            Some(index) => Some(self.data_points.remove(index)),
            None => self
                .children
                .iter_mut()
                .flatten()
                .find_map(|child| child.remove_data_point(id)),
        };
        if removed.is_some() {
            self.merge();
        }
        removed
    }

    /// Splits this node into eight children that exactly tile its boundary.
    ///
    /// Child `i` lies on the positive side of x/y/z when bit 0/1/2 of `i` is set,
    /// matching `get_child_index`. Does nothing if the node is already divided, so
    /// existing children and their points are never discarded.
    pub fn subdivide(&mut self) {
        if self.divided {
            return;
        }
        for index in 0..self.children.len() {
            let child = OctreeNode::new(self.child_boundary(index), self.capacity);
            self.children[index] = Some(Box::new(child));
        }
        self.divided = true;
    }

    /// Boundary of the child octant at `index` (same bit layout as `get_child_index`).
    fn child_boundary(&self, index: usize) -> BoundingBox {
        let quarter = self.boundary.half_extents * 0.5;
        let sign = |bit: usize| if index & bit != 0 { 1.0 } else { -1.0 };
        let offset = Vec3::new(sign(1) * quarter.x, sign(2) * quarter.y, sign(4) * quarter.z);
        BoundingBox {
            center: self.boundary.center + offset,
            half_extents: quarter,
        }
    }

    /// Whether splitting would give children whose centers differ from this node's
    /// center in `f32` precision on every axis.
    ///
    /// Without this check, coincident points would recurse without bound.
    fn can_subdivide(&self) -> bool {
        let center = self.boundary.center;
        let quarter = self.boundary.half_extents * 0.5;
        center.x + quarter.x != center.x
            && center.y + quarter.y != center.y
            && center.z + quarter.z != center.z
    }

    /// Inclusive overlap test between this node's boundary and `region`.
    ///
    /// Boxes that only touch on a face or edge still count, matching the inclusive
    /// `contains_point` check used when filtering points.
    fn overlaps(&self, region: &BoundingBox) -> bool {
        let (node_min, node_max) = (self.boundary.min(), self.boundary.max());
        let (region_min, region_max) = (region.min(), region.max());
        node_min.x <= region_max.x && region_min.x <= node_max.x
            && node_min.y <= region_max.y && region_min.y <= node_max.y
            && node_min.z <= region_max.z && region_min.z <= node_max.z
    }

    /// Appends every point in this subtree that lies inside `region` to `results`.
    pub fn query(&self, region: &BoundingBox, results: &mut Vec<DataPoint>) {
        if !self.overlaps(region) {
            return;
        }
        results.extend(
            self.data_points
                .iter()
                .filter(|dp| region.contains_point(&dp.position))
                .copied(),
        );
        for child in self.children.iter().flatten() {
            child.query(region, results);
        }
    }

    /// Appends every point in this subtree within `sqrt(squared_radius)` of `center`.
    ///
    /// Nodes that do not overlap `region` are skipped. `region` is expected to be a box
    /// enclosing the query sphere.
    pub fn query_within_radius(
        &self,
        center: Vec3,
        squared_radius: f32,
        region: &BoundingBox,
        results: &mut Vec<DataPoint>,
    ) {
        if !self.overlaps(region) {
            return;
        }
        results.extend(
            self.data_points
                .iter()
                .filter(|dp| (dp.position - center).length_squared() <= squared_radius)
                .copied(),
        );
        for child in self.children.iter().flatten() {
            child.query_within_radius(center, squared_radius, region, results);
        }
    }

    /// Collapses the children into this node when their points, plus this node's own
    /// points, fit within `capacity`.
    ///
    /// Only leaf children are collapsed. If any child is itself divided, nothing
    /// happens, because dropping it would discard its descendants' points.
    pub fn merge(&mut self) {
        if !self.divided {
            return;
        }
        let mut total_points = self.data_points.len();
        for child in self.children.iter().flatten() {
            if child.divided {
                return;
            }
            total_points += child.data_points.len();
        }
        if total_points > self.capacity {
            return;
        }
        for slot in self.children.iter_mut() {
            if let Some(mut child) = slot.take() {
                self.data_points.append(&mut child.data_points);
            }
        }
        self.divided = false;
    }

    /// Replaces this node with a parent twice its size, extended towards
    /// `data_point.position`.
    ///
    /// The old node becomes exactly one octant of the new parent. The other seven
    /// octants are created empty. The point itself is NOT inserted; only its position
    /// is used to choose the growth direction on each axis (ties grow towards +).
    pub fn grow(&mut self, data_point: &DataPoint) {
        let old = self.boundary;
        let target = data_point.position;
        // Grow by one old half extent per axis towards the target. The old node then
        // sits on the opposite side of the new center, which sets that axis's index bit.
        let mut offset = old.half_extents;
        let mut old_root_index = 0;
        if target.x < old.center.x {
            offset.x = -offset.x;
            old_root_index |= 1;
        }
        if target.y < old.center.y {
            offset.y = -offset.y;
            old_root_index |= 2;
        }
        if target.z < old.center.z {
            offset.z = -offset.z;
            old_root_index |= 4;
        }
        let new_boundary = BoundingBox {
            center: old.center + offset,
            half_extents: old.half_extents * 2.0,
        };
        let mut new_root = OctreeNode::new(new_boundary, self.capacity);
        new_root.subdivide();
        // Move the old subtree into place without cloning it.
        let old_root = std::mem::replace(self, new_root);
        self.children[old_root_index] = Some(Box::new(old_root));
    }

    /// Index of the child octant containing `point`: bit 0/1/2 is set when the
    /// x/y/z coordinate is at or above this node's center.
    fn get_child_index(&self, point: &Vec3) -> usize {
        let mut index = 0;
        if point.x >= self.boundary.center.x {
            index |= 1;
        }
        if point.y >= self.boundary.center.y {
            index |= 2;
        }
        if point.z >= self.boundary.center.z {
            index |= 4;
        }
        index
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment