Skip to content

Instantly share code, notes, and snippets.

@dvdvdmt
Last active May 6, 2017 12:10
Show Gist options
  • Save dvdvdmt/8e27b7e33eb4b3fb000ab6e7f6feac44 to your computer and use it in GitHub Desktop.
Save dvdvdmt/8e27b7e33eb4b3fb000ab6e7f6feac44 to your computer and use it in GitHub Desktop.
Max heap based on list
import heapq
# Max heap based on list
class MaxHeap(list):
    """Binary max-heap stored directly in a Python list.

    The item at index 0 is always the largest.  A node at index ``i`` has its
    children at ``2*i + 1`` / ``2*i + 2`` and its parent at ``(i - 1) // 2``.
    Items only need to support the ``>`` comparison.
    """

    def __init__(self, l=()):
        """Build a heap from the items of iterable *l* (they are copied).

        Uses bottom-up heapify: sift down every internal node, starting at
        the last one and moving towards the root.
        """
        super().__init__(l)
        # Nodes at index >= len // 2 are leaves and never need sifting.
        for i in range(len(self) // 2 - 1, -1, -1):
            self.sift_down(i)

    def sift_up(self, i):
        """Move the item at index *i* up while it is larger than its parent."""
        while i > 0:
            parent = self.parent(i)
            if not self[i] > self[parent]:
                break
            self[i], self[parent] = self[parent], self[i]
            i = parent

    def sift_down(self, i=0, size=None):
        """Move the item at index *i* down while a child is larger than it.

        Only indices below *size* are treated as part of the heap; *size*
        defaults to the whole list (heap_sort passes a shrinking size).
        """
        # `is None` (not falsiness) so that an explicit size=0 means "empty".
        if size is None:
            size = len(self)
        while True:
            largest = i
            left = self.l_child(i)
            right = self.r_child(i)
            if left < size and self[left] > self[largest]:
                largest = left
            if right < size and self[right] > self[largest]:
                largest = right
            if largest == i:
                return
            self[i], self[largest] = self[largest], self[i]
            i = largest

    def insert(self, p):
        """Add item *p* to the heap."""
        self.append(p)
        self.sift_up(len(self) - 1)

    def get_max(self):
        """Return the largest item without removing it.

        Raises IndexError if the heap is empty.
        """
        if not self:
            raise IndexError("get_max from empty heap")
        return self[0]

    def extract_max(self):
        """Remove and return the largest item.

        Raises IndexError if the heap is empty.
        """
        if not self:
            raise IndexError("extract_max from empty heap")
        last = self.pop()
        if not self:
            # The popped item was the only one, so it is also the maximum.
            return last
        res = self[0]
        # Move the last leaf to the root, then restore the heap property.
        self[0] = last
        self.sift_down()
        return res

    def remove(self, i):
        """Delete the item at index *i*; negative indices count from the end.

        Raises IndexError if *i* is out of range.
        """
        n = len(self)
        if not -n <= i < n:
            raise IndexError("heap index out of range")
        if i < 0:
            i += n
        last = self.pop()
        if i == n - 1:
            return  # the removed slot was the last one; nothing to repair
        self[i] = last
        # The moved leaf may be larger than its new parent or smaller than
        # its new children; exactly one direction can need fixing.
        if i > 0 and self[i] > self[self.parent(i)]:
            self.sift_up(i)
        else:
            self.sift_down(i)

    def change_priority(self, i, p):
        """Replace the item at index *i* with *p* and restore heap order."""
        old_p = self[i]
        self[i] = p
        if i < 0:
            # Sifting works on non-negative positions only.
            i += len(self)
        if p > old_p:
            self.sift_up(i)
        else:
            self.sift_down(i)

    def parent(self, i):
        """Return the parent index of *i*; the root (0) is its own parent."""
        # Floor division, not round(): round() rounds halves to even, which
        # gives e.g. parent(4) == 2 instead of 1.
        return (i - 1) // 2 if i > 0 else 0

    def l_child(self, i):
        """Return the index of the left child of *i*."""
        return 2 * i + 1

    def r_child(self, i):
        """Return the index of the right child of *i*."""
        return 2 * i + 2
# Example: heap sort in ascending order (works in place on a heap copy of the input).
def heap_sort(ar):
    """Return the items of *ar* sorted in ascending order.

    *ar* is copied into a MaxHeap and the sort runs in place inside that
    copy, so *ar* itself is not modified.  The returned object is that
    MaxHeap instance (a list subclass); once sorted, its contents no longer
    satisfy the max-heap property.
    """
    heap = MaxHeap(ar)
    # Swap the current maximum behind the shrinking heap region, then
    # re-heapify only the region in front of it.
    for end in range(len(heap) - 1, 0, -1):
        heap[0], heap[end] = heap[end], heap[0]
        heap.sift_down(size=end)
    return heap
# Self-checks run at import time: the heap's maximum and heap_sort's output
# must agree with the built-ins max() and sorted().  The samples cover
# duplicates, a large outlier in the last slot, and zeros with a negative.
# (Built-ins replace the private, undocumented heapq._heapify_max.)
for sample in ([1, 2, 3, 5, 5, 6], [7, 1, 2, 3, 4, 900], [0, 0, 0, 1, -1]):
    assert MaxHeap(sample).get_max() == max(sample)
    assert list(heap_sort(sample)) == sorted(sample)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment