Skip to content

Instantly share code, notes, and snippets.

@cpq
Created July 29, 2013 15:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cpq/6105055 to your computer and use it in GitHub Desktop.
Save cpq/6105055 to your computer and use it in GitHub Desktop.
Big number integer arithmetic in a single header file
/*
* Large integer manipulation functions.
* All integers are assumed to be unsigned and represented as
* a contiguous array of bytes (unsigned char *ptr, int len)
* in little-endian byte order (ptr points to the least significant byte).
*/
/*
 * Size in bytes of the fixed on-stack scratch buffers used by the
 * multiply and divide helpers; operands longer than this are not
 * supported by them.
 * NOTE(review): some <math.h> implementations (SVID/XSI extensions) also
 * define a macro named HUGE; confirm there is no clash in files that
 * include this header together with <math.h>.
 */
#define HUGE 4096 /* Length of the biggest integer in bytes */
/*
 * Compare two integers. Pointers to integers and their lengths are given.
 * The lengths may differ (high-order zero bytes are insignificant) and
 * either may be zero, which denotes the value 0.
 * Return: 0 if equal, >0 if 'a' is greater, <0 if 'b' is greater.
 * If the longer operand has a nonzero byte beyond the shorter one's
 * length, +1 or -1 is returned; otherwise the result is the difference
 * of the most significant differing bytes (or 0 when equal).
 */
static __inline int
cmp_int(const void *va, int la, const void *vb, int lb)
{
	const unsigned char *a = va, *b = vb;
	int i;

	/*
	 * Bytes of the longer operand beyond the shorter one's length:
	 * any nonzero byte there decides the comparison outright.
	 */
	for (i = la - 1; i >= lb; i--)
		if (a[i] != 0)
			return (1);
	for (i = lb - 1; i >= la; i--)
		if (b[i] != 0)
			return (-1);

	/* Compare the common low-order bytes, most significant first */
	for (i = (la < lb ? la : lb) - 1; i >= 0; i--)
		if (a[i] != b[i])
			return (a[i] - b[i]);

	/* All significant bytes equal; also covers zero-length operands */
	return (0);
}
/*
 * a += b, modulo 256^la.
 * Only the first min(la, lb) bytes of b take part; a carry out of the
 * top byte of a is dropped.
 */
static __inline void
add_int(void *va, int la, const void *vb, int lb)
{
	unsigned char *dst = va;
	const unsigned char *src = vb;
	int common = (la < lb) ? la : lb;
	int pos = 0, carry = 0;

	/* Byte-wise addition over the part both operands cover */
	while (pos < common) {
		int sum = dst[pos] + src[pos] + carry;

		dst[pos] = (unsigned char) sum;
		carry = sum >> 8;
		pos++;
	}

	/* Ripple any remaining carry into the upper bytes of a */
	for (; carry && pos < la; pos++) {
		int sum = dst[pos] + 1;

		dst[pos] = (unsigned char) sum;
		carry = sum >> 8;
	}
}
/*
 * a -= b (Assumes a >= b)
 * The result is stored in a's la bytes. Bytes of b at or beyond la are
 * ignored: if a >= b they must all be zero anyway.
 * If a < b the result wraps modulo 256^la (no error is reported).
 */
static __inline void
sub_int(void *va, int la, const void *vb, int lb)
{
	unsigned char *a = va;
	const unsigned char *b = vb;
	int i, n, borrow, diff;

	/* Never touch a[] beyond la, even when b is the longer buffer */
	n = la < lb ? la : lb;
	borrow = 0;
	for (i = 0; i < n; i++) {
		diff = a[i] - b[i] - borrow;
		borrow = diff < 0;
		a[i] = (unsigned char) diff;	/* well-defined modulo 256 */
	}

	/*
	 * Propagate the borrow through a's remaining bytes: a zero byte
	 * becomes 0xff and passes the borrow on; the first nonzero byte
	 * absorbs it.
	 */
	for (; borrow && i < la; i++) {
		borrow = (a[i] == 0);
		a[i]--;
	}
}
/*
 * a = b
 * Copies min(la, lb) low-order bytes of b into a and clears any bytes
 * of a above that, so a shorter b is zero-extended and a longer b is
 * truncated to la bytes.
 */
static __inline void
copy_int(void *va, int la, const void *vb, int lb)
{
	unsigned char *dst = va;
	const unsigned char *src = vb;
	int k;

	for (k = 0; k < la; k++)
		dst[k] = (k < lb) ? src[k] : 0;
}
/*
 * a *= b, truncated to la bytes (carries out of the top byte of a are
 * discarded). a and b may overlap: the product is accumulated in a
 * scratch buffer and only copied back into a at the end.
 * la must not exceed HUGE, the size of that buffer; abort() is called
 * otherwise instead of overflowing the stack.
 */
static __inline void
mul_int(void *va, int la, const void *vb, int lb)
{
	register unsigned char *a = va;
	register const unsigned char *b = vb;
	int i, j, result, overflow;
	unsigned char p[HUGE];

	if (la > HUGE)
		abort();	/* p[] cannot hold la bytes */

	/* Initialize temporary storage for result */
	for (i = 0; i < la; i++)
		p[i] = 0;

	/*
	 * Schoolbook multiplication: add b[i] * a, shifted left by i bytes.
	 * Only output bytes below la are kept, so rows with i >= la add
	 * nothing and the final carry of each row falls off the top.
	 */
	for (i = 0; i < lb && i < la; i++) {
		overflow = 0;
		for (j = 0; i + j < la; j++) {
			/* At most 255 + 255 * 255 + 255 = 0xffff: fits in int */
			result = p[i + j] + b[i] * a[j] + overflow;
			p[i + j] = result & 0xff;
			overflow = result >> 8;
		}
	}

	/* Copy the result */
	for (i = 0; i < la; i++)
		a[i] = p[i];
}
/*
 * a /= b
 * c = a mod b (if c is not NULL; zero-extended or truncated to lc bytes)
 *
 * b must be nonzero: division by zero calls abort().
 * Returns the comparison of the original dividend with the divisor, as
 * cmp_int() would: <0 if a < b (quotient 0, remainder a), 0 if a == b
 * (quotient 1, remainder 0), >0 if a > b. Hence the quotient is nonzero
 * exactly when the return value is >= 0.
 * When a >= b, the significant part of a (after dropping high zero
 * bytes) must not be longer than HUGE bytes; abort() is called otherwise.
 */
static __inline int
div_int(unsigned char *a, int la, const unsigned char *b, int lb,
    unsigned char *c, int lc)
{
	unsigned char quot[HUGE];	/* quotient, built one byte at a time */
	unsigned char rem[HUGE + 1];	/* running remainder, always < 256 * b */
	int i, order;

	/* Normalize the numbers: strip leading (most significant) zeroes */
	while (la > 0 && a[la - 1] == 0)
		la--;
	while (lb > 0 && b[lb - 1] == 0)
		lb--;
	if (lb == 0)
		abort();	/* Division by zero */

	order = cmp_int(a, la, b, lb);
	if (order < 0) {
		/* a < b: quotient is 0, and a itself is the remainder */
		if (c != NULL)
			copy_int(c, lc, a, la);
		(void) memset(a, 0, la);
		return (order);
	}

	/* Here lb <= la, so bounding la also bounds rem[] */
	if (la > HUGE)
		abort();

	/*
	 * Long division, one byte of a at a time from the top. Before each
	 * shift rem < b, so rem * 256 + a[i] fits in lb + 1 bytes and the
	 * quotient byte found by repeated subtraction is at most 255.
	 */
	(void) memset(quot, 0, la);
	(void) memset(rem, 0, lb + 1);
	for (i = la - 1; i >= 0; i--) {
		/* rem = rem * 256 + a[i]; rem[lb] is 0 because rem < b */
		(void) memmove(rem + 1, rem, lb);
		rem[0] = a[i];
		while (cmp_int(rem, lb + 1, b, lb) >= 0) {
			sub_int(rem, lb + 1, b, lb);
			quot[i]++;
		}
	}

	if (c != NULL)
		copy_int(c, lc, rem, lb);	/* rem < b: fits in lb bytes */
	copy_int(a, la, quot, la);
	return (order);
}
/*
 * Print the number to fp in the given base (2..16), most significant
 * digit first, using upper-case letters and no trailing newline.
 *
 * Base 16 prints two digits per significant byte, so there may be one
 * leading '0' (10 prints as "0A"); other bases print the minimal digit
 * string ("0" for zero). Any length is supported and the result does not
 * depend on the host's byte order.
 * Calls abort() for a base outside 2..16, a negative length, or when the
 * scratch buffer cannot be allocated.
 */
static __inline void
print_int(const void *mem, int len, unsigned char base, FILE *fp)
{
	static const char digits[] = "0123456789ABCDEF";
	const unsigned char *p = mem;
	unsigned char *num, *out;
	size_t n, i, ndigits;
	unsigned int rem, cur;

	if (len < 0 || base < 2 || base > 16)
		abort();
	n = (size_t) len;

	/*
	 * Optimize here for the hexadecimal output: no arithmetic needed,
	 * every byte maps to exactly two digits.
	 */
	if (base == 16) {
		/* Skip most significant zero bytes, keeping at least one */
		while (n > 1 && p[n - 1] == 0)
			n--;
		while (n-- > 0) {
			(void) fputc(digits[p[n] >> 4], fp);
			(void) fputc(digits[p[n] & 0x0f], fp);
		}
		return;
	}

	/* Skip most significant zero bytes; zero prints as a single "0" */
	while (n > 0 && p[n - 1] == 0)
		n--;
	if (n == 0) {
		(void) fputc('0', fp);
		return;
	}

	/*
	 * One allocation: a working copy of the number (n bytes) followed by
	 * room for its digits. Any base >= 2 needs at most 8 digits per byte.
	 */
	if (n > (size_t) -1 / 9)
		abort();
	num = malloc(n * 9);
	if (num == NULL)
		abort();
	out = num + n;
	(void) memcpy(num, p, n);

	/* Repeatedly divide by base; remainders are the digits, lowest first */
	ndigits = 0;
	do {
		rem = 0;
		for (i = n; i-- > 0;) {
			/* rem < base <= 16, so cur < 4096 and cur / base < 256 */
			cur = (rem << 8) | num[i];
			num[i] = (unsigned char) (cur / base);
			rem = cur % base;
		}
		out[ndigits++] = (unsigned char) digits[rem];
		while (n > 0 && num[n - 1] == 0)
			n--;
	} while (n > 0);

	/* Digits were produced least significant first: emit them reversed */
	while (ndigits > 0)
		(void) fputc(out[--ndigits], fp);
	free(num);
}
/*
 * Map an ASCII hexadecimal digit to its value 0..15. Both upper- and
 * lower-case letters are accepted; every other character maps to 0.
 */
static __inline unsigned char
hex(unsigned char ch)
{
	if (ch >= '0' && ch <= '9')
		return (unsigned char) (ch - '0');
	if (ch >= 'a' && ch <= 'f')
		return (unsigned char) (ch - 'a' + 10);
	if (ch >= 'A' && ch <= 'F')
		return (unsigned char) (ch - 'A' + 10);
	return 0;
}
/*
 * dst = strtoul(s, NULL, base), for base 2..16, stored little-endian in
 * dst_len bytes.
 *
 * Parses the longest prefix of s consisting of digits valid in base
 * (0-9, a-f, A-F; no sign, whitespace or "0x" prefix is accepted).
 * Returns the number of characters consumed, or 0 if there were none,
 * if base is out of range, or (base 16 only) if there are more digits
 * than fit in dst_len bytes, even when the extra ones are leading zeroes.
 * For other bases overflow is not detected: the value silently wraps
 * modulo 256^dst_len.
 */
static __inline int
scan_int(const unsigned char *s, unsigned char base,
    unsigned char *dst, int dst_len)
{
	const unsigned char *e = s;
	unsigned char ch;
	int i, len;

	if (base < 2 || base > 16)
		return (0);

	/* Find where the number ends: the first char that is not a digit of base */
	while (isxdigit(*e) && hex(*e) < base)
		e++;
	if ((len = (int) (e - s)) == 0)
		return (0);

	/* For hex, we can identify overflow easily */
	if (base == 16) {
		if (len > dst_len * 2)
			return (0);
		/*
		 * Store the integer into destination, in little endian order,
		 * two digits per byte starting from the least significant
		 * (last) digit. Hex needs no multiplications/additions.
		 */
		for (i = 0; i < len; i++) {
			ch = hex(s[len - i - 1]);
			if (i & 1)
				dst[i / 2] |= ch << 4;
			else
				dst[i / 2] = ch;
		}
		/* Zero the bytes above the (len + 1) / 2 bytes just written */
		for (i = (len + 1) / 2; i < dst_len; i++)
			dst[i] = 0;
	} else {
		/* Initialize destination */
		for (i = 0; i < dst_len; i++)
			dst[i] = 0;
		/* dst = dst * base + digit, most significant digit first */
		for (i = 0; i < len; i++) {
			ch = hex(s[i]);
			mul_int(dst, dst_len, &base, sizeof(base));
			add_int(dst, dst_len, &ch, sizeof(ch));
		}
	}
	return (len);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment