Skip to content

Instantly share code, notes, and snippets.

@jdx
Last active February 16, 2020 17:25
Show Gist options
  • Save jdx/2fbf3fccc4de9723d127ea4c77c806ff to your computer and use it in GitHub Desktop.
Save jdx/2fbf3fccc4de9723d127ea4c77c806ff to your computer and use it in GitHub Desktop.
Generic min/max binary heap in Python
from collections import namedtuple
from typing import TypeVar, Generic, Optional, List, Union, Tuple, Callable, Any
T = TypeVar("T")


class Heap(Generic[T]):
    """Binary heap of ``T`` ordered by a key function.

    With ``order="min"`` (the default) the element with the smallest key is on
    top; with ``order="max"`` the element with the largest key is. Elements are
    stored in ``self.x`` as an implicit binary tree: the children of index ``i``
    sit at ``2*i + 1`` and ``2*i + 2``.
    """

    x: List[T]

    def __init__(self, key: Callable[[T], Any] = lambda x: x, order: str = "min"):
        """Create an empty heap.

        Args:
            key: Maps each element to the value it is ordered by (identity by default).
            order: ``"min"`` keeps the smallest key on top, ``"max"`` the largest.

        Raises:
            ValueError: if ``order`` is neither ``"min"`` nor ``"max"``.
        """
        if order not in ("min", "max"):
            # Any other string used to silently produce a max-heap; fail loudly instead.
            raise ValueError(f"order must be 'min' or 'max', got {order!r}")
        self.x = []
        self.key = key
        self.order = order

    def top(self) -> Optional[T]:
        """Return the top element without removing it, or None if the heap is empty.

        None is also what a stored None element looks like, so callers that may
        store None should check emptiness via ``self.x`` instead.
        """
        return self.x[0] if self.x else None

    def pop(self) -> Optional[T]:
        """Remove and return the top element, or None if the heap is empty."""
        if not self.x:
            return None
        root = self.x[0]
        # Move the last leaf into the root slot, then sift it down to restore order.
        tail = self.x.pop()
        if self.x:
            self.x[0] = tail
            self._heapify_down(0)
        return root

    def add(self, element: T) -> None:
        """Insert ``element`` and sift it up to its ordered position."""
        self.x.append(element)
        self._heapify_up(len(self.x) - 1)

    def _heapify_up(self, child_idx: int) -> None:
        """Swap the element at ``child_idx`` with its parent until heap order holds."""
        while child_idx > 0:
            parent_idx = self._parent_idx(child_idx)
            if self._compare(self.x[child_idx], self.x[parent_idx]):
                return
            self.x[child_idx], self.x[parent_idx] = self.x[parent_idx], self.x[child_idx]
            child_idx = parent_idx

    def _heapify_down(self, parent_idx: int) -> None:
        """Swap the element at ``parent_idx`` with its best child until heap order holds."""
        while True:
            children = self._child_idx(parent_idx)
            if children is None:
                return
            child_idx = children[0]
            # Prefer the child that belongs higher; on equal keys the right child wins.
            if len(children) == 2 and self._compare(self.x[child_idx], self.x[children[1]]):
                child_idx = children[1]
            if self._compare(self.x[child_idx], self.x[parent_idx]):
                return
            self.x[child_idx], self.x[parent_idx] = self.x[parent_idx], self.x[child_idx]
            parent_idx = child_idx

    def _parent_idx(self, child_idx: int) -> int:
        """Return the index of the parent of ``child_idx`` (meaningful for ``child_idx > 0``)."""
        return (child_idx - 1) // 2

    def _child_idx(self, parent_idx: int) -> Union[None, Tuple[int], Tuple[int, int]]:
        """Return the indices of the children of ``parent_idx`` that exist.

        Returns None for a leaf, a 1-tuple when only the left child exists, and
        a 2-tuple ``(left, right)`` otherwise.
        """
        left = 2 * parent_idx + 1
        size = len(self.x)
        if left >= size:
            return None
        if left + 1 >= size:
            return (left,)
        return (left, left + 1)

    def _compare(self, child: T, parent: T) -> bool:
        """Return True if ``child`` may sit below ``parent`` without violating heap order.

        Equal keys count as ordered, so equal elements are never swapped.
        """
        if self.order == "min":
            return self.key(child) >= self.key(parent)
        return self.key(child) <= self.key(parent)
def test_min_heap():
    """A default Heap yields values in ascending order, duplicates included."""
    heap = Heap()
    for value in (1, 3, 5, 2, 1, 9):
        heap.add(value)
    assert heap.top() == 1
    for expected in (1, 1, 2, 3, 5, 9):
        assert heap.pop() == expected
    # Popping an exhausted heap returns None instead of raising.
    assert heap.pop() is None
def test_max_heap():
    """A Heap built with order="max" yields values in descending order."""
    heap = Heap(order="max")
    for value in (1, 3, 5, 2, 1, 9):
        heap.add(value)
    assert heap.top() == 9
    for expected in (9, 5, 3, 2, 1, 1):
        assert heap.pop() == expected
    # Popping an exhausted heap returns None instead of raising.
    assert heap.pop() is None
def test_custom_key():
    """Elements are ordered by the supplied key callable, not by the elements themselves."""
    Car = namedtuple("Car", ["make", "model", "value"])
    heap = Heap(key=lambda car: car.value)
    cars = [
        Car("Honda", "Pilot", 2000),
        Car("Acura", "RSX", 3000),
        Car("Ford", "Explorer", 1000),
    ]
    for car in cars:
        heap.add(car)
    # Cheapest first, because the key extracts .value and the heap is a min-heap.
    for make in ("Ford", "Honda", "Acura"):
        assert heap.pop().make == make
    assert heap.pop() is None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment