Skip to content

Instantly share code, notes, and snippets.

@MaskRay
Last active August 16, 2019 04:46
Show Gist options
  • Save MaskRay/258495af940ca1d50fd1a26088403b38 to your computer and use it in GitHub Desktop.
Save MaskRay/258495af940ca1d50fd1a26088403b38 to your computer and use it in GitHub Desktop.
FFT技巧
https://www.hackerrank.com/contests/w23/challenges/sasha-and-swaps-ii
感谢ftiasch老师教导,参考了 https://async.icpc-camp.org/d/408-fft 和其他一些地方的东西
P = 1e9+7
Q = ceil(sqrt(P))
a, b为两个向量
c = convolution(a, b)
c系数取值为 [0, n*(P-1)^2] 的整数,若n*(P-1)^2的表示需要超过53 bits(double mantissa)则很可能会出错
通常用的 IFFT 是 unnormalized 的,中间结果是 [0, n^2*(P-1)^2] 的整数,但因为涉及到 `x /= n` 所以只要求能表示 [0, n*(P-1)^2] 的整数
# 技巧
## 技巧0:折半
分解a = a0 + Q * a1, 0 <= a0,a1 < Q
求出
c00 = convolution(a0, b0)
c01 = convolution(a0, b1)
c10 = convolution(a1, b0)
c11 = convolution(a1, b1)
以上四个 convolution 系数最大值为 n*(Q-1)^2 ~= n*P
c = c00 + Q*(c01+c10) + Q*Q*c11
使用Toom-2 (Karatsuba)可以简化为三次convolution
## 技巧1:[0,P-1] => [-(P-1)/2, (P-1)/2]
设 a,b 系数取自 [0,P-1] 的 uniform distribution
则 c 系数均值阶为 nP^2/4,方差阶为 nP^4/9(精确值为 n*(P^4/9 - P^4/16) = 7nP^4/144)
若平移至 [-(P-1)/2, (P-1)/2]
则 c 系数均值为 0,方差阶为 nP^4/144
由Chebyshev's inequality,系数绝对值在若干倍标准差以内
## 技巧2a:正交地计算两个FFT;辅助技巧0
取S与sqrt(P)接近且 M=P-S*S%P 尽可能小
分解 a = a0 + S * a1, b = b0 + S * b1
用两次FFT一次IFFT计算 convolution(a0+i*sqrt(M)*a1, b0+i*sqrt(M)*b1) 即得到
convolution(a0,b0) - M*convolution(a1,b1) + i*sqrt(M)*(convolution(a0,b1)+convolution(a1,b0))
分离real和imag即可算出 c = convolution(a0, b0) + S * (convolution(a0, b1) + convolution(a1, b0)) - M * convolution(a1, b1)
## 技巧2b
效率比技巧2a略低,用两次FFT和两次IFFT,但系数绝对值更小
分解 a = a0 + Q * a1, b = b0 + Q * b1
记 A = a0+i*a1, B = b0+i*b1,函数 rev(a) = {a[0], a[n-1], a[n-2], ..., a[1]}
计算 FFT(A) 与 FFT(B) 后求出:
- FFT(a0) = FFT(re(A)) = [FFT(A) + FFT(conj(A))] / 2 = [FFT(A) + conj(rev(FFT(A)))] / 2
- FFT(a1) = FFT(im(A)) = [FFT(A) - FFT(conj(A))] * -0.5i = [FFT(A) - conj(rev(FFT(A)))] * -0.5i
- FFT(b0) = FFT(re(B)) = [FFT(B) + FFT(conj(B))] / 2 = [FFT(B) + conj(rev(FFT(B)))] / 2
- FFT(b1) = FFT(im(B)) = [FFT(B) - FFT(conj(B))] * -0.5i = [FFT(B) - conj(rev(FFT(B)))] * -0.5i
再用 IFFT 计算:
convolution(a0, b0) + i * convolution(a0, b1) = IFFT(FFT(a0)*FFT(b0) + i*FFT(a0)*FFT(b1))
convolution(a1, b0) + i * convolution(a1, b1) = IFFT(FFT(a1)*FFT(b0) + i*FFT(a1)*FFT(b1))
分离real和imag即可算出 c = convolution(a0, b0) + Q * (convolution(a0, b1) + convolution(a1, b0)) + Q * Q * convolution(a1, b1)
另外,用Haskell的记号:ifft = (/n) . fft . rev
https://www.hackerrank.com/rest/contests/w23/challenges/sasha-and-swaps-ii/hackers/Hezhu/download_solution
# 题目
sasha-and-swaps-ii 题中 P = 10^9+7,取S=10^5,M=70
绝对值<=P-1的原系数调整为 (floor((P-1)/2/Q) + i*M*(Q-1))
结果的real绝对值最大值约为 n * ((P/Q)^2/4+M*Q^2),55.96 bits,但均值为0取到最大值可能性极低
防止最坏情况下出问题,可以取 independent and identically distributed 随机数向量 noise
convolution(a, b) = convolution(a+noise, b) - convolution(noise, b)
# 其他
根据 Roundoff Error Analysis of the Fast Fourier Transform,没仔细看
relative error 均值为 log2(n)*浮点运算精度*变换前系数最大值
不记得在哪里看到的说法:unit root 一定要用 cos(2*M_PI/n*m)、sin(2*M_PI/n*m) 直接计算,或者用 lookup table;用乘法 `w *= dw` 递推计算会使误差达到指数级。我觉得这可能是指误差达到 n*浮点运算精度*变换前系数最大值
有了技巧0+1+2,感觉 complex<double> 的 Fast Fourier transform 恒优于 Fast number theoretic transform
涉及 int64 时,a*b % m 性能很差。`long x = a*b, r = x - mod*long(double(a)*double(b)/mod+0.5); return r < 0 ? r + mod : r;`比 汇编MUL DIV 快,但还是不如 complex<double>
代码中调用 complex<double>::operator* 的地方会编译为 call __muldc3,__muldc3 会判断NAN INF,有很多多余操作,性能很低。如果编译时带上 -ffast-math 可以快很多,但 __attribute__((optimize("fast-math"))) 这些没有效果,因为 __muldc3 在其他文件中不受到影响。最好自行实现 complex<double> 的乘法
#include <cmath>
#include <complex>
#include <iostream>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
// Pulls all of std into the global namespace; the rest of this file relies on
// unqualified names such as vector, complex, swap, cos, sin and round.
using namespace std;
// Shorthand for the complex sample type used by the FFT routines.
typedef complex<double> cd;
// FOR(i, a, b): ascending loop, i = a, a+1, ..., b-1.
// The loop variable takes the type of `b` with references and cv-qualifiers stripped.
#define FOR(i, a, b) for (remove_cv<remove_reference<decltype(b)>::type>::type i = (a); i < (b); i++)
// REP(i, n): ascending loop, i = 0, 1, ..., n-1.
#define REP(i, n) FOR(i, 0, n)
// ROF(i, a, b): descending loop, i = b-1, b-2, ..., a.
// NOTE: the loop variable takes the type of `b`; if that type is unsigned and
// a == 0, the condition `--i >= 0` can never become false.
#define ROF(i, a, b) for (remove_cv<remove_reference<decltype(b)>::type>::type i = (b); --i >= (a); )
// MOD: modulus of all coefficient arithmetic (10^9 + 7).
// SQ:  base used to split a coefficient z into a low digit z % SQ and a high digit z / SQ.
// NN:  number of entries in the bitrev/units tables, i.e. the largest supported transform length.
const long MOD = 1000000007, SQ = 100000, NN = 262144;
// sqrt(MOD - SQ*SQ % MOD) = sqrt(70). Since SQ*SQ == -70 (mod MOD), scaling the
// high digit by this factor in the imaginary part makes a complex product carry
// the value modulo MOD across its real and imaginary parts.
const double IROOT = sqrt(double(MOD-SQ*SQ%MOD));
// bitrev[i]: i with its low log2(n) bits reversed; filled by fft_prepare(n).
long bitrev[NN];
// units[i]: cos(2*pi*i/n) + i*sin(2*pi*i/n); filled by fft_prepare(n).
cd units[NN];
void fft_prepare(long n)
{
long logn = 63-__builtin_clzl(n);
REP(i, n)
bitrev[i] = bitrev[i>>1] >> 1 | (i & 1) << logn-1;
double ph = 2*M_PI/n;
REP(i, n)
units[i] = {cos(ph*i), sin(ph*i)};
}
// In-place radix-2 decimation-in-time FFT of a[0..n).
//
// Parameters:
//   a  - array of n complex samples, overwritten with the result.
//   n  - transform length; fft_prepare(n) must have been called with this same
//        n beforehand, because the bitrev/units tables are indexed relative to it.
//   is - direction: is < 0 computes the normalized inverse transform, any other
//        value computes the forward transform.
//
// The inverse is obtained as (1/n) * FFT(rev(a)), where rev keeps a[0] in place
// and reverses a[1..n); no conjugated twiddle table is needed.
void fft_dit2(cd a[], long n, int is)
{
  const bool inverse = is < 0;
  if (inverse)
    for (long i = 1, j = n-1; i < j; i++, j--)
      std::swap(a[i], a[j]);
  // Bit-reversal permutation; the `i < bitrev[i]` test swaps each pair once.
  for (long i = 0; i < n; i++)
    if (i < bitrev[i])
      std::swap(a[i], a[bitrev[i]]);
  // Butterfly passes over blocks of size m = 2, 4, ..., n. dwi = n/m is the
  // stride through `units` that yields the m-th roots of unity.
  for (long m = 2, dwi = n>>1; m <= n; m <<= 1, dwi >>= 1) {
    const long mh = m >> 1;
    for (long r = 0; r < n; r += m) {
      cd *x = a+r, *y = a+r+mh;
      const cd *w = units;
      for (long j = 0; j < mh; j++) {
        // Complex product (*y) * (*w) written out by hand, so the compiler does
        // not emit the NaN/Inf-checking library call for complex multiplication.
        const cd t(y->real()*w->real()-y->imag()*w->imag(), y->real()*w->imag()+y->imag()*w->real());
        *y++ = *x-t;
        *x++ += t;
        w += dwi;
      }
    }
  }
  if (inverse)
    for (long i = 0; i < n; i++)
      a[i] *= 1.0/n;
}
// Packs the residues in `a` into a complex vector of length n and returns its
// forward transform.
//
// Each residue is first moved to the representative closest to zero (values
// above MOD/2 have MOD subtracted), then split in base SQ: the low digit goes
// into the real part and the high digit, scaled by IROOT, into the imaginary
// part. Entries past a.size() stay zero. The caller must have run
// fft_prepare(n) already.
vector<cd> fft_interleave(const vector<int>& a, long n)
{
  vector<cd> packed(n);
  for (size_t k = 0; k != a.size(); ++k) {
    long centered = a[k];
    if (centered > MOD/2)
      centered -= MOD;
    const long low = centered % SQ;
    const long high = centered / SQ;
    packed[k] = cd(low, high*IROOT);
  }
  fft_dit2(packed.data(), n, 1);
  return packed;
}
// Inverse-transforms `a` in place and decodes every sample back to a residue
// in [0, MOD).
//
// For each sample the real part is rounded to an integer x and the imaginary
// part, divided by IROOT, is rounded to an integer y; the residue is
// (x + y*SQ) mod MOD, normalized to be non-negative. The tables must have been
// prepared for a.size() via fft_prepare.
vector<int> ifft_interleave(vector<cd>& a)
{
  const size_t len = a.size();
  fft_dit2(&a[0], len, -1);
  vector<int> out(len);
  for (size_t k = 0; k < len; ++k) {
    const long re = round(a[k].real());
    const long im = long(round(a[k].imag()/IROOT));
    // Reduce y before multiplying so the sum stays within a 64-bit long.
    long value = (re + im%MOD*SQ) % MOD;
    if (value < 0)
      value += MOD;
    out[k] = value;
  }
  return out;
}
// Returns the coefficients, modulo MOD and lowest degree first, of
//   (x + l) * (x + l+1) * ... * (x + h-1).
//
// The returned vector has a power-of-two length greater than the degree h-l;
// entries above the degree are padding (zero in the direct case, and the
// rounded output of the inverse transform in the FFT case).
//
// An empty range (h <= l) yields the constant polynomial 1. The original code
// evaluated __builtin_clzl(0) for h == l, whose result is undefined.
//
// Ranges of fewer than 64 factors are multiplied out directly; longer ranges
// are split in half, solved recursively and combined with an FFT product.
vector<int> rising_factorial(long l, long h)
{
  if (h <= l)
    return vector<int>(1, 1);
  if (h-l <= 64-1) {
    vector<int> r(h-l+1);
    r[0] = 1;
    for (long i = l; i < h; i++) {
      // Multiply the current polynomial by (x + i) in place:
      // new[j] = old[j]*i + old[j-1], walking upwards with `prev` = old[j-1].
      long prev = 0;
      for (long j = 0; j <= i-l+1; j++) {
        const long cur = r[j];
        r[j] = (cur*i + prev) % MOD;
        prev = cur;
      }
    }
    // Pad to the smallest power of two strictly greater than the degree, so
    // the result can be fed straight into the FFT at the next level up.
    r.resize(1L << (63-__builtin_clzl(r.size()-1)+1));
    return r;
  }
  const long m = (l+h) >> 1;
  auto a = rising_factorial(l, m), b = rising_factorial(m, h);
  // Smallest power of two above the combined (padded) lengths; this is always
  // greater than the degree of the product.
  const long n = 1L << (63-__builtin_clzl(a.size()+b.size()-2)+1);
  // The tables are shared globals, so they are rebuilt for this level's n only
  // after both recursive calls have finished using them.
  fft_prepare(n);
  auto aa = fft_interleave(a, n), bb = fft_interleave(b, n);
  for (long i = 0; i < n; i++)
    aa[i] *= bb[i];
  return ifft_interleave(aa);
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
long n;
cin >> n;
vector<int> stirling1 = rising_factorial(0, n);
ROF(i, 0, n-1)
stirling1[i] = (stirling1[i]+stirling1[i+2])%MOD;
ROF(i, 1, n)
cout << stirling1[i] << ' ';
cout << endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment