Last active
June 16, 2021 06:08
-
-
Save angus-g/24356f290b22bb668d12c0a7844df077 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import vector | |
class V(vector.Vector):
    """Python subclass of the pybind11-exposed ``vector.Vector``.

    Carries a single attribute ``val`` and implements the pickle protocol so
    instances can be round-tripped through ``pickle.dumps``/``pickle.loads``.
    """

    def __init__(self):
        # Call the bound base initialiser explicitly: this is what constructs
        # the underlying C++ object for a Python-derived class.
        vector.Vector.__init__(self)
        self.val = 0

    def plus(self, x):
        """Add ``x.val`` to ``self.val`` in place."""
        self.val += x.val

    def __getstate__(self):
        """Return a shallow copy of the instance ``__dict__`` as pickle state."""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """Restore pickled state onto a freshly created instance.

        Unpickling creates the object via ``__new__`` and never runs
        ``__init__``, so the C++ base must be constructed here; otherwise the
        object has no valid C++ instance behind it when it is handed back to
        C++ code.
        """
        vector.Vector.__init__(self)
        self.__dict__.update(state)
# Round-trip demo: build a vector, write it out through the C++ extension,
# then read it back in.
original = V()
original.val = 42
vector.serialise_vector(original)
restored = vector.load_vector()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// g++ -g -Wall -shared -std=c++11 -fPIC -Icereal/include $(python3 -m pybind11 --includes) vector.cpp -o vector$(python3-config --extension-suffix) | |
#include <pybind11/pybind11.h> | |
namespace py = pybind11; | |
#include <fstream> | |
#include <cereal/access.hpp> | |
#include <cereal/types/polymorphic.hpp> | |
#include <cereal/archives/binary.hpp> | |
/// Abstract vector interface over the scalar type `Real`.
///
/// The class has a pure-virtual member and therefore cannot be instantiated
/// directly; concrete vectors derive from it and override `plus`.
template <class Real>
class Vector {
public:
    // Written without a template-id: `~Vector<Real>()` is ill-formed as of
    // C++20, while the injected class name works in every standard.
    virtual ~Vector() = default;

    /// Combine `x` into this vector (implemented by derived classes).
    virtual void plus(const Vector<Real> &x) = 0;
};
/// Concrete stand-in for the pure-virtual `Vector<double>`: each virtual is
/// forwarded to a Python override, and cereal `save`/`load` hooks are provided.
class PyVector : public Vector<double> {
public:
    // Forwarding override; other virtuals would follow the same pattern.
    void plus(const Vector<double> &x) override {
        PYBIND11_OVERRIDE_PURE(void, Vector<double>, plus, x);
    }

    /// cereal save hook: pickles the Python object bound to `this` and writes
    /// the resulting bytes to `archive` as a single string.
    template <class Archive>
    void save(Archive &archive) const {
        py::gil_scoped_acquire gil;  // Python calls below need the GIL
        const py::object pickle_module = py::module_::import("pickle");
        const py::object blob = pickle_module.attr("dumps")(py::cast(this));
        std::string payload = blob.cast<std::string>();
        archive(payload);
    }

    /// Intentionally empty: a PyVector is not meant to be loaded by value;
    /// loading goes through the shared_ptr wrapper overload instead.
    template <class Archive>
    void load(Archive &) {}
};
// specialise cereal's loading into a PtrWrapper so we can allocate our own | |
template <class Archive> | |
void load(Archive &archive, cereal::memory_detail::PtrWrapper<std::shared_ptr<PyVector> &> &wrapper) { | |
uint32_t id; | |
archive(CEREAL_NVP_("id", id)); | |
if (id & cereal::detail::msb_32bit) { | |
py::gil_scoped_acquire gil; | |
// un-pickle to python object | |
py::object pickle = py::module_::import("pickle"); | |
std::string pickled; | |
archive(pickled); | |
py::object loaded = pickle.attr("loads")(py::bytes(pickled)); | |
std::cout << "unpickled: " << loaded.attr("val").cast<int>() << std::endl; | |
// pybind clone pattern | |
auto keep_alive = std::make_shared<py::object>(loaded); | |
auto p = loaded.cast<PyVector *>(); | |
std::shared_ptr<PyVector> ptr(keep_alive, p); // alias to keep_alive, but point to p | |
archive.registerSharedPointer(id, ptr); | |
wrapper.ptr = std::move(ptr); | |
} else { | |
wrapper.ptr = std::static_pointer_cast<PyVector>(archive.getSharedPointer(id)); | |
} | |
} | |
// Register PyVector with cereal's polymorphic machinery and declare
// Vector<double> as its base, so it can be (de)serialised through a
// std::shared_ptr<Vector<double>>.
CEREAL_REGISTER_TYPE(PyVector);
CEREAL_REGISTER_POLYMORPHIC_RELATION(Vector<double>, PyVector);
/// Write `v` to the file "vector.cereal" using cereal's binary output archive.
void serialise_vector(const std::shared_ptr<Vector<double>> &v) {
    std::ofstream output{"vector.cereal", std::ios::binary};
    cereal::BinaryOutputArchive writer{output};
    writer(v);
}
std::shared_ptr<Vector<double>> load_vector() { | |
std::ifstream is("vector.cereal", std::ios::binary); | |
cereal::BinaryInputArchive iarchive(is); | |
std::shared_ptr<Vector<double>> v; | |
iarchive(v); | |
std::cout << "successfully loaded, going back to pybind" << std::endl; | |
return v; // <---- segfault here-ish | |
} | |
/// Python module `vector`: exposes the `Vector` class (with `PyVector` as its
/// trampoline) plus the two serialisation helpers.
PYBIND11_MODULE(vector, m) {
    m.doc() = "cereal example";

    // Register the class before the functions that mention it: pybind11 builds
    // each function's docstring signature at definition time, and a type that
    // is not yet registered shows up there under its raw C++ name.
    py::class_<Vector<double>, PyVector, std::shared_ptr<Vector<double>>>(m, "Vector")
        .def(py::init<>());

    m.def("serialise_vector", &serialise_vector, "Serialise a Vector pointer to disk");
    m.def("load_vector", &load_vector, "Load a serialised Vector from disk");
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment