Skip to content

Instantly share code, notes, and snippets.

@estshorter
Last active March 9, 2023 14:19
Show Gist options
  • Save estshorter/b9059d800b05064cc329207a8ab725dc to your computer and use it in GitHub Desktop.
Save estshorter/b9059d800b05064cc329207a8ab725dc to your computer and use it in GitHub Desktop.
nanobind sample
import numpy as np
import my_ext
# Build a 4x2 uint8 array holding 0..7 in row-major (C-contiguous) order.
a = np.arange(8, dtype=np.uint8).reshape((4, 2))
# Ask the extension to print the array's pointer, rank, per-axis extents and dtype flags.
my_ext.inspect(a)
print("--------\n")
# prepare() mutates the extension's static buffer; calc() then combines that buffer with `a`.
my_ext.prepare()
b = my_ext.calc(a)
print(b)
#include <nanobind/ndarray.h>
namespace nb = nanobind;
// Static buffer the extension operates on: prepare() doubles it in place and
// calc() accumulates into it and returns it as a 2x4 array. Static storage
// keeps the pointer handed to nanobind valid for the life of the process.
static std::uint8_t data[] = { 0, 1, 2, 3, 4, 5, 6, 7 };
// Element type of `data` (std::uint8_t), derived from the array itself so the
// two cannot drift apart.
using data_type = std::remove_reference_t<decltype(*data)>;

/// Doubles every element of `data` in place. The result is stored back into a
/// uint8, so values wrap modulo 256 once they exceed 255.
/// @return Always true.
bool prepare() {
    // Range-for follows the array's real length instead of a hard-coded 8.
    for (auto& value : data) {
        value *= 2;
    }
    return true;
}
/// Adds the 4x2 input element-wise into the static `data` buffer and returns
/// that buffer as a 2x4 array. The sum is stored back into a uint8, so it wraps
/// modulo 256. The accumulation persists across calls: each call adds on top
/// of the previous result.
/// @param arr_in C-contiguous 4x2 CPU array of data_type (enforced by nanobind).
/// @return 2x4 ndarray over `data`.
/// NOTE(review): the returned ndarray wraps `data` with no owner; whether Python
/// receives a copy or a view depends on nanobind's return-value policy for
/// ownerless arrays -- confirm for the nanobind version in use.
auto calc(nb::ndarray<data_type, nb::shape<4, 2>, nb::c_contig, nb::device::cpu> arr_in) {
    // Element count of `data`, derived instead of hard-coding 8.
    constexpr size_t count = sizeof(data) / sizeof(data[0]);
    // nb::shape<4, 2> + nb::c_contig guarantee exactly 4 * 2 contiguous input
    // elements; reading `count` of them is only in bounds if the sizes agree.
    static_assert(count == 4 * 2, "data must hold exactly one 4x2 input");

    const auto* src = arr_in.data();
    for (size_t i = 0; i < count; ++i) {
        data[i] += src[i];
    }

    constexpr size_t shape[2] = { 2, 4 };
    static_assert(shape[0] * shape[1] == count, "output shape must cover all of data");
    return nb::ndarray<nb::numpy, data_type, nb::shape<2, nb::any>>(data, /* ndim = */ 2, shape);
}
/// Prints diagnostic information about an arbitrary ndarray: its data pointer,
/// its rank, the extent of every axis, and whether its dtype is one of four
/// specific types (each reported as 0 or 1).
void inspect(nb::ndarray<> arr) {
    const size_t rank = arr.ndim();
    printf("Array data pointer : %p\n", arr.data());
    printf("Array dimension : %zu\n", rank);
    for (size_t axis = 0; axis != rank; ++axis)
        printf("Array dimension [%zu] : %zu\n", axis, arr.shape(axis));

    // bool flags are promoted to int (0/1) for the %i conversions below.
    const bool is_u8 = arr.dtype() == nb::dtype<uint8_t>();
    const bool is_i16 = arr.dtype() == nb::dtype<int16_t>();
    const bool is_u32 = arr.dtype() == nb::dtype<uint32_t>();
    const bool is_f32 = arr.dtype() == nb::dtype<float>();
    printf("Array dtype: uint8_t=%i, int16=%i, uint32=%i, float32=%i\n",
           is_u8, is_i16, is_u32, is_f32);
}
// Defines the Python extension module `my_ext` and registers three functions
// under the same names as their C++ counterparts.
NB_MODULE(my_ext, m) {
// Prints pointer, rank, per-axis extents and dtype flags of any ndarray.
m.def("inspect", &inspect);
// Doubles the static `data` buffer in place.
m.def("prepare", &prepare);
// Accumulates a 4x2 C-contiguous CPU array into `data` and returns it as 2x4.
m.def("calc", &calc);
}
#include <nanobind/ndarray.h>
#include <array>
#include <vector>
namespace nb = nanobind;
static std::vector<std::uint8_t> s_data = { 0, 1, 2, 3, 4, 5, 6, 7 };
static std::vector<std::uint8_t> s_zeros = { 0, 0, 0, 0, 0, 0, 0, 0 };
static bool s_initialized = false;
using data_type = decltype(s_data)::value_type;
constexpr std::array<size_t, 2> shape = { 2, 4 };
auto arr = nb::ndarray<nb::numpy, data_type, nb::shape<2, nb::any>>(s_data.data(), shape.size(), shape.data());
bool prepare() {
for (int i = 0; i < s_data.size(); i++) {
s_data[i] *= 2;
}
s_initialized = true;
return true;
}
auto calc(nb::ndarray<data_type, nb::shape<4, 2>, nb::c_contig, nb::device::cpu> arr_in) {
if (!s_initialized) {
return nb::ndarray<nb::numpy, data_type, nb::shape<2, nb::any>>(s_zeros.data(), shape.size(), shape.data());
}
for (int i = 0; i < s_data.size(); i++) {
s_data[i] += arr_in.data()[i];
//printf("%f, ", arr_in.data()[i]);
}
//printf("\n");
return arr;
}
/// Prints diagnostic information about an arbitrary ndarray: its data pointer,
/// its rank, the extent of every axis, and whether its dtype is one of four
/// specific types (each reported as 0 or 1).
void inspect(nb::ndarray<> arr) {
    const size_t rank = arr.ndim();
    printf("Array data pointer : %p\n", arr.data());
    printf("Array dimension : %zu\n", rank);
    for (size_t axis = 0; axis != rank; ++axis)
        printf("Array dimension [%zu] : %zu\n", axis, arr.shape(axis));

    // bool flags are promoted to int (0/1) for the %i conversions below.
    const bool is_u8 = arr.dtype() == nb::dtype<uint8_t>();
    const bool is_i16 = arr.dtype() == nb::dtype<int16_t>();
    const bool is_u32 = arr.dtype() == nb::dtype<uint32_t>();
    const bool is_f32 = arr.dtype() == nb::dtype<float>();
    printf("Array dtype: uint8_t=%i, int16=%i, uint32=%i, float32=%i\n",
           is_u8, is_i16, is_u32, is_f32);
}
// Defines the Python extension module `my_ext` and registers three functions
// under the same names as their C++ counterparts.
NB_MODULE(my_ext, m) {
// Prints pointer, rank, per-axis extents and dtype flags of any ndarray.
m.def("inspect", &inspect);
// Doubles s_data in place and marks the module as initialized for calc().
m.def("prepare", &prepare);
// Accumulates a 4x2 C-contiguous CPU array into s_data and returns a 2x4 array
// (or a zero array if prepare() has not been called yet).
m.def("calc", &calc);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment