Skip to content

Instantly share code, notes, and snippets.

@Mivik
Last active March 1, 2021 11:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Mivik/34d2e8d7f718b915642d6a7b8081c56c to your computer and use it in GitHub Desktop.
Save Mivik/34d2e8d7f718b915642d6a7b8081c56c to your computer and use it in GitHub Desktop.
mod_solver
// Mivik 2021.3.1
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <random>
#include <vector>
// 64-bit type used for intermediate products of two residues.
using qe = long long;
// Prime modulus for all arithmetic in this file.
constexpr int mod = 998244353;

// (x + y) mod `mod`, for x, y already in [0, mod).
inline int add(int x, int y) {
	const int s = x + y;
	return s < mod ? s : s - mod;
}
// x = (x + y) mod `mod`, for x, y already in [0, mod).
inline void Add(int &x, int y) {
	x += y;
	if (x >= mod) x -= mod;
}
// (x - y) mod `mod`, for x, y already in [0, mod).
inline int sub(int x, int y) {
	const int d = x - y;
	return d < 0 ? d + mod : d;
}
// x = (x - y) mod `mod`, for x, y already in [0, mod).
inline void Sub(int &x, int y) {
	x -= y;
	if (x < 0) x += mod;
}
// x^p mod `mod` by square-and-multiply. The default exponent mod - 2 yields the
// modular inverse of x (Fermat), valid for x not divisible by `mod`.
inline qe ksm(qe x, int p = mod - 2) {
	qe result = 1;
	while (p != 0) {
		if (p & 1) result = result * x % mod;
		x = x * x % mod;
		p >>= 1;
	}
	return result;
}
// coco: preserve_begin
// In-place iterative radix-2 number-theoretic transform modulo `mod`
// (998244353 = 119 * 2^23 + 1, roots of unity derived from 3).
//   v   : array of `len` values, transformed in place.
//   len : transform length; must be 0 or a power of two not exceeding 2^N.
//   rev : false = forward transform; true = inverse transform (including the
//         1/len scaling), so a forward call followed by an inverse call restores v.
inline void ntt(int *v, int len, bool rev) {
	static const int N = 19; // Lengths up to 2^N, i.e. products of two polynomials with up to 2^(N-1) terms each.
	// Tables built once on first use:
	//   rt[q | k] = w^k for k < q, where w = 3^((mod-1)/(2q)) (a 2q-th root of unity);
	//   inv[i]    = (2^i)^(-1) mod `mod`.
	static struct _
	{ int rt[1 << N], inv[N + 1]; _() {
		for (int i = 0, j = 1, q = 0; i <= N; ++i, j = (q = j) << 1) {
			inv[i] = ksm(j);
			const int del = ksm(3, (mod - 1) / j);
			for (int k = 0, c = 1; k < q; ++k, c = (qe)c * del % mod) rt[q | k] = c;
		}
	} } _;
	// Validate before indexing the fixed-size tables r[] and rt[]: an oversized or
	// non-power-of-two len would otherwise write past the end of r[].
	assert(len >= 0 && !(len & (len - 1)));
	assert(len <= (1 << N));
	if (len <= 1) return; // Length 0/1 transforms (forward and inverse) are the identity.
	// Bit-reversal permutation, cached for the most recently used length.
	static int last, r[1 << N];
	if (last != len) {
		for (int i = 1; i < len; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1)? (len >> 1): 0);
		last = len;
	}
	for (int i = 1; i < len; ++i)
		if (i < r[i]) std::swap(v[i], v[r[i]]);
	// Butterflies: i is the half-size of the blocks being merged (block size q = 2i).
	for (int i = 1, q = 2; i < len; i = q, q <<= 1) {
		for (int j = 0; j < len; j += q)
			for (int k = 0; k < i; ++k) {
				const int x = v[j | k], y = (qe)v[i | j | k] * _.rt[i | k] % mod;
				v[j | k] = add(x, y);
				v[i | j | k] = sub(x, y);
			}
	}
	if (!rev) return;
	// Inverse DFT = forward DFT with output indices 1..len-1 reversed, scaled by 1/len.
	std::reverse(v + 1, v + len);
	const int inv_len = _.inv[__builtin_ctz(len)];
	for (int i = 0; i < len; ++i) v[i] = (qe)v[i] * inv_len % mod;
}
// Dense polynomial over Z/mod: element i is the coefficient of x^i.
// Naming convention produced by the helper macros below:
//   p.n(...)          returns a new poly and leaves p untouched;
//   p.from_n(a, ...)  overwrites p with the result computed from a and returns p;
//   p.inplace_n()     modifies p itself and returns p.
// The from_* methods resize/clear *this before reading `a`, so they are not written
// to handle `a` aliasing *this.
struct poly : public std::vector<int> {
#define v (*this)
#define F(n) \
	inline poly n(int len = -1) const { return poly().from_##n(v, len); } \
	inline poly& from_##n(const poly &a, int len = -1)
#define I(n) \
	inline poly n() const { return poly(v).inplace_##n(); } \
	inline poly& inplace_##n()
	// Smallest power of two >= x (0 maps to 0); x must be small enough that the result fits in int.
	static inline int round_up(int x) { return (x & (x - 1))? (1 << (32 - __builtin_clz(x))): x; }
	poly() {}
	poly(const std::initializer_list<int> &list): std::vector<int>(list) {}
	// Overwrites *this with the first n coefficients of a, zero-padded to exactly n terms.
	inline poly& from(const poly &a, int n) {
		resize(n); const int len = std::min(n, (int)a.size());
		std::copy(a.begin(), a.begin() + len, begin());
		std::fill(begin() + len, end(), 0);
		return v;
	}
	// Copy of the first n coefficients, zero-padded to exactly n terms.
	inline poly take(int n) const { return poly().from(v, n); }
	// Resizes to round_up(len) (zero-padding or truncating) and applies ::ntt in place.
	inline poly& ntt(int len, bool rev) { resize(round_up(len), 0); ::ntt(data(), size(), rev); return v; }
	inline poly& ntt(bool rev) { return ntt(size(), rev); }
	// Drops trailing zero coefficients (the zero polynomial becomes empty).
	inline poly& trim() { while (!empty() && !back()) pop_back(); return v; }
	// *this = product with a, resized to l terms. The result is the exact product
	// truncated to l terms when l >= size() + a.size() - 1; for smaller l the cyclic
	// NTT may wrap high-degree terms around.
	inline poly& imul(const poly &a, int l) {
		static poly tmp;
		if (empty() || a.empty()) return clear(), v;
		const int len = round_up(l);
		(tmp = a).ntt(len, 0); ntt(len, 0); // a is copied first, so `p *= p` works
		for (int i = 0; i < len; ++i) v[i] = (qe)v[i] * tmp[i] % mod;
		return ntt(len, 1), resize(l), v;
	}
	inline poly& operator*=(const poly &a) { return imul(a, size() + a.size() - 1); }
	inline poly operator*(const poly &t) const { return poly(v) *= t; }
	I(reverse) { return std::reverse(begin(), end()), v; }
	// Newton-iteration driver: starts from the single coefficient `initial`, then for
	// l = 2, 4, ..., round_up(len) calls trans(l, a mod x^l), which is expected to update
	// *this to l terms. The result is resized to len terms (len == -1 means a.size()).
	template<class Func>
	inline poly& newton(const poly &a, int len, int initial, const Func &trans) {
		static poly tmp;
		if (len == -1) len = a.size();
		if (a.empty()) return clear(), v;
		const int lim = round_up(len);
		clear(); push_back(initial);
		for (int l = 2; l <= lim; l <<= 1) { tmp.from(a, l); trans(l, tmp); }
		return resize(len), v;
	}
	// Power-series inverse: *this = a^(-1) mod x^len (len == -1 means a.size()).
	// Requires a[0] != 0. Each step computes g <- g * (2 - a * g) mod x^l.
	F(inv) {
		assert(!a.empty());
		return newton(a, len, ksm(a[0]), [this](int l, poly &tmp) {
			const int q = l << 1; // a * g * g has degree < 2l, so 2l points avoid wrap-around
			tmp.ntt(q, 0); ntt(q, 0);
			for (int i = 0; i < q; ++i)
				v[i] = (qe)v[i] * sub(2, (qe)v[i] * tmp[i] % mod) % mod;
			ntt(1); resize(l);
		});
	}
	// Polynomial quotient: *this = floor(a / b); b must be trimmed (b.back() != 0).
	// Computed as rev(a) * rev(b)^(-1) mod x^len, reversed, with len = deg a - deg b + 1.
	inline poly& from_division(const poly &a, const poly &b) {
		assert(!b.empty());
		// deg a < deg b: the quotient is zero (and len below would be <= 0).
		if (a.size() < b.size()) return clear(), v;
		const int len = a.size() - b.size() + 1;
		resize(len); std::reverse_copy(a.end() - len, a.end(), begin());
		v *= b.reverse().inv(len);
		return resize(len), inplace_reverse();
	}
	// Polynomial remainder a - b * q, where q must be floor(a / b). Since the remainder
	// has degree < m = deg b, only the low m terms are computed; the result has exactly
	// m coefficients and may need trim().
	inline poly& from_remainder(const poly &a, const poly &b, const poly &q) {
		assert(!b.empty());
		const int m = b.size() - 1;
		// deg a < deg b: a is its own remainder (zero-padded to m terms). This also keeps
		// the a[i] reads below in bounds.
		if ((int)a.size() <= m) return from(a, m);
		from(b, m) *= q.take(m); resize(m);
		for (int i = 0; i < m; ++i) v[i] = sub(a[i], v[i]);
		return v;
	}
	inline poly operator/(const poly &a) const { return poly().from_division(v, a); }
	inline poly operator%(const poly &a) const { return poly().from_remainder(v, a, v / a); }
#undef v
#undef F
#undef I
};
inline void print(const poly &A) {
putchar('[');
for (int v : A) printf(" %d", v);
printf(" ]\n");
}
// coco: preserve_end
// Euler-criterion exponent: for prime `mod`, r^T == 1 exactly when r is a nonzero
// quadratic residue.
constexpr int T = (mod - 1) / 2;
// Polynomial read by main(); element i is the coefficient of x^i.
poly F;
// Factorials and inverse factorials modulo `mod`, filled by init().
std::vector<int> fac;
std::vector<int> ifac;
// Sieve tables filled by init(): primes, smallest prime factor, composite flags.
std::vector<int> prs;
std::vector<int> low;
std::vector<bool> comp;
inline void init(int n) {
fac.resize(n + 1); ifac.resize(n + 1);
for (int i = fac[0] = ifac[0] = 1; i <= n; ++i) fac[i] = (qe)fac[i - 1] * i % mod;
ifac[n] = ksm(fac[n]); for (int i = n; i > 1; --i) ifac[i - 1] = (qe)ifac[i] * i % mod;
low.resize(n + 1); comp.resize(n + 1);
for (int i = 2; i <= n; ++i) {
if (!comp[i]) { prs.push_back(i); low[i] = i; }
for (int j = 0, k; j < prs.size() && (k = i * prs[j]) <= n; ++j) {
comp[k] = true; low[k] = prs[j];
if (!(i % prs[j])) break;
}
}
}
// Reads the next integer from stdin and returns it reduced into [0, mod).
// Non-digit characters before the first digit are skipped; the value is negated only
// if the character immediately before the first digit is '-'. Digit strings of any
// length are accepted (reduced modulo `mod` as they are read). The first character
// after the number is pushed back. Returns 0 if end of input is reached before a digit.
inline int read_mod_int() {
	// int, not char: EOF must stay distinguishable from a 0xFF byte, and isdigit()
	// requires an unsigned-char value or EOF.
	int lst = '?';
	int c = getchar();
	while (c != EOF && !isdigit(c)) { lst = c; c = getchar(); }
	if (c == EOF) return 0;
	const bool negative = lst == '-';
	int r = 0;
	do { r = ((qe)r * 10 + (c - '0')) % mod; c = getchar(); } while (c != EOF && isdigit(c));
	if (c != EOF) ungetc(c, stdin);
	return negative ? sub(0, r) : r;
}
// In-place Taylor shift: replaces A(x) by A(x + c).
// Coefficient i of the result is sum_{j >= i} A[j] * C(j, i) * c^(j - i). It is computed
// as one convolution of (A[j] * j!) with the reversed sequence (c^t / t!), followed by
// scaling with 1 / i!. Requires fac/ifac to cover indices < A.size() (see init).
void offset(poly &A, int c) {
	static poly powers, scaled; // scratch buffers reused across calls
	const int n = A.size();
	powers.resize(n);
	scaled.resize(n);
	int cpow = 1; // c^i
	for (int i = 0; i < n; ++i) {
		scaled[i] = (qe)A[i] * fac[i] % mod;
		powers[n - 1 - i] = (qe)cpow * ifac[i] % mod;
		cpow = (qe)cpow * c % mod;
	}
	powers *= scaled;
	assert((int)powers.size() == 2 * n - 1);
	for (int i = 0; i < n; ++i) A[i] = (qe)powers[n - 1 + i] * ifac[i] % mod;
}
// Returns cur^k mod A by binary exponentiation; with the default cur this is x^k mod A.
// Requires k >= 0 and a trimmed A. A partial product is reduced (and trimmed) once it
// has at least A.size() terms, so for k > 0 the result has fewer than A.size() terms.
// An empty cur (the zero polynomial) is returned unchanged.
inline poly x_k_mod(qe k, const poly &A, poly cur = { 0, 1 }) {
	assert(k >= 0); // a negative k never reaches 0 under >>= and would loop forever
	if (cur.empty()) return cur;
	// Fast path: x^k is already reduced only when k < deg A, i.e. k + 1 < A.size().
	// k == deg A must go through the reduction below.
	if (cur.size() == 2 && cur[0] == 0 && cur[1] == 1 && k + 1 < (qe)A.size()) {
		poly r; r.resize(k + 1); r[k] = 1; return r;
	}
	auto try_mod = [&A](poly &B) { if (B.size() >= A.size()) (B = B % A).trim(); };
	poly ret = { 1 };
	for (qe p = k; p; p >>= 1, try_mod(cur *= cur)) if (p & 1) try_mod(ret *= cur);
	return ret;
}
// Returns A divided by its leading coefficient, so the last element becomes 1.
// A.back() must be nonzero (A trimmed). An empty A (the zero polynomial) yields an empty result.
inline poly monic(const poly &A) {
	if (A.empty()) return poly(); // A.back() would be undefined behaviour
	poly r; r.resize(A.size());
	const int inv = ksm(A.back());
	for (int i = (int)r.size() - 1; i >= 0; --i) r[i] = (qe)A[i] * inv % mod;
	return r;
}
// Euclidean GCD of two trimmed polynomials.
// If either argument is empty (the zero polynomial), the other one is returned as-is,
// without normalization. Otherwise the Euclid loop runs and its result is made monic.
inline poly gcd(poly A, poly B) {
	if (A.empty()) return B;
	if (B.empty()) return A;
	for (;;) {
		A = A % B;
		A.trim();
		if (A.empty()) return monic(B);
		A.swap(B); // (A, B) <- (B, A mod B)
	}
}
// Returns gcd(A, x^T - 1) with T = (mod - 1) / 2. Because x^T - 1 is the product of
// (x - r) over the nonzero quadratic residues r, this is the product of the distinct
// linear factors of A whose roots are nonzero quadratic residues. The result is monic,
// except when A itself divides x^T - 1; in that case A is returned unchanged.
inline poly gcd(const poly &A) {
	auto B = x_k_mod(T, A);
	// x^T mod A is empty when A divides x^T (e.g. A = x^2); make B[0] exist before Sub.
	if (B.empty()) B.resize(1);
	Sub(B[0], 1);
	B.trim(); return gcd(A, B);
}
// Rabin's irreducibility test over F_mod for a trimmed A of degree n.
// Returns false iff A is irreducible; returns true for n <= 1.
// Uses the smallest-prime-factor table low[], so init(m) with m >= n must have run.
inline bool is_reducible(const poly &A) {
	const int n = A.size() - 1; if (n <= 1) return true;
	// calc(k) = (x^(mod^k) - x) mod A, trimmed.
	auto calc = [&A](int k) {
		poly tmp = { 0, 1 };
		while (k--) tmp = x_k_mod(mod, A, tmp);
		if (tmp.size() < 2) tmp.resize(2);
		Sub(tmp[1], 1); return tmp.trim();
	};
	// Irreducible A must divide x^(mod^n) - x ...
	auto tmp = calc(n); if (!tmp.empty()) return true;
	// ... and be coprime to x^(mod^(n/p)) - x for every prime p dividing n.
	for (int cur = n; cur != 1; ) {
		const int p = low[cur];
		// A goes first so the first Euclid step divides A (degree n) by the shorter
		// remainder, never a shorter polynomial by a longer one.
		tmp = gcd(A, calc(n / p));
		if (!(tmp.size() == 1 && tmp[0] == 1)) return true;
		do cur /= p; while (!(cur % p));
	}
	return false;
}
// Returns the roots of a monic, squarefree polynomial whose roots all lie in F_mod
// (a product of distinct linear factors), sorted ascending.
// It shifts A(x) -> A(x + c) for random c until gcd(A, x^T - 1) is a proper factor.
// That separates the roots that are nonzero quadratic residues from the rest.
// Each part is then split recursively.
inline std::vector<int> split_roots(poly A) {
	static std::mt19937 rng(std::random_device{}());
	static std::uniform_int_distribution<int> dist(1, mod - 1);
	if (A.size() <= 1) return {};
	if (A.size() == 2) return { (int)((qe)sub(0, A[0]) * ksm(A[1]) % mod) };
	int off = 0; // total shift applied so far: current A(x) = input A(x + off)
	while (true) {
		const poly g = gcd(A);
		if (g.size() == 1 || g == A) {
			// Trivial split: move every root by -c and try again.
			const int c = dist(rng);
			offset(A, c);
			Add(off, c);
			continue;
		}
		// Exact quotient A / g via a power-series inverse; g(0) != 0 because g divides x^T - 1.
		const int o = A.size();
		poly div = A * g.inv(o);
		div.resize(o);
		div.trim();
		std::vector<int> roots = split_roots(g);
		const std::vector<int> rest = split_roots(div);
		roots.insert(roots.end(), rest.begin(), rest.end());
		for (int &r : roots) Add(r, off); // undo the shift
		std::sort(roots.begin(), roots.end());
		roots.erase(std::unique(roots.begin(), roots.end()), roots.end());
		return roots;
	}
}
// Returns the distinct roots of A in F_mod, sorted ascending (empty if there are none).
// A must be trimmed. offset() needs fac/ifac, so init(m) with m >= deg A must have run.
std::vector<int> solve(poly A) {
	if (A.size() <= 1) return {};
	if (A.size() == 2) return { (int)((qe)sub(0, A[0]) * ksm(A[1]) % mod) };
	// Keep only the distinct linear factors: L = gcd(A, x^mod - x). Without this step,
	// a part with no roots in F_mod (e.g. a product of two irreducible quadratics)
	// could never be split by the random shifts, and the splitting would loop forever.
	poly xp = x_k_mod(mod, A);
	if (xp.size() < 2) xp.resize(2);
	Sub(xp[1], 1);
	xp.trim();
	const poly lin = gcd(A, xp);
	return split_roots(monic(lin));
}
// Interactive driver: reads the degree n, then the n + 1 coefficients of x^0 .. x^n,
// and prints the distinct roots modulo `mod`.
int main() {
	printf("Degree of the polynomial: "); fflush(stdout);
	int n;
	if (scanf("%d", &n) != 1) { fprintf(stderr, "Failed to read the degree.\n"); return -1; }
	if (n < 0) { fprintf(stderr, "Negative degree?\n"); return -1; }
	if (n == 0) { fprintf(stderr, "What? A constant?\n"); return -1; }
	// ::ntt handles lengths up to 2^19. The largest product here, e.g. the Taylor shift
	// of n + 1 terms, needs 2n + 1 <= 2^19, so the degree must stay below 2^18.
	if (n >= (1 << 18)) { fprintf(stderr, "Degree too large (at most %d is supported).\n", (1 << 18) - 1); return -1; }
	init(n); F.resize(n + 1);
	printf("Coefficients (0 ~ %d): ", n); fflush(stdout);
	for (int &v : F) v = read_mod_int();
	if (F.trim().size() <= 1) { fprintf(stderr, "What? A constant?\n"); return -1; }
	printf("Roots: ");
	for (int rt : solve(F)) printf("%d ", rt);
	putchar('\n');
	return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment