Skip to content

Instantly share code, notes, and snippets.

@nkaretnikov
Created September 29, 2018 08:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nkaretnikov/d8226c3a6df7e588f2aa524a9fb1e1a8 to your computer and use it in GitHub Desktop.
Save nkaretnikov/d8226c3a6df7e588f2aa524a9fb1e1a8 to your computer and use it in GitHub Desktop.
min-heap
#!/usr/bin/env python
# min-heap
def build_heap(A):
    """Rearrange the list A, in place, into a min-heap.

    Calls heapify on every internal node, from the last parent back to the
    root, so that both subtrees of a node are already heaps when that node
    is sifted down. Returns None.
    """
    n = len(A)
    # Indices n // 2 .. n - 1 have no children (2*i + 1 >= n), so only the
    # first n // 2 nodes need sifting. Floor division is required: plain `/`
    # only worked under Python 2 and gives a float in Python 3, which
    # range() rejects.
    for i in reversed(range(n // 2)):
        heapify(A, n, i)
def swap(A, n, m):
    """Exchange, in place, the elements of A at indices n and m."""
    A[n], A[m] = A[m], A[n]
def heapify(A, n, i):
    """Sift A[i] down until the subtree rooted at i is a min-heap.

    Only the first n elements of A are treated as part of the heap. Both
    child subtrees of i are assumed to already satisfy the min-heap
    property (true when called from build_heap, or on the root after
    extract). Modifies A in place and returns None.
    """
    while True:
        left = 2 * i + 1
        right = 2 * i + 2
        # `smallest` rather than `min`, to avoid shadowing the builtin.
        smallest = i
        if left < n and A[left] < A[smallest]:
            smallest = left
        if right < n and A[right] < A[smallest]:
            smallest = right
        if smallest == i:
            return
        A[i], A[smallest] = A[smallest], A[i]
        # Continue in the subtree the old A[i] was moved into.
        i = smallest
def extract(A, n):
    """Remove and return the root (minimum) element of the min-heap A.

    n is the current heap size. It is used only for the empty and
    single-element checks; the list itself is shrunk with A.pop().

    The heap property is NOT restored afterwards: the last element is moved
    into the root slot, and the caller must call heapify(A, len(A), 0).

    If the heap is empty, writes "Empty heap" to stderr and exits with
    status 1 (raises SystemExit(1)).
    """
    import sys
    if n <= 0:
        sys.stderr.write("Empty heap\n")
        # sys.exit rather than the site-provided exit(): exit() is meant for
        # the interactive shell, is missing under `python -S`, and closes
        # sys.stdin before raising SystemExit.
        sys.exit(1)
    if n == 1:
        return A.pop()
    first = A[0]
    # Make the last element the new root and shrink the list by one.
    A[0] = A.pop()
    return first
def traverse(xs):
    """Assert that the list xs satisfies the min-heap property.

    Each element at index c >= 1 is compared with its parent at index
    (c - 1) // 2. An AssertionError is raised if a parent is greater than
    one of its children. Returns None.
    """
    for child in range(1, len(xs)):
        parent = (child - 1) // 2
        assert xs[parent] <= xs[child]
def print_tree(xs):
    """Print a rough, level-by-level picture of the heap stored in xs.

    Level i holds up to 2**i elements. Each level is written on its own
    line, and the whole picture is followed by one blank line.

    XXX: the indentation/padding formulas are a quick-n-dirty heuristic and
    only approximate a balanced tree layout.
    """
    import sys
    total = len(xs)
    # One extra column of padding when the total element count is even.
    m = 1 if total % 2 == 0 else 0
    rest = xs
    level = 0
    while rest:
        width = 2 ** level  # capacity of this level
        n = len(rest)       # elements not yet printed
        # Floor division keeps these ints; `/` breaks `" " * ...` in Python 3.
        shift = n // width
        pad = (3 + m) + n - width  # a negative pad simply yields no spaces
        sys.stdout.write(" " * shift)
        for x in rest[:width]:
            sys.stdout.write("{}{}".format(x, " " * pad))
        # Always terminate the row, including a partially filled last level.
        sys.stdout.write("\n")
        rest = rest[width:]
        level += 1
    sys.stdout.write("\n")
# x
# 1x
# x1x
# 3 x
# 1x3 x
# x1x1x1x
# 6 x
# 2 x6 x
# x3 x x x
def test(inp):
    """Exercise build_heap, extract and heapify on every permutation of inp.

    For each permutation:
    - build a heap, then print and check it;
    - extract the minimum;
    - restore the heap by sifting the root down, then print and check again.

    Checks use traverse, which raises AssertionError on a violation. A "----"
    line starts each permutation and "====" ends the run.
    """
    from itertools import permutations
    for perm in permutations(inp):
        print("----")
        print("build_heap")
        A0 = list(perm)
        build_heap(A0)
        print("in: {}, out: {}".format(list(perm), A0))
        print_tree(A0)
        traverse(A0)
        if not A0:
            # Nothing to extract; extract() would terminate the process.
            continue
        print("extract: {}".format(extract(A0, len(A0))))
        print_tree(A0)
        print("build_heap again")
        # After extract only the root is out of place, so sifting it down
        # restores the heap; a full build_heap is not needed.
        heapify(A0, len(A0), 0)
        print_tree(A0)
        traverse(A0)
    # NOTE(review): the original's indentation was lost; this end-of-run
    # marker is assumed to sit after the loop.
    print("====")
def perms(n, A):
    """Print every arrangement of the first n elements of A (Heap's algorithm).

    See https://en.wikipedia.org/wiki/Heap%27s_algorithm (non-recursive form).
    A is permuted in place, and the whole list is printed once initially and
    once after each swap, giving n! lines in total. On return A holds the
    last arrangement generated, not the original order. Returns None.
    """
    # c[k] counts the swaps already made at level k; it replaces the loop
    # counter of the recursive formulation.
    c = [0] * n
    print(A)
    # Level 0 never swaps (one element has one arrangement), so start at 1.
    i = 1
    while i < n:
        if c[i] < i:
            # Even levels swap with position 0; odd levels with position c[i].
            j = 0 if i % 2 == 0 else c[i]
            A[j], A[i] = A[i], A[j]
            print(A)
            c[i] += 1
            i = 1
        else:
            c[i] = 0
            i += 1
# Demo run: print all 4! = 24 arrangements of [1, 2, 3, 4].
A = list(range(1, 5))
perms(len(A), A)
# Heap checks over every permutation of each input -- uncomment to run.
# test([])
# test([3])
# test([3,7])
# test([3,7,2])
# test([3,7,2,4])
# test([3,7,2,4,5])
# test([3,7,2,4,5,1])
# test([3,7,2,4,5,1,6])
# # test([2,2,2,2,2,3,3])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment