Skip to content

Instantly share code, notes, and snippets.

@pitrou
Created March 3, 2016 11:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pitrou/e12353a5839fb60d3fc1 to your computer and use it in GitHub Desktop.
Save pitrou/e12353a5839fb60d3fc1 to your computer and use it in GitHub Desktop.
import numba as nb
import numpy as np
from time import process_time as clock
class Timer:
    """Context manager that measures the time spent inside its ``with`` body.

    Time is read with ``clock`` (bound to ``time.process_time`` by this file's
    imports), so the measurement is CPU time of the current process.

    Attributes set on the instance:
        title:    optional label; when truthy, start/finish messages are printed.
        start:    clock reading taken on entry.
        end:      clock reading taken on exit.
        interval: ``end - start``.
    """

    def __init__(self, title=None):
        self.title = title

    def __enter__(self):
        label = self.title
        if label:
            print('Beginning {0}'.format(label))
        # Read the clock last so the announcement is not part of the timing.
        self.start = clock()
        return self

    def __exit__(self, *args):
        stop = clock()
        self.end = stop
        self.interval = stop - self.start
        if not self.title:
            return
        print('{1} took {0:0.4f} seconds'.format(self.interval, self.title))
@nb.jit(nopython=True)
def insertion_sort(A, low, high):
    """
    Sort the slice A[low:high + 1] in place by straight insertion.

    Both bounds are inclusive. Elements outside that range are untouched.
    The comparison is a strict ``<``, so equal elements keep their order.
    """
    for k in range(low + 1, high + 1):
        key = A[k]
        pos = k
        # Walk left through the already-sorted prefix A[low:k], shifting
        # every element greater than key one slot to the right.
        while pos > low:
            prev = A[pos - 1]
            if not (key < prev):
                break
            A[pos] = prev
            pos -= 1
        A[pos] = key
@nb.jit( nopython=True )
def merge2(x):
    """Return a sorted copy of the 1-D array ``x`` (``x`` itself is not modified).

    Bottom-up merge sort:
      1. Sort consecutive runs of ``width`` (25) elements with ``insertion_sort``.
      2. Repeatedly merge adjacent sorted runs pairwise with ``_merge2``,
         doubling ``width`` each pass, until one run covers the whole array.

    Two buffers are ping-ponged between passes: each pass reads from ``r`` and
    writes into ``tgt``, then the two are swapped.
    """
    n = x.shape[0]
    r = x.copy()
    # Each merge pass writes every slot of tgt[0:n] before anything reads it,
    # so the scratch buffer needs no zero-initialisation.
    tgt = np.empty_like(r)

    # Phase 1: insertion-sort small chunks (insertion_sort takes inclusive bounds).
    width = 25
    i = 0
    while i < n:
        iend = min(i + width, n)
        insertion_sort(r, i, iend - 1)
        i = iend

    # Phase 2: merge sorted chunks bottom-up.
    while width < n:
        i = 0
        while i < n:
            istart = i
            # The left run is r[istart:imid] and the right run is r[imid:iend].
            # Both are clamped to n, so the right run of the final pair may be empty.
            imid = min(istart + width, n)
            iend = min(istart + 2 * width, n)
            _merge2(r, tgt, istart, imid, iend)
            i = iend
        # tgt now holds runs of 2*width sorted elements; it becomes the source
        # for the next pass, and the old source is reused as the next target.
        r, tgt = tgt, r
        width *= 2
    return r
@nb.jit( nopython=True )
def _merge2(src_arr, tgt_arr, istart, imid, iend):
    """ The merge part of the merge sort.

    Merge the runs src_arr[istart:imid] and src_arr[imid:iend] into
    tgt_arr[istart:iend]. Both runs are assumed to be sorted already. On ties,
    the element from the left run is taken first.

    Every slot of tgt_arr[istart:iend] is written exactly once. src_arr is
    only read at indices inside [istart, iend), even when either run is empty
    (e.g. imid == iend == len(src_arr)).
    """
    i0 = istart
    i1 = imid
    ipos = istart
    if i0 < imid and i1 < iend:
        # Cache the current head of each run so each step needs one array load.
        # A head is reloaded only after checking that its run is non-empty.
        # This avoids reading one element past the end of a run.
        v0 = src_arr[i0]
        v1 = src_arr[i1]
        while True:
            if v0 <= v1:
                tgt_arr[ipos] = v0
                ipos += 1
                i0 += 1
                if i0 == imid:
                    break
                v0 = src_arr[i0]
            else:
                tgt_arr[ipos] = v1
                ipos += 1
                i1 += 1
                if i1 == iend:
                    break
                v1 = src_arr[i1]
    # At most one of the two runs still has elements; copy the remainder.
    while i0 < imid:
        tgt_arr[ipos] = src_arr[i0]
        ipos += 1
        i0 += 1
    while i1 < iend:
        tgt_arr[ipos] = src_arr[i1]
        ipos += 1
        i1 += 1
def test_merge_multi():
    """Benchmark ``merge2`` against NumPy's merge sort and check that they agree.

    Sorts prefixes of one random int32 array, with sizes log-spaced from n0 to
    n1. For each size it prints the ratio of merge2 time to ``np.sort`` time,
    then asserts that both results are identical.
    """
    np.random.seed(42)
    n0 = 20
    n1 = 500000
    nsteps = 30
    # randint's upper bound is exclusive, so n1 + 1 reproduces the inclusive
    # range of the deprecated random_integers(0, n1). Under the same seed it
    # also draws the same values.
    src = np.random.randint(0, n1 + 1, size=n1).astype(np.int32)
    # JIT warmup: compile merge2 for this dtype before anything is timed.
    merge2(src[:2])
    for n in np.logspace(np.log10(n0), np.log10(n1), nsteps, dtype=np.intc):
        x = src[:n]
        with Timer() as t0:
            r = merge2(x)
        with Timer() as t1:
            e = np.sort(x, kind='merge')
        # On coarse-resolution clocks a tiny np.sort can measure as 0 s;
        # report nan instead of raising ZeroDivisionError.
        if t1.interval > 0:
            ratio = t0.interval / t1.interval
        else:
            ratio = float('nan')
        print('n = %6s => nb/np duration %.2f' % (n, ratio))
        np.testing.assert_equal(e, r)
# Run the benchmark only when executed as a script, not on import.
# Without this guard, pytest collection would also run it twice.
if __name__ == "__main__":
    test_merge_multi()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment