Skip to content

Instantly share code, notes, and snippets.

@csaftoiu
Created June 20, 2016 23:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save csaftoiu/9194ef9ffd98b7b106b359ca55557010 to your computer and use it in GitHub Desktop.
Save csaftoiu/9194ef9ffd98b7b106b359ca55557010 to your computer and use it in GitHub Desktop.
64-bit multiplication using 32-bit math
#include <string.h>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <vector>
using namespace std;
// Negates `a` with the built-in unary minus on the signed type.
// Caveat: the most negative int64_t has no representable negation, so for that
// single input the signed negation overflows.
static inline int64_t neg_cpp(int64_t a) {
    const int64_t negated = -a;
    return negated;
}
// Negates `a` by round-tripping through uint64_t: the value is converted to
// unsigned, negated there (unsigned arithmetic wraps modulo 2^64), and the
// resulting bit pattern is converted back to int64_t.
static inline int64_t neg_cast(int64_t a) {
    const uint64_t bits = (uint64_t)a;
    const uint64_t negated_bits = -bits;
    return (int64_t)negated_bits;
}
// Negates `a` with the explicit two's-complement recipe: flip every bit of the
// unsigned representation, add one, and convert the result back to int64_t.
static inline int64_t neg_alg(int64_t a) {
    uint64_t bits = (uint64_t)a;
    bits = ~bits;   // one's complement
    bits += 1;      // +1 turns it into the two's-complement negation
    return (int64_t)bits;
}
// Adds two unsigned 64-bit integers using only 32-bit additions.
//
// Each operand is split into a high and a low 32-bit half; the halves are
// added separately and the carry out of the low half is propagated into the
// high half. A carry out of the high half is discarded, so the result wraps
// modulo 2^64 exactly like native unsigned addition.
//
// lhs, rhs: the operands.
// Returns:  (lhs + rhs) mod 2^64.
uint64_t uiadd64(uint64_t lhs, uint64_t rhs) {
    const uint32_t lhs_lo = static_cast<uint32_t>(lhs & 0xffffffffu);
    const uint32_t lhs_hi = static_cast<uint32_t>(lhs >> 32);
    const uint32_t rhs_lo = static_cast<uint32_t>(rhs & 0xffffffffu);
    const uint32_t rhs_hi = static_cast<uint32_t>(rhs >> 32);

    // 32-bit unsigned addition wraps; the wrapped sum is smaller than either
    // operand exactly when a carry was produced, so one comparison suffices.
    const uint32_t sum_lo = lhs_lo + rhs_lo;
    const uint32_t carry = (sum_lo < lhs_lo) ? 1u : 0u;
    const uint32_t sum_hi = lhs_hi + rhs_hi + carry;

    // Reassemble in unsigned 64-bit space: shifting a signed value here would
    // push bits into the sign position.
    return (static_cast<uint64_t>(sum_hi) << 32) | sum_lo;
}
// Signed 64-bit addition built on uiadd64: both operands are reinterpreted as
// unsigned, added there, and the sum is converted back to int64_t.
int64_t iadd64(int64_t lhs, int64_t rhs)
{
    const uint64_t unsigned_lhs = (uint64_t)lhs;
    const uint64_t unsigned_rhs = (uint64_t)rhs;
    const uint64_t unsigned_sum = uiadd64(unsigned_lhs, unsigned_rhs);
    return (int64_t)unsigned_sum;
}
// Multiplies two 16-bit values using 32-bit arithmetic.
//
// a, b:      the factors.
// overflow:  out-parameter; receives the upper 16 bits of the 32-bit product.
// Returns:   the lower 16 bits of the 32-bit product.
static inline uint16_t mul_16(uint16_t a, uint16_t b, uint16_t *overflow) {
    const uint32_t wide_a = a;
    const uint32_t wide_b = b;
    const uint32_t product = wide_a * wide_b;
    const uint16_t high_half = (uint16_t)(product >> 16);
    const uint16_t low_half = (uint16_t)product;
    *overflow = high_half;
    return low_half;
}
// Adds two 16-bit values using 32-bit arithmetic.
//
// a, b:      the addends.
// overflow:  out-parameter; receives the upper 16 bits of the 32-bit sum
//            (the carry out of the 16-bit addition).
// Returns:   the lower 16 bits of the 32-bit sum.
static inline uint16_t sum_16(uint16_t a, uint16_t b, uint16_t *overflow) {
    uint32_t total = a;
    total += b;
    const uint16_t carry_out = (uint16_t)(total >> 16);
    const uint16_t low_half = (uint16_t)total;
    *overflow = carry_out;
    return low_half;
}
// Adds five 16-bit values using 32-bit arithmetic.
//
// a..e:      the addends.
// overflow:  out-parameter; receives the upper 16 bits of the 32-bit sum
//            (the carry out of the 16-bit column).
// Returns:   the lower 16 bits of the 32-bit sum.
static inline uint16_t sum_5_16(uint16_t a, uint16_t b, uint16_t c, uint16_t d, uint16_t e, uint16_t *overflow) {
    const uint16_t terms[5] = {a, b, c, d, e};
    uint32_t total = 0;
    for (const uint16_t term : terms) {
        total += term;
    }
    const uint16_t carry_out = (uint16_t)(total >> 16);
    const uint16_t low_half = (uint16_t)total;
    *overflow = carry_out;
    return low_half;
}
// Multiplies two signed 64-bit integers by schoolbook long multiplication on
// 16-bit limbs, so that every partial product fits in 32 bits.
// Implements this algorithm: http://i.imgur.com/GPyYz5h.jpg
//
// a, b:     the factors.
// Returns:  the low 64 bits of the product, i.e. the result wraps modulo 2^64
//           when the true product does not fit in int64_t.
// Throws:   std::runtime_error("math logic fail") if the internal carry
//           invariant is ever violated.
static inline int64_t mul_alg(int64_t a, int64_t b)
{
    // Work on magnitudes in unsigned arithmetic. Negating the most negative
    // int64_t as a signed value overflows, whereas unsigned negation is
    // well-defined modulo 2^64 and yields the correct magnitude (2^63).
    const bool negate = (a < 0) != (b < 0);
    const uint64_t mag_a = (a < 0) ? (uint64_t(0) - static_cast<uint64_t>(a))
                                   : static_cast<uint64_t>(a);
    const uint64_t mag_b = (b < 0) ? (uint64_t(0) - static_cast<uint64_t>(b))
                                   : static_cast<uint64_t>(b);

    // Split each magnitude into four 16-bit limbs, least significant first.
    const uint16_t A[4] = {
        static_cast<uint16_t>(mag_a),
        static_cast<uint16_t>(mag_a >> 16),
        static_cast<uint16_t>(mag_a >> 32),
        static_cast<uint16_t>(mag_a >> 48)};
    const uint16_t B[4] = {
        static_cast<uint16_t>(mag_b),
        static_cast<uint16_t>(mag_b >> 16),
        static_cast<uint16_t>(mag_b >> 32),
        static_cast<uint16_t>(mag_b >> 48)};

    // F[row] holds the partial product A * B[row], already shifted left by
    // `row` limbs. Only columns 0..3 contribute to the 64-bit result; column 4
    // exists so the carry out of column 3 has somewhere to land.
    uint16_t F[4][5] = {};

    for (int row = 0; row < 4; ++row) {
        for (int col = row; col < 4; ++col) {
            uint16_t carry = 0;
            // Low half of the limb product goes to this cell, the high half is
            // parked in the next cell.
            const uint16_t low = mul_16(A[col - row], B[row], &F[row][col + 1]);
            // This cell already holds the high half carried over from the
            // previous column; add the new low half on top of it.
            F[row][col] = sum_16(F[row][col], low, &carry);
            // Fold that carry into the next cell. The high half of a 16x16
            // product is at most 0xfffe, so adding a carry of at most 1 can
            // never overflow 16 bits.
            F[row][col + 1] = sum_16(F[row][col + 1], carry, &carry);
            if (carry != 0) {
                throw std::runtime_error("math logic fail");
            }
        }
    }

    // Sum the partial products column by column, carrying into the next
    // column. R has one extra slot so the final carry can be written and
    // then ignored.
    uint16_t R[5] = {};
    for (int col = 0; col < 4; ++col) {
        R[col] = sum_5_16(R[col], F[0][col], F[1][col], F[2][col], F[3][col], &R[col + 1]);
    }

    // Reassemble the four result limbs into a 64-bit magnitude.
    uint64_t magnitude = uiadd64(
        uiadd64(static_cast<uint64_t>(R[0]), static_cast<uint64_t>(R[1]) << 16),
        uiadd64(static_cast<uint64_t>(R[2]) << 32, static_cast<uint64_t>(R[3]) << 48));

    // Re-apply the sign in unsigned space so a wrapped result cannot trigger
    // signed overflow.
    if (negate) {
        magnitude = uint64_t(0) - magnitude;
    }
    return static_cast<int64_t>(magnitude);
}
// Reference multiplication using the native 64-bit multiply.
//
// The product is computed on the unsigned representations and converted back,
// so a result that does not fit in int64_t wraps modulo 2^64 instead of
// overflowing a signed multiplication. For products that do fit, the value is
// the ordinary a * b.
static inline int64_t mul_cpp(int64_t a, int64_t b) {
    const uint64_t product = static_cast<uint64_t>(a) * static_cast<uint64_t>(b);
    return static_cast<int64_t>(product);
}
// Prints, for every ordered pair of sample operands, the product computed by
// the native multiply (mul_cpp) and by the limb-based algorithm (mul_alg) so
// the two can be compared line by line.
int main()
{
    // Sample operands: small values, 32-bit patterns, and the extremes of int64_t.
    const std::vector<int64_t> operands = {-1, -100, 0, 1, 0x10101010, 0xffffffff, -INT64_C(0x7fffffffffffffff), -INT64_C(0x7fffffffffffffff) - 1, INT64_C(0x7fffffffffffffff)};

    for (const int64_t lhs : operands) {
        for (const int64_t rhs : operands) {
            std::cout << "C++: " << lhs << " * " << rhs << " = ";
            std::cout << mul_cpp(lhs, rhs) << std::endl;

            std::cout << "alg: " << lhs << " * " << rhs << " = ";
            std::cout << mul_alg(lhs, rhs) << std::endl;

            std::cout << "--" << std::endl;
        }
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment