Created June 25, 2020 16:46
-
-
Save pesterie/84ad3622e0e6ad176d784759f32b8dee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
template<class T> | |
inline auto create_fftw_iodim64(xt::xarray<T> const& input, std::size_t axis) | |
-> std::conditional_t<(std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>), fftw_iodim64, fftwf_iodim64> | |
{ | |
if constexpr(std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>) | |
{ | |
return fftw_iodim64{static_cast<int>(input.shape(axis)), static_cast<int>(input.strides()[axis]), | |
static_cast<int>(input.strides()[axis])}; | |
} | |
else if constexpr(std::is_same_v<T, float> || std::is_same_v<T, std::complex<float>>) | |
{ | |
return fftwf_iodim64{static_cast<int>(input.shape(axis)), static_cast<int>(input.strides()[axis]), | |
static_cast<int>(input.strides()[axis])}; | |
} | |
else | |
{ | |
return {}; | |
} | |
} | |
/// Returns, in ascending order, every axis index in [0, dims) that is not listed in
/// `target_axes`. Entries of `target_axes` that are >= dims are simply ignored.
///
/// @param target_axes  axes that will be transformed (order and duplicates irrelevant).
/// @param dims         total number of dimensions of the array.
/// @return the remaining ("loop") axes.
inline auto create_non_target_axes(std::vector<std::size_t> const& target_axes, std::size_t dims)
    -> std::vector<std::size_t>
{
    std::vector<std::size_t> non_target_axes(dims);
    // Seed with std::size_t so iota increments in the element type, not int.
    std::iota(non_target_axes.begin(), non_target_axes.end(), std::size_t{0});

    // Call std::find explicitly. The previous unqualified `find` only compiled through ADL,
    // which is not guaranteed for std::vector iterators (e.g. when they are raw pointers).
    auto const is_target = [&target_axes](std::size_t axis) {
        return std::find(target_axes.begin(), target_axes.end(), axis) != target_axes.end();
    };
    non_target_axes.erase(std::remove_if(non_target_axes.begin(), non_target_axes.end(), is_target),
                          non_target_axes.end());
    return non_target_axes;
}
/// Accessor for a single process-wide std::mutex.
///
/// The function-local static is constructed on the first call, and every call returns
/// the same instance. NOTE(review): nothing in this section locks it. It is presumably
/// meant to serialize FFTW planner calls (the FFTW planner is not thread-safe). Confirm
/// whether callers lock it.
inline std::mutex& fftw_global_mutex()
{
    static std::mutex instance{};
    return instance;
}
template<typename T> | |
inline auto make_fftw_iodim64s(xt::xarray<T> const& input, std::vector<std::size_t> const& target_axes) | |
-> std::tuple<std::vector<std::conditional_t<(std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>), fftw_iodim64, fftwf_iodim64>>, | |
std::vector<std::conditional_t<(std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>), fftw_iodim64, fftwf_iodim64>>> | |
{ | |
using fftw_iodim64_type = std::conditional_t<(std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>), fftw_iodim64, fftwf_iodim64>; | |
std::vector<std::size_t> non_target_axes(create_non_target_axes(target_axes, input.dimension())); | |
std::vector<fftw_iodim64_type> dims, howmany_dims; | |
for(const auto& axe: target_axes) | |
{ | |
dims.push_back(create_fftw_iodim64(input, axe)); | |
} | |
for(const auto& axe: non_target_axes) | |
{ | |
howmany_dims.push_back(create_fftw_iodim64(input, axe)); | |
} | |
return std::make_tuple(dims, howmany_dims); | |
} | |
inline auto execute_rfftw(std::vector<fftw_iodim64> const& dims, std::vector<fftw_iodim64> const& howmany_dims, double* input, | |
fftw_complex* output) -> void | |
{ | |
fftw_plan many_plan = fftw_plan_guru64_dft_r2c(int(dims.size()), dims.data(), int(howmany_dims.size()), | |
howmany_dims.data(), input, output, FFTW_ESTIMATE); | |
fftw_execute_dft_r2c(many_plan, input, output); | |
fftw_destroy_plan(many_plan); | |
} | |
inline auto execute_irfftw(std::vector<fftw_iodim64> const& dims, std::vector<fftw_iodim64> const& howmany_dims, fftw_complex* input, | |
double* output) -> void | |
{ | |
fftw_plan many_plan = | |
fftw_plan_guru64_dft_c2r(int(dims.size()), dims.data(), int(howmany_dims.size()), | |
howmany_dims.data(), input, output, FFTW_ESTIMATE); | |
fftw_execute_dft_c2r(many_plan, input, output); | |
fftw_destroy_plan(many_plan); | |
} | |
inline auto execute_rfftw(std::vector<fftw_iodim64> const& dims, std::vector<fftw_iodim64> const& howmany_dims, float* input, | |
fftwf_complex* output) -> void | |
{ | |
fftwf_plan many_plan = fftwf_plan_guru64_dft_r2c( | |
int(dims.size()), dims.data(), int(howmany_dims.size()), | |
howmany_dims.data(), input, output, FFTW_ESTIMATE); | |
fftwf_execute_dft_r2c(many_plan, input, output); | |
fftwf_destroy_plan(many_plan); | |
} | |
inline auto execute_irfftw(std::vector<fftwf_iodim64> const& dims, std::vector<fftwf_iodim64> const& howmany_dims, | |
fftwf_complex* input, float* output) -> void | |
{ | |
fftwf_plan many_plan = fftwf_plan_guru64_dft_c2r( | |
int(dims.size()), dims.data(), int(howmany_dims.size()), | |
howmany_dims.data(), input, output, FFTW_ESTIMATE); | |
fftwf_execute_dft_c2r(many_plan, input, output); | |
fftwf_destroy_plan(many_plan); | |
} | |
inline auto execute_fftw(std::vector<fftw_iodim64> const& dims, std::vector<fftw_iodim64> const& howmany_dims, fftw_complex* input, | |
fftw_complex* output, int sign) -> void | |
{ | |
fftw_plan many_plan = fftw_plan_guru64_dft(static_cast<int>(dims.size()), dims.data(), static_cast<int>(howmany_dims.size()), | |
howmany_dims.data(), input, output, sign, FFTW_ESTIMATE); | |
if(many_plan == NULL) std::cout << "no plan created\n"; | |
fftw_execute_dft(many_plan, input, output); | |
fftw_destroy_plan(many_plan); | |
} | |
inline auto execute_fftw(std::vector<fftwf_iodim64> const& dims, std::vector<fftwf_iodim64> const& howmany_dims, fftwf_complex* input, | |
fftwf_complex* output, int sign) -> void | |
{ | |
fftwf_plan many_plan = | |
fftwf_plan_guru64_dft(int(dims.size()), dims.data(), int(howmany_dims.size()), | |
howmany_dims.data(), input, output, sign, FFTW_ESTIMATE); | |
fftwf_execute_dft(many_plan, input, output); | |
fftwf_destroy_plan(many_plan); | |
} | |
template<typename T> | |
inline auto rfft(xt::xarray<T> const& input, std::vector<std::size_t> const& axis) -> xt::xarray<std::complex<T>> | |
{ | |
auto fft_out = xt::xarray<std::complex<T>>::from_shape(input.shape()); | |
using fftw_iodim64_type = std::conditional_t<std::is_same_v<T, double>, fftw_iodim64, fftwf_iodim64>; | |
std::vector<fftw_iodim64_type> dims; | |
std::vector<fftw_iodim64_type> howmany_dims; | |
std::tie(dims, howmany_dims) = make_fftw_iodim64s(input, axis); | |
execute_rfftw(dims, howmany_dims, const_cast<T*>(input.data()), | |
reinterpret_cast<fftw_complex*>(fft_out.data())); | |
return fft_out; | |
} | |
template<typename T> | |
inline auto fft(xt::xarray<std::complex<T>> const& input, std::vector<std::size_t> const& axis, int sign) | |
-> xt::xarray<std::complex<T>> | |
{ | |
auto fft_out = xt::xarray<std::complex<T>>::from_shape(input.shape()); | |
using fftw_iodim64_type = std::conditional_t<std::is_same_v<T, double>, fftw_iodim64, fftwf_iodim64>; | |
using fftw_complex_type = std::conditional_t<std::is_same_v<T, double>, fftw_complex, fftwf_complex>; | |
std::vector<fftw_iodim64_type> dims; | |
std::vector<fftw_iodim64_type> howmany_dims; | |
std::tie(dims, howmany_dims) = make_fftw_iodim64s(input, axis); | |
execute_fftw(dims, howmany_dims, reinterpret_cast<fftw_complex_type*>(const_cast<std::complex<T>*>(input.data())), | |
reinterpret_cast<fftw_complex_type*>(fft_out.data()), sign); | |
return fft_out; | |
} | |
template<typename T> | |
inline auto fft(xt::xarray<std::complex<T>> const& input, xt::xarray<std::complex<T>>& output, | |
std::vector<std::size_t> const& axis, int sign) -> void | |
{ | |
using fftw_iodim64_type = std::conditional_t<std::is_same_v<T, double>, fftw_iodim64, fftwf_iodim64>; | |
using fftw_complex_type = std::conditional_t<std::is_same_v<T, double>, fftw_complex, fftwf_complex>; | |
std::vector<fftw_iodim64_type> dims; | |
std::vector<fftw_iodim64_type> howmany_dims; | |
std::tie(dims, howmany_dims) = make_fftw_iodim64s(input, axis); | |
execute_fftw(dims, howmany_dims, reinterpret_cast<fftw_complex_type*>(const_cast<std::complex<T>*>(input.data())), | |
reinterpret_cast<fftw_complex_type*>(output.data()), sign); | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.