Skip to content

Instantly share code, notes, and snippets.

@jdxcode

jdxcode/min_max_heap.py

Last active Feb 16, 2020
Embed
What would you like to do?
Min/max heap in Python
from collections import namedtuple
from typing import TypeVar, Generic, Optional, List, Union, Tuple, Callable, Any
T = TypeVar("T")


class Heap(Generic[T]):
    """Array-backed binary heap ordered by ``key(element)``.

    With ``order="min"`` (the default) the element with the smallest key is on
    top; with ``order="max"`` the element with the largest key is on top.
    ``add`` and ``pop`` each sift an element at most one tree level per step,
    so they perform O(log n) key comparisons.

    Args:
        key: Function mapping an element to the value it is ordered by.
            Defaults to the identity, so elements are compared directly.
        order: ``"min"`` or ``"max"``.

    Raises:
        ValueError: If ``order`` is neither ``"min"`` nor ``"max"``.
    """

    # Backing array: x[0] is the top; the children of index i live at
    # 2*i + 1 and 2*i + 2, and its parent at (i - 1) // 2.
    x: List[T]

    def __init__(self, key: Callable[[T], Any] = lambda x: x, order: str = "min"):
        # _compare treats every value other than "min" as a max-heap, so
        # reject unknown orders up front instead of silently misbehaving.
        if order not in ("min", "max"):
            raise ValueError(f"order must be 'min' or 'max', got {order!r}")
        self.x = []
        self.key = key
        self.order = order

    def top(self) -> Optional[T]:
        """Return the top element without removing it, or None if the heap is empty."""
        return self.x[0] if self.x else None

    def pop(self) -> Optional[T]:
        """Remove and return the top element, or return None if the heap is empty."""
        if not self.x:
            return None
        val = self.x[0]
        # Move the last leaf into the root slot, then sift it down.
        tail = self.x.pop()
        if self.x:
            self.x[0] = tail
            self._heapify_down(0)
        return val

    def add(self, element: T) -> None:
        """Insert ``element`` and restore the heap property."""
        self.x.append(element)
        self._heapify_up(len(self.x) - 1)

    def _heapify_up(self, child_idx: int) -> None:
        """Sift the element at ``child_idx`` toward the root until it is in order."""
        while child_idx > 0:
            parent_idx = self._parent_idx(child_idx)
            if self._compare(self.x[child_idx], self.x[parent_idx]):
                return
            self.x[child_idx], self.x[parent_idx] = self.x[parent_idx], self.x[child_idx]
            child_idx = parent_idx

    def _heapify_down(self, parent_idx: int) -> None:
        """Sift the element at ``parent_idx`` toward the leaves until it is in order."""
        while True:
            children = self._child_idx(parent_idx)
            if children is None:
                return
            child_idx = children[0]
            # Pick the child that belongs higher up. _compare is True on equal
            # keys, so ties go to the right child.
            if len(children) == 2 and self._compare(self.x[children[0]], self.x[children[1]]):
                child_idx = children[1]
            if self._compare(self.x[child_idx], self.x[parent_idx]):
                return
            self.x[child_idx], self.x[parent_idx] = self.x[parent_idx], self.x[child_idx]
            parent_idx = child_idx

    def _parent_idx(self, child_idx: int) -> int:
        """Return the index of the parent of ``child_idx`` (expects ``child_idx > 0``)."""
        return (child_idx - 1) // 2

    def _child_idx(self, parent_idx: int) -> Union[None, Tuple[int], Tuple[int, int]]:
        """Return the indices of the children of ``parent_idx`` that actually exist.

        Returns None for a leaf, a 1-tuple when only the left child exists,
        and a 2-tuple when both children exist.
        """
        left = 2 * parent_idx + 1
        # Bounds are checked against len(self.x) itself: index `left` exists
        # iff left < len(self.x).
        if left >= len(self.x):
            return None
        if left + 1 >= len(self.x):
            return (left,)
        return (left, left + 1)

    def _compare(self, child: T, parent: T) -> bool:
        """Return True if ``child`` may sit below ``parent`` (the heap property holds)."""
        if self.order == "min":
            return self.key(child) >= self.key(parent)
        return self.key(child) <= self.key(parent)
def test_min_heap():
    """A default heap yields its values in ascending order, then None once drained."""
    heap = Heap()
    for value in (1, 3, 5, 2, 1, 9):
        heap.add(value)
    assert heap.top() == 1
    for expected in (1, 1, 2, 3, 5, 9):
        assert heap.pop() == expected
    assert heap.pop() is None
def test_max_heap():
    """An order="max" heap yields its values in descending order, then None once drained."""
    heap = Heap(order="max")
    for value in (1, 3, 5, 2, 1, 9):
        heap.add(value)
    assert heap.top() == 9
    for expected in (9, 5, 3, 2, 1, 1):
        assert heap.pop() == expected
    assert heap.pop() is None
def test_custom_key():
    """A key function orders records by one field rather than by the whole record."""
    Car = namedtuple("Car", ["make", "model", "value"])
    heap = Heap(key=lambda car: car.value)
    for car in (
        Car("Honda", "Pilot", 2000),
        Car("Acura", "RSX", 3000),
        Car("Ford", "Explorer", 1000),
    ):
        heap.add(car)
    for make in ("Ford", "Honda", "Acura"):
        assert heap.pop().make == make
    assert heap.pop() is None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment