Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
#include <iostream>
#include <complex>
#include <cmath>
#include <iomanip>
#include <vector>
#include <algorithm>
#include <map>
#include <tuple>
using namespace std;
typedef complex<double> xd;   // complex sample / spectrum value
typedef vector<double> dvec;  // real polynomial coefficients, lowest degree first
typedef vector<xd> xvec;      // complex working buffer
const double PI = acos(0) * 2;
const xd J(0, 1); // imaginary unit, sqrt(-1)

/// Polynomial multiplication (linear convolution) via a recursive radix-2 FFT.
class FFT
{
public:
    /// @brief Multiply two polynomials given as coefficient vectors.
    /// @param a coefficients of the first polynomial; a[k] multiplies x^k
    /// @param b coefficients of the second polynomial; b[k] multiplies x^k
    /// @return the a.size() + b.size() - 1 coefficients of a*b, or an empty
    ///         vector when either input is empty. Arithmetic is floating point,
    ///         so integer inputs come back only approximately integral; round
    ///         the result if exact integers are needed.
    static dvec convolve(const dvec &a, const dvec &b)
    {
        // An empty polynomial yields an empty product. This also keeps the
        // unsigned expression a.size() + b.size() - 1 from wrapping around,
        // which would otherwise loop forever below or copy out of bounds.
        if (a.empty() || b.empty()) return dvec();

        // Number of coefficients of the product polynomial.
        const size_t deg = a.size() + b.size() - 1;
        // The radix-2 transform needs a power-of-two length; making it at
        // least deg ensures the cyclic convolution equals the linear one.
        size_t N = 1;
        while (N < deg) N <<= 1;

        // Twiddle factors for length N: roots[k] = exp(2*pi*i*k/N), k < N/2.
        // A sub-transform of length n reuses them as roots[k * (N / n)].
        xvec roots(N / 2);
        for (size_t k = 0; k < roots.size(); ++k) {
            roots[k] = polar(1.0, 2. * PI * k / N);
        }

        // Zero-padded complex copies of the inputs.
        xvec acof(N), bcof(N);
        copy(a.begin(), a.end(), acof.begin());
        copy(b.begin(), b.end(), bcof.begin());

        // Evaluation: forward FFT of both operands.
        const xvec apv = transform(acof, roots, false);
        const xvec bpv = transform(bcof, roots, false);
        // Point-wise multiplication of the spectra.
        xvec cpv(N);
        for (size_t i = 0; i < N; ++i) {
            cpv[i] = apv[i] * bpv[i];
        }
        // Interpolation: inverse FFT, then the 1/N normalisation.
        cpv = transform(cpv, roots, true);
        dvec c(deg);
        for (size_t i = 0; i < deg; ++i) {
            c[i] = cpv[i].real() / N;
        }
        return c;
    }
private:
    /// @brief Recursive radix-2 Cooley-Tukey DFT.
    /// @param s input samples; s.size() must be 1 or a power of two that
    ///          divides 2 * roots.size()
    /// @param roots twiddle table built by convolve() for the top-level length
    /// @param inv false: forward transform using exp(-2*pi*i*m/n);
    ///            true: inverse transform using exp(+2*pi*i*m/n), unscaled
    /// @return the transformed samples (s itself is left unchanged)
    static xvec transform(const xvec &s, const xvec &roots, bool inv)
    {
        const size_t n = s.size();
        if (n == 1) return s;
        const size_t half = n / 2;

        // Split into even- and odd-indexed samples and transform each half.
        xvec even, odd;
        even.reserve(half);
        odd.reserve(half);
        for (size_t i = 0; i < n; i += 2) {
            even.push_back(s[i]);
            odd.push_back(s[i + 1]);
        }
        even = transform(even, roots, inv);
        odd = transform(odd, roots, inv);

        // roots has N/2 entries for top-level length N, so stepping by
        // N/n = roots.size()/half yields exp(2*pi*i*m/n).
        const size_t stride = roots.size() / half;
        xvec out(n);
        for (size_t m = 0; m < half; ++m) {
            const xd w = inv ? roots[m * stride] : conj(roots[m * stride]);
            const xd t = w * odd[m];
            out[m] = even[m] + t;
            out[m + half] = even[m] - t;
        }
        return out;
    }
};
int main()
{
dvec a = { 6, 7, -10, 9 };
dvec b = { -2, 0, 4, -5 };
dvec c = FFT::convolve(a, b);
// Output: -12 -14 44 -20 -75 86 -45
for (const auto &t : c) cout << t << ' ';
cout << endl;
a = { 6, 7, -10, 9, 6, 7, -10, 9 };
b = { -2, 0, 4, -5, -2, 0, 4, -5 };
c = FFT::convolve(a, b);
// Output: -12 -14 44 -20 -99 58 43 -40 -162 158 -46 -20 -75 86 -45
for (const auto &t : c) cout << t << ' ';
cout << endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment