Skip to content

Instantly share code, notes, and snippets.

@weikengchen
Created March 25, 2021 20:41
Show Gist options
  • Save weikengchen/ac6de8086f2144bd3677e2771aaf015c to your computer and use it in GitHub Desktop.
Save weikengchen/ac6de8086f2144bd3677e2771aaf015c to your computer and use it in GitHub Desktop.
utility.h
#ifndef FP_UTILITY_H__
#define FP_UTILITY_H__
#include <emp-tool/emp-tool.h>
using namespace emp;
using namespace std;
// Bit length associated with the field modulus: p below lies just under 2^62.
// NOTE(review): despite the macro name, p is not of the form 2^k - 1 -- confirm
// what callers expect this macro to mean.
#define MERSENNE_PRIME_EXP 62
// Integer type used to hold a single field element.
#define FIELD_TYPE uint64_t
// Field modulus: p = 2^62 - 2^26 - 2^25 + 1, i.e. 2^62 - p = 100663295.
// (Presumably prime; primality is not checked anywhere in this header.)
const static __uint128_t p = 4611686018326724609;
// NOTE(review): not referenced by any function in the visible part of this
// header (the templates below declare their own local `r`); purpose unclear.
const static int r = 1;
// Same value as p.
const static __uint128_t pr = 4611686018326724609;
// Block built from two copies of the modulus; used by vec_partial_mod for the
// lane-wise compare-and-subtract.
const static block prs = makeBlock(4611686018326724609ULL, 4611686018326724609ULL);
// 2^61 - 1: selects the low 61 bits of a word in mod(uint64_t).
const static uint64_t mask = 2305843009213693951LL;
// 64-bit copy of the modulus; used by the scalar add_mod.
const static uint64_t PR = 4611686018326724609;
// add[k] = (k * 2^61) mod p for k = 0..7. mod(uint64_t) indexes this table with
// the top three bits of its argument to fold them back below 2^62.
const uint64_t add[8] = {
0ULL,                    // 0 * 2^61
2305843009213693952ULL,  // 1 * 2^61
100663295ULL,            // 2^62 mod p
2305843009314357247ULL,  // 2^61 + (2^62 mod p)
201326590ULL,            // 2 * (2^62 mod p)
2305843009415020542ULL,  // 2^61 + 2 * (2^62 mod p)
301989885ULL,            // 3 * (2^62 mod p)
2305843009515683837ULL,  // 2^61 + 3 * (2^62 mod p)
};
// Non-const block built from two copies of PR.
// NOTE(review): not referenced in the visible part of this header; prs above
// holds the same value.
static __m128i PRs = makeBlock(PR, PR);
#if defined(__x86_64__) && defined(__BMI2__)
/**
 * Full 64x64 -> 128-bit unsigned multiplication (BMI2 path).
 *
 * @param a first factor
 * @param b second factor
 * @param c out: receives the high 64 bits of a * b
 * @return the low 64 bits of a * b
 */
inline uint64_t mul64(uint64_t a, uint64_t b, uint64_t * c) {
	// _mulx_u64 stores through an `unsigned long long*`. On LP64 targets
	// uint64_t is `unsigned long`, a distinct type, so casting the caller's
	// pointer would break strict aliasing. Use a local of the exact type and
	// copy the value out instead.
	unsigned long long hi = 0;
	const unsigned long long lo = _mulx_u64(a, b, &hi);
	*c = hi;
	return lo;
}
#else
/**
 * Full 64x64 -> 128-bit unsigned multiplication (portable path using
 * the compiler's __uint128_t).
 *
 * @param a first factor
 * @param b second factor
 * @param c out: receives the high 64 bits of a * b
 * @return the low 64 bits of a * b
 */
inline uint64_t mul64(uint64_t a, uint64_t b, uint64_t * c) {
	const __uint128_t product = static_cast<__uint128_t>(a) * b;
	*c = static_cast<uint64_t>(product >> 64);
	return static_cast<uint64_t>(product);
}
#endif
/**
 * Reduce a 64-bit word modulo the field modulus p.
 *
 * The top three bits of x are replaced by their precomputed residue from the
 * `add` table and added to the low 61 bits (selected by `mask`). With the
 * table values in this file that sum is below 2p, so a single conditional
 * subtraction finishes the reduction.
 *
 * @param x any 64-bit value
 * @return x mod p
 */
inline uint64_t mod(uint64_t x) {
	const uint64_t low_bits = x & mask;
	const uint64_t top_residue = add[x >> 61];
	const uint64_t folded = top_residue + low_bits;
	if (folded >= p)
		return folded - p;
	return folded;
}
/**
 * Reduce a 128-bit value modulo the field modulus.
 *
 * k is split into 64-bit halves hi:lo. lo is reduced directly. The hi * 2^64
 * part is folded with 2^64 = 2^28 + 2^27 - 2^2 (mod the file's modulus); inside
 * each fold, the bits of hi that would be shifted past bit 63 are folded with
 * 2^63 = 2^27 + 2^26 - 2.
 *
 * @param k value to reduce
 * @param p modulus used for the final negation and the running sums. Note it
 *          shadows the file-level `p`. The folding constants are hard-wired to
 *          the file's modulus, so callers are expected to pass that value
 *          (the original comment reads "assume p is pr").
 * @return the reduced value
 */
inline __uint128_t mod(__uint128_t k, __uint128_t p) {
	const uint64_t hi = static_cast<uint64_t>(k >> 64);
	const uint64_t lo = static_cast<uint64_t>(k);

	// Bits of hi that each left shift would push out of the 64-bit word.
	const uint64_t spill28 = hi >> 35;
	const uint64_t spill27 = hi >> 36;
	const uint64_t spill2 = hi >> 61;

	unsigned long long acc = mod(lo);

	// add hi * 2^28
	const unsigned long long plus28 = mod(
		(((hi << 29) >> 29) << 28) + (spill28 << 27) + (spill28 << 26) - (spill28 << 1));

	// add hi * 2^27
	const unsigned long long plus27 = mod(
		(((hi << 28) >> 28) << 27) + (spill27 << 27) + (spill27 << 26) - (spill27 << 1));

	// sub hi * 2^2: reduce it, then negate modulo p so it can be added
	unsigned long long minus4 = mod(
		(((hi << 3) >> 3) << 2) + (spill2 << 27) + (spill2 << 26) - (spill2 << 1));
	minus4 = (p - minus4) % p;

	acc = (acc + plus28) % p;
	acc = (acc + plus27) % p;
	acc = (acc + minus4) % p;
	return acc;
}
/**
 * Lane-wise conditional subtraction of the modulus from both 64-bit lanes.
 *
 * A lane is left unchanged when prs > lane (signed 64-bit compare); otherwise
 * the modulus held in prs is subtracted from it once.
 *
 * @param i block holding two 64-bit lanes
 * @return block with the modulus subtracted from every lane that was not below it
 */
inline block vec_partial_mod(block i) {
	// All-ones in lanes that are already below the modulus.
	const block already_reduced = _mm_cmpgt_epi64(prs, i);
	// Modulus in the lanes that still need a subtraction, zero elsewhere.
	const block to_subtract = _mm_andnot_si128(already_reduced, prs);
	return _mm_sub_epi64(i, to_subtract);
}
/**
 * Reduce both 64-bit lanes of a block with the scalar mod(uint64_t).
 *
 * @param i block holding two 64-bit lanes
 * @return makeBlock(mod(lane 1), mod(lane 0))
 */
inline block vec_mod(block i) {
	const uint64_t lane0 = _mm_extract_epi64(i, 0);
	const uint64_t lane1 = _mm_extract_epi64(i, 1);
	const uint64_t reduced0 = mod(lane0);
	const uint64_t reduced1 = mod(lane1);
	return makeBlock(reduced1, reduced0);
}
inline uint64_t mult_mod(uint64_t a, uint64_t b) {
unsigned long long lo, hi;
lo = _mulx_u64(a, b, &hi);
unsigned long long res;
res = mod(lo);
// add hi * 2^28
unsigned long long res1;
res1 = ((hi << 29) >> 29) << 28;
res1 += (hi >> 35) << 27;
res1 += (hi >> 35) << 26;
res1 -= (hi >> 35) << 1;
res1 = mod(res1);
// add hi * 2^27
unsigned long long res2;
res2 = ((hi << 28) >> 28) << 27;
res2 += (hi >> 36) << 27;
res2 += (hi >> 36) << 26;
res2 -= (hi >> 36) << 1;
res2 = mod(res2);
// sub hi * 2^2
unsigned long long res3;
res3 = ((hi << 3) >> 3) << 2;
res3 += (hi >> 61) << 27;
res3 += (hi >> 61) << 26;
res3 -= (hi >> 61) << 1;
res3 = mod(res3);
res3 = (p - res3) % p;
res += res1;
res %= p;
res += res2;
res %= p;
res += res3;
res %= p;
return res;
}
/**
 * Multiply both 64-bit lanes of a block by the same scalar, modulo p.
 *
 * @param a block holding two 64-bit lanes
 * @param b scalar factor applied to each lane
 * @return makeBlock(mult_mod(lane 1, b), mult_mod(lane 0, b))
 */
inline block mult_mod(block a, uint64_t b) {
	const uint64_t lane0 = _mm_extract_epi64(a, 0);
	const uint64_t lane1 = _mm_extract_epi64(a, 1);
	const uint64_t prod0 = mult_mod(lane0, b);
	const uint64_t prod1 = mult_mod(lane1, b);
	return makeBlock(prod1, prod0);
}
/**
 * Batch of two block-by-scalar modular multiplications:
 * res[i] = a[i] * b[i] lane-wise, for i = 0, 1.
 *
 * Both products are computed before anything is stored to res.
 *
 * @param res output array of two blocks
 * @param a   input array of two blocks
 * @param b   input array of two scalars; b[i] multiplies both lanes of a[i]
 */
inline void mult_mod_bch2(block* res, block *a, uint64_t *b) {
	block products[2];
	for (int idx = 0; idx < 2; ++idx) {
		const uint64_t lane1 = _mm_extract_epi64(a[idx], 1);
		const uint64_t lane0 = _mm_extract_epi64(a[idx], 0);
		const uint64_t prod1 = mult_mod(lane1, b[idx]);
		const uint64_t prod0 = mult_mod(lane0, b[idx]);
		products[idx] = makeBlock(prod1, prod0);
	}
	res[0] = products[0];
	res[1] = products[1];
}
/**
 * Element-wise modular product of two length-2 arrays:
 * res[i] = mult_mod(a[i], b[i]).
 *
 * @param res output array of two elements
 * @param a   first input array of two elements
 * @param b   second input array of two elements
 */
inline void mult_mod_bch2(uint64_t *res, uint64_t *a, uint64_t *b) {
	// Highest index first, matching the original store order.
	for (int idx = 1; idx >= 0; --idx)
		res[idx] = mult_mod(a[idx], b[idx]);
}
/**
 * Element-wise modular product of two length-4 arrays:
 * res[i] = mult_mod(a[i], b[i]).
 *
 * @param res output array of four elements
 * @param a   first input array of four elements
 * @param b   second input array of four elements
 */
inline void mult_mod_bch4(uint64_t *res, uint64_t *a, uint64_t *b) {
	// Highest index first, matching the original store order.
	for (int idx = 3; idx >= 0; --idx)
		res[idx] = mult_mod(a[idx], b[idx]);
}
/**
 * Lane-wise addition of two blocks followed by one conditional subtraction of
 * the modulus per lane (via vec_partial_mod).
 *
 * @param a first operand, two 64-bit lanes
 * @param b second operand, two 64-bit lanes
 * @return vec_partial_mod applied to the lane-wise sum a + b
 */
inline block add_mod(block a, block b) {
	return vec_partial_mod(_mm_add_epi64(a, b));
}
/**
 * Add the same scalar to both 64-bit lanes of a block, followed by one
 * conditional subtraction of the modulus per lane (via vec_partial_mod).
 *
 * @param a block holding two 64-bit lanes
 * @param b scalar added to each lane
 * @return vec_partial_mod applied to the lane-wise sum
 */
inline block add_mod(block a, uint64_t b) {
	// Broadcast b with the SSE2 set1 intrinsic rather than casting it to the
	// MMX __m64 type; both lanes receive the same 64-bit pattern.
	const block b_both_lanes = _mm_set1_epi64x(static_cast<long long>(b));
	return vec_partial_mod(_mm_add_epi64(a, b_both_lanes));
}
/**
 * Add two 64-bit values and subtract the modulus PR once if the sum reached it.
 *
 * @param a first addend
 * @param b second addend
 * @return a + b if that is below PR, otherwise a + b - PR
 */
inline uint64_t add_mod(uint64_t a, uint64_t b) {
	const uint64_t sum = a + b;
	if (sum < PR)
		return sum;
	return sum - PR;
}
/**
 * Replace x by the reduction of its low 64 bits modulo p; the high 64 bits of
 * the input are ignored.
 *
 * @param x in/out: on return holds mod(low 64 bits of x)
 */
inline void extract_fp(__uint128_t& x) {
	const uint64_t low_word = static_cast<uint64_t>(x);
	x = mod(low_word);
}
/**
 * Replace x by the reduction of the full 128-bit value modulo p.
 *
 * @param x in/out: on return holds x mod p
 */
inline void extract_fp_whole_uint128(__uint128_t& x) {
	// This function used to carry its own copy of the hi/lo folding performed
	// by mod(__uint128_t, __uint128_t). Delegate to that helper so the
	// reduction lives in one place.
	x = mod(x, p);
}
/**
 * Fill coeff[0..sz-1] with successive powers of seed under mult_mod:
 * coeff[0] = seed and coeff[i] = mult_mod(coeff[i-1], seed).
 *
 * @param coeff output buffer with room for at least sz elements
 * @param seed  base value
 * @param sz    number of coefficients to generate; nothing is written when
 *              sz <= 0
 */
template<typename T>
void uni_hash_coeff_gen(T* coeff, T seed, int sz) {
	// An empty request must not touch coeff[0]: the buffer may have no elements.
	if (sz <= 0)
		return;
	coeff[0] = seed;
	for (int i = 1; i < sz; ++i)
		coeff[i] = mult_mod(coeff[i-1], seed);
}
/**
 * Modular inner product of two arrays of the same element type: accumulates
 * mult_mod(a[i], b[i]) with add_mod over i = 0..sz-1, starting from (T)0.
 *
 * @param a  first input array, at least sz elements
 * @param b  second input array, at least sz elements
 * @param sz number of element pairs to combine
 * @return the accumulated sum; (T)0 when sz <= 0
 */
template<typename T>
T vector_inn_prdt_sum_red(const T *a, const T *b, int sz) {
	T acc = (T)0;
	int idx = 0;
	while (idx < sz) {
		acc = add_mod(acc, mult_mod(a[idx], b[idx]));
		++idx;
	}
	return acc;
}
/**
 * Modular inner product of two arrays with different element types: each a[i]
 * is cast to T, multiplied with b[i] via mult_mod, and accumulated with
 * add_mod starting from (T)0.
 *
 * @param a  first input array (element type S), at least sz elements
 * @param b  second input array (element type T), at least sz elements
 * @param sz number of element pairs to combine
 * @return the accumulated sum; (T)0 when sz <= 0
 */
template<typename S, typename T>
T vector_inn_prdt_sum_red(const S *a, const T *b, int sz) {
	T acc = (T)0;
	int idx = 0;
	while (idx < sz) {
		acc = add_mod(acc, mult_mod((T)a[idx], b[idx]));
		++idx;
	}
	return acc;
}
/*
void feq_send(NetIO *io, void* in, int nbytes) {
Hash hash;
block h = hash.hash_for_block(in, nbytes);
io->send_data(&h, sizeof(block));
}
bool feq_recv(NetIO *io, void* in, int nbytes) {
Hash hash;
block h = hash.hash_for_block(in, nbytes);
block r;
io->recv_data(&r, sizeof(block));
if(!cmpBlock(&r, &h, 1)) return false;
else return true;
}
*/
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment