Skip to content

Instantly share code, notes, and snippets.

@vurtun
Last active January 11, 2023 15:22
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vurtun/1c84e0956ab02685e8d74e01aadbdf3b to your computer and use it in GitHub Desktop.
Save vurtun/1c84e0956ab02685e8d74e01aadbdf3b to your computer and use it in GitHub Desktop.
// ref: http://www.codercorner.com/RadixSortRevisited.htm
// http://stereopsis.com/radix.html
// int/float: https://github.com/lshamis/FloatRadixSort
// string: https://github.com/rantala/string-sorting/blob/master/src/msd_ce.cpp
/* Length-delimited string view used as the element type of the string sort.
 * The sort helpers only read str[0 .. len-1]; no NUL terminator is consulted. */
struct str {
const char *str; /* first character of the string */
const char *end; /* the str() helper macro sets this to str + len; not read by the sort code in this file */
int len;         /* number of characters */
};
/* ---------------------------------------------------------------------------
* Sort
* ---------------------------------------------------------------------------
*/
#include <stdio.h>
#include <assert.h>
#include <string.h>
/* cast(t, p): plain C cast of expression p to type t. */
#define cast(t, p) ((t)(p))
/* for_nstep(i,n,s): loop header that declares `int i`, starts at 0 and
 * advances by s while i < n. */
#define for_nstep(i,n,s) for (int i = 0; i < (n); i += (s))
/* for_cnt(i,n): for_nstep with a step of 1. */
#define for_cnt(i,n) for_nstep(i,n,1)
/* fori_cnt(i,n): same counting loop, but i must already be declared by the
 * caller (the macro only assigns to it). */
#define fori_cnt(i,n) for (i = 0; i < (n); i += 1)
// clang-format off
/* Key converter: pointer to a function that reads the key stored at p and
 * returns it as an unsigned value. */
typedef unsigned(*sort_conv_f)(const void *p);
/* Optional accessor. Note this names a function TYPE, not a pointer; when used
 * as a parameter type it is adjusted to a pointer to function. It receives the
 * address of an element (plus key offset) and the caller's usr pointer and
 * returns the pointer that is then handed on to the converter / string code. */
typedef void*(sort_access_f)(const void *data, void *usr);
/* sort__access: fetch the unsigned key of the element at byte address a.
 * If `access` is non-null the key pointer is access(a + off, usr), otherwise
 * it is a + off; either way the pointer is passed through conv. All macro
 * arguments are parenthesized so arbitrary expressions can be passed. */
#define sort__access(a,usr,access,conv,off) ((access) ? (conv)((access)((a) + (off), (usr))) : (conv)((a) + (off)))
/* sort__char_at: character of struct str *s at position d as a value in
 * [0, 255], or -1 once d is past the end. The byte is read as unsigned char so
 * bytes >= 0x80 can never be negative and collide with the -1 end marker. */
#define sort__char_at(s,d) (((d) < (s)->len) ? (int)(unsigned char)(s)->str[(d)] : -1)
/* sort__str_get: resolve element address a to a struct str*, going through
 * access(a, usr) when an accessor is supplied. */
#define sort__str_get(a,access,usr) ((struct str*)((access) ? ((access)((a), (usr))) : (a)))
/* Key conversion for unsigned short: the stored value is already order
 * preserving, so it is simply widened to unsigned. */
static inline unsigned
sort__cast_ushort(const void *p) {
  const unsigned short *key = p;
  return (unsigned)*key;
}
/* Key conversion for short: reinterpret the bits as unsigned short and flip
 * the sign bit (bit 15) so that negative values order before positive ones
 * when compared as unsigned. */
static inline unsigned
sort__cast_short(const void *p) {
  union {
    short value;
    unsigned short bits;
  } pun;
  pun.value = *(const short *)p;
  return (unsigned)pun.bits ^ 0x8000u;
}
/* Key conversion for unsigned int: identity, the value is used as the key. */
static inline unsigned
sort__cast_uint(const void *p) {
  const unsigned *key = p;
  return *key;
}
/* Key conversion for int: reinterpret the bits as unsigned and flip the sign
 * bit (bit 31) so that negative values order before positive ones when
 * compared as unsigned. */
static inline unsigned
sort__cast_int(const void *p) {
  union {
    int value;
    unsigned bits;
  } pun;
  pun.value = *(const int *)p;
  return pun.bits ^ 0x80000000u;
}
/* Key conversion for float: map the raw bit pattern to an unsigned key whose
 * unsigned order matches the numeric order of the floats.
 *   - sign bit clear: set the top bit, moving the value into the upper half;
 *   - sign bit set:   negate the bit pattern (mod 2^32), which lands in the
 *                     lower half and reverses the order of the magnitudes.
 * -0.0f and +0.0f end up with the same key. */
static inline unsigned
sort__cast_flt(const void *p) {
  union {
    float value;
    unsigned bits;
  } pun;
  pun.value = *(const float *)p;
  if ((pun.bits >> 31u) == 1u) {
    /* multiplying by (unsigned)-1 and toggling bit 31 twice is a plain negate */
    return 0u - pun.bits;
  }
  return pun.bits ^ 0x80000000u;
}
/* Front ends that bind a key converter and pass no accessor / user pointer.
 *   out : unsigned index buffer; the radix routines require room for
 *         2*n+512 entries (16-bit keys) or 2*n+1024 entries (32-bit keys)
 *   a   : first element, siz : element stride in bytes,
 *   n   : element count,  off : byte offset of the key inside an element */
#define sort_short(out,a,siz,n,off) sort_radix16(out,a,siz,n,off,0,0,sort__cast_short)
#define sort_ushort(out,a,siz,n,off) sort_radix16(out,a,siz,n,off,0,0,sort__cast_ushort)
#define sort_int(out,a,siz,n,off) sort_radix32(out,a,siz,n,off,0,0,sort__cast_int)
#define sort_uint(out,a,siz,n,off) sort_radix32(out,a,siz,n,off,0,0,sort__cast_uint)
#define sort_flt(out,a,siz,n,off) sort_radix32(out,a,siz,n,off,0,0,sort__cast_flt)
/* String sort front end. sort__str takes two rank buffers and an oracle
 * buffer ahead of the array arguments, so the macro has to forward all three;
 * the previous five-parameter form could not expand to a valid call. */
#define sort_str(rnk,rnk2,oracle,a,n,siz,off) sort__str(rnk,rnk2,oracle,a,n,siz,off,0,0)
/* Shorthands for tightly packed arrays of a single scalar type. */
#define sort_shorts(out,a,n) sort_short(out,a,szof(short),n,0)
#define sort_ushorts(out,a,n) sort_ushort(out,a,szof(unsigned short),n,0)
#define sort_ints(out,a,n) sort_int(out,a,szof(int), n,0)
#define sort_uints(out,a,n) sort_uint(out,a,szof(unsigned),n,0)
#define sort_flts(out,a,n) sort_flt(out,a,szof(float),n,0)
// clang-format on
/* Index-based LSD radix sort for 16-bit keys (two 8-bit passes).
 *
 * out    : in/out index buffer. On entry out[0..n-1] must hold the element
 *          indices to order (sort_radix16 fills in 0..n-1); on return
 *          out[0..n-1] holds them ordered by ascending key. The buffer must
 *          have room for at least 2*n+512 entries: [n, 2n) is the ping-pong
 *          area and [2n, 2n+512) holds two 256-entry histograms.
 * a      : first element of the array; never written.
 * siz    : stride in bytes between consecutive elements.
 * n      : number of elements; n <= 0 returns immediately.
 * off    : byte offset of the key inside an element.
 * usr    : opaque pointer forwarded to access.
 * access : optional accessor (may be null), see sort__access.
 * conv   : converts a key pointer to the unsigned key; the passes order by
 *          bits 0-7 and then bits 8-15 of that value.
 */
static void
sort__radix16(unsigned *restrict out, const void *a, int siz, int n, int off,
              void *usr, sort_access_f access, sort_conv_f conv) {
  assert(a);
  assert(out);
  if (n <= 0) {
    /* nothing to order; also avoids reading a key from an empty array */
    return;
  }
  /* <!> out needs to be at least size: 2*n+512 <!> */
  unsigned *buf = out + 2 * n;
  unsigned *restrict h[] = {buf, buf + 256};
  const unsigned char *b = (const unsigned char*)a;
  const unsigned char *e = b + (size_t)n * (size_t)siz;

  /* build both histograms in one sweep and detect already ordered input */
  int is_sorted = 1;
  memset(buf, 0, 512 * sizeof(unsigned));
  unsigned last = sort__access(b, usr, access, conv, off);
  for (const unsigned char *it = b; it < e; it += siz) {
    unsigned k = sort__access(it, usr, access, conv, off);
    if (k < last) {
      is_sorted = 0;
    }
    h[0][k & 0xff]++;
    h[1][(k >> 8) & 0xff]++;
    last = k;
  }
  if (is_sorted) {
    return; /* keys ascend in array order: out is left as passed in */
  }
  /* convert histograms into exclusive prefix sums (bucket start offsets) */
  unsigned sum[2] = {0};
  for (int i = 0; i < 256; ++i) {
    unsigned t0 = h[0][i] + sum[0]; h[0][i] = sum[0]; sum[0] = t0;
    unsigned t1 = h[1][i] + sum[1]; h[1][i] = sum[1]; sum[1] = t1;
  }
  /* sort 8 bits at a time, ping-ponging between the two halves of out;
   * after an even number of passes the result is back in out[0..n-1] */
  unsigned *restrict idx[] = {out, out + n};
  for (int p = 0, d = 1, s = 0; p < 2; ++p, d = !d, s = !s) {
    for (unsigned i = 0u; i != (unsigned)n; ++i) {
      unsigned at = idx[s][i];
      const unsigned char *elem = b + (size_t)at * (size_t)siz;
      unsigned k = sort__access(elem, usr, access, conv, off);
      idx[d][h[p][(k >> (8 * p)) & 0xff]++] = at;
    }
  }
}
/* Sort n elements by a 16-bit key: seeds out[0..n-1] with the identity
 * permutation 0..n-1 and lets sort__radix16 reorder it. Parameters are the
 * same as for sort__radix16 (out needs room for 2*n+512 entries). */
static void
sort_radix16(unsigned *restrict out, const void *a, int siz, int n, int off,
             void *usr, sort_access_f access, sort_conv_f conv) {
  assert(a);
  assert(out);
  int slot = 0;
  while (slot < n) {
    out[slot] = (unsigned)slot;
    ++slot;
  }
  sort__radix16(out, a, siz, n, off, usr, access, conv);
}
/* Index-based LSD radix sort for 32-bit keys (four 8-bit passes).
 *
 * out    : in/out index buffer. On entry out[0..n-1] must hold the element
 *          indices to order (sort_radix32 fills in 0..n-1); on return
 *          out[0..n-1] holds them ordered by ascending key. The buffer must
 *          have room for at least 2*n+1024 entries: [n, 2n) is the ping-pong
 *          area and [2n, 2n+1024) holds four 256-entry histograms.
 * a      : first element of the array; never written.
 * siz    : stride in bytes between consecutive elements.
 * n      : number of elements; n <= 0 returns immediately.
 * off    : byte offset of the key inside an element.
 * usr    : opaque pointer forwarded to access.
 * access : optional accessor (may be null), see sort__access.
 * conv   : converts a key pointer to the unsigned 32-bit key.
 */
static void
sort__radix32(unsigned *restrict out, const void *a, int siz, int n, int off,
              void *usr, sort_access_f access, sort_conv_f conv) {
  assert(a);
  assert(out);
  if (n <= 0) {
    /* nothing to order; also avoids reading a key from an empty array */
    return;
  }
  /* <!> out needs to be at least size: 2*n+1024 <!> */
  unsigned *buf = out + 2 * n;
  unsigned *restrict h[] = {buf, buf + 256, buf + 512, buf + 768};
  const unsigned char *b = (const unsigned char*)a;
  const unsigned char *e = b + (size_t)n * (size_t)siz;

  /* build all four histograms in one sweep and detect already ordered input */
  int is_sorted = 1;
  memset(buf, 0, 1024 * sizeof(unsigned));
  unsigned last = sort__access(b, usr, access, conv, off);
  for (const unsigned char *it = b; it < e; it += siz) {
    unsigned k = sort__access(it, usr, access, conv, off);
    if (k < last) {
      is_sorted = 0;
    }
    h[0][(k & 0xff)]++;
    h[1][(k >> 8) & 0xff]++;
    h[2][(k >> 16) & 0xff]++;
    h[3][(k >> 24)]++;
    last = k;
  }
  if (is_sorted) {
    return; /* keys ascend in array order: out is left as passed in */
  }
  /* convert histograms into exclusive prefix sums (bucket start offsets) */
  unsigned sum[4] = {0};
  for (int i = 0; i < 256; ++i) {
    unsigned t0 = h[0][i] + sum[0]; h[0][i] = sum[0]; sum[0] = t0;
    unsigned t1 = h[1][i] + sum[1]; h[1][i] = sum[1]; sum[1] = t1;
    unsigned t2 = h[2][i] + sum[2]; h[2][i] = sum[2]; sum[2] = t2;
    unsigned t3 = h[3][i] + sum[3]; h[3][i] = sum[3]; sum[3] = t3;
  }
  /* sort 8 bits at a time, ping-ponging between the two halves of out;
   * after an even number of passes the result is back in out[0..n-1] */
  unsigned *restrict idx[] = {out, out + n};
  for (int p = 0, d = 1, s = 0; p < 4; ++p, d = !d, s = !s) {
    for (unsigned i = 0u; i != (unsigned)n; ++i) {
      unsigned at = idx[s][i];
      const unsigned char *elem = b + (size_t)at * (size_t)siz;
      unsigned k = sort__access(elem, usr, access, conv, off);
      idx[d][h[p][(k >> (8 * p)) & 0xff]++] = at;
    }
  }
}
/* Sort n elements by a 32-bit key: seeds out[0..n-1] with the identity
 * permutation 0..n-1 and lets sort__radix32 reorder it. Parameters are the
 * same as for sort__radix32 (out needs room for 2*n+1024 entries). */
static void
sort_radix32(unsigned *restrict out, const void *a, int siz, int n, int off,
             void *usr, sort_access_f access, sort_conv_f conv) {
  assert(a);
  assert(out);
  int slot = 0;
  while (slot < n) {
    out[slot] = (unsigned)slot;
    ++slot;
  }
  sort__radix32(out, a, siz, n, off, usr, access, conv);
}
/* Sort key of one string at character position d.
 *
 * p      : address of the element (already offset to the struct str, or to
 *          whatever access expects).
 * d      : character position to look at.
 * access : optional accessor resolving p to the struct str; may be null.
 * usr    : opaque pointer forwarded to access.
 *
 * Returns the byte at position d as a value in [0, 255], or -1 when d is at
 * or past the end of the string. The result is an int and the byte is read as
 * unsigned char on purpose: squeezing it through plain `char` would turn the
 * -1 end marker into 255 where char is unsigned, and bytes >= 0x80 into
 * negative values where char is signed. */
static int
sort__str_at(unsigned char *p, int d, sort_access_f access, void *usr) {
  const struct str *s = sort__str_get(p, access, usr);
  if (d >= s->len) {
    return -1;
  }
  return (int)(unsigned char)s->str[d];
}
/* Three-way string quicksort over the index range rnk[lo..hi] (inclusive).
 *
 * rnk    : element indices being ordered in place.
 * a      : first element of the array; element i lives at a + i*siz.
 * lo, hi : inclusive bounds of the range to sort.
 * d      : character position all strings in the range are compared at.
 * siz    : element stride in bytes.
 * off    : byte offset of the string inside an element.
 * access : optional accessor, usr: its opaque argument.
 *
 * The range is split around the key of rnk[lo] into "smaller", "equal" and
 * "larger" parts. The smaller part is sorted at the same position, the equal
 * part at position d+1 (unless the pivot key is negative, i.e. the strings
 * have ended), and the larger part is handled by the next loop iteration
 * instead of a third recursive call. */
static void
sort__str_q3s(int *rnk, void *a, int lo, int hi, int d, int siz, int off,
              sort_access_f access, void *usr) {
  unsigned char *base = a;
  while (lo < hi) {
    int below = lo;   /* rnk[lo..below-1] hold keys smaller than the pivot */
    int above = hi;   /* rnk[above+1..hi] hold keys larger than the pivot */
    int cur = lo + 1; /* next index to classify */
    const int pivot = sort__str_at(base + rnk[lo] * siz + off, d, access, usr);
    while (cur <= above) {
      const int key = sort__str_at(base + rnk[cur] * siz + off, d, access, usr);
      if (key == pivot) {
        ++cur;
        continue;
      }
      if (key < pivot) {
        const int held = rnk[below];
        rnk[below] = rnk[cur];
        rnk[cur] = held;
        ++below;
        ++cur;
      } else {
        const int held = rnk[cur];
        rnk[cur] = rnk[above];
        rnk[above] = held;
        --above;
      }
    }
    sort__str_q3s(rnk, a, lo, below - 1, d, siz, off, access, usr);
    if (pivot >= 0) {
      sort__str_q3s(rnk, a, below, above, d + 1, siz, off, access, usr);
    }
    lo = above + 1;
  }
}
/* MSD radix sort step for strings, operating on the index range
 * rnk[lo .. lo+n-1].
 *
 * rnk    : element indices; the ordered result is always left in this buffer.
 * rnk2   : scratch buffer of the same length as rnk; must not overlap rnk.
 * oracle : scratch for one key per string of the current range; needs room
 *          for n entries of the outermost call.
 * a      : first element of the array; element i lives at a + i*siz.
 * n      : number of indices in the range.
 * siz    : element stride in bytes, off: byte offset of the string in it.
 * lo, hi : inclusive bounds of the range (hi == lo + n - 1).
 * access : optional accessor, usr: its opaque argument.
 * d      : character position to distribute on.
 *
 * Ranges shorter than 32 are handed to sort__str_q3s. Otherwise the strings
 * are distributed into 257 buckets (end-of-string first, then byte values
 * 0..255), copied back into rnk, and every bucket of strings that have not
 * ended is sorted recursively at position d+1.
 *
 * Returns rnk. */
static int*
sort__str_base(int *rnk, int *rnk2, short *oracle,
               void *a, int n, int siz, int off, int lo, int hi,
               sort_access_f access, void *usr, int d) {
  unsigned char *p = a;
  if (n < 32) {
    sort__str_q3s(rnk, a, lo, hi, d, siz, off, access, usr);
    return rnk;
  }
  /* cnt[k + 1] counts the strings whose key at position d is k, where k is
   * -1 for "string ended" or the byte value 0..255. The key is computed here
   * with an explicit unsigned char read so it can never index below 0. */
  int cnt[257] = {0};
  for (int i = 0; i < n; ++i) {
    const struct str *s = sort__str_get(p + rnk[lo + i] * siz + off, access, usr);
    oracle[i] = (short)((d < s->len) ? (int)(unsigned char)s->str[d] : -1);
    ++cnt[oracle[i] + 1];
  }
  /* exclusive prefix sums: pos[b] is the first slot of bucket b within the range */
  int pos[257];
  pos[0] = 0;
  for (int i = 0; i < 256; ++i) {
    pos[i + 1] = pos[i] + cnt[i];
  }
  /* stable scatter into the scratch buffer, then bring the range back so
   * that every level of the recursion finds its input in rnk */
  for (int i = 0; i < n; ++i) {
    rnk2[lo + pos[oracle[i] + 1]++] = rnk[lo + i];
  }
  memcpy(rnk + lo, rnk2 + lo, (size_t)n * sizeof(*rnk));
  /* bucket 0 holds strings that ended at d: they are equal, so skip it.
   * oracle is free to be reused by the children since only cnt is needed. */
  int start = lo + cnt[0];
  for (int b = 1; b < 257; ++b) {
    if (cnt[b] == 0) {
      continue;
    }
    sort__str_base(rnk, rnk2, oracle, a, cnt[b], siz, off,
                   start, start + cnt[b] - 1, access, usr, d + 1);
    start += cnt[b];
  }
  return rnk;
}
/* Entry point of the string sort: fills rnk[0..n-1] with 0..n-1 and orders
 * those indices by the strings they refer to, starting at character 0.
 *
 * rnk, rnk2 : index buffers with room for n entries each.
 * oracle    : scratch buffer with room for n entries.
 * a         : first element, siz: element stride, off: byte offset of the
 *             string inside an element.
 * access    : optional accessor, usr: its opaque argument.
 *
 * Returns whatever buffer sort__str_base reports as holding the result. */
static int*
sort__str(int *rnk, int *rnk2, short *oracle,
          void *a, int n, int siz, int off,
          sort_access_f access, void *usr) {
  int slot = 0;
  while (slot < n) {
    rnk[slot] = slot;
    ++slot;
  }
  const int first = 0;
  const int last = n - 1;
  return sort__str_base(rnk, rnk2, oracle, a, n, siz, off, first, last, access, usr, 0);
}
/* ---------------------------------------------------------------------------
* Test
* ---------------------------------------------------------------------------
*/
#include <stdio.h>
/* unused(a): evaluate a and discard the result (silences unused warnings). */
#define unused(a) ((void)a)
/* cast(t, p): same definition as earlier in the file. */
#define cast(t, p) ((t)(p))
/* szof(a): sizeof as an int. */
#define szof(a) ((int)sizeof(a))
/* cntof(a): element count of a true array (not valid for pointers). */
#define cntof(a) ((int)(sizeof(a) / sizeof((a)[0])))
/* str(s,n): compound literal struct str for the n characters starting at s. */
#define str(s,n) (struct str){s, (s) + (n), (n)}
/* strv(s): struct str for a char array such as a string literal; the -1
 * drops the terminating NUL from the length. */
#define strv(s) str(s, cntof(s)-1)
/* Demo driver. Exactly one of the three preprocessor branches is compiled:
 * integer sort (enabled), float sort, or string sort. Each prints the input
 * in ascending order, one element per line, by walking the index array
 * produced by the sort. */
int main(void) {
#if 1
  int arr[] = { 256, -36789, 170, 45, 75, 1765987, 90, 802, -24, 2, -66, 17895 };
  /* the radix sort writes unsigned indices and needs 2*n+1024 slots */
  unsigned sorted[cntof(arr) * 2 + 1024];
  sort_ints(sorted, arr, cntof(arr));
  for (int i = 0; i < cntof(arr); ++i) {
    printf("%d\n", arr[sorted[i]]);
  }
#elif 0
  float arr[] = { 170.0f, 0.001f, -0.05f, 20.0f, -30.0f, 802.0f, 2.5f, 2000.65f, -12.54f, 66.0f };
  /* the radix sort writes unsigned indices and needs 2*n+1024 slots */
  unsigned sorted[cntof(arr) * 2 + 1024];
  sort_flts(sorted, arr, cntof(arr));
  for (int i = 0; i < cntof(arr); ++i) {
    printf("%.3f\n", arr[sorted[i]]);
  }
#else
  struct str arr[] = {
    strv("aaaba"),
    strv("dfjasdlifjai"),
    strv("jiifjeogiejogp"),
    strv("aabaaaa"),
    strv("gsgj"),
    strv("gerph"),
    strv("aaaaaaa"),
    strv("htjltjlrth"),
    strv("joasdjfisdjfdo"),
    strv("hthe"),
    strv("aaaaaba"),
    strv("jrykpjl"),
    strv("hkoptjltp"),
    strv("aaaaaa"),
    strv("lprrjt")
  };
  short tmp3[cntof(arr)];
  int tmp[cntof(arr)], tmp2[cntof(arr)];
  /* call the sort directly with its full argument list: two rank buffers,
   * the oracle buffer, then array, count, stride, offset, and no accessor */
  int *idx = sort__str(tmp, tmp2, tmp3, arr, cntof(arr), szof(arr[0]), 0, 0, 0);
  for (int i = 0; i < cntof(arr); ++i) {
    printf("%.*s\n", arr[idx[i]].len, arr[idx[i]].str);
  }
#endif
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment