Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@hamsham
Last active February 16, 2023 09:30
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save hamsham/16c638b8252ffd23b8a6 to your computer and use it in GitHub Desktop.
Save hamsham/16c638b8252ffd23b8a6 to your computer and use it in GitHub Desktop.
Bignum multiplication of numbers with arbitrary bases, using the Schönhage-Strassen algorithm.
/**
* Bignum multiplication of numbers with arbitrary bases, using the
* Schonhage-Strassen algorithm.
*
* g++ -std=c++11 -Wall -Wextra -Werror -pedantic-errors bn_strassen.cpp -o bn_strassen
*
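* usage: bn_strassen.exe 123456789 987654321
*   (note: argument parsing is currently disabled in main(), which
*   multiplies built-in test numbers instead)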
*
* Sources:
* Ando Emerencia
* Multiplying Huge Integers Using Fourier Transforms
* http://www.cs.rug.nl/~ando/pdfs/Ando_Emerencia_multiplying_huge_integers_using_fourier_transforms_paper.pdf
*
* Manuel Blum
* Carnegie Mellon University, Online lecture notes
* http://www.cs.cmu.edu/
* http://www.cs.cmu.edu/afs/cs/academic/class/15451-s10/www/lectures/lect0423.txt
* http://www.cs.cmu.edu/afs/cs/academic/class/15750-s01/www/notes/lect0424
*
* Christoph Luders
* Implementation of the DKSS Algorithm for Multiplication of Large Numbers
* http://wrogn.com/implementation-of-the-dkss-algorithm-for-multiplication-of-large-numbers-issac15/
* BIGNUM Library
* http://wrogn.com/bignum/
*
* Institut für Mathematik der Universität Zürich
* Fast Arithmetic
* http://www.math.uzh.ch/?file&key1=23398
*
* Stack Overflow
* http://stackoverflow.com/questions/31369856/problems-using-fft-on-multiplication-of-large-numbers
*/
#include <cassert>
#include <chrono>
#include <cmath>
#include <complex>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <string>
#include <thread>
#include <vector>
// Digit-base selection: define NUM_BASE_16, NUM_BASE_8 or NUM_BASE_2
// (e.g. -DNUM_BASE_8 on the command line) to change the base; base 10 is
// used when none is defined.
//#define NUM_BASE_16

// main() doubles the length of test number A / B this many times.
#define SIZE_MULTIPLIER_A 20
#define SIZE_MULTIPLIER_B 11

// Debug trace helper: prints the enclosing function name and line number.
#define PRINT_DEBUG std::cout << "DEBUG: " << __FUNCTION__ << " - " << __LINE__ << std::endl

// Digit range of the selected base: every digit d satisfies
// NUM_MIN <= d <= NUM_MAX, and NUM_BASE == NUM_MAX + 1.
#ifdef NUM_BASE_16
enum {
    NUM_MIN = 0,
    NUM_MAX = 15,
    NUM_BASE = 16
};
#elif defined(NUM_BASE_8)
enum {
    NUM_MIN = 0,
    NUM_MAX = 7,
    NUM_BASE = 8
};
#elif defined(NUM_BASE_2)
enum {
    NUM_MIN = 0,
    NUM_MAX = 1,
    NUM_BASE = 2
};
#else
enum {
    NUM_MIN = 0,
    NUM_MAX = 9,
    NUM_BASE = 10
};
#endif

// A single digit of a bignum.
typedef unsigned int bn_t;

// Wider intermediate type for carries and native-integer conversion.
// `unsigned long` is only 32 bits on LLP64 platforms (64-bit Windows), i.e.
// no wider than bn_t there; `unsigned long long` is guaranteed >= 64 bits.
typedef unsigned long long bn_double_t;

// Digit vector, least-significant digit at index 0.
typedef std::vector<bn_t> bignum;

using std::chrono::steady_clock;
typedef std::chrono::time_point<steady_clock> time_point;
typedef std::chrono::milliseconds millis;
typedef std::chrono::duration<double> duration;

// Complex-signal helpers used by the FFT routines.
template <typename flt_t>
using cmplx_list_t = std::vector<std::complex<flt_t>>;
template <typename flt_t>
using cmplx_size_t = typename cmplx_list_t<flt_t>::size_type;
template <typename flt_t>
using cmplx_value_t = typename cmplx_list_t<flt_t>::value_type;
/**
* Simple bignum printing operation
*/
std::ostream& operator << (std::ostream& ostr, const bignum& bn) {
for (unsigned i = bn.size(); i > 0; --i) {
const unsigned index = i-1;
bn_t digit = bn[index];
#ifdef NUM_BASE_16
ostr << std::hex << digit;
#else
ostr << digit;
#endif
}
return ostr;
}
/**
 * Bignum equality: the numbers must have the same digit count and match
 * digit-for-digit (no normalisation of leading zeros is performed).
 */
bool operator == (const bignum& a, const bignum& b) {
    if (a.size() == b.size()) {
        bignum::const_iterator lhs = a.begin();
        bignum::const_iterator rhs = b.begin();
        while (lhs != a.end()) {
            if (*lhs++ != *rhs++) {
                return false;
            }
        }
        return true;
    }
    return false;
}
/**
 * Parse a digit string (most-significant digit first) and append its digits
 * to `b`, least-significant digit first.
 *
 * Only characters that are digits of the configured base are accepted
 * (e.g. '8' and '9' are rejected in base 8; 'a'-'f'/'A'-'F' only in base 16).
 * On an invalid character or a null `number`, an error is written to
 * std::cerr, `b` is cleared and false is returned.
 *
 * @param b      Output bignum; parsed digits are appended with push_back.
 * @param number NUL-terminated digit string.
 * @return true on success, false on invalid input.
 */
bool bignum_from_str(bignum& b, const char* const number) {
    if (number == nullptr) {
        std::cerr << "Invalid parameter: null number string." << std::endl;
        b.clear();
        return false;
    }
    std::size_t i = std::strlen(number);
    while (i-- > 0) {
        const char ch = number[i];
        // Start out-of-range so unrecognised characters fail the check below.
        bn_t digit = static_cast<bn_t>(NUM_BASE);
        if (ch >= '0' && ch <= '9') {
            digit = static_cast<bn_t>(ch - '0');
        }
#ifdef NUM_BASE_16
        else if (ch >= 'a' && ch <= 'f') {
            digit = static_cast<bn_t>(10 + (ch - 'a'));
        }
        else if (ch >= 'A' && ch <= 'F') {
            digit = static_cast<bn_t>(10 + (ch - 'A'));
        }
#endif
        // Reject anything that is not a digit of the configured base.
        if (digit > static_cast<bn_t>(NUM_MAX)) {
            // Print the offending character itself, not its numeric code.
            std::cerr
                << "Invalid digit \'" << ch
                << "\' in parameter \'" << number << "\'."
                << std::endl;
            b.clear();
            return false;
        }
        b.push_back(digit);
    }
    return true;
}
/**
 * Convert a native unsigned integer into a bignum, least-significant digit
 * first. Zero yields an empty bignum.
 */
bignum bignum_from_num(bn_double_t n) {
    // Count digits first so the output is allocated exactly once.
    bignum::size_type numDigits = 0;
    for (bn_double_t rest = n; rest != 0; rest /= NUM_BASE) {
        ++numDigits;
    }
    bignum ret;
    ret.reserve(numDigits);
    for (; n != 0; n /= NUM_BASE) {
        // Reduce in the wide type *before* narrowing: casting first would
        // truncate values outside bn_t's range and yield wrong digits.
        ret.push_back(static_cast<bn_t>(n % NUM_BASE));
    }
    return ret;
}
/**
* Convert the first two parameters into numbers
*/
bool tokenize(int argc, char** argv, bignum& out1, bignum& out2){
bignum first = {};
bignum second = {};
if (argc != 3) {
std::cerr << "Invalid number of arguments." << std::endl;
return false;
};
if (!bignum_from_str(first, argv[1])
|| !bignum_from_str(second, argv[2])
) {
return false;
}
out1 = std::move(first);
out2 = std::move(second);
return true;
}
/**
 * Multiply two bignums with the schoolbook O(len(a) * len(b)) algorithm.
 *
 * Serves as a reference for mul_strassen(). Most-significant zero digits are
 * trimmed from the result (a zero product is an empty bignum), so the two
 * functions produce the same representation and compare with operator==.
 */
bignum mul_naive(const bignum& a, const bignum& b) {
    bignum ret(a.size() + b.size());
    for (bignum::size_type i = 0; i < b.size(); ++i) {
        bn_t carry = 0;
        for (bignum::size_type j = 0; j < a.size(); ++j) {
            // Stored digits are < NUM_BASE, so this sum is small.
            const bn_t cell = ret[i + j] + carry + a[j] * b[i];
            ret[i + j] = cell % NUM_BASE;
            carry = cell / NUM_BASE;
        }
        // Position i + a.size() is not touched by any earlier row.
        ret[i + a.size()] += carry;
    }
    // Trim most-significant zero digits.
    while (!ret.empty() && ret.back() == 0) {
        ret.pop_back();
    }
    return ret;
}
/**
 * Smallest power of two >= n; returns n itself when it is already a power of
 * two, and 0 for n == 0.
 *
 * The bit-smearing covers the full width of bignum::size_type (including
 * 64-bit sizes). If the result is not representable in the type it wraps
 * to 0.
 */
bignum::size_type next_pow2(bignum::size_type n) {
    if (n == 0) {
        return 0;
    }
    --n;
    // Copy the highest set bit into every lower position: 1, 2, 4, ... bits.
    const unsigned width =
        static_cast<unsigned>(std::numeric_limits<bignum::size_type>::digits);
    for (unsigned shift = 1; shift < width; shift <<= 1) {
        n |= n >> shift;
    }
    return n + 1;
}
/**
 * True when n is a non-zero power of two.
 *
 * Takes `unsigned long long` so 64-bit sizes (e.g. bignum::size_type) are not
 * truncated to `unsigned int` first; truncation could make a value such as
 * 2^32 + 2 look like a power of two.
 */
constexpr bool is_pow2(const unsigned long long n) {
    return n != 0 && (n & (n - 1)) == 0;
}
/**
 * In-place recursive radix-2 Cooley-Tukey FFT:
 *   X[k] = sum_j x[j] * exp(-2*pi*i*j*k / N),  N = x.size().
 *
 * Precondition: N is 0 or a power of two. No zero-padding is done here; a
 * non-power-of-two length trips the assertion in debug builds. Inputs of
 * size 0 and 1 are left unchanged.
 */
template <typename flt_t>
void fft(cmplx_list_t<flt_t>& x) {
    static constexpr flt_t pi = 3.1415926535897932384626433832795;
    const cmplx_size_t<flt_t> len = x.size();
    // Base case: the DFT of zero or one sample is the input itself. Handling
    // len == 0 here also prevents endless recursion on empty input.
    if (len <= 1) {
        return;
    }
    assert((len & (len - 1)) == 0 && "fft() requires a power-of-two length");
    // Partition the input into even- and odd-indexed samples.
    const cmplx_size_t<flt_t> halfLen = len / 2;
    cmplx_list_t<flt_t> evens; evens.reserve(halfLen);
    cmplx_list_t<flt_t> odds; odds.reserve(halfLen);
    for (cmplx_size_t<flt_t> i = 0; i < halfLen; ++i) {
        evens.push_back(x[2*i]);
        odds.push_back(x[2*i+1]);
    }
    // TODO: Fix the recursion! (every level allocates two new vectors)
    fft<flt_t>(evens);
    fft<flt_t>(odds);
    // Butterfly: combine the half-size transforms with twiddle factors.
    const flt_t nf = static_cast<flt_t>(len);
    for (cmplx_size_t<flt_t> k = 0; k < halfLen; ++k) {
        const flt_t kf = static_cast<flt_t>(k);
        const flt_t w = static_cast<flt_t>(-2.0 * kf * pi / nf);
        const std::complex<flt_t> wk {std::cos(w), std::sin(w)};
        const std::complex<flt_t> t = wk * odds[k];
        x[k] = evens[k] + t;
        x[k+halfLen] = evens[k] - t;
    }
}
/**
 * In-place inverse FFT using the conjugation identity
 *   ifft(x) = conj(fft(conj(x))) / N.
 */
template <typename flt_t>
void ifft(cmplx_list_t<flt_t>& x) {
    const cmplx_size_t<flt_t> count = x.size();
    for (cmplx_size_t<flt_t> n = 0; n < count; ++n) {
        x[n] = std::conj(x[n]);
    }
    fft<flt_t>(x);
    const flt_t scale = static_cast<flt_t>(count);
    for (cmplx_size_t<flt_t> n = 0; n < count; ++n) {
        x[n] = std::conj(x[n]) / scale;
    }
}
/**
 * Pack two bignums into one complex signal so a single FFT transforms both.
 *
 * Element i is (a[i] + i*b[i]); digits past the end of either input are
 * zero. The length is the smallest power of two >= a.size() + b.size(),
 * which fits every digit of the product and meets fft()'s radix-2
 * requirement. The length is never below 1, even for two empty inputs.
 */
template <typename flt_t>
cmplx_list_t<flt_t> create_fft_table(const bignum& a, const bignum& b) {
    const bignum::size_type aLen = a.size();
    const bignum::size_type bLen = b.size();
    // next_pow2() already returns a power of two unchanged, so no separate
    // is_pow2() check (which could truncate a 64-bit size) is needed.
    bignum::size_type size = next_pow2(aLen + bLen);
    if (size == 0) {
        size = 1;
    }
    cmplx_list_t<flt_t> ret;
    ret.reserve(size);
    for (cmplx_size_t<flt_t> i = 0; i < size; ++i) {
        // Zero-pad past the end of either input.
        const bignum::value_type aVal = i < aLen ? a[i] : 0;
        const bignum::value_type bVal = i < bLen ? b[i] : 0;
        ret.push_back(std::complex<flt_t>{static_cast<flt_t>(aVal), static_cast<flt_t>(bVal)});
    }
    return ret;
}
/**
 * Replace a packed signal (a[i] + i*b[i], as built by create_fft_table())
 * with the spectrum of the digit-wise convolution of a and b, in place.
 *
 * With T = FFT(a + i*b), N = T.size() and m = (N - k) mod N:
 *   A[k] = (T[k] + conj(T[m])) / 2,   B[k] = (T[k] - conj(T[m])) / (2i)
 * so A[k]*B[k] = x1*x2 / (4i), where x1/x2 are that sum/difference.
 * Dividing by 4i maps (re, im) to (im, -re) / 4.
 * Applying ifft() to the result gives the convolution coefficients.
 */
template <typename flt_t>
void convolute_fft(cmplx_list_t<flt_t>& fftTable) {
    const cmplx_size_t<flt_t> fftSize = fftTable.size();
    if (fftSize == 0) {
        return;
    }
    cmplx_list_t<flt_t> transforms(fftTable);
    fft<flt_t>(transforms);
    // Point-wise multiplication in the frequency domain.
    constexpr flt_t quarter = static_cast<flt_t>(0.25);
    for (cmplx_size_t<flt_t> i = 0; i < fftSize; ++i) {
        const cmplx_value_t<flt_t>& ti = transforms[i];
        // Mirror index (N - i) mod N, written without unsigned negation so it
        // is correct for any table length, not only powers of two.
        const cmplx_value_t<flt_t> tc = std::conj(transforms[(fftSize - i) % fftSize]);
        const cmplx_value_t<flt_t> x1 = ti + tc;
        const cmplx_value_t<flt_t> x2 = ti - tc;
        const cmplx_value_t<flt_t> x3 = x1 * x2;
        fftTable[i] = cmplx_value_t<flt_t>{x3.imag(), -x3.real()} * quarter;
    }
}
/**
 * Multiply two bignums via FFT-based convolution on complex doubles.
 *
 * Packs a and b into one complex signal, forms the product spectrum
 * (convolute_fft), inverse-transforms back to the digit convolution, then
 * rounds each coefficient to the nearest integer and propagates carries in
 * NUM_BASE. Most-significant zero digits are trimmed, so a zero product is
 * an empty bignum.
 *
 * NOTE(review): correctness relies on every coefficient coming back within
 * 0.5 of the exact integer; for very long inputs double-precision error may
 * exceed that — cross-check against mul_naive when scaling up.
 */
bignum mul_strassen(const bignum& a, const bignum& b) {
    // Default type is double. Use floats if they perform well enough.
    typedef double flt_t;
    // Initialise straight from the returned prvalue: wrapping the call in
    // std::move() disables copy elision (GCC -Wpessimizing-move, fatal
    // under -Wall -Werror).
    cmplx_list_t<flt_t> fftTable = create_fft_table<flt_t>(a, b);
    convolute_fft<flt_t>(fftTable);
    ifft<flt_t>(fftTable);

    bignum ret;
    ret.reserve(fftTable.size());
    bn_double_t carry = 0;
    for (cmplx_size_t<flt_t> i = 0; i < fftTable.size(); ++i) {
        // The imaginary part is discarded; round the real part to an integer.
        const flt_t x = fftTable[i].real();
        const bn_double_t ci =
            static_cast<bn_double_t>(static_cast<flt_t>(carry) + std::floor(x + 0.5));
        ret.push_back(static_cast<bn_t>(ci % NUM_BASE));
        carry = ci / NUM_BASE;
    }
    // Flush any carry still pending after the last coefficient.
    while (carry != 0) {
        ret.push_back(static_cast<bn_t>(carry % NUM_BASE));
        carry /= NUM_BASE;
    }
    // Trim most-significant zero digits.
    while (!ret.empty() && ret.back() == 0) {
        ret.pop_back();
    }
    return ret;
}
/**
 * Bytes occupied by the digits currently in use (digit count times digit
 * size); spare capacity and vector bookkeeping are not included.
 */
bignum::size_type get_num_bytes(const bignum& b) {
    const bignum::size_type digitBytes = sizeof(bn_t);
    return b.size() * digitBytes;
}
void print_bignum_stats(const std::string& name, const bignum& b) {
std::cout << name << " element count: " << b.size() << std::endl;
std::cout << name << " byte size: " << get_num_bytes(b) << std::endl;
}
void run_mult_bench(
const std::string& name,
const bignum& a,
const bignum& b,
bignum& result,
bignum (*mul_func)(const bignum&, const bignum&)
) {
std::cout << "Performing " << name << " Multiplication..." << std::endl;
const steady_clock::time_point start = steady_clock::now();
result = std::move(mul_func(a, b));
const steady_clock::time_point end = steady_clock::now();
const duration elapsed = end - start;
std::cout << name << " method time: " << elapsed.count() << "s." << std::endl;
print_bignum_stats(name, result);
std::cout << std::endl;
}
/**
* Main()
*/
int main(int, char**) {
#ifdef NUM_BASE_16
bignum a = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
bignum b = {15,14,13,12,11,10,9,8,7,6,5,4,3,2,1};
#elif defined(NUM_BASE_8)
bignum a = {0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7};
bignum b = {0,7,6,5,4,3,2,1,0,7,6,5,4,3,2,1};
#elif defined(NUM_BASE_2)
bignum a = {0,1,1,1,0,1,1,1,0,1,1,1,0,1,1,1,0,1,1,1};
bignum b = {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1};
#else
bignum a = {1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9};
bignum b = {9,8,7,6,5,4,3,2,1,0,9,8,7,6,5,4,3,2,1};
#endif
for (unsigned i = 0; i < SIZE_MULTIPLIER_A; ++i) {
a.insert(a.end(), a.begin(), a.end());
}
for (unsigned i = 0; i < SIZE_MULTIPLIER_B; ++i) {
b.insert(b.end(), b.begin(), b.end());
}
/*
if (!tokenize(argc, argv, a, b)) {
return -1;
}
*/
std::cout << "Testing bignum multiplication." << std::endl;
print_bignum_stats("Test number A", a);
print_bignum_stats("Test Number B", b);
std::cout << '\n' << a << "\n\n" << b << '\n' << std::endl;
/*
bignum naive;
std::thread t1{
run_mult_bench,
"Naive",
std::ref(a),
std::ref(b),
std::ref(naive),
mul_naive
};
*/
bignum strassen;
std::thread t2{
run_mult_bench,
"Strassen",
std::ref(a),
std::ref(b),
std::ref(strassen),
mul_strassen
};
//t1.join();
t2.join();
//std::cout << '\n' << naive << "\n\n" << strassen << '\n' << std::endl;
std::cout << '\n' << strassen << '\n' << std::endl;
//std::cout << "Naive == Strassen: " << (naive == strassen) << std::endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment