Skip to content

Instantly share code, notes, and snippets.

@nerditation
Created May 10, 2020 10:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nerditation/441a2d4409a778ae77683fc0645b69de to your computer and use it in GitHub Desktop.
Save nerditation/441a2d4409a778ae77683fc0645b69de to your computer and use it in GitHub Desktop.
proof of concept to use the C++ coroutines with the Lua API.
// PoC: use C++ resumable functions (a.k.a. coroutines) to
// eliminate the need to manually transform code into
// continuation-passing style (CPS) when working with the Lua APIs
// that take continuations, i.e. `lua_yieldk()`, `lua_callk()` and `lua_pcallk()`.
// This code needs C++17 plus Coroutines TS support.
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <new>
#include <type_traits>
#include <utility>
#include <variant>
#ifdef _MSC_VER
// cl.exe /std:c++17 /await mylua-cpp-coroutines.cpp lua.lib
#include <experimental/coroutine>
using std::experimental::coroutine_handle;
using std::experimental::suspend_always;
#else
// g++-10 -std=c++17 -fcoroutines mylua-cpp-coroutines.cpp -llua
#include <coroutine>
using std::coroutine_handle;
using std::suspend_always;
#endif
#include <lua.hpp>
// true: use type-erased std::function as the storage type of the coroutine,
// can use stateful lambda or other callable objects
// false: use plain function pointer as the storage type, can only use
// capture-less lambdas or plain functions, need to use the (extended) api
// to manage upvalues
#define USE_STD_FUNCTION 1
// As far as I understand, if the stack is currently unwound across a coroutine/resumer
// boundary, the coroutine is left in a non-resumable state and its frame
// is destroyed by the runtime. This means:
// - we can't call lua_callk() and lua_pcallk() directly from
//   within the awaiter's await_suspend() method
// - we can't call lua_yieldk() directly inside promise_type::yield_value()
// - we have to let the coroutine reach a suspend point and let the caller/resumer
//   decide what to do next
// so we use the following tag types
// Requests a suspended coroutine posts to its resumer.

// Finished: `nresults` values on top of the Lua stack are returned to Lua.
struct returned_ { int nresults; };

// Yield `nresults` values from the top of the stack via lua_yieldk().
struct yielded_ { int nresults; };

// Call the function below the `nargs` arguments via lua_callk().
struct called_ { int nargs; int nresults; };

// Like called_, but through lua_pcallk() with message handler `errfunc`.
struct pcalled_ { int nargs; int nresults; int errfunc; };
// Convenience constructors that mimic the style of the Lua API.
// They only build a request object; nothing happens until the result is
// `co_await`ed inside a mylua_ResumableHandle coroutine, hence [[nodiscard]].
// The lua_State parameter is unused; it only keeps call sites Lua-like.

// co_await mylua_call(L, nargs, nresults);
[[nodiscard]] auto mylua_call(lua_State *, int nargs, int nresults) {
    return called_{ nargs, nresults };
}
// int status = co_await mylua_pcall(L, nargs, nresults, errfunc);
[[nodiscard]] auto mylua_pcall(lua_State *, int nargs, int nresults, int errfunc = 0) {
    return pcalled_{ nargs, nresults, errfunc };
}
// co_await mylua_yield(L, nresults);
// can also use the shorthand: co_yield nresults;
[[nodiscard]] auto mylua_yield(lua_State *, int nresults) {
    return yielded_{ nresults };
}
// just a wrapper of coroutine_handle<promise_type>
// not a RAII class
// use like a non-owning pointer
struct mylua_ResumableHandle {
struct promise_type {
// values yielded to the caller/resumer to call the lua APIs
std::variant<returned_, yielded_, called_, pcalled_> state;
// value passed by the caller/resumer, return value from the lua APIs
// only lua_pcallk() returns a status.
// for lua_yieldk(), it is always LUA_YIELD
// for lua_callk(), it is always LUA_OK
int resumed_value;
void return_value(int nresults) noexcept {
state = returned_{ nresults };
}
#if 0
// currently it is ill-formed if both return_value and return_void are present
// but there is a paper P1713 suggesting to relax this restriction
// see https://github.com/cplusplus/papers/issues/479
void return_void() noexcept {
return_value(0);
}
#endif
auto yield_value(int nresults = 0) noexcept {
state = yielded_{ nresults };
return suspend_always{};
};
template <typename T>
auto await_transform(T &&args) noexcept {
this->state = args;
return suspend_always{};
}
auto await_transform(pcalled_ &&args) noexcept {
this->state = args;
struct _ : public suspend_always
{
promise_type const &promise;
_(promise_type const &promise) : promise(promise) {}
int await_resume() noexcept {
return promise.resumed_value;
}
};
return _{ *this };
}
auto initial_suspend() const noexcept {
return suspend_always{};
}
auto final_suspend() const noexcept {
return suspend_always{};
}
void unhandled_exception() const {
// throwing exceptions across coroutine/resumer boundary
// would destroy the coroutine frame
//throw;
abort();
}
auto get_return_object() {
auto handle = mylua_ResumableHandle::handle_type::from_promise(*this);
return mylua_ResumableHandle{ handle };
}
};
using handle_type = coroutine_handle<promise_type>;
handle_type handle_;
#if 0
~mylua_ResumableHandle() noexcept {
destroy();
}
#endif
static mylua_ResumableHandle from_context(lua_KContext ctx) noexcept {
auto ptr = reinterpret_cast<void *>(static_cast<intptr_t>(ctx));
return { handle_type::from_address(ptr) };
}
lua_KContext to_context() const noexcept {
auto ptr = handle_.address();
return static_cast<lua_KContext>(reinterpret_cast<intptr_t>(ptr));
}
promise_type & promise() const noexcept {
return handle_.promise();
}
void destroy() noexcept {
if (handle_) {
handle_.destroy();
}
}
void resume() {
return handle_.resume();
}
};
// Visitor for std::visit built from several callables (the "overloaded" idiom):
// it inherits operator() from every callable it is built from.
template<class ... FNs> struct overloaded_visitor : FNs... { using FNs::operator()...; };
// Builds an overloaded_visitor from the given callables. The callable types are
// decayed before use as base classes, so lvalue lambdas work as well as
// temporaries (an lvalue would otherwise deduce a reference type, which cannot
// be a base class). The callables are copied or moved in.
template<class ... FNs>
auto overloaded(FNs && ... fns) noexcept((std::is_nothrow_constructible_v<std::decay_t<FNs>, FNs &&> && ...)) {
    return overloaded_visitor<std::decay_t<FNs>...>{ std::forward<FNs>(fns)... };
}
// Continuation function (lua_KFunction) that drives a C++ coroutine.
// `ctx` encodes the coroutine handle; `status` is the outcome of the Lua API
// call that last suspended us, and is handed to the coroutine as resumed_value.
// The coroutine is resumed until it posts a request this function must act on:
// - returned_: free the frame and return the results to Lua
// - yielded_:  lua_yieldk() with this function as continuation
// - called_ / pcalled_: perform the call. If it returns normally, loop and
//   resume again. If the callee yields, Lua re-enters here later.
// A loop is used instead of self-recursion, so a long run of non-yielding calls
// does not grow the C stack. The branches are intentionally NOT noexcept:
// lua_callk/lua_yieldk may throw when Lua is built as C++.
// NOTE(review): if an error escapes lua_callk, the coroutine frame is never
// destroyed (leaked).
static int myluai_resumable_wrapper_k(lua_State *L, int status, lua_KContext ctx) {
    auto handle = mylua_ResumableHandle::from_context(ctx);
    for (;;) {
        auto &promise = handle.promise();
        promise.resumed_value = status;
        handle.resume();
        auto &request = promise.state;
        if (auto const *ret = std::get_if<returned_>(&request)) {
            int const nresults = ret->nresults; // read before the frame is freed
            handle.destroy();
            return nresults;
        }
        if (auto const *yld = std::get_if<yielded_>(&request)) {
            // With a continuation, Lua resumes us by calling this function again.
            return lua_yieldk(L, yld->nresults, ctx, &myluai_resumable_wrapper_k);
        }
        if (auto const *call = std::get_if<called_>(&request)) {
            lua_callk(L, call->nargs, call->nresults, ctx, &myluai_resumable_wrapper_k);
            status = LUA_OK; // returned normally (no yield inside the callee)
            continue;
        }
        auto const &pcall = std::get<pcalled_>(request);
        status = lua_pcallk(L, pcall.nargs, pcall.nresults, pcall.errfunc, ctx, &myluai_resumable_wrapper_k);
    }
}
#if USE_STD_FUNCTION
using mylua_ResumableFunction = std::function<mylua_ResumableHandle(lua_State *)>;
// lua_CFunction behind every closure pushed by mylua_pushresumable().
// Upvalue 1 is the userdata holding the std::function coroutine factory.
static int myluai_resumable_wrapper(lua_State *L) {
    void *storage = lua_touserdata(L, lua_upvalueindex(1));
    mylua_ResumableFunction &factory = *static_cast<mylua_ResumableFunction *>(storage);
    // Invoking the factory creates the coroutine frame. The coroutine parks at
    // promise_type::initial_suspend() before any user code has run.
    lua_KContext const ctx = factory(L).to_context();
    // Everything after that is driven by the continuation function.
    return myluai_resumable_wrapper_k(L, LUA_OK, ctx);
}
// Registry name of the metatable (created by mylua_init) whose __gc destroys
// the std::function stored in a resumable's userdata.
constexpr auto cpp_std_function_typename = "__cplusplus(std::function<>)";
// Pushes a C closure that, each time it is called, creates a new C++ coroutine
// from `coro` and drives it via myluai_resumable_wrapper.
// `coro` is any callable convertible to mylua_ResumableFunction; it is stored
// in a full userdata kept as the closure's only upvalue.
// NOTE: mylua_init() must have run on this state first. Otherwise no
// metatable is attached and the std::function is never destroyed.
template <typename F>
static void mylua_pushresumable(lua_State *L, F&& coro) {
    void *storage = lua_newuserdata(L, sizeof(mylua_ResumableFunction));
    ::new (storage) mylua_ResumableFunction(std::forward<F>(coro));
    luaL_setmetatable(L, cpp_std_function_typename);
    lua_pushcclosure(L, myluai_resumable_wrapper, 1);
}
// Creates the registry metatable used by mylua_pushresumable(). Its __gc runs
// the destructor of the std::function stored in the userdata.
// Call once per lua_State before pushing any resumable function.
// Leaves the stack unchanged and returns 0.
static int mylua_init(lua_State *L) {
    luaL_newmetatable(L, cpp_std_function_typename);
    auto finalizer = [](lua_State *L) -> int {
        void *storage = luaL_checkudata(L, 1, cpp_std_function_typename);
        static_cast<mylua_ResumableFunction *>(storage)->~function();
        return 0;
    };
    lua_pushcfunction(L, finalizer);
    lua_setfield(L, -2, "__gc");
    lua_pop(L, 1); // drop the metatable pushed by luaL_newmetatable
    return 0;
}
#else
// Plain function-pointer variant: the coroutine factory travels as a light
// userdata. Conversion between function and data pointers is only
// conditionally supported in C++, and reading a void* through a
// function-pointer reference breaks strict aliasing. The bits are therefore
// copied with std::memcpy.
using mylua_ResumableFunction = mylua_ResumableHandle(*)(lua_State *);
static_assert(sizeof(void *) >= sizeof(mylua_ResumableFunction));
// lua_CFunction behind every closure pushed by mylua_pushresumable().
// Upvalue 1 is the light userdata encoding the factory function pointer.
static int myluai_resumable_wrapper(lua_State *L) {
    void *p = lua_touserdata(L, lua_upvalueindex(1));
    mylua_ResumableFunction coro;
    std::memcpy(&coro, &p, sizeof coro);
    // call it the first time to create the coroutine frame and get the handle
    // will suspend at promise::initial_suspend()
    auto handle = coro(L);
    // delegate the rest to the continuation
    return myluai_resumable_wrapper_k(L, LUA_OK, handle.to_context());
}
// Pushes a C closure running coroutine factory `coro`.
// The `nupvals` values currently on top of the stack become the user upvalues
// 1..nupvals; read them via mylua_ResumableHandleupvalueindex(). The factory
// pointer itself is stored (bit-copied into a light userdata) as real upvalue 1.
static void mylua_pushresumable(lua_State *L, mylua_ResumableFunction coro, int nupvals = 0) {
    void *p = nullptr;
    std::memcpy(&p, &coro, sizeof coro); // avoids aliasing a void* as a function pointer
    lua_pushlightuserdata(L, p);
    lua_insert(L, -1 - nupvals); // move the factory below the user upvalues
    lua_pushcclosure(L, myluai_resumable_wrapper, nupvals + 1);
}
// Pseudo-index of the i-th user upvalue (1-based) of a closure created by
// mylua_pushresumable(). Real upvalue 1 holds the coroutine factory, so
// user upvalues are shifted up by one.
static int mylua_ResumableHandleupvalueindex(int i) {
    int const shifted = i + 1; // skip the reserved factory slot
    return lua_upvalueindex(shifted);
}
// No metatable is needed when plain function pointers are stored. This stub
// lets callers run mylua_init() regardless of USE_STD_FUNCTION. Returns 0.
static int mylua_init(lua_State * /* unused */) { return 0; }
#endif
/**
-- generate all the numbers from 2 to n
function gen (n)
    return coroutine.wrap(function ()
        for i=2,n do coroutine.yield(i) end
    end)
end
*/
// C++ port of the Lua `gen` above. Argument 1 is an integer n. Returns a
// coroutine.wrap'ed generator that yields 2, 3, ..., n (n included, exactly
// like Lua's numeric `for`) and then returns nothing.
static int gen(lua_State *L) {
    auto N = luaL_checkinteger(L, 1);
    lua_getglobal(L, "coroutine");
    lua_getfield(L, -1, "wrap");
#if USE_STD_FUNCTION
    // The captured N lives in the lambda object inside the closure's userdata,
    // which stays alive while the closure is running or suspended.
    mylua_pushresumable(L, [N](lua_State *L) -> mylua_ResumableHandle {
        // lua_Integer (not int) so a large N is neither truncated nor overflowed
        for (lua_Integer i = 2; i <= N; ++i) {
            lua_pushinteger(L, i);
            // same as co_await mylua_yield(L, 1);
            co_yield 1;
            if (i == N) {
                break; // stop before ++i could overflow when N == LUA_MAXINTEGER
            }
        }
        co_return 0;
    });
#else
    lua_pushvalue(L, 1);
    mylua_pushresumable(L, [](lua_State *L) -> mylua_ResumableHandle {
        auto N = lua_tointeger(L, mylua_ResumableHandleupvalueindex(1));
        for (lua_Integer i = 2; i <= N; ++i) {
            lua_pushinteger(L, i);
            co_yield 1;
            if (i == N) {
                break; // stop before ++i could overflow when N == LUA_MAXINTEGER
            }
        }
        co_return 0;
    }, 1);
#endif
    // coroutine.wrap(resumable) -> generator, left on top of the stack
    lua_call(L, 1, 1);
    return 1;
}
/**
-- filter the numbers generated by `g', removing multiples of `p'
function filter (p, g)
    return coroutine.wrap(function ()
        for n in g do
            if n%p ~= 0 then coroutine.yield(n) end
        end
    end)
end
*/
// C++ port of the Lua `filter` above. Argument 1 is a non-zero integer p and
// argument 2 a generator function g. Returns a coroutine.wrap'ed generator
// that yields the values of g that are not multiples of p, stopping when g
// returns nil.
static int filter(lua_State *L) {
    auto p = luaL_checkinteger(L, 1);
    // `n % 0` is undefined behaviour in C++ (Lua itself raises an error for it)
    luaL_argcheck(L, p != 0, 1, "divisor must not be zero");
    luaL_checktype(L, 2, LUA_TFUNCTION);
    lua_getglobal(L, "coroutine");
    lua_getfield(L, -1, "wrap");
#if USE_STD_FUNCTION
    // need this to be captured by the cpp coroutine
    // NOTE(review): the registry reference is released only when g is
    // exhausted. An abandoned filter keeps g alive forever.
    lua_pushvalue(L, 2);
    auto g = luaL_ref(L, LUA_REGISTRYINDEX);
    mylua_pushresumable(L, [p, g](lua_State *L) -> mylua_ResumableHandle {
        // repeatedly call the generator till it return nil
        for (;;) {
            lua_rawgeti(L, LUA_REGISTRYINDEX, g);
            co_await mylua_call(L, 0, 1);
            if (lua_isnil(L, -1)) {
                break;
            }
            auto n = lua_tointeger(L, -1);
            // every integer is a multiple of -1; testing it first also avoids
            // the LUA_MININTEGER % -1 overflow
            if (p != -1 && (n % p) != 0) {
                co_await mylua_yield(L, 1);
            } else {
                lua_pop(L, 1);
            }
        }
        luaL_unref(L, LUA_REGISTRYINDEX, g);
        co_return 0;
    });
#else
    lua_pushvalue(L, 1);
    lua_pushvalue(L, 2);
    mylua_pushresumable(L, [](lua_State *L) -> mylua_ResumableHandle {
        auto p = lua_tointeger(L, mylua_ResumableHandleupvalueindex(1));
        for (;;) {
            lua_pushvalue(L, mylua_ResumableHandleupvalueindex(2));
            co_await mylua_call(L, 0, 1);
            if (lua_isnil(L, -1)) {
                break;
            }
            auto n = lua_tointeger(L, -1);
            // every integer is a multiple of -1; testing it first also avoids
            // the LUA_MININTEGER % -1 overflow
            if (p != -1 && (n % p) != 0) {
                co_yield 1;
            } else {
                lua_pop(L, 1);
            }
        }
        co_return 0;
    }, 2);
#endif
    // coroutine.wrap(resumable) -> filtered generator, left on top of the stack
    lua_call(L, 1, 1);
    return 1;
}
int main() {
auto L = luaL_newstate();
luaL_openlibs(L);
mylua_init(L);
lua_pushcfunction(L, &gen);
lua_setglobal(L, "gen");
lua_pushcfunction(L, &filter);
lua_setglobal(L, "filter");
luaL_dostring(L, R"(
N=N or 500 -- from command line
x = gen(N) -- generate primes up to N
while 1 do
local n = x() -- pick a number until done
if n == nil then break end
print(n) -- must be a prime number
x = filter(n, x) -- now remove its multiples
end
)");
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment