Created October 2, 2011 11:21
-
-
Save zed/1257360 to your computer and use it in GitHub Desktop.
naive quicksort in Cython #qsort
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#
/cachegrind.out.profilestats
/profilestats.prof
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
cimport numpy as np | |
# Ranges with fewer than CUTOFF elements are finished by insertion sort
# instead of being partitioned further. Declared as an anonymous enum because
# DEF compile-time constants are deprecated since Cython 3.0.
cdef enum:
    CUTOFF = 17
def sort(np.ndarray[np.float_t, ndim=1] a):
    """Sort the 1-D ``np.float_t`` array *a* in place, in ascending order.

    The C-level routines walk a raw ``np.float_t*`` pointer with unit stride,
    so they are only valid on C-contiguous memory. A strided view such as
    ``x[::2]`` or ``x[::-1]`` used to be sorted through ``a.data`` directly.
    That sorted the wrong elements, or read past the buffer for negative
    strides. Such arrays are now sorted in a contiguous temporary copy, and
    the result is written back into *a*.

    NaN values are not supported. Every ordering comparison with NaN is
    false, so the partition step's assertions can fail.
    """
    cdef np.ndarray[np.float_t, ndim=1, mode="c"] buf
    if np.PyArray_IS_C_CONTIGUOUS(a):
        qsort(<np.float_t*>a.data, 0, a.shape[0])
        return
    # ndarray.copy() defaults to order="C", so buf is contiguous.
    buf = a.copy()
    qsort(<np.float_t*>buf.data, 0, buf.shape[0])
    a[:] = buf
cdef void qsort(np.float_t* a, Py_ssize_t start, Py_ssize_t end):
    """Sort a[start:end] in place.

    Ranges shorter than CUTOFF are finished with insertion sort. Longer
    ranges are partitioned around a median-of-three pivot. Only the smaller
    side is sorted recursively; the larger side is handled by looping. This
    keeps recursion depth at O(log n) even when partitions are badly
    unbalanced. Recursing into both sides can be O(n) deep, e.g. for
    all-equal input, which can exhaust the C stack on large arrays.
    Many equal keys still cost quadratic time with this partition scheme.
    """
    cdef Py_ssize_t boundary, mid
    while (end - start) >= CUTOFF:
        # Median-of-three: order a[start] <= a[mid] <= a[end-1], then move the
        # median into a[end-1], which is where partition() takes its pivot.
        # With the plain last-element pivot, already-sorted input always
        # picks the maximum and degrades to quadratic time.
        mid = start + (end - start) // 2
        if a[mid] < a[start]:
            swap(a, start, mid)
        if a[end - 1] < a[start]:
            swap(a, start, end - 1)
        if a[end - 1] < a[mid]:
            swap(a, mid, end - 1)
        swap(a, mid, end - 1)

        boundary = partition(a, start, end)
        # a[boundary] is already in its final position.
        if boundary - start < end - (boundary + 1):
            qsort(a, start, boundary)
            start = boundary + 1
        else:
            qsort(a, boundary + 1, end)
            end = boundary
    insertion_sort(a, start, end)
cdef Py_ssize_t partition(np.float_t* a, Py_ssize_t start, Py_ssize_t end):
    """Partition a[start:end] around the value stored in a[end-1].

    Let ``k`` be the returned index. On return, a[k] equals the pivot, every
    element of a[start:k] is < pivot, and every element of a[k:end] is
    >= pivot. Requires end > start.
    """
    assert end > start
    cdef Py_ssize_t lo = start
    cdef Py_ssize_t hi = end - 1
    cdef np.float_t pivot = a[hi]
    while True:
        # Invariant: a[start:lo] < pivot and a[hi:end] >= pivot.
        # Skip elements already on the left side. a[end-1] holds the pivot
        # and is never < pivot, so this scan cannot run past end-1.
        while a[lo] < pivot:
            lo += 1
        # Skip elements already on the right side.
        while lo < hi and pivot <= a[hi]:
            hi -= 1
        if lo >= hi:
            break
        # Both a[lo] and a[hi] are on the wrong side: exchange them.
        assert a[hi] < pivot <= a[lo]
        swap(a, lo, hi)
        assert a[lo] < pivot <= a[hi]
    assert lo >= hi and lo < end
    # Move the pivot from the end of the range into the split position.
    swap(a, lo, end - 1)
    assert a[lo] == pivot
    return lo
cdef inline void swap(np.float_t* a, Py_ssize_t i, Py_ssize_t j):
    """Exchange the elements at positions i and j of the buffer a."""
    cdef np.float_t held = a[i]
    a[i] = a[j]
    a[j] = held
cdef void insertion_sort(np.float_t* a, Py_ssize_t start, Py_ssize_t end):
    """Sort a[start:end] in place by straight insertion.

    Each element is lifted out and larger predecessors are shifted one slot
    right until a value <= it (or the start of the range) is reached.
    """
    cdef Py_ssize_t k, hole
    cdef np.float_t item
    # A one-element prefix is trivially sorted, so begin with the second item.
    for k in range(start + 1, end):
        # Invariant: a[start:k] is sorted.
        item = a[k]
        hole = k
        while hole > start and not (a[hole - 1] <= item):
            a[hole] = a[hole - 1]
            hole -= 1
        a[hole] = item
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import pyximport; pyximport.install() | |
from timeit import default_timer as timer | |
from quicksort import sort as qsort | |
def measure(f):
    """Return a wrapper that calls ``f(L)`` once and returns how long it took.

    The elapsed time is taken with ``timer`` (``timeit.default_timer``) and
    returned in seconds. ``f``'s own return value is discarded.

    If ``f`` raises, the exception propagates to the caller. A ``return``
    inside ``finally`` would silently discard it and report a bogus timing.
    """
    import functools

    @functools.wraps(f)
    def wrapper(L):
        start = timer()
        f(L)
        return timer() - start

    return wrapper
# qsort wrapped so that each call returns its run time in seconds.
measure_qsort = measure(qsort)


def run_tests():
    """Check ``qsort`` for correctness, then print timings for several sizes.

    * Degenerate inputs (empty, one element, all equal) must come back
      unchanged.
    * Sorted, reverse-sorted and duplicate-heavy arrays must match
      ``np.sort``.
    * Random arrays of 2..1000 elements are shuffled and re-sorted ten
      times. Every run is compared exactly with ``sorted``, and the best
      time is reported.
    * Random arrays of 10**4..10**6 elements are sorted three times. Each
      result is checked to be non-decreasing, and the best time is reported.
    """
    import numpy as np

    for L in map(np.array, [[], [1.], [1.] * 10]):
        old = L.copy()
        qsort(L)
        assert np.all(old == L)

    base = np.random.random(1000)
    for L in (np.sort(base), np.sort(base)[::-1].copy(), np.repeat(base[:10], 100)):
        expected = np.sort(L)
        qsort(L)
        assert np.array_equal(expected, L), L

    for n in map(int, np.r_[np.arange(2, 10), 11, 101, 100, 1000]):
        L = np.random.random(n)
        # Sorting only permutes values, so compare exactly. A tolerance check
        # (np.allclose) can hide two close neighbours left out of order.
        expected = sorted(L)
        t = float("+inf")
        for _ in range(10):
            np.random.shuffle(L)
            t = min(measure_qsort(L), t)
            assert np.array_equal(expected, L), L
        report_time(n, t)

    # Sizes must be ints: current NumPy rejects float sizes such as 1e4.
    for n in (10**4, 10**5, 10**6):
        t = float("+inf")
        for _ in range(3):
            L = np.random.random_sample(n)
            t = min(measure_qsort(L), t)
            assert np.all(L[:-1] <= L[1:])
        report_time(n, t)
def report_time(n, t):
    """Print one benchmark line: array size *n* and run time *t* (seconds) in ms.

    Uses ``print(...)`` with a single argument. That works as a function on
    Python 3 and prints the same text on Python 2, unlike the old print
    statement.
    """
    print("N=%09d took us %.2g\t%s" % (n, t * 1000, "ms"))


if __name__ == "__main__":
    run_tests()
# Sample output of run_tests() (best time per array size):
# N=000000002 took us 0.00095 ms
# N=000000003 took us 0.00095 ms
# N=000000004 took us 0.00095 ms
# N=000000005 took us 0.00095 ms
# N=000000006 took us 0.00095 ms
# N=000000007 took us 0.0019 ms
# N=000000008 took us 0.00095 ms
# N=000000009 took us 0.00095 ms
# N=000000011 took us 0.00095 ms
# N=000000101 took us 0.0048 ms
# N=000000100 took us 0.005 ms
# N=000001000 took us 0.06 ms
# N=000010000 took us 0.96 ms
# N=000100000 took us 10 ms
# N=001000000 took us 1.2e+02 ms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.