Created
September 17, 2012 12:39
-
-
Save n9986/3737062 to your computer and use it in GitHub Desktop.
Spatial Hash
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import floor | |
from collections import namedtuple | |
# Axis-aligned rectangle: (x1, y1) are the lower bounds and (x2, y2) the upper
# bounds (SpatialHash walks cells from the first corner up to the second).
# SpatialHash only reads these four attributes, so any object exposing them works.
Rect = namedtuple('Rect', 'x1 y1 x2 y2')
class SpatialHash(object):
    """Uniform grid of square cells mapping each cell to a set of objects.

    Objects are registered under every cell their bounding rectangle touches,
    so a broad-phase collision query only has to inspect the cells covered by
    the query rectangle.
    """

    def __init__(self, cell_size=10.0):
        """Create an empty hash whose cells are ``cell_size`` units square.

        Raises:
            ValueError: if ``cell_size`` is not a positive number. A zero size
                would divide by zero, and a negative one would make cell
                enumeration run forever.
        """
        size = float(cell_size)
        # ``not size > 0`` also rejects NaN, which compares false to everything.
        if not size > 0:
            raise ValueError('cell_size must be positive, got %r' % (cell_size,))
        self.cell_size = size
        # Maps integer (cx, cy) cell coordinates to the set of objects in that
        # cell; only non-empty cells are kept.
        self.d = {}

    def _add(self, cell_coord, o):
        """Add the object o to the cell at cell_coord, creating the cell if needed."""
        self.d.setdefault(cell_coord, set()).add(o)

    def _remove(self, cell_coord, o):
        """Remove the object o from the cell at cell_coord.

        Raises KeyError if the cell does not exist or does not contain o.
        A cell left empty is deleted so ``d`` only holds occupied cells.
        """
        cell = self.d[cell_coord]
        cell.remove(o)
        if not cell:
            del self.d[cell_coord]

    def _cells_for_rect(self, r):
        """Return the set of integer (cx, cy) cells that r extends into.

        r is any object with ``x1``, ``y1``, ``x2``, ``y2`` attributes. Both
        edges are inclusive: a far edge lying exactly on a cell boundary also
        occupies the next cell. If x2 falls in an earlier cell than x1 (or
        likewise for y), the result is empty. A non-finite coordinate raises
        (OverflowError for infinities, ValueError for NaN).
        """
        size = self.cell_size
        # Map both edges to cell indices with the same floor() rule and
        # enumerate integer ranges. Stepping a float counter until it passes
        # the far edge instead never terminates for an infinite bound.
        x_lo = int(floor(r.x1 / size))
        x_hi = int(floor(r.x2 / size))
        y_lo = int(floor(r.y1 / size))
        y_hi = int(floor(r.y2 / size))
        return set((cx, cy)
                   for cy in range(y_lo, y_hi + 1)
                   for cx in range(x_lo, x_hi + 1))

    def add_rect(self, r, obj):
        """Register obj in every cell covered by its bounding rectangle r."""
        for c in self._cells_for_rect(r):
            self._add(c, obj)

    def remove_rect(self, r, obj):
        """Unregister obj, which must have been added with the same bounds r.

        Raises KeyError if a cell covered by r is missing or lacks obj; obj may
        already have been removed from some other cells when that happens.
        """
        for c in self._cells_for_rect(r):
            self._remove(c, obj)

    def potential_collisions(self, r, obj):
        """Return the set of other objects sharing at least one cell with r.

        This is a broad-phase query: members share a grid cell with r but may
        not actually intersect it. obj itself is always excluded.
        """
        potentials = set()
        for c in self._cells_for_rect(r):
            # An empty tuple default avoids allocating a throwaway set per miss.
            potentials.update(self.d.get(c, ()))
        potentials.discard(obj)  # obj cannot collide with itself
        return potentials
def test_cells_for_rect():
    """_cells_for_rect reports every default-size cell a rectangle touches."""
    grid = SpatialHash()
    expectations = [
        (Rect(1, 2, 9, 12), set([(0, 0), (0, 1)])),
        (Rect(7, 15, 13, 19), set([(0, 1), (1, 1)])),
    ]
    for rect, expected in expectations:
        got = grid._cells_for_rect(rect)
        assert got == expected, got
def test_add():
    """An object added with a small rectangle lands in the cell containing it."""
    grid = SpatialHash()
    grid.add_rect(Rect(1, 2, 3, 4), 'foo')
    origin_cell = grid.d[(0, 0)]
    assert 'foo' in origin_cell
def test_add_spanning():
    """A rectangle straddling x=0 and y=10 is stored in all four cells it touches."""
    grid = SpatialHash()
    grid.add_rect(Rect(-1, 9, 2, 12), 'foo')
    for cell in [(0, 0), (0, 1), (-1, 0), (-1, 1)]:
        assert 'foo' in grid.d[cell]
def test_remove():
    """Removing the only object from its cells deletes those cells entirely."""
    grid = SpatialHash()
    bounds = Rect(1, 2, 7, 12)
    grid.add_rect(bounds, 'foo')
    grid.remove_rect(bounds, 'foo')
    for cell in [(0, 0), (0, 1)]:
        assert cell not in grid.d
def test_collide():
    """Objects sharing a grid cell are reported as potential collisions.

    With the default cell size of 10, 'foo' spans cells (0, 0) and (0, 1) and
    'bar' spans (0, 1) and (1, 1). Querying for 'bar' must therefore yield
    exactly 'foo', never 'bar' itself.
    """
    h = SpatialHash()
    h.add_rect(Rect(3, 8, 4, 11), 'foo')
    r = Rect(7, 15, 13, 19)
    h.add_rect(r, 'bar')
    collisions = h.potential_collisions(r, 'bar')
    assert collisions == set(['foo']), collisions
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment