Skip to content

Instantly share code, notes, and snippets.

@justiceHui
Last active May 29, 2020 03:25
Show Gist options
  • Save justiceHui/f9aa98357cbdc4fc5372738a026a87a4 to your computer and use it in GitHub Desktop.
Save justiceHui/f9aa98357cbdc4fc5372738a026a87a4 to your computer and use it in GitHub Desktop.
FFT 성능 측정 (FFT performance benchmark)
#include <bits/stdc++.h>
#include <immintrin.h>
#include <smmintrin.h>
#define all(v) v.begin(), v.end()
using namespace std;
/*
FFT_Recursion : https://justicehui.github.io/hard-algorithm/2019/09/04/FFT/
FFT_Recursion_Fast : https://namnamseo.tistory.com/entry/FFT-in-competitive-programming
FFT_Non_Recursion : https://blog.myungwoo.kr/54
NTT : https://algoshitpo.github.io/2020/05/20/fft-ntt/
Hell_Joseon_FFT : https://github.com/koosaga/DeobureoMinkyuParty
*/
// Aliases and constants shared by the implementations below.
using ll = long long;
using cpx = complex<double>;   // one complex sample of the floating-point transforms
using poly = vector<cpx>;      // sequence of complex samples
using n_poly = vector<ll>;     // sequence of integers for the modular transform
// pi, obtained as acos(-1).
const double pi = acos(-1.0);
// Modulus used by NTT, and the base `w` that NTT::ntt raises to (mod-1)/n to get its root.
const ll mod = 998244353;
const ll w = 3;
namespace FFT_Recursion{
// Namespace-local spellings of the file-level aliases (identical types and value),
// kept here so this implementation can be copied out and compiled on its own.
using ll = long long;
using cpx = std::complex<double>;
using poly = std::vector<cpx>;
const double pi = std::acos(-1.0);

// Recursive radix-2 transform, performed in place.
//   f : samples on entry, transformed values on return. The size is expected to
//       be a power of two; sizes 0 and 1 are returned untouched.
//   w : the root of unity for this level; each recursive level receives w*w.
// Every level copies the even- and odd-indexed samples into two temporary vectors.
void fft_recursion(poly &f, const cpx w){
    const std::size_t n = f.size();
    // n == 0 must stop here as well: splitting an empty vector yields two empty
    // vectors, so without this guard the recursion would never terminate.
    if(n <= 1) return;
    const std::size_t half = n / 2;
    poly even, odd;
    even.reserve(n - half);
    odd.reserve(half);
    for(std::size_t i = 0; i < n; i++) ((i & 1) ? odd : even).push_back(f[i]);
    const cpx w2 = w * w;
    fft_recursion(even, w2);
    fft_recursion(odd, w2);
    cpx wp(1, 0);                       // running power w^i
    for(std::size_t i = 0; i < half; i++){
        const cpx t = wp * odd[i];
        f[i] = even[i] + t;
        f[i + half] = even[i] - t;
        wp *= w;
    }
}

// Multiplies two polynomials given as coefficient vectors (lowest degree first).
// Returns the coefficients of the product, rounded to the nearest integer, with
// trailing zero coefficients removed (at least one element is always returned).
std::vector<ll> multiply(const std::vector<ll> &_a, const std::vector<ll> &_b){
    std::size_t n = 1;
    while(n < _a.size() + _b.size()) n <<= 1;
    poly a(n), b(n);                    // zero-padded to the transform length
    for(std::size_t i = 0; i < _a.size(); i++) a[i] = cpx(static_cast<double>(_a[i]), 0);
    for(std::size_t i = 0; i < _b.size(); i++) b[i] = cpx(static_cast<double>(_b[i]), 0);
    const double ang = 2 * pi / static_cast<double>(n);
    const cpx w(std::cos(ang), std::sin(ang));
    fft_recursion(a, w);
    fft_recursion(b, w);
    for(std::size_t i = 0; i < n; i++) a[i] *= b[i];
    // w has unit modulus, so its conjugate is its inverse.
    fft_recursion(a, std::conj(w));
    std::vector<ll> ret(n);
    for(std::size_t i = 0; i < n; i++) ret[i] = std::llround(a[i].real() / static_cast<double>(n));
    while(ret.size() > 1 && !ret.back()) ret.pop_back();
    return ret;
}
}
namespace FFT_Recursion_Fast{
// Namespace-local spellings of the file-level aliases (identical types and value),
// kept here so this implementation can be copied out and compiled on its own.
using ll = long long;
using cpx = std::complex<double>;
using poly = std::vector<cpx>;
const double pi = std::acos(-1.0);

// Recursive worker that avoids per-level temporary vectors.
//   f        : input samples (read only).
//   res      : output buffer, same size as f.
//   st, step : this sub-problem reads f[st], f[st+step], f[st+2*step], ...
//   n        : number of samples in this sub-problem (expected power of two).
//   save_pos : the n results are written to res[save_pos .. save_pos+n-1].
//   inv      : use the inverse root; no 1/n scaling is applied here.
// NOTE(review): identifiers containing a double underscore are reserved to the
// implementation; the name is kept only so existing callers keep compiling.
void __fft_recursion_fast(const poly &f, poly &res, int st, int step, const int n, int save_pos, bool inv = false){
    if(n <= 1){
        // n == 0 has to stop too, otherwise n >> 1 stays 0 and the recursion never ends.
        if(n == 1) res[save_pos] = f[st];
        return;
    }
    const int half = n >> 1;
    __fft_recursion_fast(f, res, st, step << 1, half, save_pos, inv);
    __fft_recursion_fast(f, res, st + step, step << 1, half, save_pos + half, inv);
    const double ang = (inv ? -2 : 2) * pi / n;
    const cpx w(std::cos(ang), std::sin(ang));
    cpx wp(1, 0);                       // running power w^i
    for(int i = 0; i < half; i++){
        const cpx a = res[save_pos + i];
        const cpx b = wp * res[save_pos + i + half];
        res[save_pos + i] = a + b;
        res[save_pos + i + half] = a - b;
        wp *= w;
    }
}

// Returns the transform of f (size expected to be a power of two) as a new vector.
// With inv == true the inverse root is used, but the result is NOT divided by the
// size; callers scale it themselves. An empty input yields an empty result.
poly fft_recursion_fast(const poly &f, bool inv = false){
    poly res(f.size());
    if(!f.empty()) __fft_recursion_fast(f, res, 0, 1, static_cast<int>(f.size()), 0, inv);
    return res;
}

// Multiplies two polynomials given as coefficient vectors (lowest degree first).
// Returns the rounded product coefficients with trailing zeros removed (at least
// one element is always returned).
std::vector<ll> multiply(const std::vector<ll> &_a, const std::vector<ll> &_b){
    std::size_t n = 1;
    while(n < _a.size() + _b.size()) n <<= 1;
    poly a(n), b(n);                    // zero-padded to the transform length
    for(std::size_t i = 0; i < _a.size(); i++) a[i] = cpx(static_cast<double>(_a[i]), 0);
    for(std::size_t i = 0; i < _b.size(); i++) b[i] = cpx(static_cast<double>(_b[i]), 0);
    a = fft_recursion_fast(a);
    b = fft_recursion_fast(b);
    for(std::size_t i = 0; i < n; i++) a[i] *= b[i];
    a = fft_recursion_fast(a, true);
    std::vector<ll> ret(n);
    for(std::size_t i = 0; i < n; i++) ret[i] = std::llround(a[i].real() / static_cast<double>(n));
    while(ret.size() > 1 && !ret.back()) ret.pop_back();
    return ret;
}
}
namespace FFT_Non_Recursion{
// Namespace-local spellings of the file-level aliases (identical types and value),
// kept here so this implementation can be copied out and compiled on its own.
using ll = long long;
using cpx = std::complex<double>;
using poly = std::vector<cpx>;
const double pi = std::acos(-1.0);

// Iterative in-place radix-2 transform.
//   f   : samples on entry, transformed values on return. The size must be a
//         power of two (sizes 0 and 1 are returned untouched).
//   inv : when true the inverse roots are used and every value is divided by
//         the size, so a forward call followed by an inverse call restores f.
void fft_non_recursion(poly &f, bool inv = 0){
    const int n = static_cast<int>(f.size());
    if(n <= 1) return;
    // With any other size the bit-reversal loop below never terminates
    // (bit reaches 0 while j >= bit still holds).
    assert((n & (n - 1)) == 0 && "fft_non_recursion: size must be a power of two");

    // Bit-reversal permutation: j tracks the bit-reversed value of i.
    for(int i = 1, j = 0; i < n; i++){
        int bit = n >> 1;
        while(j >= bit){
            j -= bit;
            bit >>= 1;
        }
        j += bit;
        if(i < j) std::swap(f[i], f[j]);
    }

    // Table of the first n/2 powers of the n-th root of unity.
    const double ang = (inv ? -2 : 2) * pi / n;
    std::vector<cpx> root(n >> 1);
    for(int k = 0; k < (n >> 1); k++) root[k] = cpx(std::cos(ang * k), std::sin(ang * k));

    // Butterflies over blocks of length 2, 4, ..., n.
    for(int len = 2; len <= n; len <<= 1){
        const int half = len >> 1;
        const int step = n / len;       // stride into the root table for this block length
        for(int base = 0; base < n; base += len){
            for(int k = 0; k < half; k++){
                const cpx u = f[base + k];
                const cpx v = f[base + k + half] * root[step * k];
                f[base + k] = u + v;
                f[base + k + half] = u - v;
            }
        }
    }
    if(inv) for(cpx &x : f) x /= static_cast<double>(n);
}

// Multiplies two polynomials given as coefficient vectors (lowest degree first).
// Returns the rounded product coefficients with trailing zeros removed (at least
// one element is always returned).
std::vector<ll> multiply(const std::vector<ll> &_a, const std::vector<ll> &_b){
    std::size_t n = 1;
    while(n < _a.size() + _b.size()) n <<= 1;
    poly a(n), b(n);                    // zero-padded to the transform length
    for(std::size_t i = 0; i < _a.size(); i++) a[i] = cpx(static_cast<double>(_a[i]), 0);
    for(std::size_t i = 0; i < _b.size(); i++) b[i] = cpx(static_cast<double>(_b[i]), 0);
    fft_non_recursion(a);
    fft_non_recursion(b);
    for(std::size_t i = 0; i < n; i++) a[i] *= b[i];
    fft_non_recursion(a, 1);
    std::vector<ll> ret(n);
    for(std::size_t i = 0; i < n; i++) ret[i] = std::llround(a[i].real());
    while(ret.size() > 1 && !ret.back()) ret.pop_back();
    return ret;
}
}
namespace NTT{
// Namespace-local spellings of the file-level aliases and NTT parameters
// (identical types and values), kept here so this implementation can be
// copied out and compiled on its own.
using ll = long long;
using n_poly = std::vector<ll>;
const ll mod = 998244353;
const ll w = 3;                         // base raised to (mod-1)/n to obtain the n-th root

// Returns a^b modulo mod by binary exponentiation. b must be non-negative.
// a may be any value; it is reduced into [0, mod) first so the products below
// cannot overflow.
ll pw(ll a, ll b){
    a %= mod;
    if(a < 0) a += mod;
    ll ret = 1;
    while(b){
        if(b & 1) ret = ret * a % mod;
        b >>= 1;
        a = a * a % mod;
    }
    return ret;
}

// Iterative in-place number-theoretic transform modulo mod.
//   f   : values on entry, transformed values in [0, mod) on return. The size
//         must be a power of two (sizes 0 and 1 are only reduced modulo mod).
//   inv : when true the inverse root is used and every value is multiplied by
//         the modular inverse of the size.
void ntt(n_poly &f, bool inv = false){
    // Bring every entry into [0, mod): negative or oversized inputs would
    // otherwise leave negative residues or overflow in the butterflies.
    for(ll &x : f){
        x %= mod;
        if(x < 0) x += mod;
    }
    const int n = static_cast<int>(f.size());
    // Size 1 needs no work; it also must not reach root[0] below, because the
    // root table would have size 0.
    if(n <= 1) return;
    // With any other size the bit-reversal loop below never terminates.
    assert((n & (n - 1)) == 0 && "ntt: size must be a power of two");

    // Bit-reversal permutation: j tracks the bit-reversed value of i.
    for(int i = 1, j = 0; i < n; i++){
        int bit = n >> 1;
        for(; bit <= j; bit >>= 1) j -= bit;
        j += bit;
        if(i < j) std::swap(f[i], f[j]);
    }

    // Table of the first n/2 powers of the root for this size.
    ll ang = pw(w, (mod - 1) / n);
    if(inv) ang = pw(ang, mod - 2);
    std::vector<ll> root(n >> 1);
    root[0] = 1;
    for(int i = 1; i < (n >> 1); i++) root[i] = root[i - 1] * ang % mod;

    // Butterflies over blocks of length 2, 4, ..., n.
    for(int len = 2; len <= n; len <<= 1){
        const int half = len / 2;
        const int step = n / len;       // stride into the root table for this block length
        for(int i = 0; i < n; i += len){
            for(int j = 0; j < half; j++){
                const ll a = f[i + j];
                const ll b = f[i + j + half] * root[step * j] % mod;
                f[i + j] = (a + b) % mod;
                f[i + j + half] = (a - b + mod) % mod;
            }
        }
    }
    if(inv){
        const ll t = pw(n, mod - 2);    // modular inverse of n, only needed for the inverse
        for(ll &x : f) x = x * t % mod;
    }
}

// Multiplies two polynomials given as coefficient vectors (lowest degree first).
// Returns the product coefficients reduced modulo mod into [0, mod), with
// trailing zeros removed (at least one element is always returned).
std::vector<ll> multiply(const n_poly &_a, const n_poly &_b){
    std::vector<ll> a(_a.begin(), _a.end()), b(_b.begin(), _b.end());
    std::size_t n = 2;
    while(n < a.size() + b.size()) n <<= 1;
    a.resize(n);
    b.resize(n);
    ntt(a);
    ntt(b);
    for(std::size_t i = 0; i < n; i++) a[i] = a[i] * b[i] % mod;
    ntt(a, true);
    while(a.size() > 1 && !a.back()) a.pop_back();
    return a;
}
}
namespace Hell_Joseon_FFT{
// The SIMD intrinsic headers are included at file scope (top of the file):
// including them inside this namespace would declare the intrinsics and vector
// types as members of the namespace.

// Namespace-local spellings of the file-level alias and constant (identical type
// and value), kept here so this implementation can be copied out on its own.
using ll = long long;
const double pi = std::acos(-1.0);

// Multiplies two pairs of complex numbers at once. Each __m256d holds two complex
// values laid out as (re0, im0, re1, im1); the result holds a0*b0 and a1*b1.
__m256d mult(__m256d a, __m256d b){
    const __m256d re = _mm256_movedup_pd(a);             // (a.re0, a.re0, a.re1, a.re1)
    const __m256d im = _mm256_shuffle_pd(a, a, 15);      // (a.im0, a.im0, a.im1, a.im1)
    const __m256d re_b = _mm256_mul_pd(re, b);
    const __m256d im_b = _mm256_mul_pd(im, b);
    const __m256d im_b_swapped = _mm256_shuffle_pd(im_b, im_b, 5);  // swap re/im slots in each half
    return _mm256_addsub_pd(re_b, im_b_swapped);         // subtract in even slots, add in odd slots
}

// Iterative in-place radix-2 transform over n complex values, each stored in one
// __m128d as (re, im). n is expected to be a power of two.
//   inv : when true the inverse roots are used and every value is multiplied by 1/n.
void hell_joseon_fft(int n, __m128d a[], bool inv = false){
    // Bit-reversal permutation: j tracks the bit-reversed value of i.
    for(int i = 1, j = 0; i < n; ++i){
        int bit = n >> 1;
        for(; j >= bit; bit >>= 1) j -= bit;
        j += bit;
        if(i < j) std::swap(a[i], a[j]);
    }
    for(int len = 2; len <= n; len <<= 1){
        const double ang = 2 * pi / len * (inv ? -1 : 1);
        // Low half: the step root (cos, sin). High half: filled per butterfly below.
        __m256d wlen = _mm256_set_pd(0.0, 0.0, std::sin(ang), std::cos(ang));
        for(int i = 0; i < n; i += len){
            __m256d w = _mm256_set_pd(0.0, 0.0, 0.0, 1.0);  // low half: current twiddle, starts at 1
            for(int j = 0; j < len / 2; ++j){
                // Copy the current twiddle into both halves, put the odd-side sample
                // into the high half of wlen, then one mult() yields the next twiddle
                // (low half) and twiddle * sample (high half).
                w = _mm256_permute2f128_pd(w, w, 0);
                wlen = _mm256_insertf128_pd(wlen, a[i + j + len / 2], 1);
                w = mult(w, wlen);
                const __m128d vw = _mm256_extractf128_pd(w, 1);
                const __m128d u = a[i + j];
                a[i + j] = _mm_add_pd(u, vw);
                a[i + j + len / 2] = _mm_sub_pd(u, vw);
            }
        }
    }
    if(inv){
        const __m128d scale = _mm_set1_pd(1.0 / n);
        for(int i = 0; i < n; ++i) a[i] = _mm_mul_pd(a[i], scale);
    }
}

// Multiplies two polynomials given as coefficient vectors (lowest degree first)
// with a single forward and a single inverse transform: v is stored in the real
// parts and w in the imaginary parts, the transformed values are squared, and
// half of the imaginary part of the inverse transform is rounded to give the result.
// Unlike the sibling multiply() functions in this file, the returned vector keeps
// the full transform length (a power of two >= 2); trailing zeros are not removed.
std::vector<ll> multiply(const std::vector<ll>& v, const std::vector<ll>& w){
    int n = 2;
    while(static_cast<std::size_t>(n) < v.size() + w.size()) n <<= 1;
    // Owned by a vector so the buffer is released even if a later allocation throws.
    std::vector<__m128d> fv(static_cast<std::size_t>(n), _mm_setzero_pd());
    const std::size_t used = std::max(v.size(), w.size());
    for(std::size_t i = 0; i < used; ++i){
        const double re = i < v.size() ? static_cast<double>(v[i]) : 0.0;
        const double im = i < w.size() ? static_cast<double>(w[i]) : 0.0;
        fv[i] = _mm_set_pd(im, re);                      // (re, im): v in real part, w in imaginary part
    }
    hell_joseon_fft(n, fv.data());
    // Square two transformed values per iteration.
    for(int i = 0; i < n; i += 2){
        __m256d pair = _mm256_insertf128_pd(_mm256_castpd128_pd256(fv[i]), fv[i + 1], 1);
        pair = mult(pair, pair);
        fv[i] = _mm256_extractf128_pd(pair, 0);
        fv[i + 1] = _mm256_extractf128_pd(pair, 1);
    }
    hell_joseon_fft(n, fv.data(), true);
    std::vector<ll> ret(static_cast<std::size_t>(n));
    for(int i = 0; i < n; ++i){
        const double im = _mm_cvtsd_f64(_mm_unpackhi_pd(fv[i], fv[i]));  // imaginary part
        ret[i] = std::llround(im / 2);
    }
    return ret;
}
}
// Global Mersenne Twister engine, seeded from the steady clock's current tick count.
mt19937 rd(static_cast<unsigned>(
    chrono::steady_clock::now().time_since_epoch().count()));
// Returns a millisecond timestamp for measuring elapsed time: subtract two
// results to get a duration. It reads the monotonic steady clock, so an
// adjustment of the system time between two calls cannot distort (or make
// negative) the measured interval. The absolute value has no calendar meaning.
inline long long now(){
    using clock = std::chrono::steady_clock;
    const auto since_epoch = clock::now().time_since_epoch();
    return std::chrono::duration_cast<std::chrono::milliseconds>(since_epoch).count();
}
int main(){
const int n = 1000000;
vector<ll> a(n);
vector<ll> b(n);
uniform_int_distribution<int> rnd(0, 1000);
for(int i=0; i<n; i++){
a[i] = rnd(rd);
b[i] = rnd(rd);
}
ll t;
cout << "[[FFT speed test]]\n";
cout << "compiler option : -std=c++14 -lm -Ofast -ffast-math -mavx -mavx2 -mfma -funroll-loops\n";
t = now();
vector<ll> c1 = FFT_Recursion::multiply(a, b);
cout << "recursion : " << now() - t << "ms\n";
t = now();
vector<ll> c2 = FFT_Recursion_Fast::multiply(a, b);
cout << "recursion opt : " << now() - t << "ms\n";
t = now();
vector<ll> c3 = FFT_Non_Recursion::multiply(a, b);
cout << "non recursion : " << now() - t << "ms\n";
t = now();
vector<ll> c4 = NTT::multiply(a, b);
cout << "NTT : " << now() - t << "ms\n";
t = now();
vector<ll> c5 = Hell_Joseon_FFT::multiply(a, b);
cout << "hell joseon : " << now() - t << "ms\n";
}
/*
[[FFT speed test]]
compiler option : -std=c++14 -lm -Ofast -ffast-math -mavx -mavx2 -mfma -funroll-loops
recursion : 2447ms
recursion opt : 1279ms
non recursion : 1131ms
NTT : 2282ms
hell joseon : 324ms
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment