Skip to content

Instantly share code, notes, and snippets.

@zac-williamson
Created April 1, 2020 13:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zac-williamson/6c0e08db3f0621d9e3d3f0264403b9f3 to your computer and use it in GitHub Desktop.
Save zac-williamson/6c0e08db3f0621d9e3d3f0264403b9f3 to your computer and use it in GitHub Desktop.
// Requires the barretenberg `numeric` library on the include path (used here via <numeric/random/engine.hpp>).
#include <stdio.h>
#include <math.h>
#include <inttypes.h>
#include <x86intrin.h>
#include <stdlib.h>
#include <numeric/random/engine.hpp>
// File-local reference to the random engine returned by numeric::random::get_debug_engine().
// main() draws its random 256-bit test operands from it.
// NOTE(review): presumably the "debug" engine is deterministically seeded so runs are
// reproducible -- confirm in numeric/random/engine.hpp.
namespace {
auto& engine = numeric::random::get_debug_engine();
}
//#include <float.h>
// A pair of floats treated as a single value equal to the unevaluated sum hi + lo.
// The arithmetic helpers in this file return the rounded result in `hi` and the
// leftover (error) term in `lo`.
struct doublefloat {
    float hi; // leading component: the rounded result
    float lo; // trailing component: the correction that did not fit into `hi`
};
// Adds two floats and also returns the part of `b` that was lost to rounding.
//
// Returns { sum, err } where sum = fl(a + b) and err = b - (sum - a).
// NOTE(review): this is the classic "fast two-sum" pattern, which is normally only
// exact when |a| >= |b|; the callers in this file order their operands before calling.
doublefloat quick_two_sum(float a, float b)
{
    const float sum = a + b;
    // (sum - a) is the portion of b that actually made it into sum;
    // whatever remains of b is the rounding error.
    const float err = b - (sum - a);
    return doublefloat{ sum, err };
    // 3 add
}
// Adds two floats of arbitrary relative size and returns { rounded sum, rounding error }.
//
// The operands are ordered by MAGNITUDE before being handed to quick_two_sum(), because
// that routine recovers the error term from its second argument and therefore needs the
// dominant operand first. Ordering by value (max/min) is only equivalent when both inputs
// are non-negative; with mixed signs, e.g. (1e-8f, -1.0f), it would put the small operand
// first and the error term would be lost. For non-negative inputs the result is unchanged.
doublefloat two_sum(float a, float b)
{
    const bool a_dominates = fabsf(a) >= fabsf(b);
    const float larger = a_dominates ? a : b;
    const float smaller = a_dominates ? b : a;
    return quick_two_sum(larger, smaller);
    // 3 adds in quick_two_sum + magnitude compare/select
}
// Adds two double-float values and returns the renormalised double-float result.
//
// Steps:
//   1. r = a.hi + b.hi is the leading part of the sum.
//   2. The operand whose `hi` has the larger magnitude is taken as `left`, so that
//      (left.hi - r + right.hi) recovers the rounding error of step 1.
//   3. Both trailing parts are folded into that error term, giving s.
//   4. (r, s) is renormalised with quick_two_sum().
//
// The renormalisation orders r and s by MAGNITUDE rather than by value: quick_two_sum()
// needs the dominant operand first, and a value ordering (max/min) would put the small
// correction term first whenever r is negative. For non-negative r with |s| <= r the
// result is identical to the value ordering.
doublefloat add_doublefloats(doublefloat a, doublefloat b)
{
    const float r = a.hi + b.hi;

    const bool a_dominates = fabsf(a.hi) > fabsf(b.hi);
    const doublefloat left = a_dominates ? a : b;
    const doublefloat right = a_dominates ? b : a;

    // Error of the leading addition, plus both trailing components.
    const float s = ((left.hi - r + right.hi) + right.lo) + left.lo;

    const bool r_dominates = fabsf(r) >= fabsf(s);
    return quick_two_sum(r_dominates ? r : s, r_dominates ? s : r);
    // 15 ops
}
// Breaks `a` into two floats { hi, lo } with lo = a - hi.
//
// The splitting constant is 2^12 + 1 = 4097.
// NOTE(review): this looks like the Veltkamp/Dekker split for a 24-bit significand, where
// each half carries at most 12 significant bits so that products of halves are exact --
// confirm against the reference algorithm.
doublefloat split(float a)
{
    constexpr float splitter = static_cast<float>((1 << 12) + 1);
    const float scaled = splitter * a;
    // Subtracting (scaled - a) back out of scaled leaves only the leading bits of a.
    const float hi = scaled - (scaled - a);
    const float lo = a - hi;
    return doublefloat{ hi, lo };
    // 4 ops
}
// Multiplies two floats and returns { rounded product, correction term }.
//
// Both operands are split into halves; the partial products of the halves are then
// subtracted from the rounded product one at a time, and the remainder is combined with
// the low*low partial product to form the correction term.
// NOTE(review): the result relies on each statement being evaluated as written; compiler
// floating-point contraction (fused multiply-add) could presumably alter it -- confirm
// build flags.
doublefloat two_product(float a, float b)
{
    const float product = a * b;
    const doublefloat a_parts = split(a);
    const doublefloat b_parts = split(b);

    float residual = product - (a_parts.hi * b_parts.hi);
    residual = residual - (a_parts.lo * b_parts.hi);
    residual = residual - (a_parts.hi * b_parts.lo);

    const float correction = a_parts.lo * b_parts.lo - residual;
    return doublefloat{ product, correction };
    // 17 ops
}
void mac(float a, float b, float c, float carry_in, float& out, float& carry_out)
{
auto res = two_product(b, c); // 17
const auto t0 = two_sum(a, carry_in); // 5
res = add_doublefloats(res, t0); // 15
if (res.lo < 0) {
res.lo += (float)(1ULL << 23ULL);
res.hi -= (float)(1ULL << 23ULL);
} // 4
float x = res.hi + (float)(1ULL << 46ULL);
float u = x - (float)(1ULL << 46ULL);
float v = res.hi - u;
u = u * ((float)1 / (float)(1ULL << 23ULL)); // 4
if (v < 0) {
v += (float)(1ULL << 23ULL);
u -= (float)1;
} // 4
out = v + res.lo;
if (out >= (float)(1ULL << 23ULL)) {
res.lo -= (float)(1ULL << 23ULL);
out = v + res.lo;
u += 1;
} // 5
carry_out = u;
// 54 ops
}
// Schoolbook multiplication of two 12-limb numbers into a 24-limb result, using mac()
// for every limb product.
//
// a, b : 12 input limbs each
// out  : 24 output limbs; cleared first, then accumulated row by row
//
// For row i the first 11 columns accumulate into out[i + j] with a per-row running carry.
// The 12th column does not read out[i + 11]; its addend is `top_carry`, the carry-out of
// the previous row's 12th column, and its own carry-out replaces `top_carry`. The final
// value of `top_carry` becomes out[23].
void mul_test(float* a, float* b, float* out)
{
    constexpr size_t num_limbs = 12;
    constexpr size_t last_col = num_limbs - 1;
    constexpr size_t out_limbs = 2 * num_limbs;

    for (size_t k = 0; k < out_limbs; ++k) {
        out[k] = 0;
    }

    float top_carry = 0;
    for (size_t row = 0; row < num_limbs; ++row) {
        float* const acc = out + row;
        float carry = 0;
        for (size_t col = 0; col < last_col; ++col) {
            mac(acc[col], a[row], b[col], carry, acc[col], carry);
        }
        mac(top_carry, a[row], b[last_col], carry, acc[last_col], top_carry);
    }
    out[out_limbs - 1] = top_carry;
}
// Unpacks the four words of `input` (least-significant word first) into twelve limbs of
// 23 bits each, stored as floats in output[0..11].
//
// Limb k covers bits [23k, 23k + 23) of the 256-bit value. Limbs 2, 5 and 8 straddle a
// word boundary and are assembled from the top bits of one word and the bottom bits of
// the next; limb 11 holds only the top 3 bits.
void convert_into_floats(uint256_t& input, float* output)
{
    constexpr uint64_t limb_mask = (1UL << 23UL) - 1UL;
    constexpr unsigned long long low5_mask = (1ULL << 5ULL) - 1ULL;
    constexpr unsigned long long low10_mask = (1ULL << 10ULL) - 1ULL;
    constexpr unsigned long long low15_mask = (1ULL << 15ULL) - 1ULL;

    const auto& w0 = input.data[0];
    const auto& w1 = input.data[1];
    const auto& w2 = input.data[2];
    const auto& w3 = input.data[3];

    // Word 0: two whole limbs, then 18 bits completed by 5 bits of word 1.
    output[0] = static_cast<float>(w0 & limb_mask);
    output[1] = static_cast<float>((w0 >> 23) & limb_mask);
    output[2] = static_cast<float>((w0 >> 46) + ((w1 & low5_mask) << 18));

    // Word 1: two whole limbs, then 13 bits completed by 10 bits of word 2.
    output[3] = static_cast<float>((w1 >> 5) & limb_mask);
    output[4] = static_cast<float>((w1 >> 28) & limb_mask);
    output[5] = static_cast<float>((w1 >> 51) + ((w2 & low10_mask) << 13));

    // Word 2: two whole limbs, then 8 bits completed by 15 bits of word 3.
    output[6] = static_cast<float>((w2 >> 10) & limb_mask);
    output[7] = static_cast<float>((w2 >> 33) & limb_mask);
    output[8] = static_cast<float>((w2 >> 56) + ((w3 & low15_mask) << 8));

    // Word 3: two whole limbs and the final 3 bits.
    output[9] = static_cast<float>((w3 >> 15) & limb_mask);
    output[10] = static_cast<float>((w3 >> 38) & limb_mask);
    output[11] = static_cast<float>((w3 >> 61));
}
// Packs float-encoded limbs back into the eight 64-bit words of `output`
// (output.lo.data[0..3] followed by output.hi.data[0..3]).
//
// Limb k is placed at bit offset 23k. A limb that straddles a word boundary contributes
// its low bits to one word (the left shift drops the overflow) and its remaining high
// bits to the next word (via a right shift). Only input[0..22] are read.
// NOTE(review): presumably each limb is an integer in [0, 2^23), so the additions never
// carry between limbs -- confirm against the producer of `input`.
void convert_into_ints(float* input, uint512_t& output)
{
    // Truncate every consumed limb to an integer once, up front.
    uint64_t limb[23];
    for (size_t k = 0; k < 23; ++k) {
        limb[k] = static_cast<uint64_t>(input[k]);
    }

    const uint64_t w0 = limb[0] + (limb[1] << 23ULL) + (limb[2] << 46ULL);
    const uint64_t w1 = (limb[2] >> 18ULL) + (limb[3] << 5ULL) + (limb[4] << 28ULL) + (limb[5] << 51ULL);
    const uint64_t w2 = (limb[5] >> 13ULL) + (limb[6] << 10ULL) + (limb[7] << 33ULL) + (limb[8] << 56ULL);
    const uint64_t w3 = (limb[8] >> 8ULL) + (limb[9] << 15ULL) + (limb[10] << 38ULL) + (limb[11] << 61ULL);
    const uint64_t w4 = (limb[11] >> 3ULL) + (limb[12] << 20ULL) + (limb[13] << 43ULL);
    const uint64_t w5 = (limb[13] >> 21ULL) + (limb[14] << 2ULL) + (limb[15] << 25ULL) + (limb[16] << 48ULL);
    const uint64_t w6 = (limb[16] >> 16ULL) + (limb[17] << 7ULL) + (limb[18] << 30ULL) + (limb[19] << 53ULL);
    const uint64_t w7 = (limb[19] >> 11ULL) + (limb[20] << 12ULL) + (limb[21] << 35ULL) + (limb[22] << 58ULL);

    output.lo.data[0] = w0;
    output.lo.data[1] = w1;
    output.lo.data[2] = w2;
    output.lo.data[3] = w3;
    output.hi.data[0] = w4;
    output.hi.data[1] = w5;
    output.hi.data[2] = w6;
    output.hi.data[3] = w7;
}
int main(void)
{
bool valid = true;
std::cout << "testing 1,000 256x256->512 bit muls" << std::endl;
for (size_t i = 0; i < 1000; ++i) {
uint256_t left = engine.get_random_uint256();
uint256_t right = engine.get_random_uint256();
uint512_t expected = uint512_t(left) * uint512_t(right);
uint512_t result;
float left_floats[12];
float right_floats[12];
float output_floats[24];
convert_into_floats(left, left_floats);
convert_into_floats(right, right_floats);
mul_test(left_floats, right_floats, output_floats);
convert_into_ints(output_floats, result);
if (result != expected) {
valid = false;
}
}
if (valid) {
std::cout << "pass" << std::endl;
} else {
std::cout << "fail" << std::endl;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment