Last active
May 3, 2018 17:50
-
-
Save JustAPerson/2a454b8cf9362ada9d2f83b5e6f68fe2 to your computer and use it in GitHub Desktop.
Extra fast radix sort variant. Runs in O(n log(w / log n)) time, where w is the width of the numbers in bits. Problem set: http://courses.csail.mit.edu/6.897/spring05/psets/ps7.pdf Solutions: http://courses.csail.mit.edu/6.897/spring05/psets/ps7-sol.pdf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import math | |
def hi(x, halfw):
    """Return the high part of ``x``: everything above its lowest ``halfw`` bits."""
    upper = x >> halfw
    return upper
def lo(x, halfw):
    """Return the low part of ``x``: its lowest ``halfw`` bits."""
    # Subtraction binds tighter than ``&``, so the mask is (1 << halfw) - 1.
    mask = (1 << halfw) - 1
    return x & mask
def get_hiS(S, halfw):
    """Map each high half found in ``S`` to the smallest value that has it.

    Args:
        S: Iterable of ints.
        halfw: Number of low bits that ``hi`` strips from each value.

    Returns:
        Dict ``{hi(x, halfw): smallest x in S with that high half}``.
    """
    smallest = {}
    for value in S:
        key = hi(value, halfw)
        # Keep the first value seen for a key unless a strictly smaller one shows up.
        if key not in smallest or value < smallest[key]:
            smallest[key] = value
    return smallest
def get_loS(S, hiS, halfw):
    """Group the non-minimum values of ``S`` by their low half.

    Args:
        S: Iterable of ints.
        hiS: Dict from high half to the minimum value of that bucket, as built
            by ``get_hiS``.
        halfw: Bit count used to split values into high and low halves.

    Returns:
        Dict ``{lo(x, halfw): [x, ...]}`` holding every ``x`` in ``S`` that is
        not equal to its bucket minimum, in iteration order of ``S``.
    """
    groups = {}
    for value in S:
        # Every copy of a bucket minimum is left out.
        if hiS[hi(value, halfw)] == value:
            continue
        groups.setdefault(lo(value, halfw), []).append(value)
    return groups
def get_Sprime(S, hiS, halfw):
    """Collect the half-width keys that the recursive call has to sort.

    Args:
        S: Iterable of ints.
        hiS: Dict from high half to the minimum value of that bucket.
        halfw: Bit count used to split values into high and low halves.

    Returns:
        A new set holding every key of ``hiS`` plus the low half of every
        value in ``S`` that is not its bucket's minimum.
    """
    keys = set(hiS)
    for value in S:
        if hiS[hi(value, halfw)] != value:
            keys.add(lo(value, halfw))
    return keys
def count_sort(S, w):
    """Counting sort over the integer range ``0 .. 1 << w`` (both ends included).

    Args:
        S: Iterable of ints, each usable as an index into the count table.
        w: Bit width; the count table has ``(1 << w) + 1`` slots.

    Returns:
        A new list with the values of ``S`` in ascending order, duplicates kept.
    """
    counts = [0] * ((1 << w) + 1)
    for value in S:
        counts[value] += 1
    # Emit each index as many times as it was counted.
    return [value for value, seen in enumerate(counts) for _ in range(seen)]
def sort(S, w=None, L=None):
    """Sort non-negative integers by recursively halving their bit width.

    Each ``w``-bit value is split into a high and a low half of ``halfw`` bits.
    Values are bucketed by high half. The distinct high halves, together with
    the low halves of every non-minimum bucket member, are sorted by a single
    recursive call on ``halfw``-bit keys. The buckets are then refilled and
    concatenated in that order.

    Args:
        S: Collection of non-negative ints. It must support ``len()`` and
            repeated iteration; the recursive calls pass a ``set``.
        w: Bit width such that every value is below ``1 << w``. Derived from
            ``max(S)`` when omitted or falsy.
        L: Size of the top-level input, passed unchanged through the recursion
            and used by the base-case test. Defaults to ``len(S)``.

    Returns:
        A new list with the elements of ``S`` in ascending order.

    Raises:
        ValueError: If the width has to be derived and ``S`` contains a
            negative value.
    """
    if not S:
        # max() and log() below are undefined for an empty input.
        return []
    if not L:
        L = len(S)
    if not w:
        if min(S) < 0:
            raise ValueError("sort() only supports non-negative integers")
        # bit_length() is exact. ceil(log2(max)) is one bit short for exact
        # powers of two and raises for max(S) == 0.
        w = max(S).bit_length()

    # w <= 1 must always take the base case: ceil(w / 2) == w there, so
    # recursing on a one-bit width would never terminate when L is small.
    if w <= 1 or (w <= 4 and w <= math.log(L, 2)):
        return count_sort(S, w)

    halfw = (w + 1) // 2  # ceil(w / 2) without going through floats
    bucket_min = get_hiS(S, halfw)

    # get_loS() drops every copy of a bucket minimum, so count the copies here
    # in order to put all of them back at the front of their bucket.
    min_count = {minimum: 0 for minimum in bucket_min.values()}
    for value in S:
        if value in min_count:
            min_count[value] += 1

    order = sort(get_Sprime(S, bucket_min, halfw), halfw, L)
    by_low = get_loS(S, bucket_min, halfw)

    buckets = {
        high: [minimum] * min_count[minimum]
        for high, minimum in bucket_min.items()
    }
    # Walking the sorted keys appends each bucket's members in ascending
    # low-half order; keys that are only high halves have no entry in by_low.
    for key in order:
        for value in by_low.get(key, ()):
            buckets[hi(value, halfw)].append(value)

    # Walking the same sorted keys again emits the buckets in ascending
    # high-half order; keys that are only low halves have no bucket.
    result = []
    for key in order:
        if key in buckets:
            result.extend(buckets[key])
    return result
def rand():
    """Return a random integer ``N`` with ``0 <= N <= 2**32``.

    ``random.randint`` includes both endpoints, so ``2**32`` itself can occur.
    """
    upper = 1 << 32
    return random.randint(0, upper)
# Demo driver: build 2**20 random integers and sort them with sort() above.
L = sort([rand() for _ in range(1 << 20)])
# Disabled comparison against the built-in list sort:
# L.sort()
# S = L[:]
# S.sort()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment