Skip to content

Instantly share code, notes, and snippets.

@d1manson
Last active December 3, 2015 15:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save d1manson/31b09afa2a46e59372da to your computer and use it in GitHub Desktop.
Save d1manson/31b09afa2a46e59372da to your computer and use it in GitHub Desktop.
vectorized radix sort, but a lot slower than just np.sort
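"""
If the data consists of *small* non-negative integers, a plain counting sort
(np.bincount followed by np.repeat) is a simpler alternative.
Note that a basic LSD radix sort whose passes only cover the bits of the
maximum value orders non-negative ints correctly, but not negative ones.
To deal with negative ints, run the passes over enough bits to hold every
value in two's complement (sign bit included): the negatives, which have the
sign bit set, then all end up after the non-negatives, already sorted among
themselves. A final step finds the first negative number and moves that
(newly sorted) tail of negative numbers to the front...for example:
[ 0 15 15 17 19 26 34 84 99 -12 -11 -3]
                            |--------|
                           move to front
"""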
def radix_sort(a, batch_m_bits=4):
    """Return a sorted copy of the 1-D integer array *a* using LSD radix sort.

    Each pass buckets the elements by one ``batch_m_bits``-wide digit,
    starting from the least significant digit. Every pass is stable, so after
    the most significant digit has been processed the whole array is ordered.

    Negative values are supported: the passes then cover every bit of the
    values' two's-complement representation, which leaves the (correctly
    ordered) negatives after the non-negatives; a final rotation moves them
    to the front.

    Parameters
    ----------
    a : array_like of int
        1-D values to sort. Converted with ``np.asarray``; never modified.
    batch_m_bits : int, optional
        Bits per digit, i.e. ``2**batch_m_bits`` buckets per pass.

    Returns
    -------
    numpy.ndarray
        A new array with the dtype of ``np.asarray(a)``, sorted ascending.

    Raises
    ------
    ValueError
        If ``a`` is not 1-D or ``batch_m_bits`` is less than 1.
    TypeError
        If ``a`` does not have an integer dtype.
    """
    a = np.asarray(a)
    if a.ndim != 1:
        raise ValueError("radix_sort expects a 1-D array, got %d dimensions" % a.ndim)
    if a.dtype.kind not in "iu":
        raise TypeError("radix_sort requires an integer array, got dtype %s" % a.dtype)
    if batch_m_bits < 1:
        raise ValueError("batch_m_bits must be >= 1, got %r" % (batch_m_bits,))
    n = len(a)
    if n == 0:
        return a.copy()

    # Convert to Python ints: NumPy integer scalars do not subclass int on
    # Python 3 and so cannot be relied on to provide .bit_length().
    lo = int(a.min())
    hi = int(a.max())
    if lo < 0:
        # Bits needed to hold every value in two's complement, sign bit
        # included (~lo == -lo - 1 is the magnitude that must fit below it).
        # This never exceeds the dtype's width, so shifts stay in range.
        bit_len = max(max(hi, 0).bit_length(), (~lo).bit_length()) + 1
    else:
        bit_len = hi.bit_length()

    batch_m = 2 ** batch_m_bits
    mask = batch_m - 1
    k_shifts = -(-bit_len // batch_m_bits)  # ceil(bit_len / batch_m_bits)

    for shift in range(k_shifts):
        # Casting to intp before masking keeps bincount's input castable
        # (e.g. uint64) and avoids out-of-range masks on narrow dtypes; the
        # mask leaves only this digit of the two's-complement pattern, so
        # digits are always non-negative.
        digits = (a >> (shift * batch_m_bits)).astype(np.intp) & mask
        counts = np.bincount(digits, minlength=batch_m)
        starts = np.cumsum(counts) - counts  # first output slot of each bucket
        new_a = np.empty(n, dtype=a.dtype)
        for ii, (len_ii, start_ii) in enumerate(zip(counts, starts)):
            if len_ii:
                # Boolean indexing preserves the current order: a stable pass.
                new_a[start_ii:start_ii + len_ii] = a[digits == ii]
        a = new_a

    if k_shifts == 0:
        # Every value is 0, so no pass ran; still return a copy, never the
        # caller's own array.
        return a.copy()
    if lo < 0:
        # Negatives (sign bit set) sorted after the non-negatives; rotate that
        # already-ordered tail to the front.
        a = np.roll(a, np.count_nonzero(a < 0))
    return a
# NOTE(review): the string literal below is never executed; it only keeps an
# alternative implementation for reference. If it were turned back into code
# it would rebind (shadow) the radix_sort defined above. As written it expects:
#   * a module-level ``nax = np.newaxis`` (only mentioned in its inline comment);
#   * len(a) to be a power of two (asserted): np.flatnonzero on the
#     (batch_m, len(a)) match matrix yields row * len(a) + col, and
#     ``& (len(a) - 1)`` recovers col only for power-of-two lengths. Rows are
#     ordered by digit value and columns by position, so each pass is stable;
#   * presumably non-negative ints, since bit_len is taken from the maximum
#     only -- TODO confirm.
# It calls np.max(a).bit_length(); on Python 3 NumPy integer scalars may not
# provide bit_length(), so int(np.max(a)).bit_length() is likely needed -- verify.
"""
# slightly faster alternative if len(a) is power of 2
def radix_sort(a, batch_m_bits=3):
bit_len = np.max(a).bit_length()
assert(len(a) == 1 << (len(a).bit_length() -1))
batch_m = 2**batch_m_bits
mask = 2**batch_m_bits - 1
val_set = np.arange(batch_m, dtype=a.dtype)[:, nax] # nax = np.newaxis
for _ in range((bit_len-1)//batch_m_bits + 1): # ceil-division
a = a[np.flatnonzero((a & mask)[nax, :] == val_set) & (len(a) -1)]
val_set <<= batch_m_bits
mask <<= batch_m_bits
return a
"""
# simple example, plus a correctness check and benchmark against np.sort.
# Guarded so that importing this module does not run a 1e6-element benchmark.
if __name__ == "__main__":
    import timeit

    import numpy as np

    def _best_ms(func, repeat=3, number=1):
        """Return the best per-call wall time of ``func()`` in milliseconds."""
        return min(timeit.repeat(func, repeat=repeat, number=number)) / number * 1e3

    a = np.array([34, 19, 26, 15, 11, 3, 0, 15, 12, 84, 99, 17])
    print("example...")
    print(a)
    print("radix_sort...")
    print(radix_sort(a))
    print("")

    # test and benchmark against numpy's default sort (quicksort)...
    # randint needs integer bounds/size; float arguments are rejected by
    # modern NumPy.
    a = np.random.randint(0, 10**8, 10**6)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not np.array_equal(radix_sort(a), np.sort(a)):
        raise AssertionError("radix_sort disagrees with np.sort")
    print("test passed for large random array")
    print("")

    # Plain `timeit` replaces the IPython-only `%timeit` magic.
    print("timeit np.sort...")
    print("{:.2f} ms per loop".format(_best_ms(lambda: np.sort(a))))
    print("timeit radix_sort...")
    # The original author measured radix_sort at about 6x slower than
    # numpy's quicksort.
    print("{:.2f} ms per loop".format(_best_ms(lambda: radix_sort(a))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment