Skip to content

Instantly share code, notes, and snippets.

@asford
Last active December 2, 2017 21:24
Show Gist options
  • Save asford/c9a27b2e0c910d93140d570dd991a5c9 to your computer and use it in GitHub Desktop.
Save asford/c9a27b2e0c910d93140d570dd991a5c9 to your computer and use it in GitHub Desktop.
Simple demonstration of pybind11 GIL (global interpreter lock) handling.
import logging
import cppimport
import time
import threading
import numpy
# Build (via cppimport, if needed) and import the pybind11 extension module `spinner`.
spinner = cppimport.imp("spinner")


class CallbackSpinner(spinner.Spinner):
    """Concrete ``spinner.Spinner`` whose ``tick`` callback runs Python code.

    ``Spinner.tick`` is pure virtual on the C++ side; defining it here lets
    the C++ ``spin`` loop call back into Python after every sleep.
    """

    def tick(self):
        """Log one tick; invoked from C++ ``Spinner::spin`` after each sleep."""
        message = "CallbackSpinner.tick"
        logging.info(message)
class HeartBeat(threading.Thread):
    """Background thread that records a ``time.time()`` timestamp every ``interval`` seconds.

    Intended as a context manager: entering starts the thread; exiting clears
    ``alive`` and joins it. Each beat executes Python code, so the thread can
    only record a beat while it is able to run Python; unusually long
    ``beat_intervals`` therefore indicate the thread was starved (e.g. of the GIL).
    """

    @property
    def beat_intervals(self):
        """Seconds between consecutive beats; empty when fewer than two beats exist."""
        return self.beats[1:] - self.beats[:-1]

    @property
    def beats(self):
        """Snapshot of the recorded timestamps as a 1-D numpy array."""
        return numpy.array(self._beats)

    def __init__(self, interval):
        """
        Args:
            interval: seconds to sleep between beats.
        """
        self._beats = []
        self.alive = False
        self.interval = interval
        threading.Thread.__init__(self)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Ask the loop to stop, then wait for it; exceptions are not suppressed.
        self.alive = False
        self.join()
        return False

    def start(self):
        """Arm the beat loop and launch the thread.

        ``alive`` is set here, on the caller's thread, rather than inside
        ``run``. If ``run`` set it, a quick ``__exit__`` could clear the flag
        before the new thread got that far. ``run`` would then re-arm it and
        loop forever, and ``join`` would hang.
        """
        self.alive = True
        threading.Thread.start(self)

    def run(self):
        # Loop until __exit__ (or another caller) clears ``alive``.
        while self.alive:
            self.tick()
            time.sleep(self.interval)

    def tick(self):
        """Record one beat timestamp and log it."""
        self._beats.append(time.time())
        logging.info("HeartBeat.tick")
/*
<%
setup_pybind11(cfg)
cfg['compiler_args'] = ['-std=c++11']
%>
*/
#include <pybind11/pybind11.h>
#include <chrono>
#include <thread>
namespace py = pybind11;
/// Abstract base that repeatedly sleeps and then invokes the tick() hook.
///
/// spin() itself touches no Python state. tick() is supplied by a subclass,
/// e.g. a Python subclass reached through the PySpinner trampoline.
class Spinner {
public:
    /// Virtual so that destroying a derived object through a Spinner*
    /// (as pybind11's default std::unique_ptr<Spinner> holder does for the
    /// PySpinner alias) is well-defined rather than undefined behaviour.
    virtual ~Spinner() = default;

    /// Call tick() n_times times, sleeping `wait` seconds before each call.
    ///
    /// @param n_times number of ticks; values <= 0 do nothing.
    /// @param wait    seconds to sleep before each tick (fractions allowed).
    void spin(int n_times, float wait) {
        const std::chrono::duration<float> wait_duration{wait};
        for (int i = 0; i < n_times; ++i) {
            std::this_thread::sleep_for(wait_duration);
            tick();
        }
    }

    /// Hook invoked after each sleep; must be provided by a subclass.
    virtual void tick() = 0;
};
/// pybind11 trampoline for Spinner: routes the pure-virtual tick() to the
/// `tick` method defined on a Python subclass. If no Python override exists,
/// PYBIND11_OVERLOAD_PURE reports an error instead of calling a C++ body.
/// Per the pybind11 documentation, the override macros acquire the GIL before
/// touching Python. That is what makes calling tick() from spin_nogil safe.
struct PySpinner : Spinner {
    // Reuse whatever constructors Spinner provides.
    using Spinner::Spinner;

    void tick() override {
        // Arguments after the method name are forwarded to Python; tick()
        // has none, hence the trailing empty argument slot.
        PYBIND11_OVERLOAD_PURE(void, Spinner, tick, );
    }
};
/// Python module `spinner`: exposes Spinner with two bindings of the same
/// C++ spin() that differ only in whether the GIL is held during the call.
PYBIND11_MODULE(spinner, m) {
    m.doc() = "Demonstration of pybind11 GIL handling around a sleeping C++ loop.";

    py::class_<Spinner, PySpinner> spinner(
        m, "Spinner",
        "Sleeps in C++ and calls the overridable tick() method after each sleep.");
    // Spinner is abstract (pure-virtual tick), so pybind11 falls back to the
    // PySpinner alias here, which lets Python subclasses supply tick().
    spinner.def(py::init<>());

    // Default binding: the GIL stays held for the whole call, including every
    // C++ sleep, so other Python threads cannot run meanwhile.
    spinner.def("spin_gil", &Spinner::spin,
                py::arg("n_times"), py::arg("wait"),
                "Call tick() n_times times, sleeping `wait` seconds before "
                "each, with the GIL held throughout.");
    // call_guard<gil_scoped_release> drops the GIL for the duration of the
    // call, so other Python threads can run while C++ sleeps.
    spinner.def("spin_nogil", &Spinner::spin,
                py::call_guard<py::gil_scoped_release>(),
                py::arg("n_times"), py::arg("wait"),
                "Call tick() n_times times, sleeping `wait` seconds before "
                "each, with the GIL released while spinning.");
}
import faulthandler
import logging
import unittest
import numpy
# Root logger at INFO with millisecond timestamps, so the spacing between
# "HeartBeat.tick" and "CallbackSpinner.tick" lines is readable in the output.
logging.basicConfig(
    datefmt='%Y-%m-%dT%H:%M:%S',
    format="%(asctime)s.%(msecs).03d %(name)s %(message)s",
    level=logging.INFO,
)

# Dump Python tracebacks on fatal errors (e.g. a crash inside the C++ extension).
faulthandler.enable()
class TestGIL(unittest.TestCase):
    """Check whether a Python heartbeat thread keeps ticking while C++ spins.

    ``Spinner.spin`` sleeps in C++ between ``tick`` callbacks. ``spin_gil`` is
    bound with the default policy (GIL held for the whole call), while
    ``spin_nogil`` uses ``py::call_guard<py::gil_scoped_release>``.

    NOTE(review): with the GIL held across each 0.2 s sleep, the 0.05 s
    heartbeat presumably stalls, so ``test_gil`` is expected to fail as a
    demonstration. Confirm this is intended; ``unittest.expectedFailure``
    would make it explicit.
    """

    def _assert_steady_heartbeat(self, spin_method_name):
        """Spin a CallbackSpinner 3 x 0.2 s via ``spin_method_name`` under a 0.05 s
        heartbeat, then assert every beat interval is within 30% of 0.05 s."""
        # Imported lazily, presumably so the cppimport build happens at test time.
        import pyspinner

        with pyspinner.HeartBeat(.05) as hb:
            getattr(pyspinner.CallbackSpinner(), spin_method_name)(3, .2)

        intervals = hb.beat_intervals
        # assert_allclose on an empty array passes vacuously; require real data.
        self.assertGreater(len(intervals), 0, "heartbeat recorded fewer than two beats")
        numpy.testing.assert_allclose(intervals, .05, rtol=.3)

    def test_gil(self):
        self._assert_steady_heartbeat("spin_gil")

    def test_nogil(self):
        self._assert_steady_heartbeat("spin_nogil")


if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment