Skip to content

Instantly share code, notes, and snippets.

@simonlindholm
Created May 16, 2020 22:56
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save simonlindholm/51f88e9626408723cf906c6debd3814b to your computer and use it in GitHub Desktop.
Save simonlindholm/51f88e9626408723cf906c6debd3814b to your computer and use it in GitHub Desktop.
#include <bits/stdc++.h>
#include <immintrin.h>
using namespace std;
#define rep(i, from, to) for (int i = from; i < (to); ++i)
#define trav(a, x) for (auto& a : x)
#define all(x) x.begin(), x.end()
#define sz(x) (int)(x).size()
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
// Computes a^e modulo mod by binary (square-and-multiply) exponentiation.
//
// a   - base; it is reduced modulo mod before use, so bases larger than the
//       modulus no longer overflow the 64-bit intermediate products
// e   - exponent; must be non-negative
// mod - modulus; must be positive and small enough that (mod-1)^2 fits in a
//       signed 64-bit integer
//
// Returns a value congruent to a^e (mod mod). For e == 0 the result is
// 1 % mod, i.e. 0 when mod == 1. With a negative base the result keeps the
// sign C++'s truncating % gives it.
ll modpow(ll a, ll e, ll mod) {
    assert(e >= 0);
    a %= mod;
    ll result = 1 % mod;
    for (; e > 0; e >>= 1) {
        if (e & 1)
            result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}
// Returns a * b mod M using a floating-point estimate of the quotient.
//
// The quotient a*b/M is approximated in double precision; the remainder is
// then formed with wrapping 32-bit arithmetic and reinterpreted as a signed
// value, which can be one multiple of M too low or too high. A single
// correction step brings it back into [0, M).
unsigned modmul(unsigned a, unsigned b, unsigned M) {
    const unsigned quotient = unsigned(1.0 / M * a * b);
    const int remainder = a * b - M * quotient;
    if (remainder < 0)
        return remainder + M;   // estimate was one too high
    if (remainder >= (ll)M)
        return remainder - M;   // estimate was one too low
    return remainder;
}
// "Relaxed" floating-point modular multiplication: returns a value congruent
// to a * b modulo M, but without the final range correction, so the result
// may be negative or lie at/above M. Callers are expected to normalise once
// at the end.
int modmul2(int a, int b, int M) {
    // Double-precision estimate of the quotient, truncated toward zero.
    const int quotient = (int)(1.0 / M * a * b);
    // Both products wrap modulo 2^32; their difference is the small residue.
    const unsigned wrapped_product = (unsigned)a * (unsigned)b;
    const unsigned wrapped_multiple = (unsigned)M * (unsigned)quotient;
    return (int)(wrapped_product - wrapped_multiple);
}
typedef unsigned long long ull;
typedef __uint128_t L;
// Barrett reduction: computes a % b for a fixed divisor b using a
// precomputed reciprocal, replacing the division with a 64x64->128-bit
// multiplication.
struct FastMod {
    ull b, m;  // b: the divisor; m: floor((2^64 - 1) / b)

    // b must be non-zero. The reciprocal uses 2^64 - 1 rather than 2^64 as
    // numerator so that it still fits in 64 bits when b == 1 (with 2^64 it
    // would truncate to 0 and reduce() would return a - 1). For every b that
    // does not divide 2^64 both numerators give the same m.
    FastMod(ull b) : b(b), m(ull(-1) / b) {}

    // Returns a % b.
    ull reduce(ull a) const {
        // q underestimates a / b by at most one, so r lies in [0, 2b).
        const ull q = (ull)((L(m) * a) >> 64);
        const ull r = a - q * b;
        return r >= b ? r - b : r;
    }
};
// A residue as handled by Mont. Values produced by Mont::from/one/mul/add/sub
// are in Montgomery form (x * 2^32 mod Mod); Mont::raw wraps an arbitrary int
// without converting it.
struct N {
    int x = 0;
};

// Montgomery multiplication context for an odd, positive modulus below 2^31,
// with R = 2^32.
struct Mont {
    int Mod;     // the modulus
    int R1Mod;   // R mod Mod   (Montgomery form of 1)
    int R2Mod;   // R^2 mod Mod (used to convert into Montgomery form)
    int NPrime;  // -Mod^{-1} mod 2^32, stored as a 32-bit pattern

    Mont(int mod);

    // Montgomery reduction of the product a*b: returns a*b*R^{-1} mod Mod.
    // Requires 0 <= a, b and a*b < Mod * 2^32.
    N redc(int a, int b);

    // Wraps x unchanged (no conversion to Montgomery form).
    N raw(int x) { N r; r.x = x; return r; }
    // Converts x (which must be below Mod) into Montgomery form.
    N from(int x) { assert(x < Mod); return redc(x, R2Mod); }
    // Montgomery form of 1.
    N one() { return raw(R1Mod); }
    // Converts a value out of Montgomery form.
    int get(N a) { return redc(a.x, 1).x; }
    N mul(N a, N b) { return redc(a.x, b.x); }
    N add(N a, N b) {
        // Sum in 64 bits: a.x + b.x can exceed INT_MAX when Mod > 2^30.
        ll x = (ll)a.x + b.x;
        if (x >= Mod) x -= Mod;
        return raw((int)x);
    }
    N sub(N a, N b) { int x = a.x - b.x; if (a.x < b.x) x += Mod; return raw(x); }
};

Mont::Mont(int mod) : Mod(mod) {
    const ll B = (1LL << 32);
    assert(mod > 0 && (mod & 1) != 0);
    const ll R = B % mod;
    // Build mod^{-1} modulo 2^32 one bit at a time (Hensel lifting): whenever
    // bit i of mod * xinv is set, setting bit i of xinv clears it, because
    // mod is odd.
    ll xinv = 1;
    ll bit = 2;
    for (int i = 1; i < 32; i++, bit <<= 1) {
        const ll y = xinv * mod;
        if ((y & bit) != 0)
            xinv |= bit;
    }
    assert(((mod * xinv) & (B - 1)) == 1);
    R1Mod = (int)R;
    R2Mod = (int)(R * R % mod);
    // Only the low 32 bits matter; the value may appear negative as an int.
    NPrime = (int)(B - xinv);
}

N Mont::redc(int a, int b) {
    // Work in unsigned 64-bit: T < 2^62 and m * Mod < 2^63, so the sum fits
    // in 64 unsigned bits but may exceed the signed range for large moduli.
    ull T = (ull)((ll)a * b);
    // m = T * NPrime mod 2^32 (32-bit wrapping product, then widened).
    const ull m = (unsigned)T * (unsigned)NPrime;
    T += m * (ull)Mod;  // low 32 bits are now zero
    T >>= 32;
    if (T >= (ull)Mod)
        T -= (ull)Mod;
    return raw((int)T);
}
// Modulus used by every benchmark variant in main(); a compile-time constant.
const int M = 1'000'000'007;
// Same value in a mutable global, used by the "dynamic modulo" variants
// (methods 0 and 2). NOTE(review): presumably kept non-const so the compiler
// cannot specialise `%` for a known modulus — confirm.
int M_dynamic = M;
int main(int argc, char** argv) {
cin.sync_with_stdio(false);
cin.exceptions(cin.failbit);
int method = atoi(argv[1]);
if (method == 0) {
// Really naive, with dynamic modulo. 19.161s.
ll prod = 1;
for (int i = 1; i < M; i++) {
prod = prod * i % M_dynamic;
}
cout << prod << endl;
}
if (method == 1) {
// Naive. 5.569s.
ll prod = 1;
for (int i = 1; i < M; i++) {
prod = prod * i % M;
}
cout << prod << endl;
}
else if (method == 2) {
// Parallel really naive. 10.423s.
const int PAR = 8;
ll prods[PAR];
rep(i,0,PAR) prods[i] = 1;
int i = 1;
for (; i + PAR <= M;) {
rep(j,0,PAR)
prods[j] = prods[j] * i % M_dynamic, i++;
}
ll prod = 1;
rep(i,0,PAR) prod = prod * prods[i] % M_dynamic;
while (i < M) {
prod = prod * i % M_dynamic; i++;
}
cout << prod << endl;
}
else if (method == 3) {
// Parallel, to avoid latency bottlenecks. 1.453s.
const int PAR = 8;
ll prods[PAR];
rep(i,0,PAR) prods[i] = 1;
int i = 1;
for (; i + PAR <= M;) {
rep(j,0,PAR)
prods[j] = prods[j] * i % M, i++;
}
ll prod = 1;
rep(i,0,PAR) prod = prod * prods[i] % M;
while (i < M) {
prod = prod * i % M; i++;
}
cout << prod << endl;
}
else if (method == 4) {
// Floating-point modmul (like KACTL's version but with doubles). 3.088s.
const int PAR = 8;
unsigned prods[PAR];
rep(i,0,PAR) prods[i] = 1;
int i = 1;
for (; i + PAR <= M;) {
rep(j,0,PAR)
prods[j] = modmul(prods[j], i, M), i++;
}
ll prod = 1;
rep(i,0,PAR) prod = prod * prods[i] % M;
while (i < M) {
prod = prod * i % M; i++;
}
cout << prod << endl;
}
else if (method == 5) {
// Relaxed floating-point modmul (without the final reduction in the
// function, allowing for negative and out-of-range numbers). 2.024s.
const int PAR = 8;
int prods[PAR];
rep(i,0,PAR) prods[i] = 1;
int i = 1;
for (; i + PAR <= M;) {
rep(j,0,PAR)
prods[j] = modmul2(prods[j], i, M), i++;
}
ll prod = 1;
rep(i,0,PAR) prod = prod * (prods[i] % M) % M;
if (prod < 0) prod += M;
while (i < M) {
prod = prod * i % M; i++;
}
cout << prod << endl;
}
else if (method == 6) {
// SIMD using relaxed floating-point modmul. 0.691s.
const int PAR = 8;
typedef __m128i mi;
typedef __m256d md;
mi ones = _mm_set1_epi32(1);
mi prods[PAR];
mi iaccs[PAR];
rep(i,0,PAR) {
prods[i] = ones;
iaccs[i] = _mm_setr_epi32(4*i+1, 4*i+2, 4*i+3, 4*i+4);
}
mi iaccadd = _mm_set1_epi32(4 * PAR);
mi ms = _mm_set1_epi32(M);
md minv = _mm256_set1_pd(1.0 / M);
int i = 1;
for (; i + 4 * PAR <= M; i += 4 * PAR) {
rep(j,0,PAR) {
mi a = prods[j];
mi b = iaccs[j];
iaccs[j] = _mm_add_epi32(b, iaccadd);
mi ab = _mm_mullo_epi32(a, b);
mi fltresult = _mm256_cvtpd_epi32(
_mm256_mul_pd(
_mm256_mul_pd(minv,
_mm256_cvtepi32_pd(b)),
_mm256_cvtepi32_pd(a)
)
);
mi res = _mm_sub_epi32(ab, _mm_mullo_epi32(ms, fltresult));
prods[j] = res;
}
}
ll prod = 1;
rep(i,0,PAR) {
union {
int i[4];
mi m;
} u;
u.m = prods[i];
rep(j,0,4) prod = prod * (u.i[j] % M) % M;
}
if (prod < 0) prod += M;
while (i < M) {
prod = prod * i % M; i++;
}
cout << prod << endl;
}
else if (method == 7) {
// Barrett reduction. 1.732s.
const int PAR = 8;
FastMod fm(M);
ull prods[PAR];
rep(i,0,PAR) prods[i] = 1;
int i = 1;
for (; i + PAR <= M;) {
rep(j,0,PAR)
prods[j] = fm.reduce(prods[j] * i), i++;
}
ull prod = 1;
rep(i,0,PAR) prod = fm.reduce(prod * prods[i]);
while (i < M) {
prod = fm.reduce(prod * i), i++;
}
cout << prod << endl;
}
else if (method == 8) {
// Montgomery multiplication. 1.668s.
const int PAR = 8;
Mont mont(M);
N prods[PAR];
rep(i,0,PAR) prods[i] = mont.one();
int i = 1;
for (; i + PAR <= M;) {
rep(j,0,PAR)
prods[j] = mont.mul(prods[j], mont.raw(i)), i++;
}
N prod = mont.one();
rep(i,0,PAR) prod = mont.mul(prod, prods[i]);
while (i < M) {
prod = mont.mul(prod, mont.raw(i)), i++;
}
// We ought to multiply by R^(M-1) to account for the non-Montgomery
// form numbers that got multiplied in. But that's 1, so no need.
cout << mont.get(prod) << endl;
}
else if (method == 9) {
// SIMD Montgomery multiplication. 0.493s.
const int PAR = 8;
typedef __m256i mi;
Mont mont(M);
mi prods[PAR];
mi accs[PAR];
rep(i,0,PAR) {
prods[i] = _mm256_set1_epi64x(mont.one().x);
accs[i] = _mm256_setr_epi64x(4*i + 1, 4*i + 2, 4*i + 3, 4*i + 4);
}
mi iaccadd = _mm256_set1_epi64x(4 * PAR);
mi mnprime = _mm256_set1_epi64x(mont.NPrime);
mi mmod = _mm256_set1_epi64x(mont.Mod);
int i = 1;
for (; i + PAR * 4 <= M; i += PAR * 4) {
rep(j,0,PAR) {
mi a = prods[j];
mi b = accs[j];
accs[j] = _mm256_add_epi64(b, iaccadd);
mi T = _mm256_mul_epu32(a, b);
mi m = _mm256_mul_epu32(T, mnprime); // uses lo 32 bits of T
T = _mm256_add_epi64(T, _mm256_mul_epu32(m, mmod)); // uses lo 32 bits of m
T = _mm256_srli_epi64(T, 32);
T = _mm256_sub_epi64(T,
_mm256_andnot_si256(
_mm256_cmpgt_epi32(mmod, T),
mmod));
prods[j] = T;
}
}
N prod = mont.one();
rep(i,0,PAR) {
union {
ull i[4];
mi m;
} u;
u.m = prods[i];
rep(j,0,4)
prod = mont.mul(prod, mont.raw((int)u.i[j]));
}
while (i < M) {
prod = mont.mul(prod, mont.raw(i)), i++;
}
// We ought to multiply by R^(M-1) to account for the non-Montgomery
// form numbers that got multiplied in. But that's 1, so no need.
cout << mont.get(prod) << endl;
}
else if (method == 10) {
// Combined SIMD Montgomery multiplication and float multiplication,
// based on the theory that maybe int and float SIMD operations run
// in parallel. Turns out that's not true. 0.620s (slightly above the
// average of the two).
const int PAR = 4;
typedef __m256i mi;
typedef __m128i mi128;
typedef __m256d md;
// Montgomery setup
Mont mont(M);
mi prods[PAR];
mi accs[PAR];
rep(i,0,PAR) {
prods[i] = _mm256_set1_epi64x(mont.one().x);
accs[i] = _mm256_setr_epi64x(8*i+1, 8*i+2, 8*i+3, 8*i+4);
}
mi iaccadd = _mm256_set1_epi64x(8 * PAR);
mi mnprime = _mm256_set1_epi64x(mont.NPrime);
mi mmod = _mm256_set1_epi64x(mont.Mod);
// Float setup
mi128 ones = _mm_set1_epi32(1);
mi128 fpprods[PAR];
mi128 iaccs[PAR];
rep(i,0,PAR) {
fpprods[i] = ones;
iaccs[i] = _mm_setr_epi32(8*i+5, 8*i+6, 8*i+7, 8*i+8);
}
mi128 iaccadd128 = _mm_set1_epi32(8 * PAR);
mi128 ms = _mm_set1_epi32(M);
md minv = _mm256_set1_pd(1.0 / M);
int i = 1;
int montmuls = 0;
for (; i + PAR * 8 <= M; i += PAR * 8) {
montmuls += PAR * 4;
rep(j,0,PAR) {
// Montgomery multiplication
{
mi a = prods[j];
mi b = accs[j];
accs[j] = _mm256_add_epi64(b, iaccadd);
mi T = _mm256_mul_epu32(a, b);
mi m = _mm256_mul_epu32(T, mnprime); // uses lo 32 bits of T
T = _mm256_add_epi64(T, _mm256_mul_epu32(m, mmod)); // uses lo 32 bits of m
T = _mm256_srli_epi64(T, 32);
T = _mm256_sub_epi64(T,
_mm256_andnot_si256(
_mm256_cmpgt_epi32(mmod, T),
mmod));
prods[j] = T;
}
// Float multiplication
{
mi128 a = fpprods[j];
mi128 b = iaccs[j];
iaccs[j] = _mm_add_epi32(b, iaccadd128);
mi128 ab = _mm_mullo_epi32(a, b);
mi128 fltresult = _mm256_cvtpd_epi32(
_mm256_mul_pd(
_mm256_mul_pd(minv,
_mm256_cvtepi32_pd(b)),
_mm256_cvtepi32_pd(a)
)
);
mi128 res = _mm_sub_epi32(ab, _mm_mullo_epi32(ms, fltresult));
fpprods[j] = res;
}
}
}
// Combine Montgomery accumulators
N prod = mont.one();
rep(i,0,PAR) {
union {
ull i[4];
mi m;
} u;
u.m = prods[i];
rep(j,0,4)
prod = mont.mul(prod, mont.raw((int)u.i[j]));
}
while (i < M) {
montmuls++;
prod = mont.mul(prod, mont.raw(i));
i++;
}
ll res = mont.get(prod);
// Multiply by R^montmuls to account for the non-Montgomery form
// numbers that got multiplied in.
res = res * modpow(mont.R1Mod, montmuls, M) % M;
// Combine floating-point accumulators
rep(i,0,PAR) {
union {
int i[4];
mi128 m;
} u;
u.m = fpprods[i];
rep(j,0,4) res = res * (u.i[j] % M) % M;
}
if (res < 0) res += M;
cout << res << endl;
}
exit(0);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment