Skip to content

Instantly share code, notes, and snippets.

@bpd1069
Forked from jbedo/hilbert.c
Created March 11, 2018 14:16
Show Gist options
  • Save bpd1069/4c253d93b4924c16273bfff53b6d3b2d to your computer and use it in GitHub Desktop.
Save bpd1069/4c253d93b4924c16273bfff53b6d3b2d to your computer and use it in GitHub Desktop.
Matrix multiplication with Hilbert space-filling curves
#include<u.h>
#include<libc.h>
/* MAX: the larger of a and b.  Arguments may be evaluated twice — pass side-effect-free expressions. */
#define MAX(a, b) ((b) < (a) ? (a) : (b))
/* Hilbert curve functions (from Wikipedia) */
/*
 * rot: reorient the coordinates (*x, *y) inside a quadrant of side n
 * according to which quadrant (rx, ry) they were found in.
 * Nothing happens when ry != 0.  When ry == 0 the coordinates are
 * swapped; if additionally rx == 1, both are first mirrored (v ↦ n-1-v).
 */
void
rot(int n, int *x, int *y, int rx, int ry)
{
	int tmp;

	if(ry != 0)
		return;
	if(rx == 1){
		*x = n - 1 - *x;
		*y = n - 1 - *y;
	}
	tmp = *x;
	*x = *y;
	*y = tmp;
}
/*
 * xy2d: (x,y) ↦ d.  Returns the distance d of cell (x, y) along the
 * Hilbert curve filling an n×n grid.  n is expected to be a power of
 * two with 0 ≤ x, y < n (nmatrix supplies such an n) — TODO confirm
 * for any other caller.  Each level, from the top bit s = n/2 down,
 * adds s² times the curve-order index (0..3) of the quadrant holding
 * the point, then reorients the point with rot for the next level.
 */
int
xy2d(int n, int x, int y)
{
	int d = 0;
	int s = n / 2;

	while(s > 0){
		int qx = (x & s) != 0;
		int qy = (y & s) != 0;

		d += s * s * ((3 * qx) ^ qy);
		rot(s, &x, &y, qx, qy);
		s /= 2;
	}
	return d;
}
/*
 * d2xy: d ↦ (x,y).  Inverse of xy2d for the same n: stores in
 * (*x, *y) the cell at distance d along the Hilbert curve of an n×n
 * grid.  Coordinates are built from the smallest quadrant (s = 1)
 * upward, consuming two bits of d per level.
 */
void
d2xy(int n, int d, int *x, int *y)
{
	int s, qx, qy, rest;

	*x = 0;
	*y = 0;
	rest = d;
	for(s = 1; s < n; s *= 2, rest /= 4){
		qx = (rest / 2) & 1;
		qy = (rest ^ qx) & 1;
		rot(s, x, y, qx, qy);
		*x += s * qx;
		*y += s * qy;
	}
}
/* Memory */
/*
 * emalloc: allocate sz uninitialised bytes.  On failure it calls
 * sysfatal with the system error string (%r), which terminates the
 * program, so callers never see a nil result.
 */
void *
emalloc(ulong sz)
{
	void *mem;

	if((mem = malloc(sz)) == 0)
		sysfatal("emalloc: %r");
	return mem;
}
/*
 * cmalloc: like emalloc (fatal on failure), but the sz bytes returned
 * are zero-filled.
 */
void *
cmalloc(ulong sz)
{
	void *mem = emalloc(sz);

	return memset(mem, 0, sz);
}
/*
 * erealloc: resize the block p (nil allowed, acting like malloc) to
 * sz bytes and return the possibly-moved block.  A nil result from
 * realloc is fatal via sysfatal; note a zero sz may legitimately make
 * realloc return nil, which would also be treated as failure.
 */
void *
erealloc(void *p, ulong sz)
{
	void *q;

	q = realloc(p, sz);
	if(q == 0)
		sysfatal("erealloc: %r");
	return q;
}
/* Matrices */
/*
 * matrix: an nr×nc matrix of doubles stored in Hilbert-curve order.
 *	data	n×n elements; element (i, j) lives at data[xy2d(n, i, j)],
 *		and cells with i ≥ nr or j ≥ nc are unused padding.
 *	nr, nc	logical row and column counts.
 *	n	side of the padded square: a power of two ≥ max(nr, nc).
 *	Δi	pointer deltas for a column-major walk starting at data,
 *		terminated by 0.
 *	Δj	the same for a row-major walk.
 * All of these are set up by nmatrix.
 */
typedef struct matrix{
	double	*data;
	int	nr;
	int	nc;
	int	n;
	int	*Δi;
	int	*Δj;
} matrix;
/*
 * miter: cursor for a row-major walk over a matrix (see niter/next).
 *	pj	next entry of the matrix's Δj delta list.
 *	value	pointer to the current element in the matrix's data.
 *	x	a matrix pointer; niter and next do not assign or read it.
 *	valid	non-zero while value refers to a real element.
 */
typedef struct miter{
	int	*pj;
	double	*value;
	matrix	*x;
	int	valid;
} miter;
/*
 * filldeltas: write to p the offsets between successive cells of an
 * outer×inner walk over a Hilbert layout of side n, then a 0
 * terminator (outer*inner entries in total).  With colmajor == 0 the
 * walk is row-major (outer = row i, inner = column j); otherwise it
 * is column-major (outer = column j, inner = row i).  The walk starts
 * at offset 0 because xy2d(n, 0, 0) == 0.  Consecutive cells are
 * distinct, and xy2d is one-to-one, so no real delta is 0 and the
 * terminator is unambiguous.
 */
static void
filldeltas(int *p, int n, int outer, int inner, int colmajor)
{
	int o, k, d, prev;

	prev = 0;
	for(o = 0; o < outer; o++)
		for(k = 0; k < inner; k++){
			if(o == 0 && k == 0)
				continue;	/* the starting cell, offset 0 */
			d = colmajor ? xy2d(n, k, o) : xy2d(n, o, k);
			*p++ = d - prev;
			prev = d;
		}
	*p = 0;
}

/*
 * nmatrix: (re)initialise x as an nr×nc matrix and return it.
 *
 * If x is nil a zeroed header is allocated with cmalloc.  nr and nc are
 * always recorded.  When both are positive:
 *	n	is set to the smallest power of two ≥ max(nr, nc);
 *	data	is (re)allocated to n×n doubles; values are NOT cleared
 *		(use zero);
 *	Δj, Δi	are (re)allocated and filled with the row-major and
 *		column-major delta lists (see filldeltas).
 * With nr ≤ 0 or nc ≤ 0, existing buffers and n are left untouched.
 * Dimensions above 2^15 are fatal: xy2d returns an int in [0, n*n),
 * which would overflow.
 */
matrix *
nmatrix(matrix *x, int nr, int nc)
{
	int m;

	if(x == nil)
		x = cmalloc(sizeof(*x));
	x->nr = nr;
	x->nc = nc;
	if(nr <= 0 || nc <= 0)
		return x;

	m = MAX(nr, nc);
	if(m > (1 << 15))
		sysfatal("nmatrix: dimension %d too large", m);
	/*
	 * Integer search for the padded side.  The previous
	 * 1 << ceil(log(m)/log(2)) shifted by a double (not valid ISO C)
	 * and could round up past an exact power of two, quadrupling memory.
	 */
	for(x->n = 1; x->n < m; x->n <<= 1)
		;

	x->data = erealloc(x->data, sizeof(*x->data) * x->n * x->n);
	x->Δi = erealloc(x->Δi, sizeof(*x->Δi) * (x->nr * x->nc + 1));
	x->Δj = erealloc(x->Δj, sizeof(*x->Δj) * (x->nr * x->nc + 1));
	filldeltas(x->Δj, x->n, nr, nc, 0);
	filldeltas(x->Δi, x->n, nc, nr, 1);
	return x;
}
/*
 * dmatrix: free x together with every buffer nmatrix attached to it.
 * A nil x is ignored.  free(nil) is a no-op, so buffers that were never
 * allocated (e.g. an nr or nc ≤ 0 matrix fresh from cmalloc) need no
 * guard.  This also avoids leaking Δi/Δj if data alone is nil.
 */
void
dmatrix(matrix *x)
{
	if(x == nil)
		return;
	free(x->data);
	free(x->Δi);
	free(x->Δj);
	free(x);
}
/*
 * mget(m, r, c): element (r, c) of matrix m as an lvalue (readable and
 * assignable).  Each use costs one xy2d call; there is no bounds check.
 */
#define mget(m, r, c) ((m)->data[xy2d((m)->n, (r), (c))])
/*
 * niter: start iterator i at element (0, 0) of x (offset 0 in data)
 * for a row-major walk driven by next().  i->valid is 0 when x has no
 * rows or no columns.  i->x is not assigned.
 */
void
niter(miter *i, matrix *x)
{
	i->value = x->data;
	i->pj = x->Δj;
	i->valid = (x->nr > 0) && (x->nc > 0);
}
/*
 * next: advance i to the following element in row-major order by
 * applying the next Δj delta.  Hitting the 0 terminator clears
 * i->valid and leaves value and pj where they are, so further calls
 * keep the iterator invalid.
 */
void
next(miter *i)
{
	int delta = *i->pj;

	if(delta != 0){
		i->pj++;
		i->value += delta;
	}else
		i->valid = 0;
}
/* zero: clear all n×n stored doubles of a (padding included) to 0. */
void
zero(matrix *a)
{
	double *buf = a->data;

	memset(buf, 0, sizeof(buf[0]) * a->n * a->n);
}
/*
 * mdot: compute c = a·b and return c.
 *
 * c may be nil (a new matrix is allocated) or an existing matrix, which
 * nmatrix resizes to a->nr × b->nc.  c must not be a or b, because it is
 * cleared before the operands are read.  a->nc must equal b->nr.
 * Violating either condition is fatal.  An empty product (no rows or no
 * columns) is returned without touching any data.
 *
 * Instead of calling xy2d per element, each operand is walked through its
 * delta lists: a along rows (Δj), b down columns (Δi), c along rows (Δj).
 * For each output element the inner loop forms the dot product of the
 * current row of a with the current column of b.  That leaves bp at the
 * top of b's next column and ap at the start of a's next row.  At the end
 * of an output row that row start is kept and the walk over b restarts.
 */
matrix *
mdot(matrix *a, matrix *b, matrix *c)
{
	double *ap, *arow, *bp, *cp;
	int *pi, *pirow, *pj, *pk;
	int col, icol;

	/* a mismatch would walk b->Δi past its end or mix columns */
	if(a->nc != b->nr)
		sysfatal("mdot: cannot multiply %dx%d by %dx%d",
			a->nr, a->nc, b->nr, b->nc);
	if(c == a || c == b)
		sysfatal("mdot: result may not alias an operand");
	c = nmatrix(c, a->nr, b->nc);
	if(c->nr <= 0 || c->nc <= 0)
		return c;	/* nothing to compute; c->data may be nil */
	zero(c);
	arow = a->data;
	bp = b->data;
	cp = c->data;
	pirow = a->Δj;
	pj = b->Δi;
	pk = c->Δj;
	for(col = 1;; cp += *pk++, col++){
		ap = arow;
		pi = pirow;
		for(icol = 0; icol < a->nc; icol++, ap += *pi++, bp += *pj++)
			*cp += *ap * *bp;
		if(col == c->nc){
			/* end of an output row: keep a's next row, restart b */
			arow = ap;
			pirow = pi;
			bp = b->data;
			pj = b->Δi;
			col = 0;
		}
		if(*pk == 0)
			break;	/* terminator: every element of c is done */
	}
	return c;
}
/*
 * tdot: reference product c = a·b of two n×n row-major matrices using
 * the plain i-j-k triple loop.  c is (re)allocated with erealloc (nil is
 * allowed), zero-filled, computed and returned.  main times it as the
 * baseline for mdot.
 */
double *
tdot(double *a, double *b, double *c, uint n)
{
	uint row, col, k, ri;

	c = erealloc(c, sizeof(*c) * n * n);
	memset(c, 0, sizeof(*c) * n * n);
	for(row = 0; row < n; row++){
		ri = row * n;	/* offset of this row in a and c */
		for(col = 0; col < n; col++)
			for(k = 0; k < n; k++)
				c[ri + col] += a[ri + k] * b[k * n + col];
	}
	return c;
}
/*
 * usage: print the command synopsis on fd 2 and exit with status "usage".
 * The message now carries the conventional "usage:" prefix; previously it
 * printed "prog: [-n msize]", which read like an unexplained error.
 */
void
usage(void)
{
	fprint(2, "usage: %s [-n msize]\n", argv0);
	exits("usage");
}
/*
 * main: benchmark Hilbert-ordered multiplication against the naive
 * triple loop on one random n×n matrix (n = 128 unless -n is given).
 * Prints three tab-separated cycle counts on one line, in this order:
 * building the Hilbert layout (nmatrix), mdot(x, x), tdot(sx, sx).
 * Afterwards the two products are compared, and a mismatch is fatal.
 */
void
main(int argc, char **argv)
{
	int n = 128;
	uvlong begin, end;
	matrix *x, *y;
	miter it;
	double *sx, *sy, *px, *py;

	ARGBEGIN{
	case 'n':
		n = atoi(EARGF(usage()));
		break;
	case 'h':
	default:
		usage();
	}ARGEND;
	if(n <= 1)
		sysfatal("Matrix size must be ≥ 2");	/* sysfatal adds the newline */

	cycles(&begin);
	x = nmatrix(nil, n, n);
	cycles(&end);
	print("%ulld\t", end - begin);

	/* fill x (Hilbert layout) and sx (row-major) with identical values */
	sx = emalloc(sizeof(*sx) * n * n);
	for(niter(&it, x), px = sx; it.valid; next(&it), px++)
		*it.value = *px = frand() - 0.5;

	cycles(&begin);
	y = mdot(x, x, nil);
	cycles(&end);
	print("%ulld\t", end - begin);
	dmatrix(x);

	cycles(&begin);
	sy = tdot(sx, sx, nil, n);
	cycles(&end);
	print("%ulld\n", end - begin);

	/*
	 * Both versions start each element at 0 and add the products in
	 * increasing k order, so they should agree; a mismatch means mdot
	 * or its layout is broken and the timings are meaningless.
	 */
	for(niter(&it, y), py = sy; it.valid; next(&it), py++)
		if(fabs(*it.value - *py) > 1e-9 * n)
			sysfatal("mdot and tdot disagree: %g != %g", *it.value, *py);

	dmatrix(y);
	free(sy);
	free(sx);
	exits(0);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment