Last active
March 1, 2021 11:46
-
-
Save Mivik/34d2e8d7f718b915642d6a7b8081c56c to your computer and use it in GitHub Desktop.
mod_solver
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Mivik 2021.3.1 | |
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <random>
#include <vector>
typedef long long qe;
const int mod = 998244353;
// Modular add/subtract. Operands must already be reduced into [0, mod); since mod < 2^30,
// x + y cannot overflow an int.
inline int add(int x, int y) { return (x += y) >= mod? x - mod: x; }
inline void Add(int &x, int y) { if ((x += y) >= mod) x -= mod; }
inline int sub(int x, int y) { return (x -= y) < 0? x + mod: x; }
inline void Sub(int &x, int y) { if ((x -= y) < 0) x += mod; }
// x^p mod `mod` by binary exponentiation. With the default p = mod - 2 this is the modular
// inverse of x (Fermat; mod is prime). x may be any 64-bit value: it is reduced into
// [0, mod) first, so x * x can never overflow. p must be non-negative (a negative p
// returns 1 instead of looping forever).
inline qe ksm(qe x, int p = mod - 2) {
	x %= mod;
	if (x < 0) x += mod;
	qe ret = 1;
	for (; p > 0; p >>= 1, x = x * x % mod)
		if (p & 1) ret = ret * x % mod;
	return ret;
}
// coco: preserve_begin | |
// In-place number-theoretic transform of v[0 .. len) modulo `mod`.
// len must be a power of two (asserted below) and at most 1 << N, the size of the static
// tables.
// rev == false: forward transform (bit-reversal permutation, then iterative radix-2
// butterflies). rev == true: inverse transform, done as the same forward pass followed by
// reversing v[1 .. len) and scaling every entry by len^{-1}.
// Writes the function-local statics `last` and `r`, so concurrent calls are not safe.
inline void ntt(int *v, int len, bool rev) { | |
static const int N = 19; // So we support polynomials containing up to (2 ^ (N - 1)) terms. | |
// Tables built once, on the first call:
//   inv[i] = (2^i)^{-1} mod `mod`;
//   for each half-length q = 2^(i-1): rt[q | k] = w^k for 0 <= k < q, where
//   w = 3^((mod - 1) / 2q) is a 2q-th root of unity (3 is presumably chosen as a
//   primitive root of `mod`).
static struct _ | |
{ int rt[1 << N], inv[N + 1]; _() { | |
for (int i = 0, j = 1, q = 0; i <= N; ++i, j = (q = j) << 1) { | |
inv[i] = ksm(j); | |
const int del = ksm(3, (mod - 1) / j); | |
for (int k = 0, c = 1; k < q; ++k, c = (qe)c * del % mod) rt[q | k] = c; | |
} | |
} } _; | |
// r[i] = i with its log2(len) low bits reversed. Rebuilt only when len differs from the
// len of the previous call.
static int last, r[1 << N]; | |
if (last != len) { | |
for (int i = 1; i < len; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1)? (len >> 1): 0); | |
last = len; | |
} | |
for (int i = 1; i < len; ++i) | |
if (i < r[i]) std::swap(v[i], v[r[i]]); | |
// NOTE(review): this power-of-two check only runs after r[] has already been used.
assert(!(len & (len - 1))); | |
// Stage with half-length i and block length q = 2i. Each butterfly combines v[j | k] with
// v[i | j | k] * rt[i | k], where rt[i | k] is the k-th power of a 2i-th root of unity.
for (int i = 1, q = 2; i < len; i = q, q <<= 1) { | |
for (int j = 0; j < len; j += q) | |
for (int k = 0; k < i; ++k) { | |
const int x = v[j | k], y = (qe)v[i | j | k] * _.rt[i | k] % mod; | |
v[j | k] = add(x, y); | |
v[i | j | k] = sub(x, y); | |
} | |
} | |
if (!rev) return; | |
// Inverse transform: reversing v[1 .. len) turns the forward result into the inverse one
// up to a factor of len; inv[ctz(len)] = len^{-1} removes that factor.
std::reverse(v + 1, v + len); | |
for (int i = 0, r = _.inv[__builtin_ctz(len)]; i < len; ++i) v[i] = (qe)v[i] * r % mod; | |
} | |
// Polynomial over Z/mod stored as a coefficient vector, lowest degree first: element i is
// the coefficient of x^i. The empty vector is the zero polynomial (trim() produces it).
// Multiplication goes through ntt(); inverse and division use Newton iteration.
// imul() and newton() keep `static poly tmp` scratch buffers, so they are not reentrant.
struct poly : public std::vector<int> { | |
#define v (*this) | |
// F(n) declares `n(len) const`, which returns a new poly built by `from_n(*this, len)`, and
// opens the definition of `from_n(a, len)`. I(n) declares `n() const`, which returns a copy
// transformed by `inplace_n()`, and opens the definition of `inplace_n()`.
#define F(n) \ | |
inline poly n(int len = -1) const { return poly().from_##n(v, len); } \ | |
inline poly& from_##n(const poly &a, int len = -1) | |
#define I(n) \ | |
inline poly n() const { return poly(v).inplace_##n(); } \ | |
inline poly& inplace_##n() | |
// Smallest power of two >= x for x >= 1 (returns 0 for x == 0).
static inline int round_up(int x) { return (x & (x - 1))? (1 << (32 - __builtin_clz(x))): x; } | |
poly() {} | |
poly(const std::initializer_list<int> &list): std::vector<int>(list) {} | |
// *this = the first n coefficients of a, zero-padded up to length n.
inline poly& from(const poly &a, int n) { | |
resize(n); const int len = std::min(n, (int)a.size()); | |
std::copy(a.begin(), a.begin() + len, begin()); | |
std::fill(begin() + len, end(), 0); | |
return v; | |
} | |
// Copy truncated or zero-padded to exactly n coefficients.
inline poly take(int n) const { return poly().from(v, n); } | |
// Resize to round_up(len) (this truncates when len < size()), then transform in place.
inline poly& ntt(int len, bool rev) { resize(round_up(len), 0); ::ntt(data(), size(), rev); return v; } | |
inline poly& ntt(bool rev) { return ntt(size(), rev); } | |
// Drop trailing (highest-degree) zero coefficients; the zero polynomial becomes empty.
inline poly& trim() { while (!empty() && !back()) pop_back(); return v; } | |
// *this = (*this * a) truncated to l coefficients. Either factor empty => result empty.
// Uses the static scratch buffer `tmp`.
inline poly& imul(const poly &a, int l) { | |
static poly tmp; | |
if (empty() || a.empty()) return clear(), v; | |
const int len = round_up(l); | |
(tmp = a).ntt(len, 0); ntt(len, 0); | |
for (int i = 0; i < len; ++i) v[i] = (qe)v[i] * tmp[i] % mod; | |
return ntt(len, 1), resize(l), v; | |
} | |
// Full product: size() + a.size() - 1 coefficients.
inline poly& operator*=(const poly &a) { return imul(a, size() + a.size() - 1); } | |
inline poly operator*(const poly &t) const { return poly(v) *= t; } | |
// reverse(): reversed copy; inplace_reverse(): reverse the coefficient order in place.
I(reverse) { return std::reverse(begin(), end()), v; } | |
// Newton-iteration driver. Starts from the single coefficient `initial`, then for
// l = 2, 4, ..., round_up(len) calls trans(l, tmp) with tmp = a truncated/padded to l
// terms; trans is expected to extend *this to l correct terms. The result is resized to
// len (len == -1 means a.size()). An empty a yields an empty result.
template<class Func> | |
inline poly& newton(const poly &a, int len, int initial, const Func &trans) { | |
static poly tmp; | |
if (len == -1) len = a.size(); | |
if (a.empty()) return clear(), v; | |
const int lim = round_up(len); | |
clear(); push_back(initial); | |
for (int l = 2; l <= lim; l <<= 1) { tmp.from(a, l); trans(l, tmp); } | |
return resize(len), v; | |
} | |
// inv(len) / from_inv(a, len): power-series inverse of a modulo x^len via the update
// b <- b * (2 - b * a). Only non-emptiness is asserted; a[0] must also be nonzero, since
// ksm(0) seeds the iteration with 0 and the result stays 0.
F(inv) { | |
assert(!a.empty()); | |
return newton(a, len, ksm(a[0]), [this](int l, poly &tmp) { | |
const int q = l << 1; | |
tmp.ntt(q, 0); ntt(q, 0); | |
for (int i = 0; i < q; ++i) | |
v[i] = (qe)v[i] * sub(2, (qe)v[i] * tmp[i] % mod) % mod; | |
ntt(1); resize(l); | |
}); | |
} | |
// *this = quotient of a divided by b, computed as rev(q) = rev(a) * rev(b)^{-1} mod x^len,
// where len = a.size() - b.size() + 1.
// Requires a.size() + 1 >= b.size() (asserted; under NDEBUG a negative len reaches
// resize()). b.back() must be nonzero (b trimmed) so that rev(b) is invertible.
inline poly& from_division(const poly &a, const poly &b) { | |
const int len = a.size() - b.size() + 1; assert(len >= 0); | |
resize(len); std::reverse_copy(a.end() - len, a.end(), begin()); | |
v *= b.reverse().inv(len); | |
return resize(len), inplace_reverse(); | |
} | |
// *this = (a - b * q) restricted to its low m = b.size() - 1 coefficients, i.e. a mod b
// when q = a / b. Reads a[0 .. m), so a.size() >= m is required.
inline poly& from_remainder(const poly &a, const poly &b, const poly &q) { | |
const int m = b.size() - 1; | |
from(b, m) *= q.take(m); resize(m); | |
for (int i = 0; i < m; ++i) v[i] = sub(a[i], v[i]); | |
return v; | |
} | |
inline poly operator/(const poly &a) const { return poly().from_division(v, a); } | |
inline poly operator%(const poly &a) const { return poly().from_remainder(v, a, v / a); } | |
// The helper macros are scoped to this struct body.
#undef v | |
#undef F | |
#undef I | |
}; | |
// Debug helper: writes A's coefficients, lowest degree first, to stdout as
// "[ a0 a1 ... ]" followed by a newline. (Not called anywhere in the visible code.)
inline void print(const poly &A) { | |
putchar('['); | |
for (int v : A) printf(" %d", v); | |
printf(" ]\n"); | |
} | |
// coco: preserve_end | |
const int T = (mod - 1) / 2; | |
poly F; | |
std::vector<int> fac, ifac; | |
std::vector<int> prs, low; | |
std::vector<bool> comp; | |
inline void init(int n) { | |
fac.resize(n + 1); ifac.resize(n + 1); | |
for (int i = fac[0] = ifac[0] = 1; i <= n; ++i) fac[i] = (qe)fac[i - 1] * i % mod; | |
ifac[n] = ksm(fac[n]); for (int i = n; i > 1; --i) ifac[i - 1] = (qe)ifac[i] * i % mod; | |
low.resize(n + 1); comp.resize(n + 1); | |
for (int i = 2; i <= n; ++i) { | |
if (!comp[i]) { prs.push_back(i); low[i] = i; } | |
for (int j = 0, k; j < prs.size() && (k = i * prs[j]) <= n; ++j) { | |
comp[k] = true; low[k] = prs[j]; | |
if (!(i % prs[j])) break; | |
} | |
} | |
} | |
inline int read_mod_int() { | |
char lst = '?', c = getchar(); | |
while (c != -1 && !isdigit(c)) { lst = c; c = getchar(); } | |
const bool f = lst == '-'; int r = 0; | |
do { r = ((qe)r * 10 + c - '0') % mod; c = getchar(); } while (isdigit(c)); | |
ungetc(c, stdin); if (f) r = sub(0, r); return r; | |
} | |
// Taylor shift: replaces A(x) by A(x + c) in place, using a single convolution.
// Coefficient i of the result is sum_{j >= i} A[j] * C(j, i) * c^(j - i); it is obtained
// from the product of (A[j] * j!) and the reversed sequence (c^k / k!).
// Requires fac/ifac to have at least A.size() entries (see init()).
void offset(poly &A, int c) {
	static poly lhs, rhs;  // scratch buffers reused across calls
	const int n = A.size();
	lhs.resize(n);
	rhs.resize(n);
	int pw = 1;  // c^i
	for (int i = 0; i < n; ++i) {
		rhs[i] = (qe)A[i] * fac[i] % mod;
		lhs[n - 1 - i] = (qe)pw * ifac[i] % mod;
		pw = (qe)pw * c % mod;
	}
	lhs *= rhs;
	assert((int)lhs.size() == 2 * n - 1);
	for (int i = 0; i < n; ++i) A[i] = (qe)lhs[n - 1 + i] * ifac[i] % mod;
}
// Returns cur^k mod A by binary exponentiation. With the default cur = x this is x^k mod A.
// A must be trimmed (nonzero leading coefficient), since `%` divides by it.
// Returns an empty poly when cur is empty (zero) or when the residue is 0.
inline poly x_k_mod(qe k, const poly &A, poly cur = { 0, 1 }) {
	if (cur.empty()) return cur;
	// Reduce B modulo A in place once it is at least as long as A.
	auto try_mod = [&A](poly &B) { if (B.size() >= A.size()) (B = B % A).trim(); };
	// Fast path: x^k is already reduced only when k < deg A (k + 1 < A.size()).
	// k == deg A must take the general path, because x^(deg A) is not a residue mod A.
	if (cur.size() == 2 && cur[0] == 0 && cur[1] == 1 && k + 1 < (qe)A.size()) {
		poly r;
		r.resize(k + 1);
		r[k] = 1;
		return r;
	}
	try_mod(cur);  // keep every operand below deg A
	poly ret = { 1 };
	for (qe p = k; p; p >>= 1) {
		if (p & 1) try_mod(ret *= cur);
		if (p > 1) try_mod(cur *= cur);  // no squaring needed after the last bit
	}
	return ret;
}
// Returns a copy of A scaled so that its last (leading) coefficient becomes 1.
// A must be non-empty (A.back() is read); if its last coefficient is 0 the result is all
// zeros, because ksm(0) == 0.
inline poly monic(const poly &A) {
	const qe scale = ksm(A.back());
	poly out = A;
	for (int &coef : out) coef = (int)(coef * scale % mod);
	return out;
}
// Monic greatest common divisor of A and B, computed with Euclid's algorithm.
// Arguments may be passed in either order and need not be trimmed. Returns an empty poly
// only when both inputs are zero; otherwise the result is monic (a constant gcd is {1}).
inline poly gcd(poly A, poly B) {
	A.trim();
	B.trim();
	// `%` requires the dividend to be at least deg(divisor) long, so start with the longer one.
	if (A.size() < B.size()) A.swap(B);
	while (!B.empty()) {
		(A = A % B).trim();  // now deg A < deg B
		A.swap(B);
	}
	return A.empty() ? A : monic(A);
}
// gcd(A, x^T - 1): the product of (x - r) over the distinct roots r of A that satisfy
// r^T == 1, i.e. the nonzero quadratic residues among A's roots (up to a scalar factor,
// depending on the two-argument gcd).
inline poly gcd(const poly &A) {
	poly B = x_k_mod(T, A);
	// x^T can be 0 mod A (e.g. A = x^2); give B a constant slot before subtracting 1.
	if (B.empty()) B.resize(1);
	Sub(B[0], 1);
	B.trim();
	return gcd(A, B);
}
inline bool is_reducible(const poly &A) { | |
const int n = A.size() - 1; if (n <= 1) return true; | |
auto calc = [&A](int k) { // (x ^ (mod ^ k) - x) % A | |
poly tmp = { 0, 1 }; | |
while (k--) tmp = x_k_mod(mod, A, tmp); | |
if (tmp.size() < 2) tmp.resize(2); | |
Sub(tmp[1], 1); return tmp.trim(); | |
}; | |
auto tmp = calc(n); if (!tmp.empty()) return true; | |
for (int cur = n; cur != 1; ) { | |
const int p = low[cur]; | |
tmp = gcd(calc(n / p), A); | |
if (!(tmp.size() == 1 && tmp[0] == 1)) return true; | |
do cur /= p; while (!(cur % p)); | |
} | |
return false; | |
} | |
// Returns the roots of A (unsorted, with no duplicates when A is square-free).
// A must be a product of distinct linear factors. It is split by taking gcd with x^T - 1,
// which separates the quadratic-residue roots from the rest. When that split is trivial,
// A's roots are shifted by a random amount and the split is retried.
static std::vector<int> split_linear_factors(poly A) {
	static std::mt19937 rng(std::random_device{}());
	static std::uniform_int_distribution<int> dist(1, mod - 1);
	if (A.size() <= 1) return {};
	if (A.size() == 2) return { (int)((qe)sub(0, A[0]) * ksm(A[1]) % mod) };
	int off = 0;  // A currently holds A_original(x + off)
	while (true) {
		const poly g = gcd(A);
		// Compare degrees, not coefficients: g is monic while A may not be.
		if (g.size() == 1 || g.size() == A.size()) {
			const int c = dist(rng);
			offset(A, c);
			Add(off, c);
			continue;
		}
		// Exact division A / g as A * g^{-1} mod x^o. g(0) != 0 because 0 is not a root
		// of x^T - 1, so g is invertible as a power series.
		const int o = A.size();
		A *= g.inv(o);
		A.resize(o);
		A.trim();
		std::vector<int> roots = split_linear_factors(g);
		const std::vector<int> rest = split_linear_factors(A);
		roots.insert(roots.end(), rest.begin(), rest.end());
		for (int &r : roots) Add(r, off);  // undo the shift
		return roots;
	}
}

// Returns all distinct roots of A in Z/mod, sorted ascending (empty if there are none).
// Requires init(m) for some m >= deg A, because offset() reads fac/ifac.
std::vector<int> solve(poly A) {
	A.trim();
	if (A.size() <= 1) return {};
	if (A.size() > 2) {
		// Keep only the linear part: gcd(A, x^mod - x) is the product of (x - r) over the
		// distinct roots r. A reducible A with no linear factor would otherwise never
		// split, and the random-shift loop would spin forever.
		poly xp = x_k_mod(mod, A);
		if (xp.size() < 2) xp.resize(2);
		Sub(xp[1], 1);
		xp.trim();
		A = gcd(A, xp);
	}
	std::vector<int> roots = split_linear_factors(A);
	std::sort(roots.begin(), roots.end());
	roots.erase(std::unique(roots.begin(), roots.end()), roots.end());
	return roots;
}
int main() { | |
printf("Degree of the polynomial: "); fflush(stdout); | |
const int n = ({ int x; scanf("%d", &x); x; }); | |
if (n < 0) { fprintf(stderr, "Negative degree?\n"); return -1; } | |
if (n == 0) { fprintf(stderr, "What? A constant?\n"); return -1; } | |
init(n); F.resize(n + 1); | |
printf("Coefficients (0 ~ %d): ", n - 1); fflush(stdout); | |
for (int &v : F) v = read_mod_int(); | |
if (F.trim().size() <= 1) { fprintf(stderr, "What? A constant?\n"); return -1; } | |
printf("Roots: "); | |
for (int rt : solve(F)) printf("%d ", rt); | |
putchar('\n'); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment