Created June 11, 2018 15:14
Save nbecker/251112eb10e8effbdeb0545eabef6b31 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "pybind11/numpy.h" // vectorize
#include "pybind11/pybind11.h"
#include "pybind11/operators.h"
#include "numpy/arrayobject.h"
#include "ndarray/pybind11.h"
#include <inttypes.h>
#include <stdexcept>
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xcontainer.hpp"
#include "xtensor-python/pyarray.hpp"
#include "xtensor-python/pyvectorize.hpp"
#include "xtensor/xview.hpp"
#include "xtensor-blas/xlinalg.hpp"
#include "xtensor/xnorm.hpp"
#include "xtensor/xeval.hpp"
namespace py = pybind11;
#include <algorithm>
#include <cmath>
#include <complex>
//#include "apply.hpp"
// Shorthand names for the two complex precisions handled by this module
// (double- and single-precision, i.e. numpy complex128 / complex64).
using complex_t = std::complex<double>;
using complex64_t = std::complex<float>;

// mag_sqr<T>: function object returning the squared magnitude of its
// argument. For a real T this is simply v * v, with result type T.
// The nested argument_type / result_type aliases advertise the call
// signature in the style of the classic adaptable function objects.
template <typename flt_t>
struct mag_sqr {
    using argument_type = flt_t;
    using result_type = flt_t;

    flt_t operator()(flt_t v) const { return v * v; }
};

// Complex specialisation: |z|^2 = re^2 + im^2, returned as the underlying
// real type (so mag_sqr<std::complex<float>> yields float). Computed
// directly rather than via std::abs to avoid a square root.
// The argument is taken by value on purpose: callers such as py::vectorize
// deduce the element type from this exact call signature.
template <typename flt_t>
struct mag_sqr<std::complex<flt_t>> {
    using argument_type = std::complex<flt_t>;
    using result_type = flt_t;

    flt_t operator()(std::complex<flt_t> z) const {
        const flt_t re = z.real();
        const flt_t im = z.imag();
        return re * re + im * im;
    }
};
// template <typename T, int N, int C> | |
// nd::Array<typename mag_sqr<T>::result_type,N,N> do_mag_sqr_flat (nd::Array<T,N,C> const& in) { | |
// ndarray::Array<typename boost::remove_const<T>::type,1,1> flat_in = ndarray::flatten<1>(in); | |
// ndarray::Array<typename mag_sqr<T>::result_type,N,N> out = ndarray::allocate(in.getShape()); | |
// ndarray::Array<typename mag_sqr<T>::result_type,1,1> flat_out = ndarray::flatten<1>(out); | |
// int size = flat_in.template getSize<0>(); | |
// for (int n=0; n < size; ++n) { | |
// flat_out[n] = mag_sqr<T>() (flat_in[n]); | |
// } | |
// return out; | |
// } | |
// template<typename flt_t> | |
// flt_t do_mag_sqr (flt_t x) { return x * x; } | |
// template<typename flt_t> | |
// flt_t do_mag_sqr (std::complex<flt_t> x) { return real(x) * real(x) + imag(x) * imag(x); } | |
// template<typename flt_t> | |
// auto mag_sqr_vec (py::array_t<std::complex<flt_t>> a) { | |
// return py::vectorize([](std::complex<flt_t> x) { return real(x) * real(x) + imag(x) * imag(x); })(a); | |
// } | |
// template<typename flt_t> | |
// auto mag_sqr_vec (py::array_t<flt_t> a) { | |
// return py::vectorize([](flt_t x) { return x * x; })(a); | |
// } | |
// Module initialiser for the `mag_sqr` Python extension (legacy
// PYBIND11_PLUGIN entry point: returns the module object, or nullptr with a
// Python error set on failure).
//
// Registered overload sets:
//   mag_sqr(in)   -> |in|^2, element-wise for float64 / float32 / complex128 /
//                    complex64 numpy arrays, or for a single scalar.
//   xt_norm_2(in) -> sqrt(sum |in_i|^2) over all elements (xt::norm_l2).
//   norm_2(in)    -> xt::linalg::norm(in, 2) for arrays; |in| for scalars.
//
// NOTE(review): PYBIND11_PLUGIN is deprecated in newer pybind11 releases in
// favour of PYBIND11_MODULE; kept so the build requirements do not change.
// NOTE(review): the argument name "in" is a Python keyword, so Python callers
// can only pass it positionally (or via **kwargs).
PYBIND11_PLUGIN (mag_sqr) {
    // The raw numpy C API (numpy/arrayobject.h is included by this file) must
    // be initialised before use; report a missing/broken numpy as ImportError.
    if (_import_array() < 0) {
        PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import");
        return nullptr;
    }

    py::module m("mag_sqr", "pybind11 example plugin");

    // mag_sqr over numpy arrays. `noconvert()` makes each overload accept only
    // an ndarray whose dtype already matches; other inputs (lists, integer
    // arrays, ...) are not coerced into these overloads. The element-wise
    // work reuses the mag_sqr<T> function objects, so the formula lives in
    // exactly one place.
    m.def("mag_sqr",
          [](py::array_t<double> a) { return py::vectorize(mag_sqr<double>{})(a); },
          py::arg("in").noconvert());
    m.def("mag_sqr",
          [](py::array_t<float> a) { return py::vectorize(mag_sqr<float>{})(a); },
          py::arg("in").noconvert());
    m.def("mag_sqr",
          [](py::array_t<std::complex<double>> a) {
              return py::vectorize(mag_sqr<std::complex<double>>{})(a);
          },
          py::arg("in").noconvert());
    m.def("mag_sqr",
          [](py::array_t<std::complex<float>> a) {
              return py::vectorize(mag_sqr<std::complex<float>>{})(a);
          },
          py::arg("in").noconvert());

    // mag_sqr for a single scalar argument.
    m.def("mag_sqr", [](double x) { return mag_sqr<double>{}(x); });
    m.def("mag_sqr", [](float x) { return mag_sqr<float>{}(x); });
    m.def("mag_sqr", [](std::complex<double> x) { return mag_sqr<std::complex<double>>{}(x); });
    m.def("mag_sqr", [](std::complex<float> x) { return mag_sqr<std::complex<float>>{}(x); });

    // xt_norm_2: element-wise L2 norm over the whole array via xtensor;
    // the trailing `()` extracts the scalar from the 0-d reduction result.
    m.def("xt_norm_2",
          [](xt::pyarray<std::complex<double>> x) { return xt::norm_l2(x)(); },
          py::arg("in").noconvert());
    m.def("xt_norm_2",
          [](xt::pyarray<double> x) { return xt::norm_l2(x)(); },
          py::arg("in").noconvert());

    // norm_2 over arrays: delegates to xt::linalg::norm(x, 2).
    // NOTE(review): presumably mirrors numpy.linalg.norm(x, 2) — for 2-d input
    // that is the spectral norm, not the element-wise (Frobenius) norm that
    // xt_norm_2 computes; confirm against the xtensor-blas version in use.
    m.def("norm_2",
          [](xt::pyarray<std::complex<double>> x) { return xt::linalg::norm(x, 2); },
          py::arg("in").noconvert());
    m.def("norm_2",
          [](xt::pyarray<double> x) { return xt::linalg::norm(x, 2); },
          py::arg("in").noconvert());

    // norm_2 of a single scalar is its absolute value (modulus for complex).
    // It is computed directly: wrapping the scalar in a 0-d xt::pyarray and
    // calling xt::linalg::norm fails, because linalg::norm only handles 1-d
    // and 2-d input and throws for any other dimension.
    m.def("norm_2", [](double x) { return std::abs(x); });
    m.def("norm_2", [](std::complex<double> x) { return std::abs(x); });

    return m.ptr();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.