Skip to content

Instantly share code, notes, and snippets.

@heiner
Last active August 6, 2019 13:21
Show Gist options
  • Save heiner/554ac6b0692919efb12b7b0e84e6ff6c to your computer and use it in GitHub Desktop.
Save heiner/554ac6b0692919efb12b7b0e84e6ff6c to your computer and use it in GitHub Desktop.
arraycapsule
/*
 * Build: CXX=c++ python3 setup.py build develop
 * Then run the demo: python run.py
 */
#include <algorithm>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
struct Holder {
Holder() : data(new std::string) {}
void set_contents(std::string s) { *data = std::move(s); }
std::string *data;
};
PYBIND11_MODULE(arraycapsule, m) {
py::class_<Holder>(m, "Holder")
.def(py::init<>())
.def("set_contents", &Holder::set_contents);
m.def("create", [](const Holder &h, py::dtype dt, std::vector<int> shape) {
std::string *data = h.data;
// Attach capsule as base in order to free data.
return py::array(
dt, shape, {}, data->data(), py::capsule(data, [](void *ptr) {
std::string *s = reinterpret_cast<std::string *>(ptr);
std::cout << std::hex << std::setfill('0');
for (unsigned int i = 0; i < s->size(); i++)
std::cout << std::setfill('0') << std::setw(2) << std::hex
<< static_cast<unsigned int>(
static_cast<unsigned char>((*s)[i]));
std::cout << std::endl;
delete s;
}));
});
}
import time  # NOTE(review): not referenced anywhere in this script
import torch  # NOTE(review): presumably imported before arraycapsule so the torch libraries the extension links against are already loaded; confirm before removing
import arraycapsule
import numpy as np

# Demo: prints the raw bytes of a float32 array before and after modifying it
# through the Holder (set_contents) and through the array returned by
# arraycapsule.create, to check whether the two share memory.


def dump(label, raw_bytes):
    """Print `label` followed by the hex encoding of `raw_bytes`."""
    print(label, raw_bytes.hex())


source = np.arange(3, dtype=np.float32)
payload = source.tobytes()
dump("original", payload)

holder = arraycapsule.Holder()
holder.set_contents(payload)
view = arraycapsule.create(holder, source.dtype, source.shape)

# Change the NumPy source, then push its new bytes into the Holder.
source[0] = 42.0
payload = source.tobytes()
dump("py modif", payload)
holder.set_contents(payload)
dump("C++ modi", view.tobytes())

# Write through the created array directly.
view[1] = 42.0
dump("2nd modi", view.tobytes())

# The capsule destructor in the extension prints the buffer when the array is freed.
print("end; C++ printout follows")
# Build and install in development mode with:
#   CXX=c++ python3 setup.py build develop
import setuptools
import sys
from torch.utils import cpp_extension

# Platform-specific flags: on macOS, compile and link against libc++ and target
# macOS 10.12 or newer. Other platforms need nothing extra.
if sys.platform == "darwin":
    extra_compile_args = ["-stdlib=libc++", "-mmacosx-version-min=10.12"]
    extra_link_args = ["-stdlib=libc++"]
else:
    extra_compile_args = []
    extra_link_args = []

# The extension module, compiled as C++17 via torch's CppExtension helper.
arraycapsule = cpp_extension.CppExtension(
    name="arraycapsule",
    sources=["arraycapsule.cc"],
    language="c++",
    extra_compile_args=["-std=c++17", *extra_compile_args],
    extra_link_args=extra_link_args,
)

setuptools.setup(
    name="arraycapsule",
    ext_modules=[arraycapsule],
    # torch's build_ext command, used together with CppExtension.
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment