// Gist mrange/7f8001ee767b0224e0e4192d1e87b517 -- last active August 9, 2017 20:13.
// GitHub page notice: this file contains bidirectional Unicode text that may be
// interpreted or compiled differently than what appears below. To review, open
// the file in an editor that reveals hidden Unicode characters.
#define _USE_MATH_DEFINES
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#ifdef _MSVC_LANG
# include <intrin.h>
#endif
#include <immintrin.h>
// Build: g++ -g -DNDEBUG --std=c++14 -pipe -Wall -O3 -ffast-math -fno-finite-math-only -march=native -mavx fft.cc
namespace | |
{ | |
// Calls `a` once to capture its result, then times `repeat` further calls.
//
//   repeat  number of timed invocations (the first, result-capturing call is
//           not part of the measurement)
//   a       nullary callable returning a non-void value
//
// Returns a tuple of (elapsed whole milliseconds for the timed calls, result of
// the first call).
template<typename T>
auto time_it (std::size_t repeat, T a)
{
  auto result = a ();

  // steady_clock is monotonic, so the measured interval cannot be distorted by
  // wall-clock adjustments made while the loop runs.
  auto before = std::chrono::steady_clock::now ();
  // The counter has the same type as `repeat` so a large count cannot wrap it.
  for (std::size_t iter = 0; iter < repeat; ++iter)
  {
    auto inner = a ();
    static_cast<void> (inner); // value intentionally discarded
  }
  auto after = std::chrono::steady_clock::now ();

  auto diff = std::chrono::duration_cast<std::chrono::milliseconds> (after - before).count ();
  return std::make_tuple (diff, std::move (result));
}
// Buffer of `n` complex numbers stored as two separate arrays: real parts in
// `real`, imaginary parts in `imag`. Both arrays start on a 32-byte boundary
// and are zero-filled on construction.
//
// The type is move-only. A moved-from instance gives up ownership of the
// storage: its `real` / `imag` still hold the old addresses, and to_string ()
// reports "{n:NaN}" for it.
struct complex
{
  // Rounds `p` up to the next multiple of 32 bytes (unchanged if already aligned).
  static double * align_pointer (double * p)
  {
    auto pp = reinterpret_cast<std::uintptr_t> (p);
    return reinterpret_cast<double*> (((pp + 31) / 32) * 32);
  }

  // Allocates zeroed storage for `n` values per array. Each allocation carries
  // 32 extra bytes so the aligned start still leaves room for `n` doubles.
  // Owning the arrays through unique_ptr means the first one is released if
  // the second allocation throws.
  explicit complex (std::size_t n)
    : n         (n)
    , realu     (std::make_unique<double[]> (n + 32 / sizeof (double)))
    , imagu     (std::make_unique<double[]> (n + 32 / sizeof (double)))
    , owns_data (true)
    , real      (align_pointer (realu.get ()))
    , imag      (align_pointer (imagu.get ()))
  {
  }

  complex ()                            = delete;
  complex (complex const &)             = delete;
  complex & operator= (complex const &) = delete;
  complex & operator= (complex &&)      = delete;

  // Takes over the storage of `o`; `o` is left non-owning.
  complex (complex && o) noexcept
    : n         (o.n)
    , realu     (std::move (o.realu))
    , imagu     (std::move (o.imagu))
    , owns_data (o.owns_data)
    , real      (o.real)
    , imag      (o.imag)
  {
    o.owns_data = false;
  }

  // Replaces every component whose magnitude is below 1E-10 with exactly 0.
  void trim_inplace ()
  {
    auto trim = [] (double d)
    {
      return std::abs (d) < 1E-10 ? 0.0 : d;
    };

    for (std::size_t i = 0; i < n; ++i)
    {
      real[i] = trim (real[i]);
      imag[i] = trim (imag[i]);
    }
  }

  // Element-wise difference. Throws std::runtime_error when the sizes differ.
  complex operator- (complex const & o) const
  {
    if (n != o.n)
    {
      throw std::runtime_error ("n && o.n must be equal");
    }

    complex result (n);
    for (std::size_t i = 0; i < n; ++i)
    {
      result.real[i] = real[i] - o.real[i];
      result.imag[i] = imag[i] - o.imag[i];
    }
    return result;
  }

  // Left fold over the elements: s = folder (s, real[i], imag[i]) for each i,
  // starting from `initial`. The accumulator is returned by value even when
  // `initial` is an lvalue, so no reference to the local accumulator escapes.
  template<typename TFolder, typename T>
  std::decay_t<T> fold (T && initial, TFolder && folder) const
  {
    std::decay_t<T> s = std::forward<T> (initial);
    for (std::size_t i = 0; i < n; ++i)
    {
      s = folder (s, real[i], imag[i]);
    }
    return s;
  }

  // Formats the buffer as "{n:<n>, <re><+/-im>i, ...}" with four decimals per
  // component, or "{n:NaN}" for an instance that has been moved from.
  std::string to_string () const
  {
    if (!owns_data)
    {
      return "{n:NaN}";
    }

    std::string s;
    s.reserve (32*n);
    s += "{n:";
    s += std::to_string (n);
    for (std::size_t i = 0; i < n; ++i)
    {
      append_term (s, real[i], imag[i]);
    }
    s += "}";
    return s;
  }

  std::size_t const         n         ;
private:
  // Appends ", <re><+/-im>i" to `s`. "%.4f" can expand a double to several
  // hundred characters, so the length reported by snprintf is honoured instead
  // of trusting the stack buffer to be large enough.
  static void append_term (std::string & s, double re, double im)
  {
    char b [128];
    auto written = std::snprintf (b, sizeof (b), ", %.4f%+.4fi", re, im);
    if (written < 0)
    {
      return; // formatting failed; nothing sensible to append
    }

    auto needed = static_cast<std::size_t> (written);
    if (needed < sizeof (b))
    {
      s.append (b, needed);
      return;
    }

    // Rare path: the text did not fit, format again into an exact-size buffer.
    std::string big (needed + 1, '\0');
    std::snprintf (&big[0], big.size (), ", %.4f%+.4fi", re, im);
    s.append (big, 0, needed);
  }

  std::unique_ptr<double[]> realu     ; // unaligned allocation backing `real`
  std::unique_ptr<double[]> imagu     ; // unaligned allocation backing `imag`
  bool                      owns_data ;
public:
  double * const            real      ;
  double * const            imag      ;
};
// Circle constants used to build rotation angles. Declared constexpr so they
// cannot be reassigned at run time.
constexpr double pi  = M_PI;
constexpr double tau = 2.0 * pi; // one full turn, in radians
// Returns (cos a, sin a) for the angle `a` in radians: element 0 is the real
// part and element 1 the imaginary part of the unit-length complex number at
// that angle.
inline std::tuple<double, double> rotation (double a)
{
  return std::make_tuple (std::cos (a), std::sin (a));
}
// Direct O(n^2) discrete Fourier transform of `samples`:
//
//   X[k] = sum over j of x[j] * w(k, j),   w(k, j) = rotation (-tau * k * j / n)
//
// Returns a new buffer of the same size holding X.
complex dft (complex const & samples)
{
  auto n = samples.n;
  complex result (n);
  auto nm = -tau / n;

  for (std::size_t k = 0; k < n; ++k)
  {
    auto sr = 0.0;
    auto si = 0.0;
    for (std::size_t j = 0; j < n; ++j)
    {
      // k*j is formed in size_t so it does not wrap at 32 bits.
      auto rot = rotation (static_cast<double> (k*j) * nm);
      auto wr  = std::get<0> (rot);
      auto wi  = std::get<1> (rot);
      auto xr  = samples.real[j];
      auto xi  = samples.imag[j];
      // Full complex product (xr + i*xi) * (wr + i*wi); both cross terms are
      // required for inputs whose spectrum has a non-zero imaginary part.
      sr += wr*xr - wi*xi;
      si += wr*xi + wi*xr;
    }
    result.real[k] = sr;
    result.imag[k] = si;
  }
  return result;
}
// True when exactly one bit of `i` is set. Zero is not a power of two; the
// bit trick alone would accept it, so it is excluded explicitly.
bool is_power_of_two (std::size_t i)
{
  return i != 0 && (i & (i - 1)) == 0;
}
// Base-2 logarithm of `i`, which must be a non-zero power of two (checked by
// assert in debug builds): ilog2 (1) == 0, ilog2 (1024) == 10.
int ilog2 (unsigned int i)
{
  // Zero is rejected as well: the bit-scan intrinsics below do not produce a
  // usable index for it.
  assert (i != 0 && (i & (i - 1)) == 0);
#ifdef _MSVC_LANG
  unsigned long ri = 0;
  _BitScanReverse (&ri, i);
  return static_cast<int> (ri);
#else
  // Index of the highest set bit of a 32-bit value.
  return 31 - __builtin_clz (i);
#endif
}
std::tuple<complex *, complex *> fft_loop (std::size_t n2, std::size_t s, std::size_t c, complex * f, complex * t) | |
{ | |
if (c > 1) | |
{ | |
auto c2 = c >> 1; | |
auto r = fft_loop (n2, s << 1, c2, f, t); | |
auto t = std::get<0> (r); | |
auto f = std::get<1> (r); | |
auto nm = -tau / c; | |
auto fr = f->real; | |
auto fi = f->imag; | |
auto tr = t->real; | |
auto ti = t->imag; | |
for (auto j = 0U; j < c2; ++j) | |
{ | |
auto w = rotation (j*nm); | |
auto wr = std::get<0> (w); | |
auto wi = std::get<1> (w); | |
auto off = s*j; | |
auto off2 = off << 1; | |
for (auto i = 0U; i < s; ++i) | |
{ | |
auto er = fr[i + off2 + 0]; | |
auto ei = fi[i + off2 + 0]; | |
auto or_ = fr[i + off2 + s]; | |
auto oi = fi[i + off2 + s]; | |
// a = w*o | |
auto ar = wr*or_ - wi*oi ; | |
auto ai = wr*oi + wi*or_; | |
tr[i + off + 0 ] = er + ar; | |
ti[i + off + 0 ] = ei + ai; | |
tr[i + off + n2] = er - ar; | |
ti[i + off + n2] = ei - ai; | |
} | |
} | |
return std::make_pair (f, t); | |
} | |
else | |
{ | |
return std::make_tuple (t, f); | |
} | |
} | |
std::vector<std::vector<std::tuple<double, double>>> generate_rotations(std::size_t n) | |
{ | |
std::vector<std::vector<std::tuple<double, double>>> result; | |
result.reserve (n); | |
auto c = 1U; | |
for (auto i = 0U; i < n; ++i) | |
{ | |
std::vector<std::tuple<double, double>> r; | |
r.reserve (c); | |
for (auto j = 0U; j < c; ++j) | |
{ | |
r.push_back (rotation ((-tau * j) / c)); | |
} | |
result.push_back (std::move (r)); | |
c *= 2; | |
} | |
return result; | |
} | |
auto const rotations = generate_rotations (16); | |
std::tuple<complex *, complex *> fft_loop_avx ( | |
std::size_t n2 | |
, std::tuple<double, double> const * rs | |
, std::size_t s | |
, std::size_t c | |
, complex * f | |
, complex * t | |
) | |
{ | |
if (c > 2) | |
{ | |
auto r = fft_loop_avx (n2, rs, s << 1, c >> 1, f, t); | |
auto t = std::get<0> (r); | |
auto f = std::get<1> (r); | |
auto fr = f->real; | |
auto fi = f->imag; | |
auto tr = t->real; | |
auto ti = t->imag; | |
if (s == 1) | |
{ | |
for (auto j = 0U; j < n2; ++j) | |
{ | |
auto w = rs[j]; | |
auto wr = std::get<0> (w); | |
auto wi = std::get<1> (w); | |
auto er = fr[2*j + 0]; | |
auto ei = fi[2*j + 0]; | |
auto or_ = fr[2*j + s]; | |
auto oi = fi[2*j + s]; | |
// a = w*o | |
auto ar = wr*or_ - wi*oi ; | |
auto ai = wr*oi + wi*or_; | |
tr[j + 0 ] = er + ar; | |
ti[j + 0 ] = ei + ai; | |
tr[j + n2] = er - ar; | |
ti[j + n2] = ei - ai; | |
} | |
} | |
else if (s >= 4) | |
{ | |
assert (s % 4 == 0); | |
for (auto j = 0U; j < n2; j += s) | |
{ | |
auto off = j; | |
auto off2 = off << 1; | |
auto w = rs[off]; | |
auto wr = _mm256_set1_pd (std::get<0> (w)); | |
auto wi = _mm256_set1_pd (std::get<1> (w)); | |
for (auto i = 0U; i < s; i += 4) | |
{ | |
auto er = _mm256_load_pd (&fr[i + off2 + 0]); | |
auto ei = _mm256_load_pd (&fi[i + off2 + 0]); | |
auto or_ = _mm256_load_pd (&fr[i + off2 + s]); | |
auto oi = _mm256_load_pd (&fi[i + off2 + s]); | |
// a = w*o | |
auto ar = _mm256_sub_pd (_mm256_mul_pd (wr, or_), _mm256_mul_pd (wi, oi )); | |
auto ai = _mm256_add_pd (_mm256_mul_pd (wr, oi ), _mm256_mul_pd (wi, or_)); | |
_mm256_store_pd(&tr[i + off + 0 ], _mm256_add_pd (er, ar)); | |
_mm256_store_pd(&ti[i + off + 0 ], _mm256_add_pd (ei, ai)); | |
_mm256_store_pd(&tr[i + off + n2], _mm256_sub_pd (er, ar)); | |
_mm256_store_pd(&ti[i + off + n2], _mm256_sub_pd (ei, ai)); | |
} | |
} | |
} | |
else | |
{ | |
for (auto j = 0U; j < n2; j += s) | |
{ | |
auto off = j; | |
auto off2 = off << 1; | |
auto w = rs[off]; | |
auto wr = std::get<0> (w); | |
auto wi = std::get<1> (w); | |
for (auto i = 0U; i < s; ++i) | |
{ | |
auto er = fr[i + off2 + 0]; | |
auto ei = fi[i + off2 + 0]; | |
auto or_ = fr[i + off2 + s]; | |
auto oi = fi[i + off2 + s]; | |
// a = w*o | |
auto ar = wr*or_ - wi*oi ; | |
auto ai = wr*oi + wi*or_; | |
tr[i + off + 0 ] = er + ar; | |
ti[i + off + 0 ] = ei + ai; | |
tr[i + off + n2] = er - ar; | |
ti[i + off + n2] = ei - ai; | |
} | |
} | |
} | |
return std::make_pair (f, t); | |
} | |
else if (c == 2) | |
{ | |
auto fr = f->real; | |
auto fi = f->imag; | |
auto tr = t->real; | |
auto ti = t->imag; | |
if (s >= 4) | |
{ | |
assert (s % 4 == 0); | |
for (auto i = 0U; i < s; i += 4) | |
{ | |
auto er = _mm256_load_pd (&fr[i + 0]); | |
auto ei = _mm256_load_pd (&fi[i + 0]); | |
auto or_ = _mm256_load_pd (&fr[i + s]); | |
auto oi = _mm256_load_pd (&fi[i + s]); | |
// a = w*o | |
auto ar = or_; | |
auto ai = oi ; | |
_mm256_store_pd(&tr[i + 0 ], _mm256_add_pd (er, ar)); | |
_mm256_store_pd(&ti[i + 0 ], _mm256_add_pd (ei, ai)); | |
_mm256_store_pd(&tr[i + n2], _mm256_sub_pd (er, ar)); | |
_mm256_store_pd(&ti[i + n2], _mm256_sub_pd (ei, ai)); | |
} | |
} | |
else | |
{ | |
for (auto i = 0U; i < s; ++i) | |
{ | |
auto er = fr[i + 0]; | |
auto ei = fi[i + 0]; | |
auto or_= fr[i + s]; | |
auto oi = fi[i + s]; | |
// a = w*o | |
auto ar = or_; | |
auto ai = oi ; | |
tr[i + 0 ] = er + ar; | |
ti[i + 0 ] = ei + ai; | |
tr[i + n2] = er - ar; | |
ti[i + n2] = ei - ai; | |
} | |
} | |
return std::make_pair (f, t); | |
} | |
else | |
{ | |
return std::make_tuple (t, f); | |
} | |
} | |
// Scalar FFT of `samples`, returned as a new buffer.
//
// Throws std::runtime_error when samples.n is 0 or not a power of two.
complex fft (complex const & samples)
{
  auto n = samples.n;
  if (n < 1)
  {
    // The check admits n == 1, so the message states the real lower bound.
    throw std::runtime_error ("samples.n must be at least 1");
  }
  if (!is_power_of_two (n))
  {
    throw std::runtime_error ("samples.n must be power of 2");
  }

  // Two work buffers; the passes alternate between them. r0 starts out as a
  // copy of the input.
  complex r0 (n);
  complex r1 (n);
  std::memcpy (r0.real, samples.real, sizeof (double)*n);
  std::memcpy (r0.imag, samples.imag, sizeof (double)*n);

  auto res = fft_loop (n >> 1, 1, n, &r0, &r1);

  // Element 1 is the buffer written last; move it out, the other is discarded.
  complex result (std::move (*std::get<1> (res)));
  return result;
}
// AVX FFT of `samples`, returned as a new buffer. Twiddle factors come from
// the precomputed `rotations` tables.
//
// Throws std::runtime_error when samples.n is 0 or not a power of two, and
// std::out_of_range when no rotation table exists for samples.n.
complex fft_avx (complex const & samples)
{
  auto n = samples.n;
  if (n < 1)
  {
    // The check admits n == 1, so the message states the real lower bound.
    throw std::runtime_error ("samples.n must be at least 1");
  }
  if (!is_power_of_two (n))
  {
    throw std::runtime_error ("samples.n must be power of 2");
  }

  // ilog2 takes unsigned int. Reject sizes that do not survive the narrowing
  // instead of looking up a table for a truncated value; the exception type
  // is the one rotations.at () raises for an unsupported size.
  auto narrow = static_cast<unsigned int> (n);
  if (narrow != n)
  {
    throw std::out_of_range ("samples.n exceeds the precomputed rotation tables");
  }

  // Two work buffers; the passes alternate between them. r0 starts out as a
  // copy of the input.
  complex r0 (n);
  complex r1 (n);
  std::memcpy (r0.real, samples.real, sizeof (double)*n);
  std::memcpy (r0.imag, samples.imag, sizeof (double)*n);

  auto & rs  = rotations.at (ilog2 (narrow));
  auto   res = fft_loop_avx (n >> 1, &rs.front (), 1, n, &r0, &r1);

  // Element 1 is the buffer written last; move it out, the other is discarded.
  complex result (std::move (*std::get<1> (res)));
  return result;
}
} | |
int main () | |
{ | |
#ifdef _DEBUG | |
auto o = 10U ; | |
auto n = 32U ; | |
#else | |
auto o = 10000U ; | |
auto n = 1024U ; | |
#endif | |
complex samples (n); | |
{ | |
auto nm = tau / n; | |
for (auto i = 0U; i < n; ++i) | |
{ | |
auto a = i * nm; | |
auto v = std::cos (a) + std::cos (2.0*a); | |
samples.real[i] = 2.0 * v / n; | |
samples.imag[i] = 0; | |
} | |
} | |
if (n < 2048) | |
{ | |
auto dft_result = dft (samples); | |
auto fft_result = fft (samples); | |
auto fft_avx_result = fft_avx (samples); | |
auto diff = fft_result - dft_result; | |
auto diff_avx = fft_avx_result - dft_result; | |
auto folder = [] (auto && s, auto && r, auto && i) | |
{ | |
return s + r*r + i*i; | |
}; | |
auto sum = diff.fold (0.0, folder); | |
auto sum_avx = diff_avx.fold (0.0, folder); | |
std::printf ("Data size : %d\n", n); | |
std::printf ("Sum : %g\n", sum / n); | |
std::printf ("Sum(avx) : %g\n", sum_avx / n); | |
if (n < 64) | |
{ | |
samples.trim_inplace (); | |
dft_result.trim_inplace (); | |
fft_result.trim_inplace (); | |
fft_avx_result.trim_inplace (); | |
std::printf ("samples : %s\n" , samples.to_string ().c_str ()); | |
std::printf ("dft_result: %s\n" , dft_result.to_string ().c_str ()); | |
std::printf ("fft_result: %s\n" , fft_result.to_string ().c_str ()); | |
std::printf ("fft_avx_result: %s\n" , fft_avx_result.to_string ().c_str ()); | |
} | |
} | |
auto fft_timed = time_it (o , [&samples] { return fft (samples); }); | |
auto fft_avx_timed = time_it (o , [&samples] { return fft_avx (samples); });; | |
auto fft_ms = std::get<0> (fft_timed) / (1.0*o); | |
auto fft_avx_ms = std::get<0> (fft_avx_timed) / (1.0*o); | |
std::printf ("FFT ms : %f\n", fft_ms ); | |
std::printf ("FFT(avx) ms : %f\n", fft_avx_ms ); | |
return 0; | |
} |
// GitHub page footer: sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.