Created
May 10, 2020 10:16
-
-
Save nerditation/441a2d4409a778ae77683fc0645b69de to your computer and use it in GitHub Desktop.
proof of concept to use the C++ coroutines with the Lua API.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// PoC: use C++ resumable functions (a.k.a. coroutines) to
// eliminate the need to manually transform code into CPS
// (continuation-passing style) when working with the Lua APIs that take
// continuations, i.e. `lua_yieldk()`, `lua_callk()` and `lua_pcallk()`.
// This code needs C++17 and Coroutines TS support.
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <type_traits>
#include <utility>
#include <variant>
#ifdef _MSC_VER
// cl.exe /std:c++17 /await mylua-cpp-coroutines.cpp lua.lib
#include <experimental/coroutine>
using std::experimental::coroutine_handle;
using std::experimental::suspend_always;
#else
// g++-10 -std=c++17 -fcoroutines mylua-cpp-coroutines.cpp -llua
#include <coroutine>
using std::coroutine_handle;
using std::suspend_always;
#endif
#include <lua.hpp>
// Selects how mylua_pushresumable() stores the coroutine factory:
// 1 (default): a type-erased std::function stored in a full userdata, so
//   stateful lambdas and other callable objects can be used; the userdata's
//   __gc metamethod (registered by mylua_init()) runs its destructor.
// 0: a plain function pointer stored as a light userdata; only capture-less
//   lambdas or plain functions can be used, and any state has to be passed
//   through the closure's upvalues (see mylua_ResumableHandleupvalueindex()).
// May be overridden at build time, e.g. -DUSE_STD_FUNCTION=0.
#ifndef USE_STD_FUNCTION
#define USE_STD_FUNCTION 1
#endif
// Request tags a suspended C++ coroutine hands back to its resumer
// (myluai_resumable_wrapper_k), which then performs the actual Lua API call.
//
// Why not call the Lua APIs from inside the coroutine machinery itself?
// To the author's understanding, unwinding the stack across a
// coroutine/resumer boundary leaves the coroutine non-resumable, and its
// frame gets destroyed by the runtime. Consequently:
//   - lua_callk() / lua_pcallk() cannot be called directly from an
//     awaiter's await_suspend();
//   - lua_yieldk() cannot be called directly from
//     promise_type::yield_value();
//   - the coroutine has to reach a suspend point and let the resumer decide
//     what to do next, based on which of these tags it recorded.

// The coroutine finished (co_return n): return `nresults` values to Lua.
struct returned_ { int nresults; };
// The coroutine asks for lua_yieldk() with `nresults` values.
struct yielded_ { int nresults; };
// The coroutine asks for lua_callk(L, nargs, nresults, ...).
struct called_ { int nargs; int nresults; };
// The coroutine asks for lua_pcallk(L, nargs, nresults, errfunc, ...).
struct pcalled_ { int nargs; int nresults; int errfunc; };
// Convenience constructors that mimic the style of the Lua API.
// Each one merely builds a request tag; nothing happens unless the result is
// `co_await`ed inside a resumable function, hence [[nodiscard]].
// The lua_State argument is unused; it is accepted only so that call sites
// read like the corresponding Lua API call.

// co_await mylua_call(L, nargs, nresults);
[[nodiscard]] auto mylua_call(lua_State *, int nargs, int nresults) {
    return called_{ nargs, nresults };
}
// int status = co_await mylua_pcall(L, nargs, nresults, errfunc);
// Builds a pcalled_ request (must be co_await'ed); the awaited int is the
// status the resumer reports for the protected call. The lua_State argument
// is unused and only mirrors lua_pcall()'s signature.
[[nodiscard]] auto mylua_pcall(lua_State *, int nargs, int nresults, int errfunc = 0) {
    return pcalled_{ nargs, nresults, errfunc };
}
// co_await mylua_yield(L, nresults);
// Inside a resumable function the shorthand `co_yield nresults;` does the
// same. Builds a yielded_ request; the lua_State argument is unused.
[[nodiscard]] auto mylua_yield(lua_State *, int nresults) {
    return yielded_{ nresults };
}
// just a wrapper of coroutine_handle<promise_type> | |
// not a RAII class | |
// use like a non-owning pointer | |
struct mylua_ResumableHandle { | |
struct promise_type { | |
// values yielded to the caller/resumer to call the lua APIs | |
std::variant<returned_, yielded_, called_, pcalled_> state; | |
// value passed by the caller/resumer, return value from the lua APIs | |
// only lua_pcallk() returns a status. | |
// for lua_yieldk(), it is always LUA_YIELD | |
// for lua_callk(), it is always LUA_OK | |
int resumed_value; | |
void return_value(int nresults) noexcept { | |
state = returned_{ nresults }; | |
} | |
#if 0 | |
// currently it is ill-formed if both return_value and return_void are present | |
// but there is a paper P1713 suggesting to relax this restriction | |
// see https://github.com/cplusplus/papers/issues/479 | |
void return_void() noexcept { | |
return_value(0); | |
} | |
#endif | |
auto yield_value(int nresults = 0) noexcept { | |
state = yielded_{ nresults }; | |
return suspend_always{}; | |
}; | |
template <typename T> | |
auto await_transform(T &&args) noexcept { | |
this->state = args; | |
return suspend_always{}; | |
} | |
auto await_transform(pcalled_ &&args) noexcept { | |
this->state = args; | |
struct _ : public suspend_always | |
{ | |
promise_type const &promise; | |
_(promise_type const &promise) : promise(promise) {} | |
int await_resume() noexcept { | |
return promise.resumed_value; | |
} | |
}; | |
return _{ *this }; | |
} | |
auto initial_suspend() const noexcept { | |
return suspend_always{}; | |
} | |
auto final_suspend() const noexcept { | |
return suspend_always{}; | |
} | |
void unhandled_exception() const { | |
// throwing exceptions across coroutine/resumer boundary | |
// would destroy the coroutine frame | |
//throw; | |
abort(); | |
} | |
auto get_return_object() { | |
auto handle = mylua_ResumableHandle::handle_type::from_promise(*this); | |
return mylua_ResumableHandle{ handle }; | |
} | |
}; | |
using handle_type = coroutine_handle<promise_type>; | |
handle_type handle_; | |
#if 0 | |
~mylua_ResumableHandle() noexcept { | |
destroy(); | |
} | |
#endif | |
static mylua_ResumableHandle from_context(lua_KContext ctx) noexcept { | |
auto ptr = reinterpret_cast<void *>(static_cast<intptr_t>(ctx)); | |
return { handle_type::from_address(ptr) }; | |
} | |
lua_KContext to_context() const noexcept { | |
auto ptr = handle_.address(); | |
return static_cast<lua_KContext>(reinterpret_cast<intptr_t>(ptr)); | |
} | |
promise_type & promise() const noexcept { | |
return handle_.promise(); | |
} | |
void destroy() noexcept { | |
if (handle_) { | |
handle_.destroy(); | |
} | |
} | |
void resume() { | |
return handle_.resume(); | |
} | |
}; | |
// Merges several callables into one overload set (e.g. for std::visit).
template <class... FNs>
struct overloaded_visitor : FNs... {
    using FNs::operator()...;
};
// Builds an overloaded_visitor from the given callables, which are copied or
// moved into it. The types are decayed before being used as base classes:
// passing an lvalue deduces FNs as a reference type, which cannot be a base.
template <class... FNs>
auto overloaded(FNs &&...fns) noexcept(
    (std::is_nothrow_constructible_v<std::decay_t<FNs>, FNs &&> && ...)) {
    return overloaded_visitor<std::decay_t<FNs>...>{ std::forward<FNs>(fns)... };
}
// lua_KFunction that drives a resumable function.
//
// `ctx` encodes the coroutine handle (see mylua_ResumableHandle::to_context).
// Each round hands `status` to the coroutine, resumes it until its next
// suspend point, then performs the Lua API call it requested:
//   returned_ -> destroy the frame and return its result count to Lua;
//   yielded_  -> lua_yieldk() with this function as the continuation;
//   called_   -> lua_callk(); if the callee returns without yielding,
//                go round again with LUA_OK;
//   pcalled_  -> lua_pcallk(); if the callee returns without yielding,
//                go round again with the status it returned.
// Looping rather than recursing keeps the C stack depth constant even when
// the coroutine issues many calls in a row without yielding.
static int myluai_resumable_wrapper_k(lua_State *L, int status, lua_KContext ctx) {
    auto handle = mylua_ResumableHandle::from_context(ctx);
    // When Lua invokes us as the continuation of lua_callk()/lua_pcallk()
    // after the callee yielded, it passes LUA_YIELD where a direct return
    // would have meant LUA_OK. Normalize so a successful (p)call always
    // reports LUA_OK to the coroutine.
    auto const &pending = handle.promise().state;
    if (status == LUA_YIELD &&
        (std::holds_alternative<called_>(pending) || std::holds_alternative<pcalled_>(pending))) {
        status = LUA_OK;
    }
    for (;;) {
        handle.promise().resumed_value = status;
        handle.resume();
        // Copy the request out: the returned_ branch destroys the frame that
        // owns the variant.
        auto const request = handle.promise().state;
        if (auto const *ret = std::get_if<returned_>(&request)) {
            handle.destroy();
            return ret->nresults;
        }
        if (auto const *yld = std::get_if<yielded_>(&request)) {
            return lua_yieldk(L, yld->nresults, ctx, &myluai_resumable_wrapper_k);
        }
        if (auto const *call = std::get_if<called_>(&request)) {
            lua_callk(L, call->nargs, call->nresults, ctx, &myluai_resumable_wrapper_k);
            status = LUA_OK;
            continue;
        }
        auto const &pcall = std::get<pcalled_>(request);
        status = lua_pcallk(L, pcall.nargs, pcall.nresults, pcall.errfunc, ctx,
                            &myluai_resumable_wrapper_k);
    }
}
#if USE_STD_FUNCTION
// Factory signature: invoked once per call of the Lua closure, it creates a
// fresh coroutine frame and returns a handle to it.
using mylua_ResumableFunction = std::function<mylua_ResumableHandle(lua_State *)>;
// lua_CFunction pushed by mylua_pushresumable(); upvalue 1 is the full
// userdata that holds the mylua_ResumableFunction.
static int myluai_resumable_wrapper(lua_State *L) {
    void *const storage = lua_touserdata(L, lua_upvalueindex(1));
    auto &factory = *static_cast<mylua_ResumableFunction *>(storage);
    // Building the frame parks it at initial_suspend() before the body runs;
    // the continuation drives it from there.
    const lua_KContext ctx = factory(L).to_context();
    return myluai_resumable_wrapper_k(L, LUA_OK, ctx);
}
constexpr auto cpp_std_function_typename = "__cplusplus(std::function<>)"; | |
template <typename F> | |
static void mylua_pushresumable(lua_State *L, F&& coro) { | |
auto p = lua_newuserdata(L, sizeof(mylua_ResumableFunction)); | |
new(p) mylua_ResumableFunction(std::forward<F>(coro)); | |
luaL_setmetatable(L, cpp_std_function_typename); | |
lua_pushcclosure(L, myluai_resumable_wrapper, 1); | |
} | |
// Register the metatable whose __gc destroys the std::function stored in
// userdata created by mylua_pushresumable(). Leaves the stack balanced and
// returns 0.
static int mylua_init(lua_State *L) {
    auto gc_metamethod = [](lua_State *S) -> int {
        using Fn = mylua_ResumableFunction;
        void *const raw = luaL_checkudata(S, 1, cpp_std_function_typename);
        static_cast<Fn *>(raw)->~Fn();
        return 0;
    };
    luaL_newmetatable(L, cpp_std_function_typename);
    lua_pushcfunction(L, gc_metamethod);
    lua_setfield(L, -2, "__gc");
    lua_pop(L, 1);
    return 0;
}
#else | |
// casting between function pointers and data pointers might not be safe | |
using mylua_ResumableFunction = mylua_ResumableHandle(*)(lua_State *); | |
static_assert(sizeof(void *) >= sizeof(mylua_ResumableFunction)); | |
static int myluai_resumable_wrapper(lua_State *L) { | |
auto p = lua_touserdata(L, lua_upvalueindex(1)); | |
auto coro = reinterpret_cast<mylua_ResumableFunction &>(p); | |
// call it the first time to create the coroutine frame and get the handle | |
// will suspend at promise::initial_suspend() | |
auto handle = coro(L); | |
// delegate the rest to the continuation | |
return myluai_resumable_wrapper_k(L, LUA_OK, handle.to_context()); | |
} | |
// Push a C closure running the resumable function `coro`. The top `nupvals`
// stack values become upvalues 2..nupvals+1 of the closure (read them with
// mylua_ResumableHandleupvalueindex(1..nupvals)); upvalue 1 is `coro` itself.
static void mylua_pushresumable(lua_State *L, mylua_ResumableFunction coro, int nupvals = 0) {
    // Value conversion (conditionally-supported, accepted by GCC/Clang/MSVC)
    // instead of reading the function pointer through a `void *&`, which
    // violated strict aliasing.
    void *const raw = reinterpret_cast<void *>(coro);
    lua_pushlightuserdata(L, raw);
    // Move the pointer below the caller-supplied upvalues.
    lua_insert(L, -1 - nupvals);
    lua_pushcclosure(L, myluai_resumable_wrapper, nupvals + 1);
}
static int mylua_ResumableHandleupvalueindex(int i) { | |
return lua_upvalueindex(i + 1); | |
} | |
static int mylua_init(lua_State *) { | |
return 0; | |
} | |
#endif | |
/** | |
-- generate all the numbers from 2 to n | |
function gen (n) | |
return coroutine.wrap(function () | |
for i=2,n do coroutine.yield(i) end | |
end) | |
end | |
*/ | |
static int gen(lua_State *L) { | |
auto N = luaL_checkinteger(L, 1); | |
lua_getglobal(L, "coroutine"); | |
lua_getfield(L, -1, "wrap"); | |
#if USE_STD_FUNCTION | |
mylua_pushresumable(L, [N](lua_State *L) -> mylua_ResumableHandle { | |
for (int i = 2; i < N; ++i) { | |
lua_pushinteger(L, i); | |
// same as co_await mylua_yield(L, 1); | |
co_yield 1; | |
} | |
co_return 0; | |
}); | |
#else | |
lua_pushvalue(L, 1); | |
mylua_pushresumable(L, [](lua_State *L) -> mylua_ResumableHandle { | |
auto N = lua_tointeger(L, mylua_ResumableHandleupvalueindex(1)); | |
for (int i = 2; i < N; ++i) { | |
lua_pushinteger(L, i); | |
co_yield 1; | |
} | |
co_return 0; | |
}, 1); | |
#endif | |
lua_call(L, 1, 1); | |
return 1; | |
} | |
/** | |
-- filter the numbers generated by `g', removing multiples of `p' | |
function filter (p, g) | |
return coroutine.wrap(function () | |
for n in g do | |
if n%p ~= 0 then coroutine.yield(n) end | |
end | |
end) | |
end | |
*/ | |
static int filter(lua_State *L) { | |
auto p = luaL_checkinteger(L, 1); | |
luaL_checktype(L, 2, LUA_TFUNCTION); | |
lua_getglobal(L, "coroutine"); | |
lua_getfield(L, -1, "wrap"); | |
#if USE_STD_FUNCTION | |
// need this to be captured by the cpp coroutine | |
lua_pushvalue(L, 2); | |
auto g = luaL_ref(L, LUA_REGISTRYINDEX); | |
mylua_pushresumable(L, [p, g](lua_State *L) -> mylua_ResumableHandle { | |
// repeatedly call the generator till it return nil | |
for (;;) { | |
lua_rawgeti(L, LUA_REGISTRYINDEX, g); | |
co_await mylua_call(L, 0, 1); | |
if (lua_isnil(L, -1)) { | |
break; | |
} | |
auto n = lua_tointeger(L, -1); | |
if ((n % p) != 0) { | |
co_await mylua_yield(L, 1); | |
} else { | |
lua_pop(L, 1); | |
} | |
} | |
luaL_unref(L, LUA_REGISTRYINDEX, g); | |
co_return 0; | |
}); | |
#else | |
lua_pushvalue(L, 1); | |
lua_pushvalue(L, 2); | |
mylua_pushresumable(L, [](lua_State *L) -> mylua_ResumableHandle { | |
auto p = lua_tointeger(L, mylua_ResumableHandleupvalueindex(1)); | |
for (;;) { | |
lua_pushvalue(L, mylua_ResumableHandleupvalueindex(2)); | |
co_await mylua_call(L, 0, 1); | |
if (lua_isnil(L, -1)) { | |
break; | |
} | |
auto n = lua_tointeger(L, -1); | |
if ((n % p) != 0) { | |
co_yield 1; | |
} else { | |
lua_pop(L, 1); | |
} | |
} | |
co_return 0; | |
}, 2); | |
#endif | |
lua_call(L, 1, 1); | |
return 1; | |
} | |
int main() { | |
auto L = luaL_newstate(); | |
luaL_openlibs(L); | |
mylua_init(L); | |
lua_pushcfunction(L, &gen); | |
lua_setglobal(L, "gen"); | |
lua_pushcfunction(L, &filter); | |
lua_setglobal(L, "filter"); | |
luaL_dostring(L, R"( | |
N=N or 500 -- from command line | |
x = gen(N) -- generate primes up to N | |
while 1 do | |
local n = x() -- pick a number until done | |
if n == nil then break end | |
print(n) -- must be a prime number | |
x = filter(n, x) -- now remove its multiples | |
end | |
)"); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment