// Source: GitHub gist mrange/7f8001ee767b0224e0e4192d1e87b517 by @mrange
// (last active August 9, 2017 20:13).
#define _USE_MATH_DEFINES
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#ifdef _MSVC_LANG
# include <intrin.h>
#endif
#include <immintrin.h>
// g++ -g -DNDEBUG --std=c++14 -pipe -Wall -O3 -ffast-math -fno-finite-math-only -march=native -mavx fft.cc
namespace
{
// Calls `a` once (untimed) to obtain a result, then measures how long `repeat`
// further calls take.
// Returns a tuple of (elapsed time of the timed calls in whole milliseconds,
// result of the first, untimed call).
template<typename T>
auto time_it (std::size_t repeat, T a)
{
  using clock = std::chrono::high_resolution_clock;

  // The value handed back to the caller comes from this warm-up invocation.
  auto first_result = a ();

  auto const started = clock::now ();
  auto done = 0U;
  while (done < repeat)
  {
    // The result of a timed call is only kept alive for this iteration.
    auto discarded = a ();
    (void) discarded;
    ++done;
  }
  auto const finished = clock::now ();

  auto const elapsed_ms =
    std::chrono::duration_cast<std::chrono::milliseconds> (finished - started).count ();
  return std::make_tuple (elapsed_ms, std::move (first_result));
}
// Structure-of-arrays buffer holding `n` complex numbers: real parts in `real`,
// imaginary parts in `imag`. Both arrays start on a 32-byte boundary.
//
// The type is move-only. A moved-from instance no longer owns storage and its
// to_string () yields "{n:NaN}"; its `real`/`imag` pointers keep pointing at
// the storage that is now owned by the move target.
struct complex
{
  // Rounds `p` up to the next multiple of 32 bytes (unchanged if already aligned).
  static double * align_pointer (double * p)
  {
    auto pp = reinterpret_cast<std::uintptr_t> (p);
    return reinterpret_cast<double*> (((pp + 31) / 32) * 32);
  }

  // Allocates `n` complex numbers, all zero.
  // Each array is over-allocated by 32 bytes so that align_pointer always has
  // room to move the start forward. std::make_unique value-initializes the
  // elements (zero) and releases the first array if the second allocation throws.
  complex (std::size_t n)
    : n         (n)
    , realu     (std::make_unique<double[]> (n + 32 / sizeof(double)))
    , imagu     (std::make_unique<double[]> (n + 32 / sizeof(double)))
    , owns_data (true)
    , real      (align_pointer (realu.get ()))
    , imag      (align_pointer (imagu.get ()))
  {
  }

  // Storage is released by the unique_ptr members.
  ~complex () = default;

  complex ()                            = delete;
  complex (complex const &)             = delete;
  complex & operator= (complex const &) = delete;

  // Takes over the storage of `o`; `o` is left without ownership.
  complex (complex && o) noexcept
    : n         (o.n)
    , realu     (std::move (o.realu))
    , imagu     (std::move (o.imagu))
    , owns_data (o.owns_data)
    , real      (o.real)
    , imag      (o.imag)
  {
    o.owns_data = false;
  }

  complex & operator= (complex &&)      = delete;

  // Replaces every component whose magnitude is below 1E-10 with exactly 0.0.
  void trim_inplace ()
  {
    auto t = [] (double d)
    {
      return std::abs (d) < 1E-10 ? 0.0 : d;
    };

    for (std::size_t i = 0; i < n; ++i)
    {
      real[i] = t (real[i]);
      imag[i] = t (imag[i]);
    }
  }

  // Element-wise difference `*this - o`.
  // Throws std::runtime_error when the two buffers differ in length.
  complex operator- (complex const & o) const
  {
    if (n != o.n)
    {
      throw std::runtime_error ("n && o.n must be equal");
    }

    complex result (n);

    for (std::size_t i = 0; i < n; ++i)
    {
      result.real[i] = real[i] - o.real[i];
      result.imag[i] = imag[i] - o.imag[i];
    }

    return result;
  }

  // Left fold over the elements: starting from `initial`, repeatedly applies
  // `folder (state, real[i], imag[i])` and returns the final state.
  // The state is always returned by value (decayed type), so passing an lvalue
  // seed does not yield a reference to a local.
  template<typename TFolder, typename T>
  std::decay_t<T> fold (T && initial, TFolder && folder) const
  {
    std::decay_t<T> s = std::forward<T> (initial);

    for (std::size_t i = 0; i < n; ++i)
    {
      s = folder (s, real[i], imag[i]);
    }

    return s;
  }

  // Formats the buffer as "{n:<count>, <re><+/-im>i, ...}" with four decimals
  // per component, or "{n:NaN}" when this instance no longer owns its data.
  std::string to_string () const
  {
    if (!owns_data)
    {
      return "{n:NaN}";
    }

    // The largest finite double needs roughly 315 characters with "%.4f", so
    // two components plus separators fit; snprintf bounds the write regardless.
    char        b [1024];
    std::string s;
    s.reserve (32*n);

    s += "{";
    std::snprintf (b, sizeof (b), "n:%zu", n);
    s += b;

    for (std::size_t i = 0; i < n; ++i)
    {
      std::snprintf (b, sizeof (b), ", %.4f%+.4fi", real[i], imag[i]);
      s += b;
    }

    s += "}";
    return s;
  }

  // Number of complex elements.
  std::size_t const n;
private:
  // Owning, possibly unaligned allocations backing `real` and `imag`.
  std::unique_ptr<double[]> realu;
  std::unique_ptr<double[]> imagu;
  // False once this instance has been moved from.
  bool                      owns_data;
public:
  // 32-byte-aligned views into the allocations above.
  double * const real;
  double * const imag;
};
// Circle constants used to build rotation angles: pi and tau = 2*pi.
// Declared constexpr so they are immutable compile-time values rather than
// writable globals.
constexpr double pi  = M_PI;
constexpr double tau = 2.0 * pi;
// Returns (cos a, sin a) for the angle `a` in radians: the real and imaginary
// components of the unit-length rotation by `a`.
inline std::tuple<double, double> rotation (double a)
{
  return std::tuple<double, double> (std::cos (a), std::sin (a));
}
// Reference discrete Fourier transform computed directly from the definition
// (two nested loops over all n samples, so O(n^2) rotations).
// For every output index k it accumulates samples[j] * w, where w is the
// rotation by -tau*k*j/n, using a full complex multiplication.
// Returns a new buffer of the same length as `samples`.
complex dft (complex const & samples)
{
  auto const n  = samples.n;

  complex result (n);

  auto const nm = -tau / n;

  for (std::size_t k = 0; k < n; ++k)
  {
    auto sr = 0.0;
    auto si = 0.0;

    for (std::size_t j = 0; j < n; ++j)
    {
      auto const rot = rotation (k*j*nm);
      auto const wr  = std::get<0> (rot);
      auto const wi  = std::get<1> (rot);
      auto const xr  = samples.real[j];
      auto const xi  = samples.imag[j];

      // (xr + i*xi) * (wr + i*wi): both components need the cross terms,
      // otherwise the imaginary output is wrong for any input that is not
      // real and even.
      sr += wr*xr - wi*xi;
      si += wr*xi + wi*xr;
    }

    result.real[k] = sr;
    result.imag[k] = si;
  }

  return result;
}
// Returns true when `i` is an exact power of two (1, 2, 4, ...).
// Zero is not a power of two; without the explicit check the bit trick below
// would accept it, because 0 & (0 - 1) is also 0.
bool is_power_of_two (std::size_t i)
{
  return i != 0 && (i & (i - 1)) == 0;
}
// Returns the zero-based position of the highest set bit of `i`, which equals
// log2 (i) when `i` is a power of two (checked by the assert in debug builds).
int ilog2 (unsigned int i)
{
  assert (is_power_of_two (i));
#ifdef _MSVC_LANG
  unsigned long highest_bit;
  _BitScanReverse (&highest_bit, i);
  return static_cast<int> (highest_bit);
#else
  // Leading-zero count subtracted from the top bit index; like the original
  // expression this assumes a 32-bit unsigned int.
  return 31 - __builtin_clz (i);
#endif
}
std::tuple<complex *, complex *> fft_loop (std::size_t n2, std::size_t s, std::size_t c, complex * f, complex * t)
{
if (c > 1)
{
auto c2 = c >> 1;
auto r = fft_loop (n2, s << 1, c2, f, t);
auto t = std::get<0> (r);
auto f = std::get<1> (r);
auto nm = -tau / c;
auto fr = f->real;
auto fi = f->imag;
auto tr = t->real;
auto ti = t->imag;
for (auto j = 0U; j < c2; ++j)
{
auto w = rotation (j*nm);
auto wr = std::get<0> (w);
auto wi = std::get<1> (w);
auto off = s*j;
auto off2 = off << 1;
for (auto i = 0U; i < s; ++i)
{
auto er = fr[i + off2 + 0];
auto ei = fi[i + off2 + 0];
auto or_ = fr[i + off2 + s];
auto oi = fi[i + off2 + s];
// a = w*o
auto ar = wr*or_ - wi*oi ;
auto ai = wr*oi + wi*or_;
tr[i + off + 0 ] = er + ar;
ti[i + off + 0 ] = ei + ai;
tr[i + off + n2] = er - ar;
ti[i + off + n2] = ei - ai;
}
}
return std::make_pair (f, t);
}
else
{
return std::make_tuple (t, f);
}
}
// Builds `n` tables of rotations. Table k (k = 0 .. n-1) has 2^k entries, and
// entry j of that table is rotation (-tau * j / 2^k).
std::vector<std::vector<std::tuple<double, double>>> generate_rotations(std::size_t n)
{
  using table_t = std::vector<std::tuple<double, double>>;

  std::vector<table_t> tables;
  tables.reserve (n);

  // `count` is the number of entries in the current table; it doubles per level.
  auto count = 1U;
  for (auto level = 0U; level < n; ++level, count *= 2)
  {
    table_t entries;
    entries.reserve (count);

    for (auto j = 0U; j < count; ++j)
    {
      entries.push_back (rotation ((-tau * j) / count));
    }

    tables.push_back (std::move (entries));
  }

  return tables;
}
// Rotation tables for 16 levels (table k has 2^k entries), built once during
// static initialization. fft_avx selects the table at index log2 (n).
std::vector<std::vector<std::tuple<double, double>>> const rotations = generate_rotations (16);
// Recursive FFT pass that ping-pongs between two buffers like fft_loop, but
// reads its rotations from a precomputed table and uses AVX intrinsics
// (four doubles per operation) whenever the stride `s` is at least 4.
//
//   n2   - half the total transform length (fft_avx passes n >> 1)
//   rs   - first entry of the rotation table for the full length
//          (fft_avx passes the table selected by ilog2 (n))
//   s    - stride at this level; 1 at the outermost call, doubled per recursion
//   c    - remaining sub-transform size; n at the outermost call, halved per recursion
//   f, t - the two work buffers
//
// Returns a tuple of buffer pointers. In the two working branches the second
// element is the buffer that was just written; for c < 2 no work is done and
// the second element is `f`, which still holds the input.
std::tuple<complex *, complex *> fft_loop_avx (
std::size_t n2
, std::tuple<double, double> const * rs
, std::size_t s
, std::size_t c
, complex * f
, complex * t
)
{
if (c > 2)
{
// Run the deeper levels first; their output buffer becomes this level's input.
auto r = fft_loop_avx (n2, rs, s << 1, c >> 1, f, t);
// NOTE: these locals shadow the parameters `t` and `f` for the rest of this branch.
auto t = std::get<0> (r);
auto f = std::get<1> (r);
auto fr = f->real;
auto fi = f->imag;
auto tr = t->real;
auto ti = t->imag;
if (s == 1)
{
// Stride 1: every output pair has its own rotation rs[j]; inputs are the
// adjacent elements 2*j and 2*j + 1.
for (auto j = 0U; j < n2; ++j)
{
auto w = rs[j];
auto wr = std::get<0> (w);
auto wi = std::get<1> (w);
auto er = fr[2*j + 0];
auto ei = fi[2*j + 0];
auto or_ = fr[2*j + s];
auto oi = fi[2*j + s];
// a = w*o
auto ar = wr*or_ - wi*oi ;
auto ai = wr*oi + wi*or_;
tr[j + 0 ] = er + ar;
ti[j + 0 ] = ei + ai;
tr[j + n2] = er - ar;
ti[j + n2] = ei - ai;
}
}
else if (s >= 4)
{
// Vector path: one rotation per group of `s` elements, broadcast to all four
// lanes. The aligned load/store intrinsics are used; every index here is a
// multiple of 4 (asserted for s) on top of the 32-byte-aligned array bases
// provided by complex::align_pointer.
assert (s % 4 == 0);
for (auto j = 0U; j < n2; j += s)
{
auto off = j;
auto off2 = off << 1;
auto w = rs[off];
auto wr = _mm256_set1_pd (std::get<0> (w));
auto wi = _mm256_set1_pd (std::get<1> (w));
for (auto i = 0U; i < s; i += 4)
{
auto er = _mm256_load_pd (&fr[i + off2 + 0]);
auto ei = _mm256_load_pd (&fi[i + off2 + 0]);
auto or_ = _mm256_load_pd (&fr[i + off2 + s]);
auto oi = _mm256_load_pd (&fi[i + off2 + s]);
// a = w*o
auto ar = _mm256_sub_pd (_mm256_mul_pd (wr, or_), _mm256_mul_pd (wi, oi ));
auto ai = _mm256_add_pd (_mm256_mul_pd (wr, oi ), _mm256_mul_pd (wi, or_));
_mm256_store_pd(&tr[i + off + 0 ], _mm256_add_pd (er, ar));
_mm256_store_pd(&ti[i + off + 0 ], _mm256_add_pd (ei, ai));
_mm256_store_pd(&tr[i + off + n2], _mm256_sub_pd (er, ar));
_mm256_store_pd(&ti[i + off + n2], _mm256_sub_pd (ei, ai));
}
}
}
else
{
// Remaining strides (neither 1 nor >= 4): same computation as the vector
// path, one element at a time.
for (auto j = 0U; j < n2; j += s)
{
auto off = j;
auto off2 = off << 1;
auto w = rs[off];
auto wr = std::get<0> (w);
auto wi = std::get<1> (w);
for (auto i = 0U; i < s; ++i)
{
auto er = fr[i + off2 + 0];
auto ei = fi[i + off2 + 0];
auto or_ = fr[i + off2 + s];
auto oi = fi[i + off2 + s];
// a = w*o
auto ar = wr*or_ - wi*oi ;
auto ai = wr*oi + wi*or_;
tr[i + off + 0 ] = er + ar;
ti[i + off + 0 ] = ei + ai;
tr[i + off + n2] = er - ar;
ti[i + off + n2] = ei - ai;
}
}
}
return std::make_pair (f, t);
}
else if (c == 2)
{
// Innermost working level: no recursion and no rotation multiply; the odd
// terms are added/subtracted as they are. Reads from `f`, writes to `t`.
auto fr = f->real;
auto fi = f->imag;
auto tr = t->real;
auto ti = t->imag;
if (s >= 4)
{
assert (s % 4 == 0);
for (auto i = 0U; i < s; i += 4)
{
auto er = _mm256_load_pd (&fr[i + 0]);
auto ei = _mm256_load_pd (&fi[i + 0]);
auto or_ = _mm256_load_pd (&fr[i + s]);
auto oi = _mm256_load_pd (&fi[i + s]);
// a = w*o
auto ar = or_;
auto ai = oi ;
_mm256_store_pd(&tr[i + 0 ], _mm256_add_pd (er, ar));
_mm256_store_pd(&ti[i + 0 ], _mm256_add_pd (ei, ai));
_mm256_store_pd(&tr[i + n2], _mm256_sub_pd (er, ar));
_mm256_store_pd(&ti[i + n2], _mm256_sub_pd (ei, ai));
}
}
else
{
// Scalar variant of the same step for strides below 4.
for (auto i = 0U; i < s; ++i)
{
auto er = fr[i + 0];
auto ei = fi[i + 0];
auto or_= fr[i + s];
auto oi = fi[i + s];
// a = w*o
auto ar = or_;
auto ai = oi ;
tr[i + 0 ] = er + ar;
ti[i + 0 ] = ei + ai;
tr[i + n2] = er - ar;
ti[i + n2] = ei - ai;
}
}
return std::make_pair (f, t);
}
else
{
// c < 2: nothing to compute; return the buffers swapped so that the second
// element is `f`.
return std::make_tuple (t, f);
}
}
// Scalar FFT of `samples`, returned as a new buffer of the same length.
// `samples` itself is not modified: it is copied into a work buffer first.
// Throws std::runtime_error when samples.n is 0 or not a power of two.
complex fft (complex const & samples)
{
  auto const n = samples.n;

  if (n < 1)
  {
    // The check accepts n == 1, so the message must not claim "greater than 1".
    throw std::runtime_error ("samples.n must be at least 1");
  }

  if (!is_power_of_two (n))
  {
    throw std::runtime_error ("samples.n must be power of 2");
  }

  // Two work buffers; fft_loop alternates between them level by level.
  complex r0 (n);
  complex r1 (n);

  std::memcpy (r0.real, samples.real, sizeof(double)*n);
  std::memcpy (r0.imag, samples.imag, sizeof(double)*n);

  auto res = fft_loop (n >> 1, 1, n, &r0, &r1);

  // The second tuple element is the buffer that holds the final output.
  auto r = std::move (*std::get<1> (res));

  return r;
}
// AVX-accelerated FFT of `samples`, returned as a new buffer of the same length.
// `samples` itself is not modified: it is copied into a work buffer first.
// Throws std::runtime_error when samples.n is 0 or not a power of two.
// The rotation table is looked up with vector::at, so a length whose log2 is
// outside the precomputed `rotations` table raises std::out_of_range.
// NOTE(review): ilog2 takes an unsigned int, so n is narrowed on that call -
// confirm that lengths beyond the unsigned int range are never passed.
complex fft_avx (complex const & samples)
{
  auto const n = samples.n;

  if (n < 1)
  {
    // The check accepts n == 1, so the message must not claim "greater than 1".
    throw std::runtime_error ("samples.n must be at least 1");
  }

  if (!is_power_of_two (n))
  {
    throw std::runtime_error ("samples.n must be power of 2");
  }

  // Two work buffers; fft_loop_avx alternates between them level by level.
  complex r0 (n);
  complex r1 (n);

  std::memcpy (r0.real, samples.real, sizeof(double)*n);
  std::memcpy (r0.imag, samples.imag, sizeof(double)*n);

  // Table with one rotation per index of the full transform length.
  auto & rs = rotations.at (ilog2 (n));

  auto res = fft_loop_avx (n >> 1, &rs.front (), 1, n, &r0, &r1);

  // The second tuple element is the buffer that holds the final output.
  auto r = std::move (*std::get<1> (res));

  return r;
}
}
// Benchmark driver: builds a test signal, checks both FFT variants against the
// reference DFT for small sizes, then times the scalar and AVX implementations.
int main ()
{
  // o = number of timed repetitions, n = transform length.
#ifdef _DEBUG
  auto o = 10U;
  auto n = 32U;
#else
  auto o = 10000U;
  auto n = 1024U;
#endif

  complex samples (n);

  {
    // Real-valued signal: cos (a) + cos (2a), scaled by 2/n.
    auto nm = tau / n;
    for (auto i = 0U; i < n; ++i)
    {
      auto a = i * nm;
      auto v = std::cos (a) + std::cos (2.0*a);
      samples.real[i] = 2.0 * v / n;
      samples.imag[i] = 0;
    }
  }

  // The reference DFT is quadratic, so the comparison only runs for small n.
  if (n < 2048)
  {
    auto dft_result     = dft (samples);
    auto fft_result     = fft (samples);
    auto fft_avx_result = fft_avx (samples);

    auto diff     = fft_result - dft_result;
    auto diff_avx = fft_avx_result - dft_result;

    // Accumulates the squared magnitude of every element.
    auto folder = [] (auto && s, auto && r, auto && i)
    {
      return s + r*r + i*i;
    };

    auto sum     = diff.fold (0.0, folder);
    auto sum_avx = diff_avx.fold (0.0, folder);

    // n is unsigned, so it is printed with %u.
    std::printf ("Data size : %u\n", n);
    std::printf ("Sum : %g\n", sum / n);
    std::printf ("Sum(avx) : %g\n", sum_avx / n);

    if (n < 64)
    {
      // Flush near-zero noise to exactly 0 so the dumps are readable.
      samples.trim_inplace ();
      dft_result.trim_inplace ();
      fft_result.trim_inplace ();
      fft_avx_result.trim_inplace ();

      std::printf ("samples : %s\n" , samples.to_string ().c_str ());
      std::printf ("dft_result: %s\n" , dft_result.to_string ().c_str ());
      std::printf ("fft_result: %s\n" , fft_result.to_string ().c_str ());
      std::printf ("fft_avx_result: %s\n" , fft_avx_result.to_string ().c_str ());
    }
  }

  auto fft_timed     = time_it (o , [&samples] { return fft (samples); });
  auto fft_avx_timed = time_it (o , [&samples] { return fft_avx (samples); });

  // Average milliseconds per transform.
  auto fft_ms     = std::get<0> (fft_timed) / (1.0*o);
  auto fft_avx_ms = std::get<0> (fft_avx_timed) / (1.0*o);

  std::printf ("FFT ms : %f\n", fft_ms );
  std::printf ("FFT(avx) ms : %f\n", fft_avx_ms );

  return 0;
}
// (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")