Last active
May 29, 2020 03:25
-
-
Save justiceHui/f9aa98357cbdc4fc5372738a026a87a4 to your computer and use it in GitHub Desktop.
FFT 성능 측정
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <bits/stdc++.h>
#include <immintrin.h>
#include <smmintrin.h>
#define all(v) v.begin(), v.end()
using namespace std;
/*
Source of each implementation benchmarked below:
FFT_Recursion : https://justicehui.github.io/hard-algorithm/2019/09/04/FFT/
FFT_Recursion_Fast : https://namnamseo.tistory.com/entry/FFT-in-competitive-programming
FFT_Non_Recursion : https://blog.myungwoo.kr/54
NTT : https://algoshitpo.github.io/2020/05/20/fft-ntt/
Hell_Joseon_FFT : https://github.com/koosaga/DeobureoMinkyuParty
*/
// Shared numeric types and constants used by every implementation below.
using ll = long long;
using cpx = complex<double>;  // complex coefficient for the floating-point FFTs
using poly = vector<cpx>;     // polynomial with complex coefficients (FFT input/output)
using n_poly = vector<ll>;    // polynomial with integer coefficients (NTT input/output)
const double pi = acos(-1.0);
// NTT parameters: 998244353 = 119 * 2^23 + 1 is prime, and 3 is used as its primitive root.
const ll w = 3, mod = 998244353;
namespace FFT_Recursion{ | |
void fft_recursion(poly &f, const cpx w){ | |
const int n = f.size(); if(n == 1) return; | |
poly v[2]; | |
for(int i=0; i<n; i++) v[i&1].push_back(f[i]); | |
fft_recursion(v[0], w*w); | |
fft_recursion(v[1], w*w); | |
cpx wp(1, 0); | |
for(int i=0; i<n/2; i++){ | |
f[i] = v[0][i] + wp * v[1][i]; | |
f[i+n/2] = v[0][i] - wp * v[1][i]; | |
wp *= w; | |
} | |
} | |
vector<ll> multiply(const vector<ll> &_a, const vector<ll> &_b){ | |
poly a, b; a.reserve(_a.size()); b.reserve(_b.size()); | |
for(auto i : _a) a.push_back(i); | |
for(auto i : _b) b.push_back(i); | |
int n = 1; | |
while(n < a.size() + b.size()) n <<= 1; | |
a.resize(n); b.resize(n); | |
cpx w(cos(2*pi/n), sin(2*pi/n)); | |
fft_recursion(a, w); fft_recursion(b, w); | |
for(int i=0; i<n; i++) a[i] *= b[i]; | |
vector<ll> ret(n); | |
fft_recursion(a, cpx(1, 0) / w); | |
for(int i=0; i<n; i++){ | |
a[i] /= cpx(n, 0); | |
ret[i] = round(a[i].real()); | |
} | |
while(ret.size() > 1 && !ret.back()) ret.pop_back(); | |
return ret; | |
} | |
} | |
namespace FFT_Recursion_Fast{ | |
void __fft_recursion_fast(poly &f, poly &res, int st, int step, const int n, int save_pos, bool inv = false){ | |
if(n == 1){ res[save_pos] = f[st]; return; } | |
__fft_recursion_fast(f, res, st, step << 1, n >> 1, save_pos, inv); | |
__fft_recursion_fast(f, res, st + step, step << 1, n >> 1, save_pos + (n >> 1), inv); | |
cpx w(cos(2*pi/n), sin(2*pi/n)), wp(1, 0); | |
if(inv) w = cpx(1, 0) / w; | |
for(int i=0; i<n/2; i++){ | |
auto a = res[save_pos+i]; | |
auto b = wp * res[save_pos+i+n/2]; | |
res[save_pos+i] = a + b; | |
res[save_pos+i+n/2] = a - b; | |
wp *= w; | |
} | |
} | |
poly fft_recursion_fast(poly f, bool inv = false){ | |
poly res(f.size()); | |
__fft_recursion_fast(f, res, 0, 1, f.size(), 0, inv); | |
return res; | |
} | |
vector<ll> multiply(const vector<ll> &_a, const vector<ll> &_b){ | |
poly a, b; a.reserve(_a.size()); b.reserve(_b.size()); | |
for(auto i : _a) a.push_back(i); | |
for(auto i : _b) b.push_back(i); | |
int n = 1; | |
while(n < a.size() + b.size()) n <<= 1; | |
a.resize(n); b.resize(n); | |
cpx w(cos(2*pi/n), sin(2*pi/n)); | |
a = fft_recursion_fast(a); | |
b = fft_recursion_fast(b); | |
for(int i=0; i<n; i++) a[i] *= b[i]; | |
vector<ll> ret(n); | |
a = fft_recursion_fast(a, 1); | |
for(int i=0; i<n; i++){ | |
a[i] /= cpx(n, 0); | |
ret[i] = round(a[i].real()); | |
} | |
while(ret.size() > 1 && !ret.back()) ret.pop_back(); | |
return ret; | |
} | |
} | |
namespace FFT_Non_Recursion{ | |
void fft_non_recursion(poly &f, bool inv = 0){ | |
int n = f.size(), j = 0; | |
vector<cpx> root(n >> 1); | |
for(int i=1; i<n; i++){ | |
int bit = (n >> 1); | |
while(j >= bit){ | |
j -= bit; bit >>= 1; | |
} | |
j += bit; | |
if(i < j) swap(f[i], f[j]); | |
} | |
double ang = 2 * pi / n; if(inv) ang *= -1; | |
for(int i=0; i<(n >> 1); i++) root[i] = cpx(cos(ang*i), sin(ang*i)); | |
for(int i=2; i<=n; i<<=1){ | |
int step = n / i; | |
for(int j=0; j<n; j+=i){ | |
for(int k=0; k<(i >> 1); k++){ | |
cpx u = f[j | k], v = f[j | k | i >> 1] * root[step * k]; | |
f[j | k] = u + v; | |
f[j | k | i >> 1] = u - v; | |
} | |
} | |
} | |
if(inv) for(int i=0; i<n; i++) f[i] /= n; | |
} | |
vector<ll> multiply(const vector<ll> &_a, const vector<ll> &_b){ | |
poly a, b; a.reserve(_a.size()); b.reserve(_b.size()); | |
for(auto i : _a) a.push_back(i); | |
for(auto i : _b) b.push_back(i); | |
int n = 1; | |
while(n < a.size() + b.size()) n <<= 1; | |
a.resize(n); b.resize(n); | |
cpx w(cos(2*pi/n), sin(2*pi/n)); | |
fft_non_recursion(a); | |
fft_non_recursion(b); | |
for(int i=0; i<n; i++) a[i] *= b[i]; | |
vector<ll> ret(n); | |
fft_non_recursion(a, 1); | |
for(int i=0; i<n; i++) ret[i] = round(a[i].real()); | |
while(ret.size() > 1 && !ret.back()) ret.pop_back(); | |
return ret; | |
} | |
} | |
namespace NTT{ | |
ll pw(ll a, ll b){ | |
ll ret = 1; | |
while(b){ | |
if(b & 1) ret = ret * a % mod; | |
b >>= 1; a = a * a % mod; | |
} | |
return ret; | |
} | |
void ntt(n_poly &f, bool inv = false){ | |
int n = f.size(), j = 0; | |
vector<ll> root(n >> 1); | |
for(int i=1; i<n; i++){ | |
int bit = (n >> 1); | |
for(; bit<=j; bit>>=1) j -= bit; | |
j += bit; | |
if(i < j) swap(f[i], f[j]); | |
} | |
ll ang = pw(w, (mod-1)/n); if(inv) ang = pw(ang, mod-2); | |
root[0] = 1; for(int i=1; i<(n >> 1); i++) root[i] = root[i-1] * ang % mod; | |
for(int len=2; len<=n; len<<=1){ | |
int step = n / len; | |
for(int i=0; i<n; i+=len){ | |
for(int j=0; j<len/2; j++){ | |
ll a = f[i+j], b = f[i+j+len/2] * root[step*j] % mod; | |
f[i+j] = (a + b) % mod; | |
f[i+j+len/2] = (a - b) % mod; | |
if(f[i+j+len/2] < 0) f[i+j+len/2] += mod; | |
} | |
} | |
} | |
ll t = pw(n, mod-2); | |
if(inv) for(int i=0; i<n; i++) f[i] = f[i] * t % mod; | |
} | |
vector<ll> multiply(n_poly &_a, n_poly &_b){ | |
vector<ll> a(all(_a)), b(all(_b)); | |
int n = 2; | |
while(n < a.size() + b.size()) n <<= 1; | |
a.resize(n); b.resize(n); | |
ntt(a); ntt(b); | |
for(int i=0; i<n; i++) a[i] = a[i] * b[i] % mod; | |
ntt(a, 1); | |
while(a.size() > 1 && !a.back()) a.pop_back(); | |
return a; | |
} | |
} | |
namespace Hell_Joseon_FFT{ | |
#include <smmintrin.h> | |
#include <immintrin.h> | |
__m256d mult(__m256d a, __m256d b){ | |
__m256d c = _mm256_movedup_pd(a); | |
__m256d d = _mm256_shuffle_pd(a, a, 15); | |
__m256d cb = _mm256_mul_pd(c, b); | |
__m256d db = _mm256_mul_pd(d, b); | |
__m256d e = _mm256_shuffle_pd(db, db, 5); | |
__m256d r = _mm256_addsub_pd(cb, e); | |
return r; | |
} | |
void hell_joseon_fft(int n, __m128d a[], bool inv = false){ | |
for(int i=1, j=0; i<n; ++i){ | |
int bit = n>>1; | |
for(;j>=bit;bit>>=1) j -= bit; | |
j += bit; | |
if(i<j) swap(a[i], a[j]); | |
} | |
for(int len=2; len<=n; len<<=1){ | |
double ang = 2*pi/len*(inv?-1:1); | |
__m256d wlen; wlen[0] = cos(ang), wlen[1] = sin(ang); | |
for(int i=0; i<n; i += len){ | |
__m256d w; w[0] = 1; w[1] = 0; | |
for(int j=0; j<len/2; ++j){ | |
w = _mm256_permute2f128_pd(w, w, 0); | |
wlen = _mm256_insertf128_pd(wlen, a[i+j+len/2], 1); | |
w = mult(w, wlen); | |
__m128d vw = _mm256_extractf128_pd(w, 1); | |
__m128d u = a[i+j]; | |
a[i+j] = _mm_add_pd(u, vw); | |
a[i+j+len/2] = _mm_sub_pd(u, vw); | |
} | |
} | |
} | |
if(inv){ | |
__m128d inv; inv[0] = inv[1] = 1.0/n; | |
for(int i=0; i<n; ++i) a[i] = _mm_mul_pd(a[i], inv); | |
} | |
} | |
vector<ll> multiply(vector<ll>& v, vector<ll>& w){ | |
int n = 2; while(n < v.size()+w.size()) n<<=1; | |
__m128d* fv = new __m128d[n]; | |
for(int i=0; i<n; ++i) fv[i][0] = fv[i][1] = 0; | |
for(int i=0; i<v.size(); ++i) fv[i][0] = v[i]; | |
for(int i=0; i<w.size(); ++i) fv[i][1] = w[i]; | |
hell_joseon_fft(n, fv); // (a+bi) is stored in FFT | |
for(int i=0; i<n; i += 2){ | |
__m256d a; | |
a = _mm256_insertf128_pd(a, fv[i], 0); | |
a = _mm256_insertf128_pd(a, fv[i+1], 1); | |
a = mult(a, a); | |
fv[i] = _mm256_extractf128_pd(a, 0); | |
fv[i+1] = _mm256_extractf128_pd(a, 1); | |
} | |
hell_joseon_fft(n, fv, 1); | |
vector<ll> ret(n); | |
for(int i=0; i<n; ++i) ret[i] = (ll)round(fv[i][1]/2); | |
delete[] fv; | |
return ret; | |
} | |
} | |
// Random engine for the benchmark inputs, seeded from the monotonic clock's current tick count.
mt19937 rd(static_cast<unsigned>(chrono::steady_clock::now().time_since_epoch().count()));

/*
 * Current time in milliseconds from a monotonic clock, for measuring elapsed intervals.
 * steady_clock is used instead of system_clock because wall-clock time can be adjusted
 * (NTP, manual changes) mid-run, which would corrupt benchmark deltas. The absolute value
 * is meaningless; only differences between two calls are.
 */
inline long long now(){
    const auto elapsed = chrono::steady_clock::now().time_since_epoch();
    return chrono::duration_cast<chrono::milliseconds>(elapsed).count();
}
int main(){ | |
const int n = 1000000; | |
vector<ll> a(n); | |
vector<ll> b(n); | |
uniform_int_distribution<int> rnd(0, 1000); | |
for(int i=0; i<n; i++){ | |
a[i] = rnd(rd); | |
b[i] = rnd(rd); | |
} | |
ll t; | |
cout << "[[FFT speed test]]\n"; | |
cout << "compiler option : -std=c++14 -lm -Ofast -ffast-math -mavx -mavx2 -mfma -funroll-loops\n"; | |
t = now(); | |
vector<ll> c1 = FFT_Recursion::multiply(a, b); | |
cout << "recursion : " << now() - t << "ms\n"; | |
t = now(); | |
vector<ll> c2 = FFT_Recursion_Fast::multiply(a, b); | |
cout << "recursion opt : " << now() - t << "ms\n"; | |
t = now(); | |
vector<ll> c3 = FFT_Non_Recursion::multiply(a, b); | |
cout << "non recursion : " << now() - t << "ms\n"; | |
t = now(); | |
vector<ll> c4 = NTT::multiply(a, b); | |
cout << "NTT : " << now() - t << "ms\n"; | |
t = now(); | |
vector<ll> c5 = Hell_Joseon_FFT::multiply(a, b); | |
cout << "hell joseon : " << now() - t << "ms\n"; | |
} | |
/* | |
[[FFT speed test]] | |
compiler option : -std=c++14 -lm -Ofast -ffast-math -mavx -mavx2 -mfma -funroll-loops | |
recursion : 2447ms | |
recursion opt : 1279ms | |
non recursion : 1131ms | |
NTT : 2282ms | |
hell joseon : 324ms | |
*/ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment