Skip to content

Instantly share code, notes, and snippets.

@n9986
Created September 17, 2012 12:39
Show Gist options
  • Save n9986/3737062 to your computer and use it in GitHub Desktop.
Spatial Hash
from math import floor
from collections import namedtuple
# Axis-aligned rectangle; SpatialHash walks cells from (x1, y1) up to (x2, y2).
Rect = namedtuple('Rect', 'x1 y1 x2 y2')
class SpatialHash(object):
    """Uniform-grid spatial hash for broad-phase collision queries.

    The plane is divided into square cells of side ``cell_size``. An object
    is stored in every cell its bounding rectangle touches, keyed by integer
    ``(cx, cy)`` cell coordinates. Only non-empty cells are kept in ``self.d``.

    A rectangle is any object with ``x1``, ``y1``, ``x2`` and ``y2``
    attributes. Callers are expected to pass ``x1 <= x2`` and ``y1 <= y2``.
    Otherwise the rectangle may map to no cells at all.

    Bounds are inclusive. An edge lying exactly on a cell boundary therefore
    also lands in the neighbouring cell.
    """

    def __init__(self, cell_size=10.0):
        """Create an empty hash whose cells are ``cell_size`` units square.

        Raises:
            ValueError: if ``cell_size`` is not strictly positive. A zero
                size would divide by zero, and a negative (or NaN) size
                yields meaningless cell coordinates.
        """
        self.cell_size = float(cell_size)
        # Written as ``not > 0`` so that NaN is rejected as well.
        if not self.cell_size > 0:
            raise ValueError('cell_size must be positive, got %r' % (cell_size,))
        # Maps (cx, cy) -> set of objects overlapping that cell.
        self.d = {}

    def _add(self, cell_coord, o):
        """Add the object ``o`` to the cell at ``cell_coord``, creating the cell if needed."""
        self.d.setdefault(cell_coord, set()).add(o)

    def _remove(self, cell_coord, o):
        """Remove the object ``o`` from the cell at ``cell_coord``.

        Raises:
            KeyError: if the cell does not exist or does not contain ``o``.
        """
        cell = self.d[cell_coord]
        cell.remove(o)
        # Drop emptied cells so the dict only ever holds occupied cells.
        if not cell:
            del self.d[cell_coord]

    def _cells_for_rect(self, r):
        """Return the set of integer ``(cx, cy)`` cell coordinates that ``r`` touches."""
        cs = self.cell_size
        # Compute the inclusive index span on each axis directly. This
        # replaces stepping a float counter, which stalls once indices pass
        # 2**53. int() is needed because floor() returns a float on Python 2.
        x_lo, x_hi = int(floor(r.x1 / cs)), int(floor(r.x2 / cs))
        y_lo, y_hi = int(floor(r.y1 / cs)), int(floor(r.y2 / cs))
        return {(cx, cy)
                for cy in range(y_lo, y_hi + 1)
                for cx in range(x_lo, x_hi + 1)}

    def add_rect(self, r, obj):
        """Add ``obj``, whose bounds are ``r``, to every cell ``r`` touches."""
        for c in self._cells_for_rect(r):
            self._add(c, obj)

    def remove_rect(self, r, obj):
        """Remove ``obj``, which was added with bounds ``r``, from every cell ``r`` touches.

        ``r`` must be the same bounds used when ``obj`` was added. Otherwise a
        KeyError may be raised, or ``obj`` may be left behind in some cells.
        """
        for c in self._cells_for_rect(r):
            self._remove(c, obj)

    def potential_collisions(self, r, obj):
        """Return every stored object sharing at least one cell with ``r``, excluding ``obj``.

        This is a broad-phase result. The returned objects may not actually
        intersect ``r``.
        """
        potentials = set()
        for c in self._cells_for_rect(r):
            cell = self.d.get(c)
            if cell:
                potentials |= cell
        potentials.discard(obj)  # an object cannot collide with itself
        return potentials
def test_cells_for_rect():
    """Each rectangle maps to exactly the default-sized cells it touches."""
    grid = SpatialHash()
    expected_cases = [
        (Rect(1, 2, 9, 12), {(0, 0), (0, 1)}),
        (Rect(7, 15, 13, 19), {(0, 1), (1, 1)}),
    ]
    for rect, expected in expected_cases:
        got = grid._cells_for_rect(rect)
        assert got == expected, got
def test_add():
    """A rectangle lying wholly inside cell (0, 0) is stored in that cell."""
    grid = SpatialHash()
    grid.add_rect(Rect(x1=1, y1=2, x2=3, y2=4), 'foo')
    origin_cell = grid.d[(0, 0)]
    assert 'foo' in origin_cell
def test_add_spanning():
    """A rectangle crossing both an x and a y cell boundary lands in all four cells."""
    grid = SpatialHash()
    grid.add_rect(Rect(-1, 9, 2, 12), 'foo')
    for coord in [(0, 0), (0, 1), (-1, 0), (-1, 1)]:
        assert 'foo' in grid.d[coord]
def test_remove():
    """Removing the only object in its cells deletes those cells from the hash."""
    grid = SpatialHash()
    bounds = Rect(1, 2, 7, 12)
    grid.add_rect(bounds, 'foo')
    grid.remove_rect(bounds, 'foo')
    for coord in ((0, 0), (0, 1)):
        assert coord not in grid.d
def test_collide():
    """Objects sharing a cell are reported as potential collisions, excluding the queried object."""
    h = SpatialHash()
    h.add_rect(Rect(3, 8, 4, 11), 'foo')
    r = Rect(7, 15, 13, 19)
    h.add_rect(r, 'bar')
    # Both rectangles occupy cell (0, 1), so 'foo' is a candidate for 'bar'.
    # (The leftover Python-2-only debug `print h.d` statement was removed.)
    collisions = h.potential_collisions(r, 'bar')
    assert collisions == set(['foo']), collisions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment