-
-
Save simonlindholm/2216d26945ca5a9579cb9e2142f7d2a6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <cstdlib> | |
#include <cassert> | |
using namespace std; | |
// Short alias for the 64-bit signed integer type used throughout.
using ll = long long;
// Barrett reduction: precomputes m = floor(4^k / mod), where 2^k is the
// smallest power of two greater than mod, so that a product can be reduced
// with one multiplication, one shift and at most one correction step.
struct Barrett {
  int m, k, mod;

  // Requires mod > 0.
  Barrett(int mod) : mod(mod) {
    assert(mod > 0);
    k = 0;
    // long long: with an int, p2 overflows for mod >= 2^30 and the loop
    // never terminates.
    long long p2 = 1;
    while (p2 <= mod)
      p2 <<= 1, k++;
    m = (int)(p2 * p2 / mod);
  }

  // Returns A * B % mod for 0 <= A, B < mod.
  // m * (A * B) < 2^(3k) must fit in a signed 64-bit integer, so results
  // are only valid for mod < 2^21.
  int modmul(int A, int B) {
    // Widen before multiplying: A * B in int overflows once mod > 46341.
    long long a = (long long)A * B;
    // q is floor(a / mod) or one less, so r lands in [0, 2 * mod).
    long long q = (m * a) >> (2 * k);
    int r = (int)(a - q * mod);
    if (r >= mod) r -= mod;
    return r;
  }
};
// Value wrapper used by Mont. Mont::from/one/mul/add/sub produce values in
// Montgomery form (x * 2^32 mod Mod); Mont::get converts back.
struct N {
  int x = 0;
};

// Montgomery modular arithmetic with radix R = 2^32.
// Requires an odd modulus with 0 < Mod < 2^31. N operands passed to
// get/mul/add/sub are expected to hold values in [0, Mod).
struct Mont {
  int Mod;     // the modulus
  int R1Mod;   // R mod Mod, i.e. 1 in Montgomery form
  int R2Mod;   // R^2 mod Mod, used by from() to enter Montgomery form
  int NPrime;  // -Mod^{-1} mod 2^32, stored as the bit pattern of an int

  Mont(int mod);

  // Montgomery reduction: returns a * b * R^{-1} mod Mod.
  // Requires 0 <= a * b < Mod * 2^32 (e.g. 0 <= a, b < Mod).
  N redc(int a, int b);

  // Wraps x as-is, with no conversion.
  N raw(int x) { N r; r.x = x; return r; }
  // Converts an ordinary residue 0 <= x < Mod into Montgomery form.
  N from(int x) { assert(0 <= x && x < Mod); return redc(x, R2Mod); }
  N one() { return raw(R1Mod); }
  // Converts a Montgomery-form value back to an ordinary residue.
  int get(N a) { return redc(a.x, 1).x; }
  N mul(N a, N b) { return redc(a.x, b.x); }
  N add(N a, N b) {
    // a.x + b.x can exceed INT_MAX once Mod > 2^30; a.x - (Mod - b.x)
    // always stays within (-Mod, Mod).
    int x = a.x - (Mod - b.x);
    if (x < 0) x += Mod;
    return raw(x);
  }
  N sub(N a, N b) { int x = a.x - b.x; if (a.x < b.x) x += Mod; return raw(x); }
};

Mont::Mont(int mod) : Mod(mod) {
  const long long B = 1LL << 32;  // the Montgomery radix R
  // R = 2^32 is only invertible modulo an odd modulus.
  assert(mod > 0 && (mod & 1) != 0);
  const long long R = B % mod;
  // Hensel lifting: xinv starts as mod^{-1} mod 2 and gains one correct bit
  // per step. Setting bit i adds mod * 2^i, which flips bit i of xinv * mod
  // without touching the lower bits.
  long long xinv = 1, bit = 2;
  for (int i = 1; i < 32; i++, bit <<= 1) {
    if (((xinv * mod) & bit) != 0)
      xinv |= bit;
  }
  assert(((mod * xinv) & (B - 1)) == 1);
  R1Mod = (int)R;
  R2Mod = (int)(R * R % mod);
  // Values >= 2^31 wrap modulo 2^32 on conversion to int (guaranteed since
  // C++20, implementation-defined before); redc converts back to unsigned.
  NPrime = (int)(B - xinv);
}

N Mont::redc(int a, int b) {
  unsigned long long T = (unsigned long long)((long long)a * b);
  // m = T * NPrime mod 2^32, chosen so that T + m * Mod is divisible by 2^32.
  unsigned long long m = (unsigned)T * (unsigned)NPrime;
  // Unsigned: T + m * Mod < Mod * 2^32 + 2^32 * Mod < 2^64, but it can exceed
  // LLONG_MAX (signed overflow) when Mod is close to 2^31.
  T += m * (unsigned long long)Mod;
  T >>= 32;  // now T < 2 * Mod
  if (T >= (unsigned long long)Mod)
    T -= Mod;
  return raw((int)T);
}
// Benchmark driver. Reads an odd modulus from argv[1], sums the Montgomery
// reductions redc(a, b) over all 0 <= a, b < mod - 2, and prints
// (sum of a * b) mod `mod`. The commented-out lines are the alternative
// reductions that were timed.
int main(int argc, char** argv) {
  if (argc < 2) {
    cerr << "usage: " << (argc > 0 ? argv[0] : "prog") << " <odd modulus>" << endl;
    return 1;
  }
  const int mod = atoi(argv[1]);
  // Barrett divides by mod, and Mont needs an odd modulus.
  if (mod <= 0 || mod % 2 == 0) {
    cerr << "modulus must be a positive odd integer" << endl;
    return 1;
  }
  Barrett bt(mod);
  Mont mt(mod);
  cerr << mt.redc(123, 321).x << endl;
  ll sum = 0;
  for (int a = 0; a < mod-2; ++a) {
    for (int b = 0; b < mod-2; ++b) {
      // sum += a * b % mod; // 4s
      // sum += bt.modmul(a, b); // 2s (but only 21 bits)
      sum += mt.redc(a, b).x; // 2s
    }
    // Reduce once per row so the running sum cannot overflow for large mod.
    sum %= mod;
  }
  sum %= mod;
  // Each redc(a, b) contributes a * b * R^{-1}; from() multiplies by R,
  // so the printed value is (sum of a * b) mod `mod`.
  sum = mt.from((int)sum).x;
  cout << sum << endl;
  return 0;
}
// Computes (mod - 1)! mod `mod` for a fixed modulus and prints it. By
// Wilson's theorem the result is mod - 1 exactly when mod is prime.
// argc and argv are unused; the modulus is hard-coded.
int main2(int argc, char** argv) {
  const int mod = 202171241; // atoi(argv[1]);
  long long factorial = 1;
  int i = 1;
  while (i < mod) {
    factorial = factorial * i % mod;
    ++i;
  }
  std::cout << factorial << std::endl;
  return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment