Last active
February 16, 2023 09:30
-
-
Save hamsham/16c638b8252ffd23b8a6 to your computer and use it in GitHub Desktop.
Bignum multiplication of numbers with arbitrary bases, using the Schönhage-Strassen algorithm.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Bignum multiplication of numbers with arbitrary bases, using the
 * Schönhage-Strassen algorithm.
 *
 * g++ -std=c++11 -Wall -Wextra -Werror -pedantic-errors bn_strassen.cpp -o bn_strassen
 *
 * usage: bn_strassen.exe 123456789 987654321
 * (Note: main() currently ignores its arguments and multiplies built-in test
 * operands; the tokenize() call that parses the arguments is commented out.)
 *
 * Sources:
 * Ando Emerencia
 * Multiplying Huge Integers Using Fourier Transforms
 * http://www.cs.rug.nl/~ando/pdfs/Ando_Emerencia_multiplying_huge_integers_using_fourier_transforms_paper.pdf
 *
 * Manuel Blum
 * Carnegie Mellon University, Online lecture notes
 * http://www.cs.cmu.edu/
 * http://www.cs.cmu.edu/afs/cs/academic/class/15451-s10/www/lectures/lect0423.txt
 * http://www.cs.cmu.edu/afs/cs/academic/class/15750-s01/www/notes/lect0424
 *
 * Christoph Luders
 * Implementation of the DKSS Algorithm for Multiplication of Large Numbers
 * http://wrogn.com/implementation-of-the-dkss-algorithm-for-multiplication-of-large-numbers-issac15/
 * BIGNUM Library
 * http://wrogn.com/bignum/
 *
 * Institut für Mathematik der Universität Zürich
 * Fast Arithmetic
 * http://www.math.uzh.ch/?file&key1=23398
 *
 * Stack Overflow
 * http://stackoverflow.com/questions/31369856/problems-using-fft-on-multiplication-of-large-numbers
 */
#include <cassert>
#include <chrono>
#include <cmath>
#include <complex>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <string>
#include <thread>
#include <vector>
// Define NUM_BASE_16, NUM_BASE_8 or NUM_BASE_2 to select the digit base.
// When none of them is defined the digits are base 10.
//#define NUM_BASE_16

// Number of times main() doubles the digit count of test operand A / B.
#define SIZE_MULTIPLIER_A 20
#define SIZE_MULTIPLIER_B 11

// Prints the current function name and line number to stdout.
#define PRINT_DEBUG std::cout << "DEBUG: " << __FUNCTION__ << " - " << __LINE__ << std::endl

// Smallest digit, largest digit and radix for the selected base.
#if defined(NUM_BASE_16)
enum { NUM_MIN = 0, NUM_MAX = 15, NUM_BASE = 16 };
#elif defined(NUM_BASE_8)
enum { NUM_MIN = 0, NUM_MAX = 7, NUM_BASE = 8 };
#elif defined(NUM_BASE_2)
enum { NUM_MIN = 0, NUM_MAX = 1, NUM_BASE = 2 };
#else
enum { NUM_MIN = 0, NUM_MAX = 9, NUM_BASE = 10 };
#endif

// A single digit, a wider type for intermediate results, and the bignum
// itself: a vector holding one digit per element.
using bn_t = unsigned int;
using bn_double_t = unsigned long;
using bignum = std::vector<bn_t>;

// Timing helpers used by the benchmark code.
using std::chrono::steady_clock;
using time_point = std::chrono::time_point<steady_clock>;
using millis = std::chrono::milliseconds;
using duration = std::chrono::duration<double>;

// A list of complex samples, plus shorthands for its size and element types.
template <typename flt_t>
using cmplx_list_t = std::vector<std::complex<flt_t>>;

template <typename flt_t>
using cmplx_size_t = typename cmplx_list_t<flt_t>::size_type;

template <typename flt_t>
using cmplx_value_t = typename cmplx_list_t<flt_t>::value_type;
/** | |
* Simple bignum printing operation | |
*/ | |
std::ostream& operator << (std::ostream& ostr, const bignum& bn) { | |
for (unsigned i = bn.size(); i > 0; --i) { | |
const unsigned index = i-1; | |
bn_t digit = bn[index]; | |
#ifdef NUM_BASE_16 | |
ostr << std::hex << digit; | |
#else | |
ostr << digit; | |
#endif | |
} | |
return ostr; | |
} | |
/** | |
* Bignum Comparison | |
*/ | |
bool operator == (const bignum& a, const bignum& b) { | |
if (a.size() != b.size()) { | |
return false; | |
} | |
for (bignum::size_type i = 0; i < a.size(); ++i) { | |
if (a[i] != b[i]) { | |
return false; | |
} | |
} | |
return true; | |
} | |
/* | |
* | |
*/ | |
bool bignum_from_str(bignum& b, const char* const number) { | |
int i = strlen(number); | |
while (i --> 0) { | |
bn_t digit = number[i]; | |
if (digit >= '0' && digit <= '9') { | |
digit -= '0'; | |
} | |
#ifdef NUM_BASE_16 | |
else if (digit >= 'a' && digit <= 'f') { | |
digit = 10 + (digit-'a'); | |
} | |
else if (digit >= 'A' && digit <= 'F') { | |
digit = 10 + (digit-'A'); | |
} | |
#endif | |
else { | |
std::cerr | |
<< "Invalid digit \'" << digit | |
<< "\' in parameter \'" << number << "\'." | |
<< std::endl; | |
b.clear(); | |
return false; | |
} | |
b.push_back(digit); | |
} | |
return true; | |
} | |
/* | |
* | |
*/ | |
bignum bignum_from_num(bn_double_t n) { | |
bn_double_t count = n; | |
bn_double_t numDigits = 0; | |
while (count) { | |
++numDigits; | |
count /= NUM_BASE; | |
} | |
bignum ret; | |
ret.reserve(numDigits); | |
count = n; | |
while (count) { | |
ret.push_back((bn_t)count % NUM_BASE); | |
count /= NUM_BASE; | |
} | |
return ret; | |
} | |
/** | |
* Convert the first two parameters into numbers | |
*/ | |
bool tokenize(int argc, char** argv, bignum& out1, bignum& out2){ | |
bignum first = {}; | |
bignum second = {}; | |
if (argc != 3) { | |
std::cerr << "Invalid number of arguments." << std::endl; | |
return false; | |
}; | |
if (!bignum_from_str(first, argv[1]) | |
|| !bignum_from_str(second, argv[2]) | |
) { | |
return false; | |
} | |
out1 = std::move(first); | |
out2 = std::move(second); | |
return true; | |
} | |
/** | |
* Multiply bignums (naive) | |
*/ | |
bignum mul_naive(const bignum& a, const bignum& b) { | |
bignum ret{}; | |
const bignum::size_type totalLen = a.size() + b.size(); | |
if (!totalLen) { | |
return ret; | |
} | |
ret.resize(totalLen); | |
for (unsigned i = 0; i < b.size(); ++i) { | |
unsigned remainder = 0; | |
for (unsigned j = 0; j < a.size(); ++j) { | |
const bignum::size_type index = i + j; | |
ret[index] += remainder + a[j] * b[i]; | |
remainder = ret[index] / NUM_BASE; | |
ret[index] = ret[index] % NUM_BASE; | |
} | |
ret[i+a.size()] += remainder; | |
} | |
return ret; | |
} | |
/* | |
* | |
*/ | |
bignum::size_type next_pow2(bignum::size_type n) { | |
if (n == 0) { | |
return 0; | |
} | |
--n; | |
n |= n >> 1; | |
n |= n >> 2; | |
n |= n >> 4; | |
n |= n >> 8; | |
n |= n >> 16; | |
return ++n; | |
} | |
/**
 * Check whether a value is a power of two.
 *
 * The parameter is the widest standard unsigned type so that container sizes
 * can be passed without being truncated.
 *
 * @param n Value to test.
 * @return true if exactly one bit of n is set; false otherwise (including 0).
 */
constexpr bool is_pow2(const unsigned long long n) {
    // Subtracting one flips the lowest set bit and every bit below it, so the
    // AND is zero only when no higher bit was set.
    return n != 0 && (n & (n - 1)) == 0;
}
template <typename flt_t> | |
void fft(cmplx_list_t<flt_t>& x) { | |
static constexpr flt_t pi = 3.1415926535897932384626433832795; | |
// in the event that an array was passed in with a non-power-of-two length. | |
const cmplx_size_t<flt_t> len = x.size(); | |
// base case | |
if (len == 1) { | |
x = cmplx_list_t<flt_t>{x[0]}; | |
return; | |
} | |
// Partition the input list into evenly-index and oddly-indexed elements. | |
const cmplx_size_t<flt_t> halfLen = len / 2; | |
cmplx_list_t<flt_t> evens; evens.reserve(halfLen); | |
cmplx_list_t<flt_t> odds; odds.reserve(halfLen); | |
for (cmplx_size_t<flt_t> i = 0; i < halfLen; ++i) { | |
// zero-pad both partitions if the input array was not at the | |
// required radix=2 capacity | |
if (i < x.size()) { | |
evens.push_back(x[2*i]); | |
odds.push_back(x[2*i+1]); | |
} | |
else { | |
evens[i] = odds[i] = std::complex<flt_t>{0}; | |
} | |
} | |
// TODO: Fix the recursion! | |
fft<flt_t>(evens); | |
fft<flt_t>(odds); | |
// combine the even and odd partitions | |
const cmplx_size_t<flt_t> nf = (flt_t)len; | |
for (cmplx_size_t<flt_t> k = 0; k < halfLen; ++k) { | |
const flt_t kf = (flt_t)k; | |
const flt_t w = -2.0 * kf * pi / nf; | |
const std::complex<flt_t> wk {std::cos(w), std::sin(w)}; | |
x[k] = evens[k] + (wk * odds[k]); | |
x[k+halfLen] = evens[k] - (wk * odds[k]); | |
} | |
} | |
template <typename flt_t> | |
void ifft(cmplx_list_t<flt_t>& x) { | |
for (cmplx_value_t<flt_t>& e : x) { | |
e = std::conj(e); | |
} | |
fft<flt_t>(x); | |
const flt_t len = static_cast<flt_t>(x.size()); | |
for (cmplx_value_t<flt_t>& e : x) { | |
e = std::conj(e) / len; | |
} | |
} | |
template <typename flt_t> | |
cmplx_list_t<flt_t> create_fft_table(const bignum& a, const bignum& b) { | |
// Create a list of complex numbers with interleaved values from the two | |
// input numbers. The output list must have a length that's a power of 2. | |
bignum::size_type aLen = a.size(); | |
bignum::size_type bLen = b.size(); | |
// Ensure the Cooley-Tukey algorithm has its radix=2 requirement fulfilled. | |
// Add a few digits to the end so the algorithm has room for overflow. | |
bignum::size_type size = aLen + bLen; | |
if (!is_pow2(size)) { | |
size = next_pow2(size); | |
} | |
cmplx_list_t<flt_t> ret; | |
ret.reserve(size); | |
for (cmplx_size_t<flt_t> i = 0; i < size; ++i) { | |
// Add some zero-padding to the output list if necessary | |
const bignum::value_type aVal = i < aLen ? a[i] : 0; | |
const bignum::value_type bVal = i < bLen ? b[i] : 0; | |
ret.push_back(std::complex<flt_t>{(flt_t)aVal, (flt_t)bVal}); | |
} | |
return ret; | |
} | |
template <typename flt_t> | |
void convolute_fft(cmplx_list_t<flt_t>& fftTable) { | |
const cmplx_size_t<flt_t> fftSize = fftTable.size(); | |
// transform. | |
cmplx_list_t<flt_t> transforms{fftTable}; | |
fft<flt_t>(transforms); | |
// point-wise multiplication in frequency domain. | |
for (cmplx_size_t<flt_t> i = 0; i < fftSize; ++i) { | |
// extract the individual transformed signals from the composed one. | |
const cmplx_value_t<flt_t>& ti = transforms[i]; | |
const cmplx_value_t<flt_t>&& tc = std::conj(transforms[-i % fftSize]); | |
// perform convolution | |
const cmplx_value_t<flt_t> x1 = ti + tc; | |
const cmplx_value_t<flt_t> x2 = ti - tc; | |
const cmplx_value_t<flt_t> x3 = x1 * x2; | |
// avoid pedantic compilers | |
constexpr flt_t rotation = flt_t{0.25}; | |
fftTable[i] = std::complex<flt_t>{x3.imag(), -x3.real()} * rotation; | |
} | |
} | |
/**
 * Multiply bignums using an FFT-based convolution.
 *
 * Both numbers are packed into one complex signal, multiplied point-wise in
 * the frequency domain, transformed back, rounded to integers and finally
 * carry-propagated in base NUM_BASE.
 *
 * @param a First factor, digits least-significant first.
 * @param b Second factor, digits least-significant first.
 * @return The product, least-significant digit first, without zero digits at
 *         the most-significant end. A zero product is an empty bignum.
 */
bignum mul_strassen(const bignum& a, const bignum& b) {
    // Default type is double. Use floats if they perform well enough.
    typedef double flt_t;

    bignum ret;

    // An empty bignum represents zero. Returning early also keeps an empty
    // table away from the FFT.
    if (a.empty() || b.empty()) {
        return ret;
    }

    // Build a complex signal with the information of both numbers. The table
    // is initialised directly from the returned temporary; wrapping it in
    // std::move would only prevent copy elision.
    cmplx_list_t<flt_t> fftTable = create_fft_table<flt_t>(a, b);
    convolute_fft<flt_t>(fftTable);
    ifft<flt_t>(fftTable);

    ret.reserve(fftTable.size());

    bn_double_t carry = 0;
    for (cmplx_size_t<flt_t> i = 0; i < fftTable.size(); ++i) {
        // Drop the imaginary part of the number.
        const flt_t x = fftTable[i].real();

        // Round to the nearest integer and add the carry from the digit below.
        const bn_double_t ci = (bn_double_t)(carry + std::floor(x + 0.5));

        ret.push_back(static_cast<bn_t>(ci % NUM_BASE));

        // Carry propagation.
        carry = ci / NUM_BASE;
    }

    // Trim zeroes from the most-significant end.
    while (!ret.empty() && ret.back() == 0) {
        ret.pop_back();
    }

    return ret;
}
/**
 * Number of bytes taken up by the digits of a bignum: the element count
 * multiplied by the size of one element.
 */
bignum::size_type get_num_bytes(const bignum& b) {
    const bignum::size_type bytesPerDigit = sizeof(bignum::value_type);
    return bytesPerDigit * b.size();
}
void print_bignum_stats(const std::string& name, const bignum& b) { | |
std::cout << name << " element count: " << b.size() << std::endl; | |
std::cout << name << " byte size: " << get_num_bytes(b) << std::endl; | |
} | |
void run_mult_bench( | |
const std::string& name, | |
const bignum& a, | |
const bignum& b, | |
bignum& result, | |
bignum (*mul_func)(const bignum&, const bignum&) | |
) { | |
std::cout << "Performing " << name << " Multiplication..." << std::endl; | |
const steady_clock::time_point start = steady_clock::now(); | |
result = std::move(mul_func(a, b)); | |
const steady_clock::time_point end = steady_clock::now(); | |
const duration elapsed = end - start; | |
std::cout << name << " method time: " << elapsed.count() << "s." << std::endl; | |
print_bignum_stats(name, result); | |
std::cout << std::endl; | |
} | |
/** | |
* Main() | |
*/ | |
int main(int, char**) { | |
#ifdef NUM_BASE_16 | |
bignum a = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}; | |
bignum b = {15,14,13,12,11,10,9,8,7,6,5,4,3,2,1}; | |
#elif defined(NUM_BASE_8) | |
bignum a = {0,1,2,3,4,5,6,7,0,1,2,3,4,5,6,7}; | |
bignum b = {0,7,6,5,4,3,2,1,0,7,6,5,4,3,2,1}; | |
#elif defined(NUM_BASE_2) | |
bignum a = {0,1,1,1,0,1,1,1,0,1,1,1,0,1,1,1,0,1,1,1}; | |
bignum b = {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1}; | |
#else | |
bignum a = {1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9}; | |
bignum b = {9,8,7,6,5,4,3,2,1,0,9,8,7,6,5,4,3,2,1}; | |
#endif | |
for (unsigned i = 0; i < SIZE_MULTIPLIER_A; ++i) { | |
a.insert(a.end(), a.begin(), a.end()); | |
} | |
for (unsigned i = 0; i < SIZE_MULTIPLIER_B; ++i) { | |
b.insert(b.end(), b.begin(), b.end()); | |
} | |
/* | |
if (!tokenize(argc, argv, a, b)) { | |
return -1; | |
} | |
*/ | |
std::cout << "Testing bignum multiplication." << std::endl; | |
print_bignum_stats("Test number A", a); | |
print_bignum_stats("Test Number B", b); | |
std::cout << '\n' << a << "\n\n" << b << '\n' << std::endl; | |
/* | |
bignum naive; | |
std::thread t1{ | |
run_mult_bench, | |
"Naive", | |
std::ref(a), | |
std::ref(b), | |
std::ref(naive), | |
mul_naive | |
}; | |
*/ | |
bignum strassen; | |
std::thread t2{ | |
run_mult_bench, | |
"Strassen", | |
std::ref(a), | |
std::ref(b), | |
std::ref(strassen), | |
mul_strassen | |
}; | |
//t1.join(); | |
t2.join(); | |
//std::cout << '\n' << naive << "\n\n" << strassen << '\n' << std::endl; | |
std::cout << '\n' << strassen << '\n' << std::endl; | |
//std::cout << "Naive == Strassen: " << (naive == strassen) << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment