Skip to content

Instantly share code, notes, and snippets.

@lukicdarkoo
Created December 28, 2015 23:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lukicdarkoo/72c566bbb09bfc0e3675 to your computer and use it in GitHub Desktop.
Save lukicdarkoo/72c566bbb09bfc0e3675 to your computer and use it in GitHub Desktop.
Parallel FFT with TBB
#include "FFT.h"
#if defined(MODE_PARALLEL_DEV) || defined(MODE_PARALLEL_TEST)
class FFTTask: public task {
public:
const int N;
std::complex<double> *x;
FFTTask(std::complex<double> *_x, int _N) :
x(_x), N(_N)
{}
task* execute() {
if (N <= 1) {
return NULL;
}
// Split even and odd
std::complex<double> odd[N/2];
std::complex<double> even[N/2];
for (int i = 0; i < N / 2; i++) {
even[i] = x[i*2];
odd[i] = x[i*2+1];
}
FFTTask& a = *new( allocate_child() ) FFTTask(even, N / 2);
FFTTask& b = *new( allocate_child() ) FFTTask(odd, N / 2);
set_ref_count(3);
spawn(b);
spawn_and_wait_for_all(a);
for (int k = 0; k < N / 2; k++) {
std::complex<double> t = exp(std::complex<double>(0, -2 * M_PI * k / N)) * odd[k];
x[k] = even[k] + t;
x[N / 2 + k] = even[k] - t;
}
return NULL;
}
};
#endif
void fft(double *x_in,
std::complex<double> *x_out,
int N) {
// Make copy of array and apply window
for (int i = 0; i < N; i++) {
x_out[i] = std::complex<double>(x_in[i], 0);
x_out[i] *= 1; // Window
}
// Start recursion
#if defined(MODE_TEST) || defined(MODE_DEV)
fft_rec(x_out, N);
#endif
#if defined(MODE_PARALLEL_DEV) || defined(MODE_PARALLEL_TEST)
FFTTask &task = *new(task::allocate_root()) FFTTask(x_out, N);
task::spawn_root_and_wait(task);
#endif
}
#if defined(MODE_TEST) || defined(MODE_DEV)
void fft_rec(std::complex<double> *x, int N) {
// Check if it is splitted enough
if (N <= 1) {
return;
}
// Split even and odd
std::complex<double> odd[N/2];
std::complex<double> even[N/2];
for (int i = 0; i < N / 2; i++) {
even[i] = x[i*2];
odd[i] = x[i*2+1];
}
// Split on tasks
fft_rec(even, N/2);
fft_rec(odd, N/2);
// Calculate DFT
for (int k = 0; k < N / 2; k++) {
std::complex<double> t = exp(std::complex<double>(0, -2 * M_PI * k / N)) * odd[k];
x[k] = even[k] + t;
x[N / 2 + k] = even[k] - t;
}
}
#endif
#ifndef FFT_h
#define FFT_h
#include <cmath>
#include <complex>
#include "Config.h"
#if defined(MODE_PARALLEL_DEV) || defined(MODE_PARALLEL_TEST)
#include "tbb/task.h"
using namespace tbb;
#endif
extern void fft(double *x_in,
std::complex<double> *x_out,
int N);
void fft_rec(std::complex<double> *x, int N);
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment