Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
#include <iostream>
#include <complex>
#include <cmath>
#include <iomanip>
#include <vector>
#include <algorithm>
#include <map>
#include <tuple>
using namespace std;
/// Complex sample type used by the FFT.
typedef std::complex<double> xd;
/// Real coefficient vector (polynomial coefficients, lowest degree first).
typedef std::vector<double> dvec;
/// Complex sample vector.
typedef std::vector<xd> xvec;
/// pi, computed as 2 * acos(0).
const double PI = std::acos(0.0) * 2;
/// The imaginary unit sqrt(-1).
const xd J(0, 1);

/// Linear convolution (polynomial multiplication) via an iterative
/// radix-2 Cooley-Tukey FFT. Holds no mutable state, so concurrent calls
/// to convolve() do not interfere with each other.
class FFT
{
public:
    /// Computes the linear convolution of a and b, i.e. the coefficients of
    /// the product of the two polynomials whose coefficients (lowest degree
    /// first) are a and b.
    /// @param a first coefficient vector
    /// @param b second coefficient vector
    /// @return vector of size a.size() + b.size() - 1, or an empty vector if
    ///         either input is empty.
    static dvec convolve(const dvec &a, const dvec &b)
    {
        // An empty operand has no product; this also prevents the size_t
        // underflow of a.size() + b.size() - 1 below.
        if (a.empty() || b.empty()) {
            return dvec();
        }

        // Degree of the resulting polynomial + 1 = size of the result.
        const std::size_t deg = a.size() + b.size() - 1;

        // The transform size must be a power of two >= deg (radix-2 FFT).
        std::size_t N = 1;
        std::size_t logN = 0;
        while (N < deg) {
            N <<= 1;
            ++logN;
        }

        // Zero-padded complex copies of the inputs.
        xvec ax(N), bx(N);
        std::copy(a.begin(), a.end(), ax.begin());
        std::copy(b.begin(), b.end(), bx.begin());

        // Evaluation: forward FFT of both operands.
        const xvec w = twiddles(N);
        transform(ax, w, logN, false);
        transform(bx, w, logN, false);

        // Point-wise multiplication in the frequency domain.
        for (std::size_t i = 0; i < N; ++i) {
            ax[i] *= bx[i];
        }

        // Interpolation: inverse FFT (unnormalised), then divide by N.
        transform(ax, w, logN, true);
        dvec c(deg);
        for (std::size_t i = 0; i < deg; ++i) {
            c[i] = ax[i].real() / static_cast<double>(N);
        }
        return c;
    }

private:
    /// Forward twiddle factors for a size-N transform:
    /// w[k] = exp(-2*pi*i*k/N) for k in [0, N/2). Empty when N < 2.
    static xvec twiddles(std::size_t N)
    {
        xvec w(N / 2);
        for (std::size_t k = 0; k < w.size(); ++k) {
            w[k] = std::exp(-2. * PI * J * static_cast<double>(k) / static_cast<double>(N));
        }
        return w;
    }

    /// In-place iterative radix-2 DFT of s.
    /// @param s    samples; s.size() must equal 2^logN
    /// @param w    forward twiddles from twiddles(s.size())
    /// @param logN log2(s.size()), used for the bit-reversal permutation
    /// @param inv  if true, computes the unnormalised inverse transform
    ///             (conjugated twiddles); the caller divides by N.
    static void transform(xvec &s, const xvec &w, std::size_t logN, bool inv)
    {
        const std::size_t N = s.size();

        // Reorder into bit-reversed index order so the butterflies can run
        // in place; indices 0 and N-1 are their own reversals.
        for (std::size_t i = 1; i + 1 < N; ++i) {
            const std::size_t r = reverseBits(i, logN);
            if (r > i) {
                std::swap(s[i], s[r]);
            }
        }

        // Butterfly stages for sub-transform lengths n = 2, 4, ..., N.
        for (std::size_t n = 2; n <= N; n <<= 1) {
            const std::size_t half = n / 2;
            // exp(-2*pi*i*m/n) == w[m * (N/n)], so stride through the table.
            const std::size_t stride = N / n;
            for (std::size_t i = 0; i < N; i += n) {
                for (std::size_t m = 0; m < half; ++m) {
                    const xd tw = inv ? std::conj(w[m * stride]) : w[m * stride];
                    const xd fodd = tw * s[i + m + half];
                    const xd feven = s[i + m];
                    s[i + m] = feven + fodd;
                    s[i + m + half] = feven - fodd;
                }
            }
        }
    }

    /// Reverses the order of the low bitNum bits of num.
    static std::size_t reverseBits(std::size_t num, std::size_t bitNum)
    {
        std::size_t reversed = 0;
        for (std::size_t i = 0; i < bitNum; ++i) {
            // Shift in bits of num from least to most significant; size_t
            // arithmetic avoids the int-shift UB of `1 << i` for i >= 32.
            reversed = (reversed << 1) | ((num >> i) & std::size_t{1});
        }
        return reversed;
    }
};
int main()
{
dvec a = { 6, 7, -10, 9 };
dvec b = { -2, 0, 4, -5 };
dvec c = FFT::convolve(a, b);
// Output: -12 -14 44 -20 -75 86 -45
for (const auto &t : c) cout << t << ' ';
cout << endl;
a = { 6, 7, -10, 9, 6, 7, -10, 9 };
b = { -2, 0, 4, -5, -2, 0, 4, -5 };
c = FFT::convolve(a, b);
// Output: -12 -14 44 -20 -99 58 43 -40 -162 158 -46 -20 -75 86 -45
for (const auto &t : c) cout << t << ' ';
cout << endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment