Last active: February 16, 2020, 17:25
-
-
Save jdx/2fbf3fccc4de9723d127ea4c77c806ff to your computer and use it in GitHub Desktop.
Min/max heap in Python (a binary heap that is a min-heap by default and a max-heap with order="max")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple | |
from typing import TypeVar, Generic, Optional, List, Union, Tuple, Callable, Any | |
T = TypeVar("T")


class Heap(Generic[T]):
    """Binary heap ordered by ``key``: a min-heap by default, a max-heap with ``order="max"``.

    Elements live in a flat list using the standard array layout, where the
    children of index ``i`` are at ``2*i + 1`` and ``2*i + 2`` and its parent
    is at ``(i - 1) // 2``.
    """

    x: List[T]  # heap storage; x[0] is the top element when the heap is non-empty

    def __init__(self, key: Callable[[T], Any] = lambda x: x, order: str = "min"):
        """Create an empty heap.

        Args:
            key: maps an element to the value it is ordered by (identity by default).
            order: ``"min"`` to pop the smallest key first, ``"max"`` for the largest.

        Raises:
            ValueError: if ``order`` is neither ``"min"`` nor ``"max"``.
        """
        if order not in ("min", "max"):
            # Without this check any other string silently behaved like "max",
            # so a typo flipped the ordering with no warning.
            raise ValueError(f"order must be 'min' or 'max', got {order!r}")
        self.x = []
        self.key = key
        self.order = order

    def top(self) -> Optional[T]:
        """Return the top element without removing it, or ``None`` if the heap is empty."""
        return self.x[0] if self.x else None

    def pop(self) -> Optional[T]:
        """Remove and return the top element, or ``None`` if the heap is empty."""
        if not self.x:
            return None
        top = self.x[0]
        last = self.x.pop()
        if self.x:
            # Move the last leaf into the root slot and sift it down into place.
            self.x[0] = last
            self._heapify_down(0)
        return top

    def add(self, element: T) -> None:
        """Insert ``element``, restoring heap order by sifting it up from the end."""
        self.x.append(element)
        self._heapify_up(len(self.x) - 1)

    def _heapify_up(self, child_idx: int) -> None:
        """Swap the element at ``child_idx`` with its parent until heap order holds."""
        while child_idx > 0:
            parent_idx = self._parent_idx(child_idx)
            if self._compare(self.x[child_idx], self.x[parent_idx]):
                return
            self.x[child_idx], self.x[parent_idx] = self.x[parent_idx], self.x[child_idx]
            child_idx = parent_idx

    def _heapify_down(self, parent_idx: int) -> None:
        """Swap the element at ``parent_idx`` with its best child until heap order holds."""
        while True:
            children = self._child_idx(parent_idx)
            if children is None:
                return
            child_idx = children[0]
            if len(children) == 2 and self._compare(self.x[child_idx], self.x[children[1]]):
                # The right child ranks at least as high as the left one (ties
                # included), so it is the candidate to promote.
                child_idx = children[1]
            if self._compare(self.x[child_idx], self.x[parent_idx]):
                return
            self.x[child_idx], self.x[parent_idx] = self.x[parent_idx], self.x[child_idx]
            parent_idx = child_idx

    def _parent_idx(self, child_idx: int) -> int:
        """Return the index of the parent of ``child_idx`` (callers pass ``child_idx > 0``)."""
        return (child_idx - 1) // 2

    def _child_idx(self, parent_idx: int) -> Union[None, Tuple[int], Tuple[int, int]]:
        """Return the indices of the existing children of ``parent_idx``.

        Returns ``None`` for a leaf, ``(left,)`` when only the left child exists,
        and ``(left, right)`` when both exist.
        """
        left = 2 * parent_idx + 1
        if left >= len(self.x):
            return None
        if left + 1 >= len(self.x):
            return (left,)
        return (left, left + 1)

    def _compare(self, child: T, parent: T) -> bool:
        """Return True when ``child`` may sit below ``parent`` without breaking heap order.

        Equal keys count as ordered, so equal elements are never swapped.
        """
        if self.order == "min":
            return self.key(child) >= self.key(parent)
        return self.key(child) <= self.key(parent)
def test_min_heap():
    """A default heap pops elements in ascending order, keeping duplicates."""
    heap = Heap()
    for value in (1, 3, 5, 2, 1, 9):
        heap.add(value)
    assert heap.top() == 1
    for expected in (1, 1, 2, 3, 5, 9):
        assert heap.pop() == expected
    # Once drained, pop() reports emptiness with None instead of raising.
    assert heap.pop() is None
def test_max_heap():
    """A heap built with order="max" pops elements in descending order."""
    heap = Heap(order="max")
    inserted = [1, 3, 5, 2, 1, 9]
    popped_order = [9, 5, 3, 2, 1, 1]
    for item in inserted:
        heap.add(item)
    assert heap.top() == 9
    for want in popped_order:
        assert heap.pop() == want
    # Draining past the last element yields None.
    assert heap.pop() is None
def test_custom_key():
    """Ordering by a key function: cars come out cheapest first by their value field."""
    Car = namedtuple("Car", ["make", "model", "value"])
    heap = Heap(key=lambda car: car.value)
    fleet = (
        Car("Honda", "Pilot", 2000),
        Car("Acura", "RSX", 3000),
        Car("Ford", "Explorer", 1000),
    )
    for car in fleet:
        heap.add(car)
    for make in ("Ford", "Honda", "Acura"):
        assert heap.pop().make == make
    assert heap.pop() is None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.