Skip to content

Instantly share code, notes, and snippets.

@Chillee
Last active October 18, 2024 00:06
Show Gist options
  • Save Chillee/4982ba0840745f63f3771bd84f280557 to your computer and use it in GitHub Desktop.
Save Chillee/4982ba0840745f63f3771bd84f280557 to your computer and use it in GitHub Desktop.
FFT
// Floating-point FFT convolution working on fixed-size member buffers.
//
// `maxn` is the largest transform length that will be needed; it is rounded up
// to a power of two (MAXN). multiply() requires
// a.size() + b.size() - 1 <= MAXN and traps otherwise instead of writing past
// the buffers.
template <int maxn> struct FFT {
    // ceil(log2(n)) for n >= 1. n <= 1 is answered directly because
    // __builtin_clz(0) is undefined.
    constexpr static int lg2(int n) { return n <= 1 ? 0 : 32 - __builtin_clz(n - 1); }
    // Buffer capacity: maxn rounded up to the next power of two.
    const static int MAXN = 1 << lg2(maxn);
    typedef std::complex<double> cpx;
    int rev[MAXN]; // bit-reversal permutation, rebuilt by every fft() call
    cpx rt[MAXN];  // rt[k + j] = exp(i*pi*j/k) for each power of two k, 0 <= j < k

    // Precomputes the twiddle table level by level: each entry is derived from
    // its parent at index i / 2, multiplied by the level's base root when i is odd.
    FFT() {
        const double pi = std::acos(-1.0);
        if (MAXN > 1) // with MAXN == 1 the table has no slot 1 and no butterflies run
            rt[1] = cpx(1.0, 0.0);
        for (int k = 2; k < MAXN; k *= 2) {
            const cpx z[] = {cpx(1.0, 0.0), std::polar(1.0, pi / k)};
            for (int i = k; i < 2 * k; i++)
                rt[i] = rt[i / 2] * z[i & 1];
        }
    }

    // In-place transform of a[0..n). n must be a power of two with n <= MAXN.
    void fft(cpx *a, int n) {
        const int bits = lg2(n);
        // rev[0] is set explicitly so the recurrence never reads an
        // uninitialised slot when the object is not in static storage.
        rev[0] = 0;
        for (int i = 1; i < n; i++)
            rev[i] = (rev[i / 2] | (i & 1) << bits) / 2;
        for (int i = 0; i < n; i++)
            if (i < rev[i])
                std::swap(a[i], a[rev[i]]);
        for (int k = 1; k < n; k *= 2)
            for (int i = 0; i < n; i += 2 * k)
                for (int j = 0; j < k; j++) {
                    // Hand-written complex product on the raw components.
                    const cpx w = rt[j + k], v = a[i + j + k];
                    const cpx z(w.real() * v.real() - w.imag() * v.imag(),
                                w.real() * v.imag() + w.imag() * v.real());
                    a[i + j + k] = a[i + j] - z;
                    a[i + j] += z;
                }
    }

    cpx in[MAXN], out[MAXN];

    // Returns the product polynomial of a and b (a.size() + b.size() - 1
    // coefficients), or an empty vector when either input is empty.
    std::vector<double> multiply(const std::vector<double> &a, const std::vector<double> &b) {
        if (a.empty() || b.empty())
            return {};
        const std::size_t total = a.size() + b.size() - 1;
        if (total > static_cast<std::size_t>(MAXN))
            __builtin_trap(); // product does not fit the fixed-size buffers
        const int sz = static_cast<int>(total), n = 1 << lg2(sz);
        std::vector<double> res(sz);
        // Only the first n slots take part in this call, so only they are
        // cleared; `out` is fully overwritten below and needs no clearing.
        std::fill_n(in, n, cpx(0.0, 0.0));
        // a goes into the real parts and b into the imaginary parts, so a
        // single forward transform covers both inputs.
        std::copy(a.begin(), a.end(), in);
        for (std::size_t i = 0; i < b.size(); i++)
            in[i].imag(b[i]);
        fft(in, n);
        for (int i = 0; i < n; i++)
            in[i] *= in[i];
        for (int i = 0; i < n; i++)
            out[i] = in[(n - i) & (n - 1)] - std::conj(in[i]);
        fft(out, n);
        for (int i = 0; i < sz; i++)
            res[i] = out[i].imag() / (4 * n);
        return res;
    }
};
// FFT-based convolution of integer polynomials with the result reduced
// modulo MOD. Each coefficient is split into a high and a low part around
// floor(sqrt(MOD)) before transforming, and the partial products are
// recombined modulo MOD at the end.
//
// `maxn` is the largest transform length that will be needed; it is rounded up
// to a power of two (MAXN). multiply() requires
// a.size() + b.size() - 1 <= MAXN and traps otherwise instead of writing past
// the buffers.
template <int maxn, int MOD> struct FFTMod {
    // ceil(log2(n)) for n >= 1. n <= 1 is answered directly because
    // __builtin_clz(0) is undefined.
    constexpr static int lg2(int n) { return n <= 1 ? 0 : 32 - __builtin_clz(n - 1); }
    // Buffer capacity: maxn rounded up to the next power of two.
    const static int MAXN = 1 << lg2(maxn);
    typedef std::complex<double> cpx;
    int rev[MAXN]; // bit-reversal permutation, rebuilt by every fft() call
    cpx rt[MAXN];  // rt[k + j] = exp(i*pi*j/k) for each power of two k, 0 <= j < k

    // Precomputes the twiddle table level by level: each entry is derived from
    // its parent at index i / 2, multiplied by the level's base root when i is odd.
    FFTMod() {
        const double pi = std::acos(-1.0);
        if (MAXN > 1) // with MAXN == 1 the table has no slot 1 and no butterflies run
            rt[1] = cpx(1.0, 0.0);
        for (int k = 2; k < MAXN; k *= 2) {
            const cpx z[] = {cpx(1.0, 0.0), std::polar(1.0, pi / k)};
            for (int i = k; i < 2 * k; i++)
                rt[i] = rt[i / 2] * z[i & 1];
        }
    }

    // In-place transform of a[0..n). n must be a power of two with n <= MAXN.
    void fft(cpx *a, int n) {
        const int bits = lg2(n);
        // rev[0] is set explicitly so the recurrence never reads an
        // uninitialised slot when the object is not in static storage.
        rev[0] = 0;
        for (int i = 1; i < n; i++)
            rev[i] = (rev[i / 2] | (i & 1) << bits) / 2;
        for (int i = 0; i < n; i++)
            if (i < rev[i])
                std::swap(a[i], a[rev[i]]);
        for (int k = 1; k < n; k *= 2)
            for (int i = 0; i < n; i += 2 * k)
                for (int j = 0; j < k; j++) {
                    // Hand-written complex product on the raw components.
                    const cpx w = rt[j + k], v = a[i + j + k];
                    const cpx z(w.real() * v.real() - w.imag() * v.imag(),
                                w.real() * v.imag() + w.imag() * v.real());
                    a[i + j + k] = a[i + j] - z;
                    a[i + j] += z;
                }
    }

    cpx in[2][MAXN], out[2][MAXN];

    // Returns the product polynomial of a and b with every coefficient in
    // [0, MOD) (a.size() + b.size() - 1 coefficients), or an empty vector when
    // either input is empty.
    std::vector<long long> multiply(const std::vector<int> &a, const std::vector<int> &b) {
        if (a.empty() || b.empty())
            return {};
        const std::size_t total = a.size() + b.size() - 1;
        if (total > static_cast<std::size_t>(MAXN))
            __builtin_trap(); // product does not fit the fixed-size buffers
        const int sz = static_cast<int>(total), n = 1 << lg2(sz);
        const int cut = static_cast<int>(std::sqrt(static_cast<double>(MOD)));
        std::vector<long long> res(sz);
        // Only the first n slots take part in this call, so only they are
        // cleared; `out` is fully overwritten below and needs no clearing.
        std::fill_n(in[0], n, cpx(0.0, 0.0));
        std::fill_n(in[1], n, cpx(0.0, 0.0));
        // Real part holds the high piece (x / cut), imaginary part the low
        // piece (x % cut). Constructed explicitly: brace-initialising from
        // non-constant ints is a narrowing conversion.
        for (std::size_t i = 0; i < a.size(); i++)
            in[0][i] = cpx(a[i] / cut, a[i] % cut);
        for (std::size_t i = 0; i < b.size(); i++)
            in[1][i] = cpx(b[i] / cut, b[i] % cut);
        fft(in[0], n), fft(in[1], n);
        for (int i = 0; i < n; i++) {
            const int j = (n - i) & (n - 1); // index of -i modulo n
            // Separate the transforms of the high (fl, gl) and low (fs, gs)
            // pieces, which were packed into one complex sequence each.
            const cpx fl = (in[0][i] + std::conj(in[0][j])) * cpx(0.5, 0.0),
                      fs = (in[0][i] - std::conj(in[0][j])) * cpx(0.0, -0.5),
                      gl = (in[1][i] + std::conj(in[1][j])) * cpx(0.5, 0.0),
                      gs = (in[1][i] - std::conj(in[1][j])) * cpx(0.0, -0.5);
            // Writing to the mirrored index makes the second forward pass act
            // as the inverse transform (up to the division by n below).
            out[0][j] = (fl * gl) + (fl * gs) * cpx(0.0, 1.0);
            out[1][j] = (fs * gl) + (fs * gs) * cpx(0.0, 1.0);
        }
        fft(out[0], n), fft(out[1], n);
        const double scale = n;
        for (int i = 0; i < sz; i++) {
            out[0][i] /= scale, out[1][i] /= scale;
            long long av = std::llround(out[0][i].real());                                     // high * high
            long long bv = std::llround(out[0][i].imag()) + std::llround(out[1][i].real());    // cross terms
            long long cv = std::llround(out[1][i].imag());                                     // low * low
            av %= MOD, bv %= MOD, cv %= MOD;
            res[i] = av * cut * cut + bv * cut + cv;
            res[i] = (res[i] % MOD + MOD) % MOD;
        }
        return res;
    }
};
// Number-theoretic transform convolution modulo 998244353 (generator 3),
// working on fixed-size member buffers.
//
// `maxn` is the largest transform length that will be needed; it is rounded up
// to a power of two (MAXN). multiply() requires
// a.size() + b.size() - 1 <= MAXN and traps otherwise instead of writing past
// the buffers.
template <int maxn> struct NTT {
    // ceil(log2(n)) for n >= 1. n <= 1 is answered directly because
    // __builtin_clz(0) is undefined.
    constexpr static int lg2(int n) { return n <= 1 ? 0 : 32 - __builtin_clz(n - 1); }
    const static int MAXN = 1 << lg2(maxn), MOD = 998244353, root = 3;
    // rev: bit-reversal permutation, rebuilt by every ntt() call.
    // rt:  rt[2^k + j] is the j-th power of root^((MOD - 1) >> (k + 1)).
    int rev[MAXN], rt[MAXN];

    // Modular helpers. add/sub expect both operands already in [0, MOD).
    int mul(int a, int b) { return static_cast<int>(static_cast<long long>(a) * b % MOD); }
    int sub(int a, int b) { return b > a ? a - b + MOD : a - b; }
    int add(int a, int b) { return a + b >= MOD ? a + b - MOD : a + b; }

    // base^exp modulo MOD by square-and-multiply, consuming exp from its
    // lowest bit upwards.
    int binExp(int base, long long exp) {
        int result = 1;
        for (; exp != 0; exp /= 2) {
            if (exp & 1)
                result = mul(result, base);
            base = mul(base, base);
        }
        return result;
    }

    // Precomputes the twiddle table level by level: each entry is derived from
    // its parent at index i / 2, multiplied by the level's base root when i is odd.
    NTT() {
        if (MAXN > 1) // with MAXN == 1 the table has no slot 1 and no butterflies run
            rt[1] = 1;
        for (int k = 1; k < lg2(MAXN); k++) {
            const int z[] = {1, binExp(root, (MOD - 1) >> (k + 1))};
            for (int i = (1 << k); i < (2 << k); i++)
                rt[i] = mul(rt[i / 2], z[i & 1]);
        }
    }

    // In-place transform of a[0..n). n must be a power of two with n <= MAXN
    // and every a[i] must be in [0, MOD).
    void ntt(int *a, int n) {
        const int bits = lg2(n);
        // rev[0] is set explicitly so the recurrence never reads an
        // uninitialised slot when the object is not in static storage.
        rev[0] = 0;
        for (int i = 1; i < n; i++)
            rev[i] = (rev[i / 2] | (i & 1) << bits) / 2;
        for (int i = 0; i < n; i++)
            if (i < rev[i])
                std::swap(a[i], a[rev[i]]);
        for (int k = 1; k < n; k *= 2)
            for (int i = 0; i < n; i += 2 * k)
                for (int j = 0; j < k; j++) {
                    const int z = mul(rt[j + k], a[i + j + k]);
                    a[i + j + k] = sub(a[i + j], z);
                    a[i + j] = add(a[i + j], z);
                }
    }

    int in[2][MAXN];

    // Returns the product polynomial of a and b with every coefficient in
    // [0, MOD) (a.size() + b.size() - 1 coefficients), or an empty vector when
    // either input is empty. Input coefficients may be any int; they are
    // reduced into [0, MOD) before transforming.
    std::vector<int> multiply(const std::vector<int> &a, const std::vector<int> &b) {
        if (a.empty() || b.empty())
            return {};
        const std::size_t total = a.size() + b.size() - 1;
        if (total > static_cast<std::size_t>(MAXN))
            __builtin_trap(); // product does not fit the fixed-size buffers
        const int sz = static_cast<int>(total), n = 1 << lg2(sz);
        // Only the first n slots take part in this call, so only they are cleared.
        std::fill_n(in[0], n, 0);
        std::fill_n(in[1], n, 0);
        // add()/sub() rely on operands being in [0, MOD); negative or
        // oversized inputs are brought into range here.
        const auto reduce = [](int x) {
            const int v = x % MOD;
            return v < 0 ? v + MOD : v;
        };
        std::transform(a.begin(), a.end(), in[0], reduce);
        std::transform(b.begin(), b.end(), in[1], reduce);
        ntt(in[0], n), ntt(in[1], n);
        const int invN = binExp(n, MOD - 2);
        for (int i = 0; i < n; i++)
            in[0][i] = mul(mul(in[0][i], in[1][i]), invN);
        // Reversing indices 1..n-1 turns the next forward transform into the inverse.
        std::reverse(in[0] + 1, in[0] + n);
        ntt(in[0], n);
        return std::vector<int>(in[0], in[0] + sz);
    }
};
@raybbian
Copy link

raybbian commented Oct 18, 2024

If repeatedly multiplying polynomials (such as FFT D&C Knapsack), the above implementations need to be modified. They currently clear the O(MAXN) sized arrays every time multiply is called.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment