Created October 9, 2019 03:35
Save sizmailov/d2a456329ad79db04cb243e2fba60657 to your computer and use it in GitHub Desktop.
`py::vectorize` with void return type
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
0.0
6.0
9.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Driver for the `test_module` extension built from the C++ source in this gist.
# The recorded output of this script is 0.0, 6.0 and 9.0.
from test_module import Sum

s = Sum()
print(s.value)

# `add` is wrapped with py::vectorize, so a list is accepted and every element is added.
s.add([1, 2, 3])
print(s.value)

# A plain scalar works as well.
s.add(3)
print(s.value)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// pybind11 first: it includes Python.h, which must precede the standard headers.
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#include <array>
#include <functional>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>

namespace py = pybind11;
namespace pybind11 { | |
namespace detail { | |
template<typename Func, typename... Args> | |
struct vectorize_helper<Func, void, Args...> { | |
private: | |
static constexpr size_t N = sizeof...(Args); | |
static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...); | |
static_assert(NVectorized>=1, | |
"pybind11::vectorize(...) requires a function with at least one vectorizable argument"); | |
public: | |
template<typename T> | |
explicit vectorize_helper(T&& f) | |
: f(std::forward<T>(f)) { } | |
void operator()(typename vectorize_arg<Args>::type... args) { | |
run(args..., | |
make_index_sequence<N>(), | |
select_indices<vectorize_arg<Args>::vectorize...>(), | |
make_index_sequence<NVectorized>()); | |
} | |
private: | |
remove_reference_t<Func> f; | |
// Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling with "/permissive-" flag | |
// when arg_call_types is manually inlined. | |
using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>; | |
template<size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type; | |
// Runs a vectorized function given arguments tuple and three index sequences: | |
// - Index is the full set of 0 ... (N-1) argument indices; | |
// - VIndex is the subset of argument indices with vectorized parameters, letting us access | |
// vectorized arguments (anything not in this sequence is passed through) | |
// - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that | |
// we can store vectorized buffer_infos in an array (argument VIndex has its buffer at | |
// index BIndex in the array). | |
template<size_t... Index, size_t... VIndex, size_t... BIndex> void run( | |
typename vectorize_arg<Args>::type& ...args, | |
index_sequence<Index...> i_seq, index_sequence<VIndex...> vi_seq, index_sequence<BIndex...> bi_seq) { | |
// Pointers to values the function was called with; the vectorized ones set here will start | |
// out as array_t<T> pointers, but they will be changed them to T pointers before we make | |
// call the wrapped function. Non-vectorized pointers are left as-is. | |
std::array<void*, N> params{{&args...}}; | |
// The array of `buffer_info`s of vectorized arguments: | |
std::array<buffer_info, NVectorized> buffers{{reinterpret_cast<array*>(params[VIndex])->request()...}}; | |
/* Determine dimensions parameters of output array */ | |
ssize_t nd = 0; | |
std::vector<ssize_t> shape(0); | |
auto trivial = broadcast(buffers, nd, shape); | |
size_t ndim = (size_t) nd; | |
size_t size = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>()); | |
// If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e. | |
// not wrapped in an array). | |
if (size==1 && ndim==0) { | |
PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr); | |
return f(*reinterpret_cast<param_n_t<Index>*>(params[Index])...); | |
} | |
// array_t<Return> result; | |
// if (trivial==broadcast_trivial::f_trivial) result = array_t<Return, array::f_style>(shape); | |
// else result = array_t<Return>(shape); | |
if (size==0) return; | |
/* Call the function */ | |
if (trivial==broadcast_trivial::non_trivial) | |
apply_broadcast(buffers, params, shape, size, i_seq, vi_seq, bi_seq); | |
else | |
apply_trivial(buffers, params, shape, size, i_seq, vi_seq, bi_seq); | |
} | |
template<size_t... Index, size_t... VIndex, size_t... BIndex> | |
void apply_trivial(std::array<buffer_info, NVectorized>& buffers, | |
std::array<void*, N>& params, | |
const std::vector<ssize_t>& shape, | |
size_t size, | |
index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) { | |
// Initialize an array of mutable byte references and sizes with references set to the | |
// appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size | |
// (except for singletons, which get an increment of 0). | |
std::array<std::pair<unsigned char*&, const size_t>, NVectorized> vecparams{{ | |
std::pair<unsigned char*&, | |
const size_t>( | |
reinterpret_cast<unsigned char*&>(params[VIndex] = buffers[BIndex].ptr), | |
buffers[BIndex].size==1 ? 0 | |
: sizeof(param_n_t< | |
VIndex>) | |
)... | |
}}; | |
for (size_t i = 0; i<size; ++i) { | |
f(*reinterpret_cast<param_n_t<Index>*>(params[Index])...); | |
for (auto& x : vecparams) x.first += x.second; | |
} | |
} | |
template<size_t... Index, size_t... VIndex, size_t... BIndex> | |
void apply_broadcast(std::array<buffer_info, NVectorized>& buffers, | |
std::array<void*, N>& params, | |
const std::vector<ssize_t>& shape, | |
size_t size, | |
index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) { | |
multi_array_iterator<NVectorized> input_iter(buffers, shape); | |
for (size_t i = 0; | |
i!=size; | |
++i, ++input_iter) { | |
PYBIND11_EXPAND_SIDE_EFFECTS(( | |
params[VIndex] = input_iter.template data<BIndex>() | |
)); | |
f(*reinterpret_cast<param_n_t<Index>*>(std::get<Index>(params))...); | |
} | |
} | |
}; | |
} | |
} | |
// Running total exposed to Python as `test_module.Sum`.
struct Sum {
    // Default member initializer: the total starts at 0.0 even when a Sum is
    // default-initialized (e.g. `Sum s;`), not only when it is value-initialized.
    double value = 0.0;
};
PYBIND11_MODULE(test_module, m) {
    // Python class `Sum`: default constructor plus read/write access to `value`.
    py::class_<Sum> sum_class(m, "Sum");
    sum_class.def(py::init<>());
    sum_class.def_readwrite("value", &Sum::value);

    // `add` is wrapped with py::vectorize, so its `double` argument may be an array.
    // The lambda runs once per element, adding each element to `target.value`.
    sum_class.def("add", py::vectorize([](Sum &target, double increment) {
        target.value += increment;
    }));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.