Skip to content

Instantly share code, notes, and snippets.

@nbecker
Created June 11, 2018 15:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nbecker/251112eb10e8effbdeb0545eabef6b31 to your computer and use it in GitHub Desktop.
Save nbecker/251112eb10e8effbdeb0545eabef6b31 to your computer and use it in GitHub Desktop.
#include "pybind11/numpy.h" // vectorize
#include "pybind11/pybind11.h"
#include "pybind11/operators.h"
#include "numpy/arrayobject.h"
#include "ndarray/pybind11.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xcontainer.hpp"
#include "xtensor-python/pyarray.hpp"
#include "xtensor-python/pyvectorize.hpp"
#include "xtensor/xview.hpp"
#include "xtensor-blas/xlinalg.hpp"
#include "xtensor/xnorm.hpp"
#include "xtensor/xeval.hpp"

#include <algorithm>
#include <cmath>
#include <complex>
#include <inttypes.h>
#include <stdexcept>

namespace py = pybind11;
//#include "apply.hpp"
// Convenience aliases for the complex element types used by the bindings.
using complex_t = std::complex<double>;
using complex64_t = std::complex<float>;

/// Function object computing the squared magnitude |x|^2 of a real value,
/// i.e. simply x * x, returned in the same type as the argument.
/// `argument_type` / `result_type` are exposed as nested aliases for code
/// that introspects the functor's signature.
template <typename flt_t>
struct mag_sqr {
    using argument_type = flt_t;
    using result_type = flt_t;

    flt_t operator()(flt_t value) const {
        return value * value;
    }
};

/// Specialization for std::complex<flt_t>: |z|^2 = re(z)^2 + im(z)^2.
/// The result is the underlying real type flt_t, not a complex number.
template <typename flt_t>
struct mag_sqr<std::complex<flt_t>> {
    using argument_type = std::complex<flt_t>;
    using result_type = flt_t;

    flt_t operator()(std::complex<flt_t> z) const {
        const flt_t re = z.real();
        const flt_t im = z.imag();
        return re * re + im * im;
    }
};
// template <typename T, int N, int C>
// nd::Array<typename mag_sqr<T>::result_type,N,N> do_mag_sqr_flat (nd::Array<T,N,C> const& in) {
// ndarray::Array<typename boost::remove_const<T>::type,1,1> flat_in = ndarray::flatten<1>(in);
// ndarray::Array<typename mag_sqr<T>::result_type,N,N> out = ndarray::allocate(in.getShape());
// ndarray::Array<typename mag_sqr<T>::result_type,1,1> flat_out = ndarray::flatten<1>(out);
// int size = flat_in.template getSize<0>();
// for (int n=0; n < size; ++n) {
// flat_out[n] = mag_sqr<T>() (flat_in[n]);
// }
// return out;
// }
// template<typename flt_t>
// flt_t do_mag_sqr (flt_t x) { return x * x; }
// template<typename flt_t>
// flt_t do_mag_sqr (std::complex<flt_t> x) { return real(x) * real(x) + imag(x) * imag(x); }
// template<typename flt_t>
// auto mag_sqr_vec (py::array_t<std::complex<flt_t>> a) {
// return py::vectorize([](std::complex<flt_t> x) { return real(x) * real(x) + imag(x) * imag(x); })(a);
// }
// template<typename flt_t>
// auto mag_sqr_vec (py::array_t<flt_t> a) {
// return py::vectorize([](flt_t x) { return x * x; })(a);
// }
// Python extension module `mag_sqr`.
//
// Exported functions (each overloaded on argument type):
//   mag_sqr(in)   -- elementwise squared magnitude |x|^2, computed with the
//                    mag_sqr<T> functor.  Overloads for float64 / float32 /
//                    complex128 / complex64 ndarrays and for scalars.
//                    Complex inputs produce the corresponding real type.
//   xt_norm_2(in) -- xt::norm_l2(in) evaluated to a scalar, for complex128
//                    or float64 ndarrays.
//   norm_2(in)    -- xt::linalg::norm(in, 2) for complex128 or float64
//                    ndarrays; |x| for a scalar.
//
// PYBIND11_MODULE replaces the deprecated PYBIND11_PLUGIN macro and the
// deprecated py::module(name, doc) constructor.
PYBIND11_MODULE(mag_sqr, m) {
    // Initialise the numpy C API before any numpy-dependent code runs.
    // The module body cannot return nullptr, so failure is reported by
    // throwing; pybind11 turns that into an ImportError for the importer.
    if (_import_array() < 0) {
        PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import");
        throw py::error_already_set();
    }
    m.doc() = "pybind11 example plugin";

    // ---- mag_sqr: ndarray overloads ----------------------------------------
    // noconvert() makes each overload accept only an ndarray whose dtype
    // already matches, so e.g. a float32 array is routed to the float
    // overload instead of being upcast to float64 by the first one tried.
    m.def("mag_sqr", [](py::array_t<double> a) {
        return py::vectorize([](double x) { return mag_sqr<double>()(x); })(a);
    }, py::arg("in").noconvert());
    m.def("mag_sqr", [](py::array_t<float> a) {
        return py::vectorize([](float x) { return mag_sqr<float>()(x); })(a);
    }, py::arg("in").noconvert());
    m.def("mag_sqr", [](py::array_t<std::complex<double>> a) {
        return py::vectorize([](std::complex<double> x) {
            return mag_sqr<std::complex<double>>()(x);
        })(a);
    }, py::arg("in").noconvert());
    m.def("mag_sqr", [](py::array_t<std::complex<float>> a) {
        return py::vectorize([](std::complex<float> x) {
            return mag_sqr<std::complex<float>>()(x);
        })(a);
    }, py::arg("in").noconvert());

    // ---- mag_sqr: scalar overloads -----------------------------------------
    // NOTE(review): pybind11 tries overloads in registration order (first
    // without, then with implicit conversion), so the float and
    // complex<float> overloads below are probably never selected -- the
    // double / complex<double> ones registered first accept those inputs.
    // Confirm before relying on float32 results for scalars.
    m.def("mag_sqr", [](double x) { return mag_sqr<double>()(x); });
    m.def("mag_sqr", [](float x) { return mag_sqr<float>()(x); });
    m.def("mag_sqr", [](std::complex<double> x) { return mag_sqr<std::complex<double>>()(x); });
    m.def("mag_sqr", [](std::complex<float> x) { return mag_sqr<std::complex<float>>()(x); });

    // ---- xt_norm_2: L2 norm via xtensor reducers -------------------------
    // norm_l2 yields a 0-d reducer expression; operator() reads its value.
    m.def("xt_norm_2", [](xt::pyarray<std::complex<double>> x) { return xt::norm_l2(x)(); },
          py::arg("in").noconvert());
    m.def("xt_norm_2", [](xt::pyarray<double> x) { return xt::norm_l2(x)(); },
          py::arg("in").noconvert());

    // ---- norm_2: 2-norm via xtensor-blas ---------------------------------
    m.def("norm_2", [](xt::pyarray<std::complex<double>> x) { return xt::linalg::norm(x, 2); },
          py::arg("in").noconvert());
    m.def("norm_2", [](xt::pyarray<double> x) { return xt::linalg::norm(x, 2); },
          py::arg("in").noconvert());

    // The 2-norm of a single value is its absolute value.  Computing it
    // directly avoids allocating a temporary 0-d pyarray on every call and
    // feeding a 0-d shape to linalg::norm.  Both return double, as before.
    m.def("norm_2", [](double x) { return std::abs(x); });
    m.def("norm_2", [](std::complex<double> x) { return std::abs(x); });
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment