Created
January 7, 2023 15:51
-
-
Save SebastianBitsch/45bcaed56f3ad9eea2323d1070d81441 to your computer and use it in GitHub Desktop.
A simple Python quad tree as seen on wikipedia: https://en.wikipedia.org/wiki/Quadtree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A simple, dependency-free Python implementation of the Quadtree
# pseudocode on Wikipedia: https://en.wikipedia.org/wiki/Quadtree
from dataclasses import dataclass, field | |
@dataclass
class Point:
    """A point in the 2D plane, identified by its x and y coordinates."""

    x: float
    y: float
@dataclass
class AABB:
    """Axis-aligned bounding box: a square given by its center and half its side length.

    The four edge coordinates are computed once in ``__post_init__`` from
    ``center`` and ``half_dim``; they are not refreshed if either of those is
    changed afterwards.
    """

    center: Point
    half_dim: float

    # Edge coordinates derived from center/half_dim; they keep the tests below simple.
    min_x: float = field(init=False)
    min_y: float = field(init=False)
    max_x: float = field(init=False)
    max_y: float = field(init=False)

    def __post_init__(self) -> None:
        """Derive the edge coordinates from the center and the half dimension."""
        self.min_x = self.center.x - self.half_dim
        self.min_y = self.center.y - self.half_dim
        self.max_x = self.center.x + self.half_dim
        self.max_y = self.center.y + self.half_dim

    def contains_point(self, p: Point) -> bool:
        """Return True if ``p`` lies inside the box; points on an edge count as inside."""
        return self.min_x <= p.x <= self.max_x and self.min_y <= p.y <= self.max_y

    def intersects_AABB(self, other: "AABB") -> bool:
        """Return True if this box overlaps ``other``; boxes that only touch also count."""
        overlaps_x = self.max_x >= other.min_x and other.max_x >= self.min_x
        overlaps_y = self.max_y >= other.min_y and other.max_y >= self.min_y
        return overlaps_x and overlaps_y
class QuadTree:
    """Point quadtree in which each node stores points up to ``node_capacity``.

    Points stay in the node that accepted them: when a full leaf is subdivided,
    its existing points are not moved into the new children.
    """

    def __init__(self, boundary: AABB, node_capacity: int = 10) -> None:
        """Create a quadtree node covering ``boundary``.

        Args:
            boundary: Region this node is responsible for.
            node_capacity: Number of points the node stores before it subdivides.

        Raises:
            ValueError: If ``node_capacity`` is less than 1.
        """
        if node_capacity < 1:
            # Such a node could never store a point directly, so insert() would
            # just keep subdividing instead of storing anything.
            raise ValueError(f"node_capacity must be at least 1, got {node_capacity}")
        self.boundary = boundary
        self.node_capacity = node_capacity
        self.children: list[QuadTree] = []
        self.points: list[Point] = []

    def insert(self, p: Point) -> bool:
        """Insert ``p`` into the tree.

        Returns:
            True if the point was stored, False if it lies outside this node's
            boundary.
        """
        if not self.boundary.contains_point(p):
            return False

        # A leaf with spare room stores the point itself.
        if not self.children and len(self.points) < self.node_capacity:
            self.points.append(p)
            return True

        if not self.children:
            self.subdivide()

        for child in self.children:
            if child.insert(p):
                return True

        # The child boundaries are computed separately from this node's, so
        # floating-point rounding can leave an in-bounds point outside all four
        # children. Keep it here rather than dropping it; query_range() checks
        # the points of every node it visits, so it is still found.
        self.points.append(p)
        return True

    def subdivide(self) -> None:
        """Split this node into four child quadrants of half the size.

        Points already stored in this node are left where they are. If the node
        already has children, nothing happens, so existing subtrees are kept.
        """
        if self.children:
            return

        c = self.boundary.center
        s = self.boundary.half_dim * 0.5
        self.children = [
            QuadTree(AABB(Point(c.x - s, c.y - s), s), self.node_capacity),
            QuadTree(AABB(Point(c.x - s, c.y + s), s), self.node_capacity),
            QuadTree(AABB(Point(c.x + s, c.y + s), s), self.node_capacity),
            QuadTree(AABB(Point(c.x + s, c.y - s), s), self.node_capacity),
        ]

    def query_range(self, area: AABB) -> list[Point]:
        """Return every point stored in this subtree that lies inside ``area``."""
        # Skip the whole subtree when the search area does not touch this node.
        if not self.boundary.intersects_AABB(area):
            return []

        points_in_range = [p for p in self.points if area.contains_point(p)]
        for child in self.children:
            points_in_range.extend(child.query_range(area))
        return points_in_range
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment