Skip to content

Instantly share code, notes, and snippets.

@pesterie
Created June 25, 2020 16:46
Show Gist options
  • Save pesterie/84ad3622e0e6ad176d784759f32b8dee to your computer and use it in GitHub Desktop.
Save pesterie/84ad3622e0e6ad176d784759f32b8dee to your computer and use it in GitHub Desktop.
/// Builds one FFTW guru64 dimension descriptor for `axis` of `input`.
/// The extent is `input.shape(axis)`. The same stride `input.strides()[axis]` is used for
/// both the input (`is`) and output (`os`) side, so the output buffer is assumed to have the
/// same shape/layout as `input`. NOTE(review): relies on xtensor strides being expressed in
/// elements, as FFTW expects.
/// Values are kept at the descriptor's native (ptrdiff_t) width instead of being narrowed
/// to `int`, which would overflow for large arrays and defeat the *64 interface.
/// @tparam T  float, double, std::complex<float> or std::complex<double>
template<class T>
inline auto create_fftw_iodim64(xt::xarray<T> const& input, std::size_t axis)
-> std::conditional_t<(std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>), fftw_iodim64, fftwf_iodim64>
{
    constexpr bool is_double_precision = std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>;
    constexpr bool is_single_precision = std::is_same_v<T, float> || std::is_same_v<T, std::complex<float>>;
    static_assert(is_double_precision || is_single_precision,
                  "create_fftw_iodim64: T must be float, double, std::complex<float> or std::complex<double>");

    using iodim_type = std::conditional_t<is_double_precision, fftw_iodim64, fftwf_iodim64>;
    iodim_type dim{};
    dim.n  = static_cast<decltype(dim.n)>(input.shape(axis));
    dim.is = static_cast<decltype(dim.is)>(input.strides()[axis]);
    dim.os = dim.is;
    return dim;
}
/// Returns, in increasing order, every axis index in [0, dims) that is NOT listed in
/// `target_axes`. Entries of `target_axes` that are >= dims are simply ignored.
/// Uses qualified std::find; the previous unqualified `find` only compiled via ADL on the
/// vector iterator type, which is not portable (e.g. when iterators are raw pointers).
inline auto create_non_target_axes(std::vector<std::size_t> const& target_axes, std::size_t dims)
-> std::vector<std::size_t>
{
    std::vector<std::size_t> non_target_axes;
    non_target_axes.reserve(dims);
    for(std::size_t axis = 0; axis < dims; ++axis)
    {
        if(std::find(target_axes.begin(), target_axes.end(), axis) == target_axes.end())
        {
            non_target_axes.push_back(axis);
        }
    }
    return non_target_axes;
}
/// Returns a single process-wide mutex (a function-local static, so every call yields the
/// same object). Presumably meant to serialize FFTW planner calls.
/// NOTE(review): none of the functions visible here lock it — confirm callers do.
inline std::mutex& fftw_global_mutex()
{
    static std::mutex planner_lock;
    return planner_lock;
}
/// Builds the FFTW guru64 dimension descriptors for a batched transform of `input`.
/// @param input        array whose shape/strides describe the data layout
/// @param target_axes  axes to transform, in the order they are handed to FFTW
/// @return (dims, howmany_dims): one descriptor per entry of `target_axes`, and one per
///         remaining axis of `input` (see create_non_target_axes), used as FFTW's loop
///         ("howmany") dimensions.
template<typename T>
inline auto make_fftw_iodim64s(xt::xarray<T> const& input, std::vector<std::size_t> const& target_axes)
-> std::tuple<std::vector<std::conditional_t<(std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>), fftw_iodim64, fftwf_iodim64>>,
std::vector<std::conditional_t<(std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>), fftw_iodim64, fftwf_iodim64>>>
{
    using fftw_iodim64_type = std::conditional_t<(std::is_same_v<T, double> || std::is_same_v<T, std::complex<double>>), fftw_iodim64, fftwf_iodim64>;
    std::vector<std::size_t> const non_target_axes = create_non_target_axes(target_axes, input.dimension());

    // Fill the vectors directly inside the tuple being returned, so they are never copied.
    std::tuple<std::vector<fftw_iodim64_type>, std::vector<fftw_iodim64_type>> result;
    auto& [dims, howmany_dims] = result;

    dims.reserve(target_axes.size());
    for(auto const axis: target_axes)
    {
        dims.push_back(create_fftw_iodim64(input, axis));
    }

    howmany_dims.reserve(non_target_axes.size());
    for(auto const axis: non_target_axes)
    {
        howmany_dims.push_back(create_fftw_iodim64(input, axis));
    }
    return result;
}
/// Double-precision real-to-complex transform through fftw_plan_guru64_dft_r2c.
/// Plans with FFTW_ESTIMATE, executes once and destroys the plan.
/// @param dims          transform dimensions
/// @param howmany_dims  loop (batch) dimensions
/// @param input         real input buffer
/// @param output        complex output buffer
/// If FFTW cannot create a plan (it returns NULL), a message is printed and nothing is
/// executed; `output` is then left untouched.
/// NOTE(review): FFTW's planner and destroy functions are not thread-safe (only the execute
/// functions are); fftw_global_mutex() is not taken here — confirm callers serialize.
inline auto execute_rfftw(std::vector<fftw_iodim64> const& dims, std::vector<fftw_iodim64> const& howmany_dims, double* input,
fftw_complex* output) -> void
{
    fftw_plan many_plan = fftw_plan_guru64_dft_r2c(static_cast<int>(dims.size()), dims.data(),
                                                   static_cast<int>(howmany_dims.size()), howmany_dims.data(),
                                                   input, output, FFTW_ESTIMATE);
    if(many_plan == nullptr)
    {
        // Executing a NULL plan would dereference it.
        std::cerr << "no plan created\n";
        return;
    }
    fftw_execute_dft_r2c(many_plan, input, output);
    fftw_destroy_plan(many_plan);
}
/// Double-precision complex-to-real transform through fftw_plan_guru64_dft_c2r.
/// Plans with FFTW_ESTIMATE, executes once and destroys the plan.
/// @param dims          transform dimensions
/// @param howmany_dims  loop (batch) dimensions
/// @param input         complex input buffer. NOTE: FFTW's default for c2r is
///                      FFTW_DESTROY_INPUT, so this buffer may be overwritten.
/// @param output        real output buffer
/// If FFTW cannot create a plan (it returns NULL), a message is printed and nothing is
/// executed; `output` is then left untouched.
/// NOTE(review): the FFTW planner is not thread-safe; fftw_global_mutex() is not taken here.
inline auto execute_irfftw(std::vector<fftw_iodim64> const& dims, std::vector<fftw_iodim64> const& howmany_dims, fftw_complex* input,
double* output) -> void
{
    fftw_plan many_plan = fftw_plan_guru64_dft_c2r(static_cast<int>(dims.size()), dims.data(),
                                                   static_cast<int>(howmany_dims.size()), howmany_dims.data(),
                                                   input, output, FFTW_ESTIMATE);
    if(many_plan == nullptr)
    {
        // Executing a NULL plan would dereference it.
        std::cerr << "no plan created\n";
        return;
    }
    fftw_execute_dft_c2r(many_plan, input, output);
    fftw_destroy_plan(many_plan);
}
inline auto execute_rfftw(std::vector<fftw_iodim64> const& dims, std::vector<fftw_iodim64> const& howmany_dims, float* input,
fftwf_complex* output) -> void
{
fftwf_plan many_plan = fftwf_plan_guru64_dft_r2c(
int(dims.size()), dims.data(), int(howmany_dims.size()),
howmany_dims.data(), input, output, FFTW_ESTIMATE);
fftwf_execute_dft_r2c(many_plan, input, output);
fftwf_destroy_plan(many_plan);
}
/// Single-precision complex-to-real transform through fftwf_plan_guru64_dft_c2r.
/// Plans with FFTW_ESTIMATE, executes once and destroys the plan.
/// @param dims          transform dimensions
/// @param howmany_dims  loop (batch) dimensions
/// @param input         complex input buffer. NOTE: FFTW's default for c2r is
///                      FFTW_DESTROY_INPUT, so this buffer may be overwritten.
/// @param output        real output buffer
/// If FFTW cannot create a plan (it returns NULL), a message is printed and nothing is
/// executed; `output` is then left untouched.
/// NOTE(review): the FFTW planner is not thread-safe; fftw_global_mutex() is not taken here.
inline auto execute_irfftw(std::vector<fftwf_iodim64> const& dims, std::vector<fftwf_iodim64> const& howmany_dims,
fftwf_complex* input, float* output) -> void
{
    fftwf_plan many_plan = fftwf_plan_guru64_dft_c2r(static_cast<int>(dims.size()), dims.data(),
                                                     static_cast<int>(howmany_dims.size()), howmany_dims.data(),
                                                     input, output, FFTW_ESTIMATE);
    if(many_plan == nullptr)
    {
        // Executing a NULL plan would dereference it.
        std::cerr << "no plan created\n";
        return;
    }
    fftwf_execute_dft_c2r(many_plan, input, output);
    fftwf_destroy_plan(many_plan);
}
/// Double-precision complex-to-complex transform through fftw_plan_guru64_dft.
/// Plans with FFTW_ESTIMATE, executes once and destroys the plan.
/// @param dims          transform dimensions
/// @param howmany_dims  loop (batch) dimensions
/// @param input         complex input buffer
/// @param output        complex output buffer (may equal `input` for an in-place transform)
/// @param sign          transform direction, forwarded to FFTW (FFTW_FORWARD / FFTW_BACKWARD)
/// If FFTW cannot create a plan (it returns NULL), a message is printed and the function
/// returns. Previously it went on to execute the NULL plan and crashed.
/// NOTE(review): the FFTW planner is not thread-safe; fftw_global_mutex() is not taken here.
inline auto execute_fftw(std::vector<fftw_iodim64> const& dims, std::vector<fftw_iodim64> const& howmany_dims, fftw_complex* input,
fftw_complex* output, int sign) -> void
{
    fftw_plan many_plan = fftw_plan_guru64_dft(static_cast<int>(dims.size()), dims.data(),
                                               static_cast<int>(howmany_dims.size()), howmany_dims.data(),
                                               input, output, sign, FFTW_ESTIMATE);
    if(many_plan == nullptr)
    {
        std::cerr << "no plan created\n";
        return;
    }
    fftw_execute_dft(many_plan, input, output);
    fftw_destroy_plan(many_plan);
}
/// Single-precision complex-to-complex transform through fftwf_plan_guru64_dft.
/// Plans with FFTW_ESTIMATE, executes once and destroys the plan.
/// @param dims          transform dimensions
/// @param howmany_dims  loop (batch) dimensions
/// @param input         complex input buffer
/// @param output        complex output buffer (may equal `input` for an in-place transform)
/// @param sign          transform direction, forwarded to FFTW (FFTW_FORWARD / FFTW_BACKWARD)
/// If FFTW cannot create a plan (it returns NULL), a message is printed and nothing is
/// executed; `output` is then left untouched.
/// NOTE(review): the FFTW planner is not thread-safe; fftw_global_mutex() is not taken here.
inline auto execute_fftw(std::vector<fftwf_iodim64> const& dims, std::vector<fftwf_iodim64> const& howmany_dims, fftwf_complex* input,
fftwf_complex* output, int sign) -> void
{
    fftwf_plan many_plan = fftwf_plan_guru64_dft(static_cast<int>(dims.size()), dims.data(),
                                                 static_cast<int>(howmany_dims.size()), howmany_dims.data(),
                                                 input, output, sign, FFTW_ESTIMATE);
    if(many_plan == nullptr)
    {
        // Executing a NULL plan would dereference it.
        std::cerr << "no plan created\n";
        return;
    }
    fftwf_execute_dft(many_plan, input, output);
    fftwf_destroy_plan(many_plan);
}
/// Real-to-complex FFT of `input` over the axes listed in `axis`.
/// @tparam T  double or float; selects the fftw_* or fftwf_* overload of execute_rfftw.
/// @return complex array with the same shape as `input`. Along the last axis listed in
///         `axis`, FFTW's r2c only produces the first n/2 + 1 bins; the remaining entries
///         are zero. They were previously left uninitialized.
/// The complex pointer type is selected by T. It was previously hard-coded to
/// fftw_complex*, so rfft<float> matched no execute_rfftw overload.
template<typename T>
inline auto rfft(xt::xarray<T> const& input, std::vector<std::size_t> const& axis) -> xt::xarray<std::complex<T>>
{
    using fftw_iodim64_type = std::conditional_t<std::is_same_v<T, double>, fftw_iodim64, fftwf_iodim64>;
    using fftw_complex_type = std::conditional_t<std::is_same_v<T, double>, fftw_complex, fftwf_complex>;

    auto fft_out = xt::xarray<std::complex<T>>::from_shape(input.shape());
    // from_shape does not initialize; zero everything FFTW will not write.
    std::fill(fft_out.data(), fft_out.data() + fft_out.size(), std::complex<T>{});

    std::vector<fftw_iodim64_type> dims;
    std::vector<fftw_iodim64_type> howmany_dims;
    std::tie(dims, howmany_dims) = make_fftw_iodim64s(input, axis);
    // const_cast: the FFTW API takes a non-const pointer, but r2c preserves its input by default.
    execute_rfftw(dims, howmany_dims, const_cast<T*>(input.data()),
                  reinterpret_cast<fftw_complex_type*>(fft_out.data()));
    return fft_out;
}
/// Complex-to-complex FFT of `input` over the axes listed in `axis`, returning a new array
/// with the same shape as `input`.
/// @param sign  transform direction, forwarded unchanged to execute_fftw
template<typename T>
inline auto fft(xt::xarray<std::complex<T>> const& input, std::vector<std::size_t> const& axis, int sign)
-> xt::xarray<std::complex<T>>
{
    // Raw FFTW complex type matching the precision of T.
    using raw_complex_t = std::conditional_t<std::is_same_v<T, double>, fftw_complex, fftwf_complex>;

    auto const [dims, howmany_dims] = make_fftw_iodim64s(input, axis);

    auto result = xt::xarray<std::complex<T>>::from_shape(input.shape());
    // std::complex<T> and the FFTW complex types share a layout, so the buffers are reinterpreted in place.
    auto* const source = reinterpret_cast<raw_complex_t*>(const_cast<std::complex<T>*>(input.data()));
    auto* const target = reinterpret_cast<raw_complex_t*>(result.data());
    execute_fftw(dims, howmany_dims, source, target, sign);
    return result;
}
/// Complex-to-complex FFT of `input` over the axes listed in `axis`, written into `output`.
/// @param output  destination array. It is resized to `input.shape()` if needed, because the
///                FFTW output strides are taken from `input`. A mismatched `output` previously
///                led to out-of-bounds writes. resize() is a no-op when the shapes already
///                match, including in-place use with `output` aliasing `input`.
/// @param sign    transform direction, forwarded unchanged to execute_fftw
template<typename T>
inline auto fft(xt::xarray<std::complex<T>> const& input, xt::xarray<std::complex<T>>& output,
std::vector<std::size_t> const& axis, int sign) -> void
{
    using fftw_iodim64_type = std::conditional_t<std::is_same_v<T, double>, fftw_iodim64, fftwf_iodim64>;
    using fftw_complex_type = std::conditional_t<std::is_same_v<T, double>, fftw_complex, fftwf_complex>;

    output.resize(input.shape());

    std::vector<fftw_iodim64_type> dims;
    std::vector<fftw_iodim64_type> howmany_dims;
    std::tie(dims, howmany_dims) = make_fftw_iodim64s(input, axis);
    execute_fftw(dims, howmany_dims, reinterpret_cast<fftw_complex_type*>(const_cast<std::complex<T>*>(input.data())),
                 reinterpret_cast<fftw_complex_type*>(output.data()), sign);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment