Last active
May 21, 2016 08:57
-
-
Save daurnimator/23e36762dc62198da8804350df654ecf to your computer and use it in GitHub Desktop.
Pure-lua cqueues-like
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
local cqueues = require "cqueues" | |
local unpack = table.unpack or unpack | |
local pack = table.pack or function(...) return {n=select("#", ...), ...} end | |
-- poll(2)-style event bits; `events` strings made of "r", "w" and "p" are
-- translated into these by the scheduler.
local POLLIN = 1
local POLLOUT = 2
local POLLPRI = 4
--- Convert any value to a string without ever raising.
-- If `tostring` fails (e.g. a `__tostring` metamethod that errors), the
-- placeholder "(null)" is returned instead.
-- @param t any value
-- @return the string form of `t`, or "(null)"
local function safe_tostring(t)
	local converted, result = pcall(tostring, t)
	if not converted then
		return "(null)"
	end
	return result
end
--- Iterate over every argument in `...`, including embedded and trailing nils.
-- Usage: `for i, v in each_arg(...) do ... end` yields (index, value) pairs
-- for i = 1 .. select("#", ...).
local function each_arg(...)
	local function advance(args, prev)
		if prev < args.n then
			local idx = prev + 1
			return idx, args[idx]
		end
		return nil
	end
	return advance, pack(...), 0
end
-- Methods shared by every scheduler instance, and the metatable exposing them.
local methods = {}
local mt = {}
mt.__name = "lua cqueue" -- type name (honoured by tostring() on Lua 5.3+)
mt.__index = methods
--- Link `thread` into a doubly-linked list directly after the list head `head`.
-- `head` is a sentinel table whose `next` field points at the first element
-- (or nil when the list is empty).
local function thread_add_head(thread, head)
	local first = head.next
	thread.prev, thread.next = head, first
	if first then
		first.prev = thread
	end
	head.next = thread
end
--- Unlink `thread` from whichever list it currently belongs to.
-- The thread's own prev/next fields are left untouched (stale); clearing them
-- is not needed, but would help catch programming errors.
local function thread_del(thread)
	local before, after = thread.prev, thread.next
	if after then
		after.prev = before
	end
	before.next = after
end
--- Move `thread` from its current list to the front of the list headed by `head`.
local function thread_move(thread, head)
	thread_del(thread) -- unlink from the old list first
	return thread_add_head(thread, head)
end
-- Unique sentinel a coroutine yields to ask its scheduler to poll for it.
local _POLL = {}

-- Weak set of every scheduler created by new(); module-level cancel() and
-- reset() iterate over it.
local cstack = setmetatable({}, {__mode = "kv"})

-- Innermost scheduler that is currently resuming a thread; each scheduler
-- remembers the one below it in its `.running` field.
local cstack_running = nil

--- Make `new` the running scheduler, remembering the previous one in `new.running`.
local function cstack_push(new)
	new.running, cstack_running = cstack_running, new
end

--- Restore the scheduler that was running before the most recent push.
local function cstack_pop()
	local top = cstack_running
	cstack_running = top.running
end
-- Timer collection: a table mapping each waiting object (threads, in this
-- file) to its absolute deadline. Methods live in the metatable, so the only
-- keys stored in the table itself are the waiting objects.
local timer_methods = {}
local timer_mt = {
	__name = "timer collection";
	__index = timer_methods;
}

--- Create an empty timer collection.
local function new_timers()
	local collection = setmetatable({}, timer_mt)
	return collection
end

--- Record (or replace) the deadline of `ob`.
function timer_methods:add(ob, deadline)
	self[ob] = deadline
end

--- Forget the deadline of `ob`, if any.
function timer_methods:remove(ob)
	self[ob] = nil
end

--- Earliest recorded deadline, or nil when there is none.
function timer_methods:min()
	local earliest = math.huge
	for _, deadline in pairs(self) do
		if earliest > deadline then
			earliest = deadline
		end
	end
	if earliest ~= math.huge then
		return earliest
	end
	return nil
end

-- `timers:each()` iterates (object, deadline) pairs; it is simply pairs(timers).
timer_methods.each = pairs
-- Condition variables. Waiting events are queued in the condition's own
-- integer slots self[head..tail]: FIFO conditions append at `tail`, LIFO
-- conditions prepend at `head`, and signal() always dequeues from `head`.
local condition_methods = {}
local condition_mt = {
	__name = "condition";
	__index = condition_methods;
}

--- Create a new condition variable.
-- @param lifo when truthy, signal() wakes the most recently queued waiter first
-- @return the condition
local function new_condition(lifo)
	return setmetatable({
		lifo = lifo;
		head = 1; -- index of the first queued event
		tail = 0; -- index of the last queued event (queue is empty while head > tail)
	}, condition_mt)
end

do
	-- Turn the values the scheduler resumed us with into wait()'s results:
	-- true (plus the remaining values) when the first value is this condition,
	-- otherwise false followed by everything that was returned.
	local function check(self, why, ...)
		if self == why then
			return true, ...
		else
			return false, why, ...
		end
	end

	--- Wait on this condition.
	-- Yields to the scheduler with _POLL, this condition and any additional
	-- pollables/timeouts given in `...`.
	-- @return true, ... when woken by this condition; false, why, ... otherwise
	function condition_methods:wait(...)
		return check(self, coroutine.yield(_POLL, self, ...))
	end
end

-- Queue wait event `fn` on condition `self` (the scheduler calls this when a
-- thread polls on the condition).
local function condition_add(self, fn)
	if self.lifo then
		local i = self.head - 1
		self[i] = fn
		self.head = i
	else
		local i = self.tail + 1
		self[i] = fn
		self.tail = i
	end
end

--- Wake threads waiting on this condition.
-- Signalled events are removed from the queue, so each one is delivered at
-- most once (previously they stayed queued and were re-signalled forever).
-- @param max maximum number of waiters to wake (all of them when nil/false);
--   previously an off-by-one woke `max + 1` waiters
function condition_methods:signal(max)
	local woken = 0
	while self.head <= self.tail and (not max or woken < max) do
		local i = self.head
		local event = self[i]
		self[i] = nil
		self.head = i + 1
		woken = woken + 1
		local thread = event.thread
		-- A finished thread has already been unlinked from its scheduler's
		-- lists; moving it again would follow its stale prev/next links and
		-- corrupt whichever list those now point into.
		if coroutine.status(thread.co) ~= "dead" then
			event.pending = true
			thread_move(thread, thread.scheduler.pending)
		end
		-- TODO: wakeup
	end
end

-- A condition can be passed to poll() directly: it is its own pollfd, with
-- no events and no timeout of its own.
function condition_methods:pollfd()
	return self
end
function condition_methods:events()
	return nil
end
function condition_methods:timeout()
	return nil
end
--- Create a new scheduler ("lua cqueue").
-- The scheduler is registered in the weak `cstack` set so module-level
-- cancel()/reset() can reach it.
local function new()
	local self = {
		thread_count = 0; -- number of threads on the polling + pending lists
		polling = {next = nil}; -- head of the list of threads waiting on events/timers
		pending = {next = nil}; -- head of the list of threads ready to be resumed
		current = nil; -- coroutine being resumed right now, if any
		running = nil; -- scheduler running below this one (set by cstack_push)
		timers = new_timers();
	}
	setmetatable(self, mt)
	cstack[self] = true
	return self
end
--- Replace the scheduler method `key` with `func`.
-- @return the previous implementation (nil if there was none)
function methods.interpose(key, func)
	local previous
	previous, methods[key] = methods[key], func
	return previous
end
-- local monotonic_clock = 0 | |
-- local function monotime() | |
-- monotonic_clock = monotonic_clock + 1 | |
-- return monotonic_clock | |
-- end | |
local monotime = cqueues.monotime | |
--- Report the innermost running scheduler.
-- @return the scheduler currently resuming a thread (or nil), and whether the
--   caller is the coroutine that scheduler is resuming
local function running()
	local cq = cstack_running
	if not cq then
		return cq, false
	end
	return cq, cq.current == coroutine.running()
end
--- Cancel the given descriptors/objects on every live scheduler.
local function cancel(...)
	for cq in pairs(cstack) do
		cq:cancel(...)
	end
end

--- Reset every live scheduler that implements `reset`.
-- The scheduler methods in this file do not define `reset` (only a commented-
-- out draft exists), and the internal poll() scheduler is always registered,
-- so calling `cq:reset()` unconditionally made this function always raise.
-- Schedulers without a `reset` method are therefore skipped.
local function reset()
	for cq in pairs(cstack) do
		local reset_fn = cq.reset
		if reset_fn ~= nil then
			reset_fn(cq)
		end
	end
end
-- Private scheduler that services poll() calls made outside any scheduler.
local poller = new()

--- Wait until one of the given pollables/conditions/timeouts is ready.
-- Inside a running scheduler this yields to it; otherwise the request is run
-- to completion on the private `poller`.
-- @return the values the scheduler resumed the poll with
local function poll(...)
	if running() then
		return coroutine.yield(_POLL, ...)
	end
	local results
	poller:wrap(function(...)
		results = {poll(...)}
	end, ...)
	-- Two steps are needed: the first runs the wrapped function up to its
	-- poll, the second resumes it once something is ready.
	assert(poller:step())
	assert(poller:step())
	if results == nil then
		results = {}
	end
	return unpack(results)
end
--- Suspend the caller for `timeout` seconds by polling on the timeout alone.
-- Returns nothing.
local sleep = function(timeout)
	poll(timeout)
end
do
	local handle_resume

	-- Resume `thread` as scheduler `self`: `self` becomes the running scheduler
	-- and `thread.co` its current coroutine for the duration of the resume.
	local function do_resume(self, thread, ...)
		cstack_push(self)
		self.current = thread.co
		return handle_resume(self, thread, coroutine.resume(thread.co, ...))
	end

	-- Detach `thread` from its list, decrement the thread count, and pass the
	-- remaining arguments through as the caller's results.
	local function cleanup(self, thread, ...)
		thread_del(thread)
		self.thread_count = self.thread_count - 1
		return ...
	end

	-- Read field `name` of pollable `v`, calling it as a method when it is a
	-- function. Returns true, value on success, or false, err if the method raised.
	local function resolve_field(v, name)
		local field = v[name]
		if type(field) == "function" then
			return pcall(field, v)
		end
		return true, field
	end

	-- Normalise an `events` value into a poll bitmask: nil -> 0, numbers pass
	-- through, strings map "r"/"w"/"p" to POLLIN/POLLOUT/POLLPRI.
	-- Returns nil for any other type.
	local function events_to_mask(events)
		local kind = type(events)
		if events == nil then
			return 0
		elseif kind == "number" then
			return events
		elseif kind == "string" then
			local mask = 0
			if events:match "r" then
				mask = mask + POLLIN
			end
			if events:match "w" then
				mask = mask + POLLOUT
			end
			if events:match "p" then
				mask = mask + POLLPRI
			end
			return mask
		end
		return nil
	end

	-- Handle the outcome of resuming `thread`.
	-- * failed thread: removed; returns nil, err, nil, co
	-- * finished thread: removed; returns true
	-- * _POLL yield: registers the wait events (timeouts, conditions) in
	--   thread.events and, if there is anything to wait for, parks the thread
	--   on the polling list (plus the timer set); returns true
	-- * any other yield: forwarded to whoever is running this scheduler, and
	--   the answer is fed back into the thread
	-- Invalid pollables make it return nil, err, nil, co, object[, pollfd].
	function handle_resume(self, thread, ok, first, ...)
		self.current = nil
		cstack_pop()
		if not ok then
			return cleanup(self, thread, nil, first, nil, thread.co)
		elseif coroutine.status(thread.co) == "dead" then
			return cleanup(self, thread, true)
		elseif first ~= _POLL then
			return do_resume(self, thread, coroutine.yield(first, ...))
		end
		local now = monotime()
		local thread_timeout = math.huge
		local registered = thread.events
		for _, v in each_arg(...) do
			local pollfd, events, timeout, cond = nil, 0, nil, nil
			if type(v) == "number" then
				timeout = v
			elseif getmetatable(v) == condition_mt then
				cond = v
			elseif v ~= nil then
				local got
				got, pollfd = resolve_field(v, "pollfd")
				if not got then
					return cleanup(self, thread, nil, "error calling method pollfd: " .. safe_tostring(pollfd), nil, thread.co, v)
				end
				if type(pollfd) == "number" then
					if pollfd == -1 then
						pollfd = nil
					end
				elseif getmetatable(pollfd) == condition_mt then
					pollfd, cond = nil, pollfd
				elseif pollfd ~= nil then
					return cleanup(self, thread, nil, "invalid pollfd (expected nil, number or condition)", nil, thread.co, v)
				end

				local raw_events
				got, raw_events = resolve_field(v, "events")
				if not got then
					return cleanup(self, thread, nil, "error calling method events: " .. safe_tostring(raw_events), nil, thread.co, v, pollfd)
				end
				events = events_to_mask(raw_events)
				if events == nil then
					return cleanup(self, thread, nil, "invalid events (expected nil, number or string)", nil, thread.co, v, pollfd)
				end

				got, timeout = resolve_field(v, "timeout")
				if not got then
					return cleanup(self, thread, nil, "error calling method timeout: " .. safe_tostring(timeout), nil, thread.co, v, pollfd)
				end
			end
			if timeout == nil then
				timeout = math.huge
			elseif type(timeout) ~= "number" then
				return cleanup(self, thread, nil, "invalid timeout (expected nil or number)", nil, thread.co, v, pollfd)
			end
			if timeout < math.huge or cond or (pollfd and events ~= 0) then
				local event = {
					thread = thread;
					value = v;
					deadline = now + timeout;
					pending = false;
				}
				registered[#registered + 1] = event
				if cond then
					condition_add(cond, event)
				end
				if pollfd then
					error("NYI pollfd")
				end
				thread_timeout = math.min(timeout, thread_timeout)
			end
		end
		if registered[1] ~= nil or thread_timeout ~= math.huge then
			if thread_timeout ~= math.huge then
				self.timers:add(thread, now + thread_timeout)
			end
			thread_move(thread, self.polling)
		end
		return true
	end

	--- Run one scheduling pass.
	-- When no thread is runnable, waits up to `timeout` seconds (default: until
	-- the earliest timer) for something to become ready, then resumes every
	-- pending thread once with the objects that became ready for it.
	-- @return true on success; nil, err, errno, thread, object, fd when a
	--   thread failed (remaining pending threads run on the next step)
	function methods:step(timeout)
		if running() then
			-- Nested inside another scheduler: let it wait for us, then don't block.
			poll(self, timeout)
			timeout = 0.0
		end
		assert(self.current == nil, "cannot step live cqueue")
		if self.pending.next then
			timeout = 0
		else
			timeout = timeout or math.huge
			local t = self.timers:min()
			if t then
				timeout = math.min(timeout, t - monotime())
			end
		end
		-- Only block when some thread is actually waiting: with an empty polling
		-- list nothing can become ready, and an empty queue would otherwise sleep
		-- forever (e.g. `new():loop()`).
		-- TODO: threads waiting only on conditions still sleep for the whole
		-- `timeout` (possibly math.huge); there is no wakeup mechanism yet.
		if timeout > 0 and self.polling.next then
			cqueues.sleep(timeout)
		end
		-- Move threads whose timers expired to the pending list.
		local now = monotime()
		for thread, deadline in self.timers:each() do
			if deadline <= now then
				for _, event in ipairs(thread.events) do
					if event.deadline <= now then
						event.pending = true
					end
				end
				thread_move(thread, self.pending)
			end
		end
		-- Run pending threads.
		local thread = self.pending.next
		while thread do
			local next_thread = thread.next -- saved in case the current thread gets moved
			local polled_ready, n = {}, 0
			for _, event in ipairs(thread.events) do
				if event.pending then
					n = n + 1
					polled_ready[n] = event.value
				end
			end
			-- These events are consumed now; the thread registers fresh ones on
			-- its next poll. Without this, stale pending events were re-delivered
			-- on every later resume, and a later empty poll still looked like it
			-- had events, parking the thread on the polling list forever.
			-- NOTE(review): stale entries may remain queued on conditions and can
			-- cause a spurious (empty) wakeup.
			thread.events = {}
			self.timers:remove(thread)
			local ok, err, errno, thd, ob, fd = do_resume(self, thread, unpack(polled_ready, 1, n))
			if not ok then
				return nil, err, errno, thd, ob, fd
			end
			thread = next_thread
		end
		return true
	end
end
--- Schedule an existing coroutine on this queue.
-- The coroutine is placed on the pending list; a later step() resumes it
-- (with no arguments on its first resume).
-- @param co a coroutine
-- @return self
function methods:attach(co)
	assert(type(co) == "thread")
	local thread = {
		co = co;
		events = {};
		scheduler = self;
	}
	thread_add_head(thread, self.pending)
	self.thread_count = self.thread_count + 1
	-- TODO: tryalert
	return self
end
--- Run `func(...)` in a new coroutine scheduled on this queue.
-- The arguments are captured by resuming the coroutine once up to an initial
-- yield, so `func` itself only starts when a later step() resumes it.
-- @return self
function methods:wrap(func, ...)
	local co = coroutine.create(function(...)
		coroutine.yield()
		return func(...)
	end)
	coroutine.resume(co, ...) -- park the arguments; stops at the initial yield
	thread_add_head({co = co; events = {}; scheduler = self}, self.pending)
	self.thread_count = self.thread_count + 1
	-- TODO: tryalert
	return self
end
--- Number of threads managed by this queue (pending + polling).
function methods:count()
	return self.thread_count
end

--- Whether this queue has no threads left.
function methods:empty()
	local n = self.thread_count
	return n == 0
end
--- Cancel pending operations on the given descriptors/objects.
-- Each argument is either a file-descriptor number or an object whose
-- `pollfd` field (or method) yields one. Previously the code re-indexed the
-- already-extracted pollfd (`v = v.pollfd` then `v.pollfd`), which errored on
-- numeric fields and never invoked pollfd methods.
-- NOTE: only descriptor resolution is implemented; the cancellation itself
-- is still a TODO.
function methods:cancel(...)
	for _, v in each_arg(...) do
		local fd = v
		if type(fd) ~= "number" then
			fd = v.pollfd
			if type(fd) == "function" then
				local ok, res = pcall(fd, v)
				if not ok then
					error("error calling method pollfd: " .. safe_tostring(res))
				end
				fd = res
			end
			if type(fd) ~= "number" then
				error("error loading field pollfd")
			end
		end
		-- TODO: stop polling `fd` and wake the threads waiting on it
	end
end
-- function methods:reset() | |
-- -- Move polling list to pending | |
-- local n = #self.pending | |
-- local e = #self.polling | |
-- for i=1, e do | |
-- self.pending[n+i] = self.polling[i] | |
-- self.polling[i] = nil | |
-- end | |
-- end | |
-- function methods:pollfd() | |
-- return nil | |
-- end | |
--- Events to poll this queue for when it is nested in another scheduler.
function methods:events()
	return "r"
end

--- Seconds until this queue next needs stepping: 0 when threads are
-- runnable, otherwise the time to the earliest timer (negative if already
-- overdue), or nil when there are no timers.
function methods:timeout()
	if self.pending.next ~= nil then
		return 0
	end
	local deadline = self.timers:min()
	if deadline == nil then
		return nil
	end
	return deadline - monotime()
end
--- Convert a relative timeout (seconds) into an absolute monotime() deadline.
-- nil/false stays nil; 0 or negative becomes the special deadline 0
-- ("already expired") so totimeout() can skip reading the clock.
local function todeadline(timeout)
	if not timeout then
		return nil
	end
	if timeout > 0 then
		return monotime() + timeout
	end
	return 0
end -- todeadline

--- Convert an absolute deadline back into remaining seconds.
-- @return remaining seconds (nil when there is no deadline), and whether the
--   deadline has expired
local function totimeout(deadline)
	if not deadline then
		return nil, false
	end
	if deadline == 0 then
		return 0, true
	end
	local now = monotime()
	if now < deadline then
		return deadline - now, false
	end
	return 0, true
end
--- Step the queue repeatedly until it is empty, a thread fails, or `timeout`
-- seconds have elapsed.
-- @return true when the queue emptied or time ran out; false, err, ... when a
--   step reported an error
function methods:loop(timeout)
	local deadline = todeadline(timeout)
	local function after_step(ok, ...)
		local remaining, expired = totimeout(deadline)
		if not ok then
			return false, ...
		end
		if expired or self:empty() then
			return true
		end
		return after_step(self:step(remaining))
	end
	return after_step(self:step(timeout))
end
--- Iterator over the errors reported while looping this queue.
-- Each call runs loop() with whatever remains of `timeout` and yields the
-- error values; iteration stops once loop() succeeds (returns no error).
function methods:errors(timeout)
	local deadline = todeadline(timeout)
	local function next_error()
		return select(2, self:loop((totimeout(deadline))))
	end
	return next_error
end
-- Module table returned by require().
local M = {}
M._POLL = _POLL
M.new = new
M.running = running
M.cancel = cancel
M.reset = reset
M.monotime = monotime
M.poll = poll
M.sleep = sleep
return M
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment