Skip to content

Instantly share code, notes, and snippets.

@dsprenkels
Created June 1, 2017 16:12
Show Gist options
  • Save dsprenkels/385ca247c87e1c84bf912c1d9fc7088a to your computer and use it in GitHub Desktop.
Save dsprenkels/385ca247c87e1c84bf912c1d9fc7088a to your computer and use it in GitHub Desktop.
/*
* Arithmetic modulo 2^221 - 3 in radix 2^16 (unsigned)
*
* Authors:
* - Daan Sprenkels <hello@dsprenkels.com>
* - Jordi Riemens <jordi.riemens@student.ru.nl>
*
* This module uses a redundant representation. Each number is represented as
* an array of 14 int64_t elements.
*/
#include <inttypes.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <assert.h>
/*
 * Variable-length big integer made of `len` int64_t limbs stored in `buf`.
 * `buf` is heap memory obtained from bigint_new(); the holder releases it
 * with free(buf).
 */
typedef struct {
    int64_t *buf;
    size_t len;
} bigint;

/*
 * Allocate a zero-initialised bigint with room for at least `len` limbs.
 *
 * `len` is rounded up to the next even number, so the returned `.len` may be
 * one larger than requested. On allocation failure the process is aborted;
 * the function never returns a bigint whose non-empty buffer is NULL.
 */
bigint bigint_new(size_t len)
{
    bigint ret;

    /* Keep `len` always divisible by two for convenience in one-level
       Karatsuba multiplication */
    len += len & 1;
    ret.buf = calloc(len, sizeof *ret.buf);
    /* Checked unconditionally instead of with assert(): assert() disappears
       under NDEBUG and would let a NULL buffer escape. calloc() is allowed
       to return NULL for a zero-size request, which is not a failure. */
    if (ret.buf == NULL && len != 0) {
        fprintf(stderr, "bigint_new: failed to allocate %zu limbs\n", len);
        abort();
    }
    ret.len = len;
    return ret;
}
/* Fixed-size array of 14 int64_t limbs, the element count named in the file
   header. Not referenced by any other code visible in this file. */
typedef int64_t bigint221[14];
/*
 * Schoolbook multiplication: returns a newly allocated bigint holding the
 * limb-wise convolution of `x` and `y` (out[i + j] += x[i] * y[j]).
 *
 * No carry propagation or reduction is performed. Limb products are
 * accumulated in int64_t, so the caller must keep the limbs small enough
 * that the sums do not overflow. The caller owns the result and must
 * free() its `buf`.
 */
bigint bigint_mul(const bigint x, const bigint y)
{
    /* x.len + y.len - 1 would wrap around to SIZE_MAX when both operands are
       empty, so compute the size explicitly for that case. */
    const size_t total = x.len + y.len;
    bigint out = bigint_new(total == 0 ? 0 : total - 1);

    /* size_t indices: the lengths are size_t, so `int` counters would be
       compared across signedness and truncate for very long operands. */
    for (size_t i = 0; i < x.len; ++i) {
        for (size_t j = 0; j < y.len; ++j) {
            out.buf[i + j] += x.buf[i] * y.buf[j];
        }
    }
    return out;
}
/*
 * One-level Karatsuba multiplication using the (A0 + A1)*(B0 + B1) middle
 * term.
 *
 * Each operand is split into a low half (A0, B0) and a high half (A1, B1) of
 * `half` limbs. The result is
 *     A0*B0 + X^half * ((A0+A1)*(B0+B1) - A0*B0 - A1*B1) + X^len * A1*B1
 * where X^k means "shifted up by k limbs". No carries are propagated.
 *
 * Preconditions (checked with assert): both operands have the same length,
 * and that length is even. The caller owns the result and must free() its
 * `buf`.
 */
bigint bigint_mul_karatsuba(const bigint x, const bigint y)
{
    assert(x.len == y.len); /* Karatsuba cannot be done in this case */
    /* With an odd length `half` would not cover the top limb, and the
       product would silently be wrong. */
    assert(x.len % 2 == 0);

    /* Empty operands: nothing to multiply. Returning early also avoids a
       zero-length VLA below, which is undefined behaviour. */
    if (x.len == 0) {
        return bigint_new(0);
    }

    const size_t len = x.len;
    const size_t half = len / 2;
    bigint out = bigint_new(x.len + y.len - 1);
    int64_t tmp[len];

    /* tmp[0 .. half) = A0 + A1 and tmp[half .. len) = B0 + B1 */
    for (size_t i = 0; i < half; ++i) {
        tmp[i] = x.buf[i] + x.buf[half + i];
        tmp[half + i] = y.buf[i] + y.buf[half + i];
    }

    /* out += X^half * (A0 + A1) * (B0 + B1) */
    for (size_t i = 0; i < half; ++i) {
        for (size_t j = 0; j < half; ++j) {
            out.buf[half + i + j] += tmp[i] * tmp[half + j];
        }
    }

    /* tmp = A0 * B0 (occupies indices 0 .. len - 2) */
    memset(tmp, 0, sizeof tmp);
    for (size_t i = 0; i < half; ++i) {
        for (size_t j = 0; j < half; ++j) {
            tmp[i + j] += x.buf[i] * y.buf[j];
        }
    }

    /* Add A0*B0 at the bottom and subtract it from the middle term */
    for (size_t i = 0; i + 1 < len; ++i) {
        out.buf[i] += tmp[i];
        out.buf[half + i] -= tmp[i];
    }

    /* tmp = A1 * B1 */
    memset(tmp, 0, sizeof tmp);
    for (size_t i = 0; i < half; ++i) {
        for (size_t j = 0; j < half; ++j) {
            tmp[i + j] += x.buf[half + i] * y.buf[half + j];
        }
    }

    /* Add A1*B1 at the top and subtract it from the middle term */
    for (size_t i = 0; i + 1 < len; ++i) {
        out.buf[len + i] += tmp[i];
        out.buf[half + i] -= tmp[i];
    }
    return out;
}
/*
 * One-level "refined" Karatsuba multiplication using the subtractive middle
 * term (A0 - A1)*(B0 - B1).
 *
 * With L = A0*B0, H = A1*B1 and M = (A0 - A1)*(B0 - B1), the product is
 *     L + X^half * (L + H - M) + X^len * H
 * where X^k means "shifted up by k limbs". Writing L = L_lo + X^half * L_hi,
 * this is rearranged so that (L_hi + H) is formed once and added at both
 * X^half and X^len. No carries are propagated.
 *
 * Preconditions (checked with assert): both operands have the same length,
 * and that length is even. The caller owns the result and must free() its
 * `buf`.
 */
bigint bigint_mul_karatsuba_refined(const bigint x, const bigint y)
{
    assert(x.len == y.len); /* Karatsuba cannot be done in this case */
    /* With an odd length `half` would not cover the top limb, and the
       product would silently be wrong. */
    assert(x.len % 2 == 0);

    /* Empty operands: nothing to multiply. Returning early also avoids a
       zero-length VLA below, which is undefined behaviour. */
    if (x.len == 0) {
        return bigint_new(0);
    }

    const size_t len = x.len;
    const size_t half = len / 2;
    bigint out = bigint_new(x.len + y.len - 1);
    int64_t tmp[len];

    /* tmp[0 .. half) = A0 - A1 and tmp[half .. len) = B0 - B1 */
    for (size_t i = 0; i < half; ++i) {
        tmp[i] = x.buf[i] - x.buf[half + i];
        tmp[half + i] = y.buf[i] - y.buf[half + i];
    }

    /* out -= X^half * (A0 - A1) * (B0 - B1) */
    for (size_t i = 0; i < half; ++i) {
        for (size_t j = 0; j < half; ++j) {
            out.buf[half + i + j] -= tmp[i] * tmp[half + j];
        }
    }

    /* tmp = L = A0 * B0 (occupies indices 0 .. len - 2) */
    memset(tmp, 0, sizeof tmp);
    for (size_t i = 0; i < half; ++i) {
        for (size_t j = 0; j < half; ++j) {
            tmp[i + j] += x.buf[i] * y.buf[j];
        }
    }

    /* Add L_lo at X^0 and at X^half */
    for (size_t i = 0; i < half; ++i) {
        out.buf[i] += tmp[i];
        out.buf[half + i] += tmp[i];
    }

    /* Move L_hi down to the bottom of tmp and clear the upper half, so that
       H can be accumulated on top of it. */
    for (size_t i = 0; i < half; ++i) {
        tmp[i] = tmp[half + i];
        tmp[half + i] = 0;
    }

    /* tmp = L_hi + H, with H = A1 * B1 */
    for (size_t i = 0; i < half; ++i) {
        for (size_t j = 0; j < half; ++j) {
            tmp[i + j] += x.buf[half + i] * y.buf[half + j];
        }
    }

    /* Add (L_hi + H) at X^len and at X^half */
    for (size_t i = 0; i + 1 < len; ++i) {
        out.buf[len + i] += tmp[i];
        out.buf[half + i] += tmp[i];
    }
    return out;
}
/*
 * Write `n` to stdout as a parenthesised sum of limbs, each multiplied by
 * its power of two: "(l0 + l1*2^16 + l2*2^32 + ...)". A bigint with no limbs
 * is printed as "0". No trailing newline is written.
 */
void bigint_print(bigint n)
{
    if (n.len == 0) {
        printf("0");
        return;
    }
    /* PRId64 is the matching conversion for int64_t; "%ld" is only correct
       on platforms where int64_t happens to be `long`. */
    printf("(%" PRId64, n.buf[0]);
    for (size_t idx = 1; idx < n.len; ++idx) {
        /* idx * 16 is the bit position of limb `idx` */
        printf(" + %" PRId64 "*2^%zu", n.buf[idx], idx * 16);
    }
    printf(")");
}
static void test()
{
bigint z = {0};
bigint x = bigint_new(14);
x.buf[13] = 1234567890;
x.buf[12] = 777;
x.buf[11] = 1039;
x.buf[10] = 42;
x.buf[9] = 1656;
x.buf[8] = 888;
x.buf[4] = 7645;
x.buf[1] = 1039;
x.buf[0] = 42;
bigint y = bigint_new(14);
y.buf[13] = 9876543;
y.buf[12] = 13;
y.buf[11] = 989754;
y.buf[10] = 42;
y.buf[2] = 999;
y.buf[0] = 1;
/* Test 1 */
z = bigint_mul_karatsuba(x, y);
bigint_print(x);
printf(" * ");
bigint_print(y);
printf(" - ");
bigint_print(z);
printf("\n");
free(z.buf);
/* Test 2 */
z = bigint_mul(x, y);
bigint_print(x);
printf(" * ");
bigint_print(y);
printf(" - ");
bigint_print(z);
printf("\n");
free(z.buf);
/* Test 3 */
z = bigint_mul_karatsuba_refined(x, y);
bigint_print(x);
printf(" * ");
bigint_print(y);
printf(" - ");
bigint_print(z);
printf("\n");
free(z.buf);
free(x.buf);
free(y.buf);
}
/* Program entry point: run the multiplication demo and report success. */
int main()
{
    test();

    return EXIT_SUCCESS;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment