Skip to content

Instantly share code, notes, and snippets.

@findstr
Created November 1, 2018 03:03
Show Gist options
  • Save findstr/98edcbd720aee1e1460426e3da252be2 to your computer and use it in GitHub Desktop.
Save findstr/98edcbd720aee1e1460426e3da252be2 to your computer and use it in GitHub Desktop.
dc3
#include <assert.h>
#include <time.h>
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
// Uncomment to compile in the tracing blocks guarded by #ifdef DEBUG and to
// make main() use its small fixed input instead of the randomized loop.
//#define DEBUG
// Padding / terminator value written after the end of every symbol array.
#define TERM (0)
// Reads element n of whatever array named `buf` is in scope at the use site.
#define V(n) buf[(n)]
// Allocates uninitialized room for n objects of type t; result is unchecked here.
#define NEW(t, n) malloc((n) * sizeof(t))
// The macros below convert between an absolute text position x with
// x % 3 != 0 and its slot in the combined s12 array. Positions with
// x % 3 == 1 occupy slots [0, n0); positions with x % 3 == 2 follow from
// slot n0 on. They all rely on a variable named `n0` being in scope.
#define ABS_TO_S12(x) ((x) % 3 == 1 ? ((x) / 3) : (x) / 3 + n0)
// Slot in the left (x % 3 == 1) half -> absolute position.
#define S12_TO_LABS(x) ((x) * 3+1)
// Slot in the right (x % 3 == 2) half -> absolute position.
#define S12_TO_RABS(x) (((x)-n0)*3+2)
// True when slot x belongs to the left (x % 3 == 1) half.
#define S12_IS_LEFT(x) ((x) < n0)
// Slot in either half -> absolute position.
#define S12_TO_ABS(x) (S12_IS_LEFT(x)?S12_TO_LABS(x):S12_TO_RABS(x))
/*
 * Lexicographically compare byte sequence a (na bytes) with b (nb bytes).
 * Returns -1 when a orders first, 1 when b orders first and 0 when both
 * are identical. If one sequence is a prefix of the other, the shorter
 * one orders first.
 */
static int
compare(const uint8_t *a, int na, const uint8_t *b, int nb)
{
	/* Only the shared prefix can be compared byte by byte. */
	const int common = (nb < na) ? nb : na;
	int pos = 0;

	while (pos < common) {
		if (a[pos] != b[pos])
			return (a[pos] < b[pos]) ? -1 : 1;
		pos++;
	}
	/* Shared prefix is equal: the length decides. */
	if (na == nb)
		return 0;
	return (na < nb) ? -1 : 1;
}
/*
 * Debug helper: writes "title:" followed by the first s integers of a,
 * each formatted as 'value' and a comma, then a newline, to stdout.
 */
static void
print(const char *title, int *a, int s)
{
	int idx = 0;

	printf("%s:", title);
	while (idx < s) {
		printf("'%d', ", a[idx]);
		idx++;
	}
	printf("\n");
}
/*
 * Debug helper: writes "title:" followed by the first s bytes of a,
 * each printed as a quoted character and a comma, then a newline.
 */
static void
printC(const char *title, const uint8_t *a, int s)
{
	int idx = 0;

	printf("%s:", title);
	while (idx < s) {
		printf("'%c', ", a[idx]);
		idx++;
	}
	printf("\n");
}
/*
 * Debug helper: prints rank r and start offset `from`, followed by the
 * bytes buf[from..size-1]. Alphabetic bytes are shown as characters,
 * every other byte as its decimal value.
 */
static void
print_rank(const uint8_t *buf, int r, int from, int size)
{
	int pos = from;

	printf("rank:%02d start:%02d = '", r, from);
	while (pos < size) {
		const uint8_t ch = buf[pos++];

		if (!isalpha(ch))
			printf("%d ", ch);
		else
			printf("%c ", ch);
	}
	printf("'\n");
}
/*
 * Debug helper: for each of the first sn entries of SA, prints the entry's
 * index as its rank together with the suffix of buf it points at.
 */
static void
print_all_rank(const uint8_t *buf, int *SA, int sn, int size)
{
	int rank = 0;

	while (rank < sn) {
		print_rank(buf, rank, SA[rank], size);
		rank++;
	}
}
/*
 * Verifies that the first sn entries of SA list suffixes of buf (total
 * length `size`) in non-decreasing order. On the first adjacent pair that
 * is out of order it prints the offending entries and trips an assert.
 */
static void
check_all_rank(const uint8_t *buf, int *SA, int sn, int size)
{
	int cur = 0;

	while (cur + 1 < sn) {
		const int lhs = SA[cur];
		const int rhs = SA[cur + 1];
		int eq = compare(buf + lhs, size - lhs, buf + rhs, size - rhs);

		if (eq > 0) {
			printf("assert:%d sa:%d sa+1:%d res:%d\n", cur, lhs, rhs, eq);
			assert(eq < 0);
		}
		cur++;
	}
}
/*
 * Lexicographic "less than or equal" on pairs:
 * returns 1 when (a1, b1) <= (a2, b2), otherwise 0.
 */
static inline int
leq2(int a1, int b1, int a2, int b2)
{
	/* The first component decides unless it ties. */
	if (a1 != a2)
		return a1 < a2;
	return b1 <= b2;
}
/*
 * Lexicographic "less than or equal" on triples:
 * returns 1 when (a1, b1, c1) <= (a2, b2, c2), otherwise 0.
 */
static inline int
leq3(int a1, int b1, int c1, int a2, int b2, int c2)
{
	/* The first component decides unless it ties. */
	if (a1 != a2)
		return a1 < a2;
	/* Tie on the first component: fall back to the remaining pair. */
	return leq2(b1, c1, b2, c2) != 0;
}
/*
 * One stable counting-sort pass.
 *
 * Copies the `size` indices in src to dst, ordered by the key buf[src[i]].
 * Indices with equal keys keep their relative order from src, which is what
 * lets several passes be chained into a multi-key radix sort.
 *
 * src:   indices to sort (not modified)
 * dst:   receives the sorted indices; must have room for `size` entries
 * buf:   key array; every key buf[src[i]] must lie in 0..space
 * space: largest key value that can occur
 *
 * Aborts the process if the temporary counter array cannot be allocated.
 */
static void
radix_pass(int *src, int *dst, int size, const int *buf, int space)
{
	const int buckets = space + 1;	/* keys 0..space inclusive */
	int i, sum;
	int *count;

	if (size <= 0)
		return;
	count = calloc((size_t)buckets, sizeof *count);
	if (count == NULL) {
		fprintf(stderr, "radix_pass: out of memory\n");
		abort();
	}
	/* Histogram of the keys. */
	for (i = 0; i < size; i++)
		++count[buf[src[i]]];
	/* Exclusive prefix sum: count[key] becomes the first output slot. */
	sum = 0;
	for (i = 0; i < buckets; i++) {
		int occurrences = count[i];

		count[i] = sum;
		sum += occurrences;
	}
	/* Scatter in input order so equal keys stay in input order. */
	for (i = 0; i < size; i++)
		dst[count[buf[src[i]]]++] = src[i];
	free(count);
}
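/*
 * Worked examples of how an input is split into the triples that start at
 * positions i % 3 == 0 (s0), i % 3 == 1 (s1) and i % 3 == 2 (s2), padded
 * with 0, together with the counts n0, n1 and n2 computed in dc3_r.
 * For "1234" (size % 3 == 1) the trailing all-zero triple shown under s1
 * is the extra dummy entry that dc3_r adds, which is why s1 lists two
 * triples while n1 is 1.
 *
 *   input    s0        s1        s2     n0  n1  n2
 *   123      123       230       300    1   1   1
 *   1234     123,400   234,000   340    2   1   1
 *   12345    123,450   234,500   345    2   2   1
 */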
/*
 * dc3_r - recursive core of the DC3 (skew) suffix array construction.
 *
 * Steps:
 *   1. Radix-sort the suffixes starting at positions i % 3 != 0 by their
 *      first three symbols and give each distinct triple a name (rank).
 *   2. If names are not unique, recurse on the string of names to get the
 *      full order of those suffixes; otherwise the names are the order.
 *   3. Sort the suffixes at i % 3 == 0 with one more radix pass, using the
 *      order from step 2 as the secondary key.
 *   4. Merge the two sorted sets into SA.
 *
 * buf:   symbols buf[0..size-1]; buf[size], buf[size+1] and buf[size+2]
 *        are also read and are expected to hold TERM. Symbols are used as
 *        radix_pass keys, so they must lie in 0..space.
 * SA:    output; receives `size` suffix start positions in sorted order.
 * size:  number of symbols. size <= 0 writes nothing, size == 1 yields {0}.
 * space: largest symbol value.
 *
 * Aborts the process if a work buffer cannot be allocated.
 */
void
dc3_r(const int *buf, int *SA, int size, int space)
{
	int i, k, lexname, v[3];
	int n0, n1, n2, n12;
	int r0, r12;
	int *p, *SA0, *s0, *SA12, *s12;

	/* Trivial inputs: the merge below needs at least one suffix per set. */
	if (size <= 0)
		return;
	if (size == 1) {
		SA[0] = 0;
		return;
	}
	n0 = (2+size) / 3;	/* number of positions with i % 3 == 0 */
	n1 = (2+size - 1) / 3;	/* number of positions with i % 3 == 1 */
	n2 = (2+size - 2) / 3;	/* number of positions with i % 3 == 2 */
	/* n0 is used in place of n1 so that a dummy slot exists when size % 3 == 1 */
	n12 = n0 + n2;
	s0 = NEW(int, n0);
	SA0 = NEW(int, n0);
	s12 = NEW(int, n12 + 3);
	SA12 = NEW(int, n12 + 3);
	if (s0 == NULL || SA0 == NULL || s12 == NULL || SA12 == NULL) {
		free(s0);
		free(SA0);
		free(s12);
		free(SA12);
		fprintf(stderr, "dc3_r: out of memory\n");
		abort();
	}
	s12[n12] = s12[n12+1] = s12[n12+2] = TERM;
	SA12[n12] = SA12[n12+1] = SA12[n12+2] = TERM;
	/* collect every position with i % 3 != 0 (plus the dummy, if any) */
	p = s12;
	for (i = 0; i < size + (n0 - n1); i++) {
		if (i % 3 != 0)
			*p++ = i;
	}
	assert(p - s12 == n12);
	/* LSD radix sort of the triples: least significant symbol first */
	radix_pass(s12, SA12, n12, buf+2, space); //tri[i+2]
	radix_pass(SA12, s12, n12, buf+1, space); //tri[i+1]
	radix_pass(s12, SA12, n12, buf+0, space); //tri[i+0]
	/* name the triples: equal triples share a name, names start at 1 */
	lexname = 0;
	s12[n12] = s12[n12 + 1] = s12[n12+2] = TERM;
	v[0] = v[1] = v[2] = -1;
	for (i = 0; i < n12; i++) {
		int idx12;
		int abs = SA12[i];

		if (V(abs) != v[0] || V(abs+1) != v[1] || V(abs+2) != v[2]) {
			++lexname;
			v[0] = V(abs);
			v[1] = V(abs+1);
			v[2] = V(abs+2);
		}
		assert(lexname > 0);
		assert(abs % 3 != 0);
		idx12 = ABS_TO_S12(abs);
		s12[idx12] = lexname;
	}
#ifdef DEBUG
	printf("---------enter size:%d\n", size);
	print("SA12", SA12, n12);
	printf("rank:%d\n", size);
	printf("s12:\n");
	for (i = 0; i < n12; i++)
		printf("abs:%d rank:%d\n", S12_TO_ABS(i), s12[i]);
	printf("\n");
#endif
	assert(s12[n12] == s12[n12 + 1]);
	assert(s12[n12+1] == s12[n12+2] && s12[n12 + 2] == TERM);
	/* now s12[slot] holds the name of the triple at that slot's position */
	if (lexname < n12) {
		/* some triples share a name: sort the name string recursively */
		dc3_r(s12, SA12, n12, lexname);
#ifdef DEBUG
		for (i = 0; i < n12; i++)
			printf("ret abs:%d rank:%d\n", S12_TO_ABS(SA12[i]), i);
#endif
		/* replace names by the unique ranks from the recursive result */
		for (i = 0; i < n12; i++)
			s12[SA12[i]] = i + 1;
	} else {
		/* names are already unique ranks: invert them to get the order */
		for (i = 0; i < n12; i++)
			SA12[s12[i] - 1] = i;
	}
	/*
	 * Sort the i % 3 == 0 suffixes: walking SA12 in order lists them by
	 * the rank of suffix i+1, and one stable pass on buf[i] finishes it.
	 */
	p = s0;
	for (i = 0; i < n12; i++) {
		int slot = SA12[i];

		if (S12_IS_LEFT(slot))
			*p++ = S12_TO_LABS(slot) - 1;
	}
	assert(p - s0 == n0);
	radix_pass(s0, SA0, n0, buf, space);
#ifdef DEBUG
	printf("merge:\n");
	print("s0", s0, n0);
	print("s12", s12, n12);
#endif
	/* merge; when size % 3 == 1 the dummy is SA12[0] and is skipped */
	r0 = 0;
	r12 = n0 - n1;
	k = 0;
	while (r0 < n0 && r12 < n12) {
		int slot = SA12[r12];
		int j = SA0[r0];
		int pos, leq;

		if (S12_IS_LEFT(slot)) {
			/* pos % 3 == 1: one symbol, then ranks of pos+1 and j+1 */
			pos = S12_TO_LABS(slot);
			leq = leq2(buf[pos], s12[ABS_TO_S12(pos+1)],
				   buf[j], s12[ABS_TO_S12(j+1)]);
		} else {
			/* pos % 3 == 2: two symbols, then ranks of pos+2 and j+2 */
			pos = S12_TO_RABS(slot);
			leq = leq3(buf[pos], buf[pos+1], s12[ABS_TO_S12(pos+2)],
				   buf[j], buf[j+1], s12[ABS_TO_S12(j+2)]);
		}
		if (leq) {
			SA[k++] = pos;
			r12++;
		} else {
			SA[k++] = j;
			r0++;
		}
	}
	/* one side is exhausted: append whatever is left of the other */
	while (r0 < n0)
		SA[k++] = SA0[r0++];
	while (r12 < n12) {
		int slot = SA12[r12++];

		SA[k++] = S12_TO_ABS(slot);
	}
#ifdef DEBUG
	print("SA", SA, size);
	printf("----leave %d\n", size);
#endif
	free(s0);
	free(s12);
	free(SA0);
	free(SA12);
}
/*
 * dc3 - build the suffix array of the byte string dat[0..size-1].
 *
 * Each byte is shifted up by one so that TERM (0) never occurs inside the
 * text, three TERM entries are appended as padding, and dc3_r does the work
 * with 0xff+1 as the largest symbol value.
 *
 * Returns a newly allocated array of `size` suffix start positions in
 * sorted order; the caller owns it and must free() it. Returns NULL when
 * size is negative or memory cannot be allocated.
 */
int *
dc3(const uint8_t *dat, int size)
{
	int i;
	int *SA, *buf;

	if (size < 0)
		return NULL;
	/* allocate at least one entry so that NULL always means failure */
	SA = NEW(int, size > 0 ? size : 1);
	buf = NEW(int, size+3);
	if (SA == NULL || buf == NULL) {
		free(SA);
		free(buf);
		return NULL;
	}
	for (i = 0; i < size; i++) {
#ifdef DEBUG
		printf("xx:%c\n", dat[i]);
#endif
		buf[i] = dat[i]+1;
	}
	/* padding read by dc3_r past the end of the text */
	for (i = size; i < size + 3; i++)
		buf[i] = TERM;
#ifdef DEBUG
	printf("xxx : size:%d %d %d\n", size, buf[0], buf[1]);
#endif
	dc3_r(buf, SA, size, 0xff+1);
	free(buf);
	return SA;
}
int
main(int argc, char * argv[])
{
time_t t = time(NULL);
printf("srand:%ld\n", t);
srand(t);
#ifndef DEBUG
int k;
for (k = 0; k < 1024 ; k++) {
#endif
int *SA;
int size = rand() % 1204 + 2;
printf("================\n");
#ifndef DEBUG
int j;
int *p = NEW(int, size);
for (j = 0; j < size; j++)
p[j] = rand() % 26 + 'A';
#else
uint8_t p[] = {
// 'K', 'Y', 'R', 'K', 'J', 'C', 'B', 'C', 'B', 'C', 'M', 'J'
'B', 'C', 'A',// 'D', 'F', 'E',
};
size = sizeof(p) / sizeof(p[0]);
#endif
printC("+++++++++++++++src:", p, size);
SA = dc3(p, size);
// print_all_rank(p, SA, size, size);
check_all_rank(p, SA, size, size);
free(SA);
#ifndef DEBUG
free(p);
}
#endif
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment