Skip to content

Instantly share code, notes, and snippets.

@angus-g
Last active June 16, 2021 06:08
Show Gist options
  • Save angus-g/24356f290b22bb668d12c0a7844df077 to your computer and use it in GitHub Desktop.
Save angus-g/24356f290b22bb668d12c0a7844df077 to your computer and use it in GitHub Desktop.
import vector
class V(vector.Vector):
    """Python implementation of the C++ ``Vector`` interface holding one value."""

    def __init__(self):
        # The pybind11 base __init__ must run so that the C++ (PyVector)
        # object backing this Python instance is constructed.
        super().__init__()
        self.val = 0

    def plus(self, x):
        """Add ``x.val`` to this vector's value in place."""
        self.val += x.val

    def __getstate__(self):
        # Only the Python-side attributes are pickled; the C++ base class
        # declares no data members of its own.
        return self.__dict__.copy()

    def __setstate__(self, state):
        # pickle re-creates the instance with V.__new__ and does NOT call
        # __init__, so the C++ base would otherwise never be constructed and
        # the C++ side (PyVector loader) would receive an invalid object.
        # Construct it here before restoring the Python attributes.
        super().__init__()
        self.__dict__.update(state)
# Round-trip a Python-derived Vector through the C++ cereal serialiser.
v = V()
v.val = 42
# Writes v to "vector.cereal"; PyVector::save pickles the Python object.
vector.serialise_vector(v)
# Reads "vector.cereal" back; the C++ loader un-pickles the stored object.
y = vector.load_vector()
// g++ -g -Wall -shared -std=c++11 -fPIC -Icereal/include $(python3 -m pybind11 --includes) vector.cpp -o vector$(python3-config --extension-suffix)
// pybind11 (which includes Python.h) must come before any standard header.
#include <pybind11/pybind11.h>
namespace py = pybind11;

#include <cstdint>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>

#include <cereal/access.hpp>
#include <cereal/types/memory.hpp>
#include <cereal/types/polymorphic.hpp>
#include <cereal/types/string.hpp>
#include <cereal/archives/binary.hpp>
/// Abstract vector interface over the scalar type `Real`.
///
/// Concrete implementations (including Python subclasses, through the
/// PyVector trampoline) provide `plus`.
template <class Real>
class Vector {
public:
    // Virtual so objects are destroyed correctly through a base pointer such
    // as std::shared_ptr<Vector<double>>. Spelled without the template-id
    // (`~Vector<Real>()`), which is ill-formed since C++20.
    virtual ~Vector() = default;

    /// Add `x` to this vector; the result is not returned.
    virtual void plus(const Vector &x) = 0;
};
// Trampoline ("alias") class for Vector<double>: pybind11 instantiates it
// behind Python subclasses of `Vector` and forwards the pure virtual
// interface to the Python overrides. It also lets cereal serialise such
// objects by pickling the Python object that owns them.
class PyVector : public Vector<double> {
public:
    // Dispatch to the Python subclass's `plus`; PYBIND11_OVERRIDE_PURE raises
    // if the subclass does not define one.
    void plus(const Vector<double> &x) override {
        PYBIND11_OVERRIDE_PURE(void, Vector<double>, plus, x);
    }

    // Write the owning Python object to `archive` as a pickled byte string.
    // Acquires the GIL itself, so callers need not hold it.
    template <class Archive>
    void save(Archive &archive) const {
        py::gil_scoped_acquire gil;
        py::object pickle = py::module_::import("pickle");
        // Borrow `this` explicitly: the Python object owns this C++ instance,
        // so the handle obtained here must never take ownership of it.
        py::object self = py::cast(this, py::return_value_policy::reference);
        py::object pickled = pickle.attr("dumps")(self);
        std::string result = pickled.cast<std::string>();
        archive(result);
    }

    // PyVector objects are only meant to be rebuilt through the
    // std::shared_ptr<PyVector> loader (by un-pickling). Any other load path
    // would leave the pickled payload written by save() unread and
    // desynchronise the archive, so fail loudly instead of doing nothing.
    template <class Archive>
    void load(Archive &) {
        throw cereal::Exception("PyVector can only be loaded through std::shared_ptr");
    }
};
// specialise cereal's loading into a PtrWrapper so we can allocate our own
template <class Archive>
void load(Archive &archive, cereal::memory_detail::PtrWrapper<std::shared_ptr<PyVector> &> &wrapper) {
uint32_t id;
archive(CEREAL_NVP_("id", id));
if (id & cereal::detail::msb_32bit) {
py::gil_scoped_acquire gil;
// un-pickle to python object
py::object pickle = py::module_::import("pickle");
std::string pickled;
archive(pickled);
py::object loaded = pickle.attr("loads")(py::bytes(pickled));
std::cout << "unpickled: " << loaded.attr("val").cast<int>() << std::endl;
// pybind clone pattern
auto keep_alive = std::make_shared<py::object>(loaded);
auto p = loaded.cast<PyVector *>();
std::shared_ptr<PyVector> ptr(keep_alive, p); // alias to keep_alive, but point to p
archive.registerSharedPointer(id, ptr);
wrapper.ptr = std::move(ptr);
} else {
wrapper.ptr = std::static_pointer_cast<PyVector>(archive.getSharedPointer(id));
}
}
// Register PyVector with cereal's polymorphic machinery, and declare it as
// derived from Vector<double>, so a std::shared_ptr<Vector<double>> whose
// dynamic type is PyVector can be saved and loaded through the base pointer.
CEREAL_REGISTER_TYPE(PyVector);
CEREAL_REGISTER_POLYMORPHIC_RELATION(Vector<double>, PyVector);
// Serialise `v` with cereal into the binary file "vector.cereal" (relative to
// the current working directory). Because Vector<double> is registered as a
// polymorphic base, a PyVector is written through PyVector::save.
void serialise_vector(const std::shared_ptr<Vector<double>> &v) {
    std::ofstream file{"vector.cereal", std::ios::binary};
    cereal::BinaryOutputArchive out_archive{file};
    out_archive(v);
}
std::shared_ptr<Vector<double>> load_vector() {
std::ifstream is("vector.cereal", std::ios::binary);
cereal::BinaryInputArchive iarchive(is);
std::shared_ptr<Vector<double>> v;
iarchive(v);
std::cout << "successfully loaded, going back to pybind" << std::endl;
return v; // <---- segfault here-ish
}
PYBIND11_MODULE(vector, m) {
    m.doc() = "cereal example";

    // Free functions that round-trip a Vector through a cereal archive file.
    m.def("serialise_vector", &serialise_vector, "Serialise a Vector pointer to disk");
    m.def("load_vector", &load_vector, "Load a serialised Vector from disk");

    // Expose the abstract Vector<double> with PyVector as its trampoline so
    // Python classes can subclass it; instances are held by std::shared_ptr,
    // the same holder type the free functions above take and return.
    using VectorD = Vector<double>;
    using VectorBinding = py::class_<VectorD, PyVector, std::shared_ptr<VectorD>>;
    VectorBinding vector_class(m, "Vector");
    vector_class.def(py::init<>());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment