Skip to content

Instantly share code, notes, and snippets.

@mumbleskates
Last active July 11, 2023 12:15
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mumbleskates/0ef75bf3f25d0faeecc73ddb9373ea75 to your computer and use it in GitHub Desktop.
Save mumbleskates/0ef75bf3f25d0faeecc73ddb9373ea75 to your computer and use it in GitHub Desktop.
Pure Python N-ary Heap implementation
# coding=utf-8
class NaryHeap(object):
    """implements an n-ary heap

    Items are kept in a flat list. The root is at index 0 and the children
    of the node at index ``i`` occupy indices ``i * n + 1`` .. ``i * n + n``.
    With ``direction=min`` the smallest item sits at the root; with
    ``direction=max`` the largest does.
    """

    def __init__(self, items=(), *, n=2, direction=min):
        """
        create a new heap

        items must be an iterable, n must be a positive integer, and direction
        must be min or max (the builtin functions themselves).

        raises ValueError when n or direction is not acceptable.
        """
        if not isinstance(n, int) or n < 1:
            raise ValueError("Degree of N-ary heap (n) must be an integer >= 1")
        if direction not in (min, max):
            raise ValueError("Direction must be min or max")
        self._n = n
        self._direction = direction
        self._list = list(items)
        # Bottom-up heapify: every index past the parent of the last item is
        # a leaf and already a valid one-item heap, so only the internal
        # nodes need sifting, deepest first. This does less work than pushing
        # the items one at a time. For 0 or 1 items the range is empty.
        last_parent_ix = (len(self._list) - 2) // n
        for ix in range(last_parent_ix, -1, -1):
            self._sift_towards_leaf(ix)

    def _parent_ix(self, ix):
        """returns the parent index of a given node"""
        return (ix - 1) // self._n

    def _children_ix(self, ix):
        """returns all valid child indices of a given index (as a range)"""
        first_child = ix * self._n + 1
        return range(first_child, min(first_child + self._n, len(self._list)))

    def _sift_towards_root(self, ix):
        """moves the item at ix up until its parent may stay above it"""
        heap = self._list
        current = heap[ix]
        while ix:  # stop when ix is 0, the root
            parent_ix = self._parent_ix(ix)
            parent = heap[parent_ix]
            # note that we want to put the option that represents the least
            # work first, because min() and max() return the earliest item
            # if all are equal
            if self._direction(parent, current) is parent:
                return  # we are done sorting this item
            # swap with the parent and continue sorting towards root
            heap[parent_ix], heap[ix] = current, parent
            ix = parent_ix

    def _sift_towards_leaf(self, ix):
        """moves the item at ix down until it may stay above all its children"""
        heap = self._list
        pick = self._direction
        current = heap[ix]
        while True:
            children = self._children_ix(ix)
            if not children:
                return  # we are done, item is now a leaf
            # get the most parental child; selecting by key avoids building
            # a tuple of (item, index) pairs on every step and only ever
            # compares the items themselves
            candidate_ix = pick(children, key=heap.__getitem__)
            candidate = heap[candidate_ix]
            # again, option representing the least work comes first
            if pick(current, candidate) is current:
                return  # we are done sorting this item
            # swap with that child and continue sorting towards leaf
            heap[candidate_ix], heap[ix] = current, candidate
            ix = candidate_ix

    def __len__(self):
        """returns the number of items currently in the heap"""
        return len(self._list)

    def push(self, item):
        """Insert an item into the heap"""
        self._list.append(item)
        self._sift_towards_root(len(self._list) - 1)

    def peek(self):
        """Show the next item in the heap

        raises IndexError if the heap is empty.
        """
        if not self._list:
            raise IndexError("peek at an empty heap")
        return self._list[0]

    def pop(self):
        """Remove and return the next item in the heap

        raises IndexError if the heap is empty.
        """
        if not self._list:
            raise IndexError("pop from an empty heap")
        last = self._list.pop()
        if not self._list:
            return last  # the heap held a single item
        # move the former last item to the root and restore the invariant
        result, self._list[0] = self._list[0], last
        self._sift_towards_leaf(0)
        return result
# coding=utf-8
from functools import partial
from itertools import product
from random import choice
import pytest
from heap import NaryHeap
# Item collections every parametrized test is run against.
TEST_ITEMS = [
    [],  # empty heap
    [1],  # single item
    list(range(10)),  # short ascending run
    list(range(123)),  # longer ascending run
    list(range(123))[::-1],  # the same run, descending
    [choice(range(123)) for _ in range(123)],  # random, duplicates likely
    50 * [1],  # all items equal
    "this string is also an iterable it turns out",
    (34, 0, 77, 95, 21, 8009, 788324),
]
# Heap degrees (children per node) exercised by the factory fixture.
TEST_DEGREES = range(1, 11)
# Both supported orderings.
TEST_DIRECTIONS = (min, max)
@pytest.fixture(scope='session', params=product(TEST_DEGREES, TEST_DIRECTIONS))
def factory(request):
    """returns a NaryHeap constructor pre-bound to one (degree, direction) pair"""
    degree, ordering = request.param
    bound_kwargs = {'n': degree, 'direction': ordering}
    return partial(NaryHeap, **bound_kwargs)
def assert_invariant(heap):
    """asserts that every non-root item is correctly ordered below its parent"""
    nodes = heap._list
    # index 0 is the root and has no parent, so start from 1
    for child_ix in range(1, len(nodes)):
        child = nodes[child_ix]
        parent = nodes[(child_ix - 1) // heap._n]
        assert heap._direction(parent, child) is parent
@pytest.mark.parametrize('items', TEST_ITEMS)
def test_init(items, factory):
    """a heap built directly from an iterable satisfies the heap invariant"""
    assert_invariant(factory(items))
@pytest.mark.parametrize('items', TEST_ITEMS)
def test_push(items, factory):
    """the heap invariant holds after each individual push"""
    empty_heap = factory()
    for element in items:
        empty_heap.push(element)
        assert_invariant(empty_heap)
@pytest.mark.parametrize('items', TEST_ITEMS)
def test_pop(items, factory):
    """items are popped in sorted order, then popping raises IndexError"""
    heap = factory(items)
    # a max-heap yields items largest first, a min-heap smallest first
    descending = heap._direction is max
    for wanted in sorted(items, reverse=descending):
        assert heap.pop() == wanted
    # heap should be empty now
    with pytest.raises(IndexError):
        heap.pop()
@pytest.mark.parametrize('items', TEST_ITEMS)
def test_peek(items, factory):
    """peek always shows exactly the object the next pop returns"""
    # this test is surely overbuilt
    heap = factory(items)
    while len(heap):
        shown = heap.peek()
        assert shown is heap.pop()
    # heap should be empty now
    with pytest.raises(IndexError):
        heap.peek()
def test_bad_direction():
    """a direction other than min or max is rejected with ValueError"""
    not_a_direction = 5  # not min or max
    with pytest.raises(ValueError):
        NaryHeap(direction=not_a_direction)
# Degrees the constructor must refuse: non-positive ints and non-int values.
BAD_DEGREES = [
    0,
    -5,
    "string",
    4.4,
    2.0,
]


@pytest.mark.parametrize('degree', BAD_DEGREES)
def test_bad_degree(degree):
    """an invalid degree is rejected with ValueError"""
    with pytest.raises(ValueError):
        NaryHeap(n=degree)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment