Skip to content

Instantly share code, notes, and snippets.

@zed
Created October 2, 2011 11:21
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zed/1257360 to your computer and use it in GitHub Desktop.
Save zed/1257360 to your computer and use it in GitHub Desktop.
naive quicksort in Cython #qsort
#
/cachegrind.out.profilestats
/profilestats.prof
cimport numpy as np
# Sub-ranges shorter than CUTOFF elements are handed to insertion_sort by
# qsort instead of being partitioned any further.
DEF CUTOFF = 17
def sort(np.ndarray[np.float_t, ndim=1, mode="c"] a not None):
    """Sort the one-dimensional float array `a` in place, in ascending order.

    The elements are accessed through the raw `a.data` pointer with unit
    stride, so the array has to be C-contiguous.  `mode="c"` makes Cython
    reject strided or reversed views with a ValueError instead of sorting
    the wrong memory, and `not None` rejects None instead of dereferencing
    a NULL buffer.
    """
    qsort(<np.float_t*>a.data, 0, a.shape[0])
cdef void qsort(np.float_t* a, Py_ssize_t start, Py_ssize_t end):
    """Sort the half-open range a[start:end] in place.

    Ranges shorter than CUTOFF are finished with insertion sort.  Larger
    ranges are partitioned; only the smaller side is sorted recursively and
    the larger side is handled by the loop, which keeps the recursion depth
    logarithmic even when the pivot choice is consistently bad (already
    sorted or all-equal input).
    """
    cdef Py_ssize_t boundary
    while (end - start) >= CUTOFF:
        boundary = partition(a, start, end)
        # a[boundary] is already in its final position; skip it.
        if boundary - start < end - (boundary + 1):
            qsort(a, start, boundary)
            start = boundary + 1
        else:
            qsort(a, boundary + 1, end)
            end = boundary
    insertion_sort(a, start, end)
cdef Py_ssize_t partition(np.float_t* a, Py_ssize_t start, Py_ssize_t end):
    """Partition a[start:end] around its last element.

    Returns the index where the pivot ends up.  Afterwards every element of
    a[start:index] is < pivot and every element of a[index:end] is >= pivot.
    """
    assert end > start
    cdef Py_ssize_t last = end - 1
    cdef Py_ssize_t left = start
    cdef Py_ssize_t right = last
    cdef np.float_t pivot = a[last]
    cdef np.float_t held
    while True:
        # Invariant: a[start:left] < pivot and a[right:end] >= pivot.
        while a[left] < pivot:
            left += 1
        while left < right and pivot <= a[right]:
            right -= 1
        if left >= right:
            break
        # Both cursors stopped on an element that belongs on the other side.
        assert a[right] < pivot and pivot <= a[left]
        held = a[left]
        a[left] = a[right]
        a[right] = held
        assert a[left] < pivot and pivot <= a[right]
    assert left >= right and left < end
    # Move the pivot from the end of the range into its final slot.
    held = a[left]
    a[left] = a[last]
    a[last] = held
    assert a[left] == pivot
    return left
cdef inline void swap(np.float_t* a, Py_ssize_t i, Py_ssize_t j):
    """Exchange the elements stored at indices i and j of a."""
    cdef np.float_t held = a[i]
    a[i] = a[j]
    a[j] = held
cdef void insertion_sort(np.float_t* a, Py_ssize_t start, Py_ssize_t end):
    """Sort the half-open range a[start:end] in place by insertion."""
    cdef Py_ssize_t pos, hole
    cdef np.float_t item
    for pos in range(start, end):
        # a[start:pos] is sorted; open a hole at pos and slide it left past
        # every element that is not <= item.
        item = a[pos]
        hole = pos
        while hole > start and not (a[hole - 1] <= item):
            a[hole] = a[hole - 1]
            hole -= 1
        a[hole] = item
#!/usr/bin/env python
import pyximport; pyximport.install()
from timeit import default_timer as timer
from quicksort import sort as qsort
def measure(f):
    """Return a wrapper that calls f(L) and returns the elapsed time.

    The elapsed time is the difference of two `timer()` readings taken
    around the call.  Exceptions raised by `f` propagate to the caller; the
    previous `return` inside `finally` silently discarded them, so a
    crashing `f` looked like a successful (and fast) run.
    """
    def wrapper(L):
        start = timer()
        f(L)
        return timer() - start
    return wrapper
# Timing wrapper around the imported sort: calling it runs qsort on the
# argument and returns how long the call took, as measured by `timer`.
measure_qsort = measure(qsort)
def run_tests():
    """Check qsort on small inputs and report best-of-N timings.

    Three phases:
      1. empty, single-element and all-equal arrays must come back unchanged;
      2. for small sizes, a shuffled random array is sorted ten times, each
         result is compared against `sorted()`, and the best time is reported;
      3. for large sizes only the best of three timings is reported.
    """
    import numpy as np

    for L in map(np.array, [[], [1.], [1.]*10]):
        old = L.copy()
        qsort(L)
        assert np.all(old == L)

    for n in np.r_[np.arange(2, 10), 11, 101, 100, 1000]:
        L = np.random.random(int(n))
        old = sorted(L.copy())
        t = float("+inf")
        for _ in range(10):
            np.random.shuffle(L)
            t = min(measure_qsort(L), t)
            assert np.allclose(old, L), L
        report_time(n, t)

    for n in np.r_[1e4, 1e5, 1e6]:
        # np.r_ yields floats here; NumPy requires an integer array size.
        size = int(n)
        t = float("+inf")
        for _ in range(3):
            t = min(measure_qsort(np.random.random_sample(size)), t)
        report_time(n, t)
def report_time(n, t):
    """Print one timing line for input size `n`.

    `t` is multiplied by 1000 before printing and labelled "ms".  The
    single-argument parenthesised print produces the same output under both
    the Python 2 print statement and the Python 3 print function.
    """
    print("N=%09d took us %.2g\t%s" % (n, t*1000, "ms"))
# Run the checks and timings only when this file is executed as a script.
if __name__=="__main__":
    run_tests()
# N=000000002 took us 0.00095 ms
# N=000000003 took us 0.00095 ms
# N=000000004 took us 0.00095 ms
# N=000000005 took us 0.00095 ms
# N=000000006 took us 0.00095 ms
# N=000000007 took us 0.0019 ms
# N=000000008 took us 0.00095 ms
# N=000000009 took us 0.00095 ms
# N=000000011 took us 0.00095 ms
# N=000000101 took us 0.0048 ms
# N=000000100 took us 0.005 ms
# N=000001000 took us 0.06 ms
# N=000010000 took us 0.96 ms
# N=000100000 took us 10 ms
# N=001000000 took us 1.2e+02 ms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment