Last active
October 9, 2019 04:40
-
-
Save sizmailov/6c0a4476561cfbdf90c92914c6dca50a to your computer and use it in GitHub Desktop.
Create no-copy numpy array from data member in pybind11
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<Foo.data at 0x55cab343a4c0> [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
<Foo.data at 0x55cab343a4c0> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
numpy array at 0x55cab343a4c0, readonly: False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demonstrates that Foo.as_numpy() returns a view of Foo.data rather than a copy:
# writing through the numpy array changes what Foo.show() prints, and the array's
# data pointer equals the address printed by Foo.show().
import test_module

foo = test_module.Foo()
foo.show()
arr = foo.as_numpy()  # named `arr`, not `np`, so the usual numpy alias is not shadowed
arr[0] = 1
foo.show()
# __array_interface__['data'] is a (pointer, read_only) tuple, which supplies both
# %-placeholders of the format string.
print("numpy array at 0x%x, readonly: %s" % arr.__array_interface__['data'])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// pybind11 (which pulls in Python.h) is included before any standard header,
// as the Python C API documentation requires.
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#include <array>
#include <ios>
#include <iostream>

namespace py = pybind11;
/// Fixed-size buffer of doubles that is exposed to Python without copying.
struct Foo {
    /// Writes "<Foo.data at ADDR> [v0, v1, ...]" to std::cout and flushes.
    void print() const {
        // Remember the stream's formatting flags so std::hex does not leak into
        // later output on std::cout.
        const std::ios_base::fmtflags saved_flags = std::cout.flags();
        std::cout << "<Foo.data at " << std::hex << data.data() << "> ";
        std::cout.flags(saved_flags);

        std::cout << "[";
        const char* separator = "";
        for (const double value : data) {
            std::cout << separator << value;
            separator = ", ";
        }
        // std::endl (not '\n') flushes, so this line appears before any output
        // the Python side prints afterwards.
        std::cout << "]" << std::endl;
    }

    static constexpr int size = 10;
    // Zero-initialized so a default-constructed Foo never holds indeterminate values.
    std::array<double, size> data{};
};
PYBIND11_MODULE(test_module, m)
{
    py::class_<Foo>(m, "Foo")
        .def(py::init<>())
        .def("show", &Foo::print)
        // Returns a 1-D float64 numpy array that aliases Foo::data (no copy).
        // The lambda takes the Python wrapper (`py::object`) instead of `Foo&` so the
        // wrapper can be passed as the array's `base`: numpy keeps a reference to it,
        // so the Foo instance that owns the buffer outlives every array viewing it.
        .def("as_numpy", [](py::object self) {
            Foo& foo = self.cast<Foo&>();
            return py::array_t<double>(
                { Foo::size },        // shape: one dimension of Foo::size elements
                { sizeof(double) },   // strides in bytes: contiguous doubles
                foo.data.data(),      // existing buffer, used in place
                self);                // base object that owns the buffer
        });
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Great, thanks a lot for your help.