Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
FFT & NTT benchmark
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <complex>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
using namespace std;

// Complex sample type used by the floating-point FFT kernels.
using cd = complex<double>;

// Iteration shorthands. FOR's loop variable takes the cv/ref-stripped type of
// the upper bound, so REP(i, n) with a `long n` declares a `long i`.
#define ALL(x) (x).begin(), (x).end()
#define FOR(i, a, b) for (remove_cv<remove_reference<decltype(b)>::type>::type i = (a); i < (b); ++i)
#define REP(i, n) FOR(i, 0, n)

// 30-bit NTT prime 998244353 = 7*17*2^23 + 1, used with G_int = 3.
// M_int = floor(2^61 / P_int) is the Barrett constant handed to Montgomery::barrett30.
constexpr long P_int = 998244353;
constexpr long M_int = (1L << 61) / P_int;
constexpr long G_int = 3;

// Modulus and generator for the `long` NTT variants.
constexpr long P_long = 1000000000949747713;
constexpr long G_long = 3;

using u64 = uint64_t;
using i64 = int64_t;

// Repetitions per benchmark measurement; main may override it from argv[1].
long times = 1;

// Largest transform size: bounds the static twiddle tables and the benchmark sweep.
constexpr long NN = 1 << 23;
// Modular inverse of a modulo b, computed with the iterative extended Euclidean
// algorithm (only the coefficient of a is tracked).
// When gcd(a, b) == 1 the result x lies in [0, b) and a*x ≡ 1 (mod b);
// otherwise the value is not an inverse (e.g. a == 0 returns 0).
extern inline long inv(long a, long b)
{
    const long modulus = b;
    long coef = 0, next_coef = 1;   // coefficients of a for (rem, next_rem)
    long rem = b, next_rem = a;
    while (next_rem != 0) {
        const long quot = rem / next_rem;

        const long c = coef - quot * next_coef;
        coef = next_coef;
        next_coef = c;

        const long r = rem - quot * next_rem;
        rem = next_rem;
        next_rem = r;
    }
    return coef < 0 ? coef + modulus : coef;
}
// 32-bit modular product: both operands are widened to long first so that
// a*b cannot overflow before the reduction by m.
extern inline int mul_mod(int a, int b, int m)
{
    const long product = static_cast<long>(a) * static_cast<long>(b);
    return static_cast<int>(product % m);
}
// 64-bit modular product a*b mod m without 128-bit integers.
// The quotient a*b/m is estimated in long double and rounded to nearest. The
// low 64 bits of a*b - q*m are then computed exactly with wrap-around unsigned
// arithmetic. If the estimate overshot by one, the (negative) remainder is
// fixed by adding m once.
// NOTE(review): correctness relies on long double carrying at least a 64-bit
// mantissa (x87 extended or wider) — confirm on the target platform.
extern inline long mul_mod(long a, long b, long m)
{
    const long double quot_est = static_cast<long double>(a) * static_cast<long double>(b) / m + 0.5;
    const unsigned long q = static_cast<unsigned long>(quot_est);
    const unsigned long low_product = static_cast<unsigned long>(a) * static_cast<unsigned long>(b);
    const unsigned long low_qm = static_cast<unsigned long>(m) * q;
    const long rem = static_cast<long>(low_product - low_qm);
    return rem < 0 ? rem + m : rem;
}
// Square-and-multiply exponentiation: returns a^b mod `mod` for b >= 0.
// Every product is reduced with the long overload of mul_mod. The initial
// result 1 is returned unreduced when b == 0.
extern inline long pow_mod(long a, long b, long mod)
{
    long result = 1;
    long base = a;
    while (b != 0) {
        if ((b & 1) != 0)
            result = mul_mod(result, base, mod);
        base = mul_mod(base, base, mod);
        b >>= 1;
    }
    return result;
}
// Builds the benchmark input: a vector of n elements holding 0, 1, ..., n-1,
// each converted from an int counter to T (as std::iota with an int seed would).
template<typename T>
vector<T> setup(long n)
{
    vector<T> values(n);
    int next = 0;
    for (auto& v : values)
        v = next++;
    return values;
}
namespace Montgomery
{
// Barrett reduction: returns a mod P for 2^29 <= P < 2^30, given M = floor(2^61 / P)
// (e.g. M_int). Assumes a < 2^60 so that (a >> 28) * M fits in 64 bits.
// The quotient estimate ((a >> 28) * M) >> 33 never exceeds a / P, so the
// difference is non-negative, and one conditional subtraction finishes the reduction.
// NOTE(review): that a single subtraction always suffices is not proven here;
// ntt_dif2 below cross-checks it against `%` with assert.
extern inline u64 barrett30(u64 a, u64 P, u64 M)
{ // 2^29 <= P < 2^30
    u64 r = a-((a>>28)*M>>33)*P;
    if (r >= P) r -= P;
    return r;
}

// a^b mod P by square-and-multiply, reducing each product with barrett30.
// Assumes 0 <= a < P (< 2^30), so every product stays below 2^60.
long pow_mod(long a, long b, long P, long M)
{
    long r = 1;
    for (; b; b >>= 1) {
        if (b & 1)
            r = barrett30(r*a, P, M);
        a = barrett30(a*a, P, M);
    }
    return r;
}

// In-place radix-2 decimation-in-frequency NTT modulo P (2^29 <= P < 2^30).
// Butterflies use Montgomery multiplication (R = 2^32) against twiddles stored
// in Montgomery form. Values are kept lazily in [0, 2P) between passes and
// fully reduced to [0, P) afterwards.
//   a  : n input values, each below 2P; overwritten with the transform.
//   n  : power of two with n/2 <= NN; should divide P - 1.
//   M  : Barrett constant floor(2^61 / P).
//   G  : generator; the root is G^((P-1)/n) when is > 0, its inverse otherwise.
//   is : when < 0, the output is also multiplied by n^{-1} mod P.
// A bit-reversal permutation follows the passes, so the output is in natural order.
void ntt_dif2(int a[], long n, long P, long M, long G, int is)
{
    // units[i] = w1^i * 2^32 mod P (Montgomery form). This is a function-local
    // static that is rebuilt on every call.
    static int units[NN];
    long invP = inv(P, 1L<<32);   // P^{-1} mod 2^32, used by Montgomery reduction
    long w1 = pow_mod(G, is > 0 ? (P-1)/n : P-1-(P-1)/n, P, M);
    long wt = (1L<<32)%P;         // 1 in Montgomery form
    REP(i, n>>1) {
        units[i] = wt;
        // Cross-check barrett30 against `%`. This uses assert rather than a
        // null-pointer write: that write is undefined behaviour, and the
        // optimizer may delete it together with the check.
        assert(barrett30(wt*w1, P, M) == u64(wt*w1%P));
        wt = barrett30(wt*w1, P, M);
    }
    for (long m = n, dwi = 1; m >= 2; m >>= 1, dwi <<= 1)
        for (long r = 0; r < n; r += m) {
            int *x = a+r, *y = a+r+(m>>1), *w = units;
            REP(j, m>>1) {
                long u = long(*x)+*y;
                // (x - y + 2P) * w; since x, y < 2P, the value x - y + 2P is positive.
                auto v = ((unsigned long)(*x)-*y+2*P)**w;
                if (u >= 2*P) u -= 2*P;
                *x++ = u;
                // Montgomery reduction: v * 2^-32 mod P, left in [0, 2P).
                *y++ = (v>>32)-(((v<<32)*invP>>32)*P>>32)+P;
                w += dwi;
            }
        }
    // Final reduction from [0, 2P) to [0, P).
    REP(i, n)
        if (a[i] >= P)
            a[i] -= P;
    // Bit-reversal permutation: masked rotations of i, then a shift by 31-logn.
    long logn = 63-__builtin_clzl(n);
    REP(i, n) {
        unsigned int x = i, t;
        t = x & 0x00ff00ff; x = t << 16 | t >> 32-16 | t ^ x;
        t = x & 0x0f0f0f0f; x = t << 8 | t >> 32- 8 | t ^ x;
        t = x & 0x33333333; x = t << 4 | t >> 32- 4 | t ^ x;
        t = x & 0x55555555; x = t << 2 | t >> 32- 2 | t ^ x;
        x >>= 31-logn;
        if (i < x)
            swap(a[i], a[x]);
    }
    if (is < 0) {
        long invn = inv(n, P);
        REP(i, n) {
            assert(barrett30(a[i]*invn, P, M) == u64(a[i]*invn%P));
            a[i] = barrett30(a[i]*invn, P, M);
        }
    }
}

// Round trip with P_int: forward transform followed by the scaled inverse.
void check(int a[], long n)
{
    ntt_dif2(&a[0], n, P_int, M_int, G_int, 1);
    ntt_dif2(&a[0], n, P_int, M_int, G_int, -1);
}

// Forward transform only, with P_int.
void run(int a[], long n)
{
    ntt_dif2(a, n, P_int, M_int, G_int, 1);
}
}
namespace NTT_dif2
{
// In-place radix-2 decimation-in-frequency NTT with compile-time modulus P and
// generator G.
//   a  : n residues in [0, P); overwritten with the transform.
//   n  : power of two with n/2 <= NN (size of `units`); should divide P - 1.
//   is : > 0 uses the root w1 = G^((P-1)/n); otherwise its inverse. When is < 0
//        the result is also multiplied by n^{-1} = n^(P-2) mod P (Fermat, which
//        assumes P is prime).
// The DIF passes leave the output in bit-reversed order; the final loop
// permutes it back to natural order.
template<typename T, T P, T G>
void ntt_dif2(T a[], long n, int is)
{
// Twiddle table units[i] = w1^i; function-local static, rebuilt on every call.
static T units[NN/2];
T logn = 63-__builtin_clzl(n), w1 = pow_mod(G, is > 0 ? (P-1)/n : P-1-(P-1)/n, P), wt = 1;
REP(i, n>>1) {
units[i] = wt;
wt = mul_mod(wt, w1, P);
}
// Butterfly passes: the block size m halves each pass while the twiddle stride
// dwi doubles, so w walks units[0], units[dwi], units[2*dwi], ...
for (long m = n, dwi = 1; m >= 2; m >>= 1, dwi <<= 1)
for (long r = 0; r < n; r += m) {
T *x = a+r, *y = a+r+(m>>1), *w = units;
REP(j, m>>1) {
// (x, y) <- (x + y, (x - y) * w) mod P; the +P keeps x - y non-negative,
// so 2P must fit in T.
T u = *x+*y, v = mul_mod(*x-*y+P, *w, P);
if (u >= P) u -= P;
*x++ = u;
*y++ = v;
w += dwi;
}
}
if (is < 0) {
T invn = pow_mod(n, P-2, P);
REP(i, n)
a[i] = mul_mod(a[i], invn, P);
}
// Bit-reversal permutation: masked rotations of i, then a shift by 31-logn.
REP(i, n) {
unsigned int x = i, t;
t = x & 0x00ff00ff; x = t << 16 | t >> 32-16 | t ^ x;
t = x & 0x0f0f0f0f; x = t << 8 | t >> 32- 8 | t ^ x;
t = x & 0x33333333; x = t << 4 | t >> 32- 4 | t ^ x;
t = x & 0x55555555; x = t << 2 | t >> 32- 2 | t ^ x;
x >>= 31-logn;
if (i < x)
swap(a[i], a[x]);
}
}
// Round trip: forward transform followed by the scaled inverse.
template<typename T, T P, T G>
void check(T a[], long n)
{
ntt_dif2<T, P, G>(a, n, 1);
ntt_dif2<T, P, G>(a, n, -1);
}
// Forward transform only.
template<typename T, T P, T G>
void run(T a[], long n)
{
ntt_dif2<T, P, G>(a, n, 1);
}
}
namespace NTT_dif2_variable_P
{
// Same radix-2 DIF NTT as NTT_dif2::ntt_dif2, but the modulus P and generator G
// are runtime arguments instead of template parameters.
//   a  : n residues in [0, P); overwritten with the transform.
//   n  : power of two with n/2 <= NN; should divide P - 1.
//   is : > 0 uses the root G^((P-1)/n); otherwise its inverse. When is < 0 the
//        result is also scaled by n^(P-2) mod P (assumes P is prime).
// Output is permuted back to natural order by the final bit-reversal loop.
template<typename T>
void ntt_dif2_p(T a[], long n, T P, T G, int is)
{
// Twiddle table units[i] = w1^i; function-local static, rebuilt on every call.
static T units[NN/2];
T logn = 63-__builtin_clzl(n), w1 = pow_mod(G, is > 0 ? (P-1)/n : P-1-(P-1)/n, P), wt = 1;
REP(i, n>>1) {
units[i] = wt;
wt = mul_mod(wt, w1, P);
}
// Butterfly passes: block size m halves, twiddle stride dwi doubles.
for (long m = n, dwi = 1; m >= 2; m >>= 1, dwi <<= 1)
for (long r = 0; r < n; r += m) {
T *x = a+r, *y = a+r+(m>>1), *w = units;
REP(j, m>>1) {
// (x, y) <- (x + y, (x - y) * w) mod P; +P keeps the difference non-negative.
T u = *x+*y, v = mul_mod(*x-*y+P, *w, P);
if (u >= P) u -= P;
*x++ = u;
*y++ = v;
w += dwi;
}
}
if (is < 0) {
T invn = pow_mod(n, P-2, P);
REP(i, n)
a[i] = mul_mod(a[i], invn, P);
}
// Bit-reversal permutation: masked rotations of i, then a shift by 31-logn.
REP(i, n) {
unsigned int x = i, t;
t = x & 0x00ff00ff; x = t << 16 | t >> 32-16 | t ^ x;
t = x & 0x0f0f0f0f; x = t << 8 | t >> 32- 8 | t ^ x;
t = x & 0x33333333; x = t << 4 | t >> 32- 4 | t ^ x;
t = x & 0x55555555; x = t << 2 | t >> 32- 2 | t ^ x;
x >>= 31-logn;
if (i < x)
swap(a[i], a[x]);
}
}
// Round trip with P_int. The modulus is read through a volatile local, so it
// reaches ntt_dif2_p as a runtime value rather than a compile-time constant.
template<typename T>
void check(T a[], long n)
{
volatile long p_int = P_int;
ntt_dif2_p<T>(a, n, p_int, G_int, 1);
ntt_dif2_p<T>(a, n, p_int, G_int, -1);
}
// Specialization for long elements: round trip with P_long / G_long.
template<>
void check(long a[], long n)
{
volatile long p_long = P_long;
ntt_dif2_p<long>(a, n, p_long, G_long, 1);
ntt_dif2_p<long>(a, n, p_long, G_long, -1);
}
// Forward transform only, with the runtime (volatile-loaded) P_int.
template<typename T>
void run(T a[], long n)
{
volatile long p_int = P_int;
ntt_dif2_p<T>(a, n, p_int, G_int, 1);
}
// Specialization for long elements: forward transform with P_long / G_long.
template<>
void run(long a[], long n)
{
volatile long p_long = P_long;
ntt_dif2_p<long>(a, n, p_long, G_long, 1);
}
}
namespace NTT_dit2
{
// In-place radix-2 decimation-in-time NTT with compile-time modulus P and
// generator G. The input is first permuted into bit-reversed order, then
// combined with butterflies of growing size m, so the output is in natural order.
//   a  : n residues in [0, P); overwritten with the transform.
//   n  : power of two with n/2 <= NN; should divide P - 1.
//   is : > 0 uses the root G^((P-1)/n); otherwise its inverse. When is < 0 the
//        result is also scaled by n^(P-2) mod P (assumes P is prime).
template<typename T, T P, T G>
void ntt_dit2(T a[], long n, int is)
{
    // Twiddle table units[i] = w1^i; function-local static, rebuilt on every call.
    static T units[NN/2];
    T logn = 63-__builtin_clzl(n), w1 = pow_mod(G, is > 0 ? (P-1)/n : P-1-(P-1)/n, P), wt = 1;
    // Bit-reversal permutation: masked rotations of i, then a shift by 31-logn.
    REP(i, n) {
        unsigned int x = i, t;
        t = x & 0x00ff00ff; x = t << 16 | t >> 32-16 | t ^ x;
        t = x & 0x0f0f0f0f; x = t << 8 | t >> 32- 8 | t ^ x;
        t = x & 0x33333333; x = t << 4 | t >> 32- 4 | t ^ x;
        t = x & 0x55555555; x = t << 2 | t >> 32- 2 | t ^ x;
        x >>= 31-logn;
        if (i < x)
            swap(a[i], a[x]);
    }
    REP(i, n>>1) {
        units[i] = wt;
        wt = mul_mod(wt, w1, P);
    }
    // Butterfly passes: the block size m doubles while the twiddle stride dwi halves.
    for (long m = 2, dwi = n>>1; m <= n; m <<= 1, dwi >>= 1)
        for (long r = 0; r < n; r += m) {
            T *x = a+r, *y = a+r+(m>>1), *w = units;
            REP(j, m>>1) {
                // (x, y) <- (x + y*w, x - y*w) mod P
                T u = *x, v = mul_mod(*y, *w, P), x1 = u+v, y1 = u-v;
                if (x1 >= P) x1 -= P;
                if (y1 < 0) y1 += P;
                *x++ = x1;
                *y++ = y1;
                w += dwi;
            }
        }
    if (is < 0) {
        T invn = pow_mod(n, P-2, P);
        REP(i, n)
            a[i] = mul_mod(a[i], invn, P);
    }
}

// Round trip: forward transform followed by the scaled inverse.
template<typename T, T P, T G>
void check(T a[], long n)
{
    ntt_dit2<T, P, G>(a, n, 1);
    ntt_dit2<T, P, G>(a, n, -1);
}

// Forward transform only. The element type must be forwarded as T: a
// hard-coded `int` makes run<long, P_long, G_long> ill-formed, because
// P_long does not fit in int and a long* cannot bind to int*.
template<typename T, T P, T G>
void run(T a[], long n)
{
    ntt_dit2<T, P, G>(a, n, 1);
}
}
namespace FFT_dif2
{
// In-place radix-2 decimation-in-frequency complex FFT computing
//   X_k = sum_j a_j * exp(-2*pi*i*j*k/n)      (sign = -1),
// without 1/n scaling and without a reordering pass: the results are left in
// bit-reversed index order. n must be a power of two with n/2 <= NN.
void fft_dif2(cd a[], long n)
{ // sign = -1
    // units[i] = exp(-2*pi*i*i/n); the negative imaginary part realises sign = -1.
    // Function-local static, rebuilt on every call.
    static cd units[NN/2];
    double ph = 2*M_PI/n;
    REP(i, n/2)
        units[i] = {cos(ph*i), -sin(ph*i)};
    for (long m = n, dwi = 1; m >= 2; m >>= 1, dwi <<= 1)
        for (long r = 0; r < n; r += m) {
            cd *x = a+r, *y = a+r+(m>>1), *w = units;
            REP(j, m>>1) {
                cd v = *y, t = *x-v;
                // y <- (x - y) * w, with the complex product expanded by hand
                // (same form as FFT_dit2). The imaginary part must be
                // re*wi + im*wr; the former `re*wi - im*wr` was not a complex product.
                *y++ = {t.real()*w->real()-t.imag()*w->imag(), t.real()*w->imag()+t.imag()*w->real()};
                *x++ += v;
                w += dwi;
            }
        }
}

// Forward transform only.
void run(cd a[], long n)
{
    fft_dif2(a, n);
}
}
namespace FFT_dit2
{
// In-place radix-2 decimation-in-time complex FFT with twiddles exp(+2*pi*i*k/n),
// unscaled. No bit-reversal pass is performed here. For a true DFT the input
// would have to be supplied in bit-reversed order; the output is then in
// natural order. n must be a power of two with n/2 <= NN.
void fft_dit2(cd a[], long n)
{
// Twiddle table units[i] = exp(+2*pi*i*i/n); function-local static, rebuilt on every call.
static cd units[NN/2];
double ph = 2*M_PI/n;
REP(i, n/2)
units[i] = {cos(ph*i), sin(ph*i)};
// Butterfly passes: block size m doubles, twiddle stride dwi halves.
for (long m = 2, dwi = n>>1; m <= n; m <<= 1, dwi >>= 1)
for (long r = 0; r < n; r += m) {
cd *x = a+r, *y = a+r+(m>>1), *w = units;
REP(j, m>>1) {
// t = y * w with the complex product written out by hand; then (x, y) <- (x + t, x - t).
cd t{y->real()*w->real()-y->imag()*w->imag(), y->real()*w->imag()+y->imag()*w->real()};
*y++ = *x-t;
*x++ += t;
w += dwi;
}
}
}
// Forward transform only.
void run(cd a[], long n)
{
fft_dit2(a, n);
}
}
// Correctness harness: runs fn in place on the identity vector 0..n-1 and
// asserts that every element comes back unchanged. main pairs this with
// forward+inverse round-trip functions.
template<typename T>
void check(long n, void(*fn)(T a[], long))
{
    auto values = setup<T>(n);
    fn(values.data(), n);
    for (long idx = 0; idx < n; ++idx)
        assert(values[idx] == idx);
}
// Benchmark harness: calls fn `times` times back-to-back on one buffer holding
// 0..n-1. Each call works in place on the previous call's output. Returns the
// mean wall-clock time per call in microseconds (steady_clock), truncated
// toward zero.
template<typename T>
long test(long n, void(*fn)(T a[], long))
{
    auto values = setup<T>(n);
    T *buf = &values[0];
    const auto t0 = chrono::steady_clock::now();
    for (long rep = 0; rep < times; ++rep)
        fn(buf, n);
    const auto elapsed = chrono::steady_clock::now() - t0;
    const long total_us = chrono::duration_cast<chrono::microseconds>(elapsed).count();
    return total_us / times;
}
int main(int argc, char* argv[])
{
if (argc > 1)
times = atoi(argv[1]);
for (long n = 1<<4; n <= 1<<4; n <<= 1) {
check<int>(n, Montgomery::check);
check<int>(n, NTT_dif2::check<int, P_int, G_int>);
check<long>(n, NTT_dif2::check<long, P_long, G_long>);
check<int>(n, NTT_dit2::check<int, P_int, G_int>);
check<long>(n, NTT_dit2::check<long, P_long, G_long>);
check<int>(n, NTT_dif2_variable_P::check<int>);
check<long>(n, NTT_dif2_variable_P::check<long>);
}
for (long n = 1<<8; n <= NN; n <<= 1) {
vector<pair<long, string>> res;
res.emplace_back(test<int>(n, Montgomery::run), "Montgomery+Barrett NTT dif2 int");
res.emplace_back(test<int>(n, NTT_dif2::run<int, P_int, G_int>), "NTT dif2 int");
res.emplace_back(test<long>(n, NTT_dif2::run<long, P_long, G_long>), "NTT dif2 long");
res.emplace_back(test<int>(n, NTT_dif2::run<int, P_int, G_int>), "NTT dit2 int");
res.emplace_back(test<long>(n, NTT_dif2::run<long, P_long, G_long>), "NTT dit2 long");
res.emplace_back(test<int>(n, NTT_dif2_variable_P::run<int>), "NTT dif2 int non-constant P");
res.emplace_back(test<long>(n, NTT_dif2_variable_P::run<long>), "NTT dif2 long non-constant P");
res.emplace_back(test<cd>(n, FFT_dif2::run), "FFT dif2");
res.emplace_back(test<cd>(n, FFT_dit2::run), "FFT dit2");
sort(ALL(res));
for (auto& x: res) cout << n << '\t' << x.first << '\t' << x.second << '\n';
cout << '\n';
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment