Skip to content

Instantly share code, notes, and snippets.

@simonlindholm
Created January 3, 2019 17:47
Show Gist options
  • Save simonlindholm/2216d26945ca5a9579cb9e2142f7d2a6 to your computer and use it in GitHub Desktop.
Save simonlindholm/2216d26945ca5a9579cb9e2142f7d2a6 to your computer and use it in GitHub Desktop.
#include <iostream>
#include <cstdlib>
#include <cassert>
using namespace std;
// Shorthand for the (at least 64-bit) signed type used for intermediate products.
using ll = long long;
// Barrett reduction. The constructor picks k so that 2^k is the smallest power
// of two exceeding mod and precomputes m = floor(4^k / mod); modmul then
// replaces the division by mod with a multiply and a shift.
struct Barrett {
	int m, k, mod;
	// Requires mod > 0.
	Barrett(int mod) : mod(mod) {
		assert(mod > 0);
		k = 0;
		// 64-bit so the doubling cannot overflow int when mod >= 2^30.
		long long p2 = 1;
		while (p2 <= mod)
			p2 <<= 1, k++;
		// Only fits in int (and is only used) when k <= 21; see modmul.
		m = (int)(p2 * p2 / mod);
	}
	// Returns A*B mod mod, for 0 <= A, B < mod.
	// Since a < mod^2 and m <= 4^k / mod, m*a < 4^k * mod < 2^(3k); the 64-bit
	// product therefore only fits for k <= 21, i.e. mod < 2^21.
	int modmul(int A, int B) {
		assert(k <= 21);
		// Widen before multiplying: A*B in int overflows once it exceeds 2^31.
		long long a = (long long)A * B;
		// q is floor(a / mod) or one less, so r lands in [0, 2*mod).
		long long q = (m * a) >> (2*k);
		int r = (int)(a - q * mod);
		if (r >= mod) r -= mod;
		return r;
	}
};
// Wrapper for a residue handled by Mont (held in Montgomery form, as produced by
// Mont::from). Presumably a distinct type so such values are not confused with
// plain ints.
struct N {
	int x{0};
};
// Montgomery arithmetic modulo an odd Mod with R = 2^32. Values of type N hold
// residues in Montgomery form (x*R mod Mod); from() enters that form and get()
// leaves it.
struct Mont {
	int Mod;     // the odd modulus
	int R1Mod;   // R mod Mod, i.e. 1 in Montgomery form
	int R2Mod;   // R^2 mod Mod, used by from() to enter Montgomery form
	int NPrime;  // -Mod^{-1} mod 2^32, stored in an int
	Mont(int mod);
	// Montgomery reduction: a*b*R^{-1} mod Mod (defined out of line).
	N redc(int a, int b);
	// Wraps x as an N without any conversion.
	N raw(int x) { N r; r.x = x; return r; }
	// Converts a plain residue 0 <= x < Mod into Montgomery form (x*R mod Mod).
	N from(int x) { assert(0 <= x && x < Mod); return redc(x, R2Mod); }
	N one() { return raw(R1Mod); }
	// Converts a Montgomery-form value back to a plain residue.
	int get(N a) { return redc(a.x, 1).x; }
	N mul(N a, N b) { return redc(a.x, b.x); }
	// Sum in unsigned arithmetic: a.x + b.x can exceed INT_MAX once Mod > 2^30.
	N add(N a, N b) {
		unsigned x = (unsigned)a.x + (unsigned)b.x;
		if (x >= (unsigned)Mod) x -= (unsigned)Mod;
		return raw((int)x);
	}
	N sub(N a, N b) { int x = a.x - b.x; if (a.x < b.x) x += Mod; return raw(x); }
};
// Precomputes the Montgomery constants for R = B = 2^32.
// mod must be positive and odd (odd so that it is invertible modulo 2^32).
Mont::Mont(int mod) : Mod(mod) {
	const ll B = (1LL << 32);
	// A negative odd mod would pass the parity check but give meaningless constants.
	assert(mod > 0);
	assert((mod & 1) != 0);
	ll R = B % mod;
	// Hensel lifting: xinv = 1 is mod^{-1} modulo 2; extend it one bit at a time
	// to mod^{-1} modulo 2^32. Adding `bit` to xinv adds bit*mod to xinv*mod,
	// which (mod being odd) flips exactly that bit and leaves lower bits intact.
	ll xinv = 1, bit = 2;
	for (int i = 1; i < 32; i++, bit <<= 1) {
		ll y = xinv * mod;
		if ((y & bit) != 0)
			xinv |= bit;
	}
	assert(((mod * xinv) & (B-1)) == 1);
	R1Mod = (int)R;
	R2Mod = (int)(R * R % mod);
	// B - xinv is -mod^{-1} mod 2^32 (in [1, 2^32)); only its low 32 bits matter,
	// since redc converts it back to unsigned.
	NPrime = (int)(B - xinv);
}
// Montgomery reduction with R = 2^32: returns a*b*R^{-1} mod Mod in [0, Mod).
// Requires a, b >= 0 and a*b < Mod * 2^32 (true whenever a, b < Mod), so the
// shifted sum is below 2*Mod and one conditional subtraction suffices.
// Unsigned 64-bit arithmetic is used because T + m*Mod can approach
// Mod^2 + 2^32*Mod, which exceeds LLONG_MAX for Mod above roughly 1.57e9
// (e.g. Mod = 2^31 - 1) but always fits in 64 unsigned bits for Mod < 2^31.
N Mont::redc(int a, int b) {
	unsigned long long T = (unsigned long long)((long long)a * b);
	// m = T * (-Mod^{-1}) mod 2^32, chosen so that T + m*Mod is divisible by 2^32.
	unsigned long long m = (unsigned)T * (unsigned)NPrime;
	T += m * (unsigned long long)Mod;
	T >>= 32;
	if (T >= (unsigned long long)Mod)
		T -= (unsigned long long)Mod;
	return raw((int)T);
}
// Benchmark driver: for the odd modulus given as argv[1], accumulates
// redc(a, b) = a*b*R^{-1} mod mod over all 0 <= a, b < mod-2. The final
// from() multiplies by R, cancelling that R^{-1}, so the printed value is
// (sum of a*b) mod mod. The commented lines are the alternatives being timed.
int main(int argc, char** argv) {
	if (argc < 2) {
		cerr << "usage: <program> <positive odd modulus>" << endl;
		return 1;
	}
	const int mod = atoi(argv[1]);
	// Barrett divides by mod and Mont needs an odd modulus; atoi returns 0 on
	// unparsable input, so reject bad values before constructing either.
	if (mod <= 0 || mod % 2 == 0) {
		cerr << "modulus must be a positive odd integer" << endl;
		return 1;
	}
	Barrett bt(mod);
	Mont mt(mod);
	cerr << mt.redc(123, 321).x << endl;
	ll sum = 0;
	for (int a = 0; a < mod-2; ++a) {
		for (int b = 0; b < mod-2; ++b) {
			// sum += a * b % mod; // 4s
			// sum += bt.modmul(a, b); // 2s (but only 21 bits)
			sum += mt.redc(a, b).x; // 2s
		}
		// Keep sum below mod^2 + mod so it cannot overflow for large mod;
		// one % per outer iteration is negligible next to the inner loop.
		sum %= mod;
	}
	sum %= mod;
	sum = mt.from((int)sum).x;
	cout << sum << endl;
	return 0;
}
int main2(int argc, char** argv) {
const int mod = 202171241; // atoi(argv[1]);
ll ret = 1;
for (int i = 1; i < mod; i++) {
ret = (ret * i) % mod;
}
cout << ret << endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment