Last active
August 23, 2023 02:10
-
-
Save phoemur/151ca1999a76478ed6d74cbbb5d5e1c9 to your computer and use it in GitHub Desktop.
Cooley–Tukey Fast Fourier Transform algorithm - Recursive Divide and Conquer implementation in C++
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// g++ -o FFT FFT.cpp -Wall -O3 | |
#include <algorithm>
#include <cmath>
#include <complex>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <type_traits>
#include <vector>
using namespace std; | |
// Returns pi as a double.
// A literal is used instead of atan(1)*4 because the <cmath> functions are not
// constexpr in standard C++14/17 (that form only compiles where the compiler
// treats atan as a constexpr builtin), so this stays a true constant expression.
constexpr auto pi() { return 3.141592653589793238462643383279502884; }
// The following helper is used for SFINAE.
// For a type spelled Wrapper<Inner> it exposes Inner as the nested alias
// `subType`. The primary template is declared but never defined, so naming
// `extractType<T>::subType` for any other kind of type is ill-formed.
template <typename T>
struct extractType;

template <template <typename ...> class Wrapper, typename Inner>
struct extractType<Wrapper<Inner>>
{
    using subType = Inner;
};
// Cooley–Tukey Fast Fourier Transform algorithm | |
// Recursive Divide and Conquer implementation | |
// Higher memory requirements and redundancy although more intuitive | |
template<typename InputIt, | |
typename value_type = typename iterator_traits<InputIt>::value_type> | |
typename enable_if<is_same<value_type, | |
complex<typename extractType<value_type>::subType>>::value, | |
void>::type // Only accepts std::complex numbers containers | |
FFT(InputIt begin, InputIt end) | |
{ | |
typename iterator_traits<InputIt>::difference_type N = distance(begin, end); | |
if (N < 2) return; | |
else { | |
// divide | |
stable_partition(begin, end, [&begin](auto& a){ | |
return distance(&*begin, &a) % 2 == 0; // pair indexes on the first half and odd on the last | |
}); | |
//conquer | |
FFT(begin, begin + N/2); // recurse even items | |
FFT(begin + N/2, end); // recurse odd items | |
//combine | |
for (decltype(N) k = 0; k < N/2; ++k) { | |
value_type even = *(begin + k); | |
value_type odd = *(begin + k + N/2); | |
value_type w = exp( value_type(0,-2.*pi()*k/N) ) * odd; | |
*(begin + k) = even + w; | |
*(begin + k + N/2) = even - w; | |
} | |
} | |
} | |
// Inverse FFT | |
template<typename InputIt, | |
typename value_type = typename iterator_traits<InputIt>::value_type> | |
typename enable_if<is_same<value_type, | |
complex<typename extractType<value_type>::subType>>::value, | |
void>::type | |
IFFT(InputIt begin, InputIt end) | |
{ | |
typename iterator_traits<InputIt>::difference_type N = distance(begin, end); | |
if (N < 2) return; | |
else { | |
// divide | |
stable_partition(begin, end, [&begin](auto& a){ | |
a = conj(a); // use the conjugate value | |
return distance(&*begin, &a) % 2 == 0; // pair indexes on the first half and odd on the last | |
}); | |
//conquer | |
FFT(begin, begin + N/2); // recurse even items on normal FFT | |
FFT(begin + N/2, end); // recurse odd items on normal FFT | |
//combine | |
for (decltype(N) k = 0; k < N/2; ++k) { | |
value_type even = *(begin + k); | |
value_type odd = *(begin + k + N/2); | |
value_type w = exp( value_type(0,-2.*pi()*k/N) ) * odd; | |
*(begin + k) = conj(even + w); //conjugate again and scale | |
*(begin + k) /= N; | |
*(begin + k + N/2) = conj(even - w); | |
*(begin + k + N/2) /= N; | |
} | |
} | |
} | |
int main() // Example of use for multiplying 2 polynomials | |
{ | |
/*Input: A[] = {5, 0, 10, 6} | |
B[] = {1, 2, 4} | |
Output: prod[] = {5, 10, 30, 26, 52, 24} */ | |
vector<complex<double>> A {5, 0, 10, 6}; | |
vector<complex<double>> B {1, 2, 4}; | |
size_t result_size = A.size()+B.size() - 1; // resulting degree after mult | |
// print originals | |
cout << "Input: A[] = {"; | |
for (auto& n: A) cout << n.real() << ", "; | |
cout << "\b\b}\n"; | |
cout << "Input: B[] = {"; | |
for (auto& n: B) cout << n.real() << ", "; | |
cout << "\b\b}\n"; | |
// pad inputs with zeroes (need to be even) | |
size_t algo_size = A.size()+B.size(); | |
if (algo_size % 2 == 1) ++algo_size; | |
A.resize(algo_size, complex<double>(0)); | |
B.resize(algo_size, complex<double>(0)); | |
vector<complex<double>> result (algo_size, 0); | |
// FFT | |
FFT(begin(A), end(A)); | |
FFT(begin(B), end(B)); | |
// Multiply | |
for (size_t i=0; i<result.size(); ++i) { | |
result[i] = A[i]*B[i]; // O(n) | |
} | |
// Inverse FFT | |
IFFT(begin(result), end(result)); | |
// Remove padding zeroes | |
result.resize(result_size); | |
cout << "Output: prod[] = {"; | |
for (auto& n:result) cout << static_cast<int>(n.real()) << ", "; | |
cout << "\b\b}\n"; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment