Skip to content

Instantly share code, notes, and snippets.

@daurnimator
Last active May 21, 2016 08:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save daurnimator/23e36762dc62198da8804350df654ecf to your computer and use it in GitHub Desktop.
Save daurnimator/23e36762dc62198da8804350df654ecf to your computer and use it in GitHub Desktop.
Pure-lua cqueues-like
local cqueues = require "cqueues"
-- Lua 5.1 / 5.2+ compatibility: `unpack` moved into `table`, and `table.pack`
-- (which records the argument count in field `n`) only exists from 5.2 on.
local unpack = table.unpack or unpack
local pack = table.pack or function (...)
	return { n = select("#", ...), ... }
end
-- Bit flags for the readiness kinds an object may ask to be polled for
-- ("r", "w" and "p" respectively when given as a string).
local POLLIN = 1
local POLLOUT = 2
local POLLPRI = 4
--- Convert any value to a string without ever raising.
-- @param t any value
-- @return tostring(t), or the string "(null)" if tostring raised
--   (for example because of a failing __tostring metamethod)
local function safe_tostring(t)
	local converted, str = pcall(tostring, t)
	if converted then
		return str
	end
	return "(null)"
end
-- Stateless iterator step for each_arg: `args` is a pack()ed table, `i` the
-- previous index. Stops once index `args.n` has been produced.
local function each_arg_next(args, i)
	if i ~= args.n then
		i = i + 1
		return i, args[i]
	end
	return nil
end

--- Iterate over a vararg list as (index, value) pairs.
-- Unlike ipairs, nil values (including trailing ones) are visited too.
local function each_arg(...)
	return each_arg_next, pack(...), 0
end
-- Method table shared by every scheduler object; reached through `mt.__index`.
local methods = {}
local mt = { __name = "lua cqueue", __index = methods }
--- Link `thread` in as the first element of the doubly-linked list whose
-- sentinel node is `head` (the sentinel only uses its `next` field).
local function thread_add_head(thread, head)
	local first = head.next
	thread.prev, thread.next = head, first
	if first then
		first.prev = thread
	end
	head.next = thread
end
--- Unlink `thread` from the doubly-linked list it is currently on.
-- The node's own prev/next fields are intentionally left untouched.
local function thread_del(thread)
	local before = thread.prev
	local after = thread.next
	before.next = after
	if after then
		after.prev = before
	end
	-- thread.prev, thread.next = nil, nil -- not needed but catches programming errors
end
--- Relink `thread` from whatever list it is on to the front of the list
-- headed by `head`.
local function thread_move(thread, head)
	thread_del(thread) -- detach from the current list
	thread_add_head(thread, head) -- push onto the destination list
end
-- Unique sentinel: a thread yields it as its first value to ask the scheduler
-- to poll on the remaining yielded values.
local _POLL = {}
-- Every scheduler created by this module, as weak keys (value is always true).
local cstack = setmetatable({}, { __mode = "kv" })

-- Innermost scheduler currently resuming one of its threads, or nil.
-- Enclosing schedulers are chained through each scheduler's `running` field.
local cstack_running

--- Make `new` the innermost running scheduler.
local function cstack_push(new)
	local outer = cstack_running
	new.running = outer
	cstack_running = new
end

--- Restore the previously running scheduler.
local function cstack_pop()
	local top = cstack_running
	cstack_running = top.running
end
-- Timer collection: a table mapping each object (a scheduler thread) to its
-- absolute deadline. Methods live in the metatable so they never collide
-- with the object keys stored in the table itself.
local timer_methods = {}
local timer_mt = {
	__name = "timer collection";
	__index = timer_methods;
}

local function new_timers()
	return setmetatable({}, timer_mt)
end

--- Register (or replace) the deadline for `ob`.
function timer_methods:add(ob, deadline)
	self[ob] = deadline
end

--- Forget the deadline registered for `ob`, if any.
function timer_methods:remove(ob)
	self[ob] = nil
end

--- Smallest registered deadline below math.huge, or nil if there is none.
function timer_methods:min()
	local earliest = math.huge
	for _, deadline in pairs(self) do
		if deadline < earliest then
			earliest = deadline
		end
	end
	if earliest ~= math.huge then
		return earliest
	end
	return nil
end

--- Iterate over (object, deadline) pairs in unspecified order.
function timer_methods:each()
	return pairs(self)
end
-- Condition variables. A condition keeps its waiting events in the integer
-- slots head..tail of the condition object itself.
local condition_methods = {}
local condition_mt = {
	__name = "condition";
	__index = condition_methods;
}

--- Create an empty condition.
-- @param lifo when truthy, new waiters are queued at the front instead of the back
local function new_condition(lifo)
	local cond = { lifo = lifo, head = 1, tail = 0 }
	return setmetatable(cond, condition_mt)
end
do
	-- Turn the values the poll resumed us with into wait()'s result:
	-- true (followed by the remaining values) when this condition is the first
	-- ready value, otherwise false followed by every ready value.
	local function woken_by(cond, first, ...)
		if first ~= cond then
			return false, first, ...
		end
		return true, ...
	end

	--- Block the calling thread until this condition (or one of the extra
	-- poll arguments, e.g. a timeout number) becomes ready.
	function condition_methods:wait(...)
		return woken_by(self, coroutine.yield(_POLL, self, ...))
	end
end
--- Queue `fn` (a waiting event) on the condition.
-- LIFO conditions grow downwards from `head`; FIFO ones grow upwards from `tail`.
local function condition_add(self, fn)
	local slot
	if self.lifo then
		slot = self.head - 1
		self.head = slot
	else
		slot = self.tail + 1
		self.tail = slot
	end
	self[slot] = fn
end
--- Wake threads waiting on this condition.
-- Events are taken from the front of the queue (slot `head`), which holds the
-- newest waiter of a LIFO condition and the oldest waiter of a FIFO one.
-- Each woken event is marked pending, its thread is moved to its scheduler's
-- pending list, and the event is removed from the queue so that a later
-- signal wakes the next waiter instead of the same one again.
-- NOTE(review): an event whose thread was meanwhile resumed for another reason
-- (e.g. its timeout fired) is still queued here and will move that thread
-- when signalled — confirm this is acceptable.
-- @param max (optional) maximum number of waiters to wake; all when nil
function condition_methods:signal(max)
	local first, last = self.head, self.tail
	if max then
		-- at most `max` events: slots first .. first + max - 1
		last = math.min(last, first + math.floor(max) - 1)
	end
	for i = first, last do
		local event = self[i]
		self[i] = nil -- dequeue
		event.pending = true
		local thread = event.thread
		thread_move(thread, thread.scheduler.pending)
		-- TODO: wakeup
	end
	if last >= first then
		self.head = last + 1
	end
end
-- A condition can itself be passed to poll(): it reports itself as its own
-- pollfd, requests no descriptor events and has no timeout of its own.
function condition_methods.pollfd(self)
	return self
end

function condition_methods.events(_)
	return nil
end

function condition_methods.timeout(_)
	return nil
end
--- Create a new scheduler and register it in `cstack`.
-- @return the scheduler object
local function new()
	local cq = {
		thread_count = 0, -- number of threads on the polling + pending lists
		polling = { next = nil }, -- list head: threads waiting for events
		pending = { next = nil }, -- list head: threads ready to be resumed
		current = nil, -- coroutine being resumed right now
		running = nil, -- enclosing running scheduler (set by cstack_push)
		timers = new_timers(),
	}
	setmetatable(cq, mt)
	cstack[cq] = true
	return cq
end
--- Replace the scheduler method named `key` with `func`.
-- Because `methods` is the shared __index table, this affects every scheduler.
-- @return the previous implementation (possibly nil)
function methods.interpose(key, func)
	local previous = methods[key]
	methods[key] = func
	return previous
end
-- Alternative counter-based clock (advances by exactly 1 on every call), left
-- commented out; it can replace cqueues.monotime below.
-- local monotonic_clock = 0
-- local function monotime()
-- monotonic_clock = monotonic_clock + 1
-- return monotonic_clock
-- end
-- Clock used for every deadline and timeout computed in this module.
local monotime = cqueues.monotime
--- Report the innermost running scheduler.
-- @return the scheduler (or nil), and whether the calling coroutine is the
--   one that scheduler is currently resuming
local function running()
	local cq = cstack_running
	if not cq then
		return cq, false
	end
	return cq, cq.current == coroutine.running()
end
--- Forward cancel(...) to every scheduler registered in `cstack`.
local function cancel(...)
	for queue in pairs(cstack) do
		queue:cancel(...)
	end
end

--- Call reset() on every scheduler registered in `cstack`.
-- NOTE(review): methods:reset exists only as commented-out code in this
-- module, so for schedulers created by new() this call raises — confirm intent.
local function reset()
	for queue in pairs(cstack) do
		queue:reset()
	end
end
-- Private scheduler used to service poll() calls made outside any running scheduler.
local poller = new()

--- Wait until one of the given objects is ready (or a timeout number expires).
-- Inside a running scheduler this yields a poll request to it. Outside, the
-- request is run in a wrapper thread on the private `poller` queue.
-- @return the ready values the poll was resumed with
local function poll(...)
	if not running() then
		local results
		poller:wrap(function (...)
			results = { poll(...) }
		end, ...)
		-- NOTE: must step twice, once to call poll and
		-- again to wake up
		assert(poller:step())
		assert(poller:step())
		return unpack(results or {})
	end
	return coroutine.yield(_POLL, ...)
end

--- Suspend the caller for `timeout` (a poll on nothing but the timeout).
local function sleep(timeout)
	poll(timeout)
end
do
-- Forward declaration: handle_resume is assigned further down and calls back
-- into do_resume.
local handle_resume

--- Resume `thread`'s coroutine with `...` while this scheduler is marked as
-- running, then hand the resume results to handle_resume.
local function do_resume(self, thread, ...)
	local co = thread.co
	cstack_push(self)
	self.current = co
	return handle_resume(self, thread, coroutine.resume(co, ...))
end
--- Detach a finished (or failed) thread from this scheduler.
-- @return `...` unchanged, so callers can `return cleanup(...)` their result
local function cleanup(self, thread, ...)
	thread_del(thread)
	local remaining = self.thread_count - 1
	self.thread_count = remaining
	return ...
end
--- Process the results of resuming `thread`'s coroutine (called by do_resume).
-- `ok, first, ...` are the values returned by coroutine.resume.
-- Returns true when the thread finished, or was parked/rescheduled after a
-- poll request. On failure the thread is removed and the result is
-- (nil, err, nil, coroutine[, object[, pollfd]]); the "error calling method"
-- paths return only (nil, message).
-- A yield of _POLL registers one event per remaining yielded value: a number
-- is a relative timeout (same units as monotime()), a condition is waited on,
-- and any other value may expose pollfd/events/timeout as fields or methods.
-- Any other yield is forwarded to whoever resumed the scheduler, and the
-- reply is passed back into the thread.
function handle_resume(self, thread, ok, first, ...)
self.current = nil
cstack_pop()
-- The coroutine raised an error: drop the thread and report it.
if not ok then
return cleanup(self, thread, nil, first, nil, thread.co)
-- The coroutine returned: drop the thread.
elseif coroutine.status(thread.co) == "dead" then
return cleanup(self, thread, true)
else
if first == _POLL then
local now = monotime()
-- Smallest timeout over all yielded values (math.huge = none).
local thread_timeout = math.huge
-- each_arg also visits nil values.
for _, v in each_arg(...) do
local pollfd, events, timeout, cond = nil, 0, nil, nil
if type(v) == "number" then
timeout = v
elseif getmetatable(v) == condition_mt then
cond = v
elseif v ~= nil then
-- NOTE(review): a non-indexable value (e.g. a boolean) makes the
-- field lookups below raise instead of returning an error tuple.
-- pollfd: field or method; -1 means "no descriptor", and a
-- condition returned here is waited on as a condition.
if v.pollfd ~= nil then
pollfd = v.pollfd
if type(pollfd) == "function" then
local pcall_ok, err = pcall(pollfd, v)
if not pcall_ok then
return cleanup(self, thread, nil, "error calling method pollfd: " .. safe_tostring(err))
end
pollfd = err
end
if type(pollfd) == "number" then
if pollfd == -1 then
pollfd = nil
end
elseif getmetatable(pollfd) == condition_mt then
pollfd, cond = nil, pollfd
elseif pollfd ~= nil then
return cleanup(self, thread, nil, "invalid pollfd (expected nil, number or condition)", nil, thread.co, v)
end
end
-- events: field or method; a string containing r/w/p is
-- converted to the POLLIN/POLLOUT/POLLPRI bit mask.
if v.events ~= nil then
events = v.events
if type(events) == "function" then
local pcall_ok, err = pcall(events, v)
if not pcall_ok then
return cleanup(self, thread, nil, "error calling method events: " .. safe_tostring(err))
end
events = err
end
if events == nil then
events = 0
elseif type(events) == "string" then
local e = 0
if events:match "r" then
e = e + POLLIN
end
if events:match "w" then
e = e + POLLOUT
end
if events:match "p" then
e = e + POLLPRI
end
events = e
elseif type(events) ~= "number" then
return cleanup(self, thread, nil, "invalid events (expected nil, number or string)", nil, thread.co, v, pollfd)
end
end
-- timeout: field or method.
if v.timeout ~= nil then
timeout = v.timeout
if type(timeout) == "function" then
local pcall_ok, err = pcall(timeout, v)
if not pcall_ok then
return cleanup(self, thread, nil, "error calling method timeout: " .. safe_tostring(err))
end
timeout = err
end
end
end
-- No timeout means "wait indefinitely".
if timeout == nil then
timeout = math.huge
elseif type(timeout) ~= "number" then
return cleanup(self, thread, nil, "invalid timeout (expected nil or number)", nil, thread.co, v, pollfd)
end
-- Only register an event when this value can actually wake the thread.
-- NOTE(review): if a later value is rejected by one of the error returns
-- above, events registered for earlier values stay in thread.events and
-- in any condition queues.
if timeout < math.huge or cond or (pollfd and events ~= 0) then
local event = {
thread = thread;
value = v;
deadline = now + timeout;
pending = false;
}
table.insert(thread.events, event)
if cond then
condition_add(cond, event)
end
-- Descriptor polling is not implemented in this pure-Lua version.
if pollfd then
error("NYI pollfd")
end
-- NOTE(review): timeout is always a number at this point, so this test is always true.
if timeout then
thread_timeout = math.min(timeout, thread_timeout)
end
end
end
-- Park the thread on the polling list if it is waiting on anything.
-- Otherwise it stays on the pending list and is resumed again on a later step.
if thread.events[1] ~= nil or thread_timeout ~= math.huge then
if thread_timeout ~= math.huge then
-- add thread to timers
self.timers:add(thread, now+thread_timeout)
end
thread_move(thread, self.polling)
end
return true
else
-- Not a poll request: pass the yield up to whoever runs this scheduler,
-- then resume the thread with the reply.
return do_resume(self, thread, coroutine.yield(first, ...))
end
end
end
--- Run one scheduling pass.
-- When called from inside another scheduler's thread, this first polls on
-- this queue through the outer scheduler and then steps without waiting.
-- Otherwise it waits until the earliest timer or `timeout`, whichever comes
-- first. It then marks expired events pending and resumes every thread on
-- the pending list.
-- @param timeout (optional) maximum time to wait; nil means no limit
-- @return true on success; or nil, err, errno, thread, object, fd from the
--   first thread that failed (remaining pending threads run on a later step)
function methods:step(timeout)
	if running() then
		poll(self, timeout)
		timeout = 0.0
	end
	assert(self.current == nil, "cannot step live cqueue")
	if self.pending.next or self.thread_count == 0 then
		-- Something is runnable right now, or there is no thread that could
		-- ever become runnable: do not wait (avoids sleeping forever).
		timeout = 0
	else
		timeout = timeout or math.huge
		local t = self.timers:min()
		if t then
			timeout = math.min(timeout, t - monotime())
		end
	end
	-- Find out what in self.polling is ready; take up to 'timeout'.
	-- pollfd readiness is not implemented, so waiting is a plain sleep.
	if timeout > 0 then
		cqueues.sleep(timeout)
	end
	-- Move threads whose timers expired to the .pending list
	local now = monotime()
	for thread, deadline in self.timers:each() do
		if deadline <= now then
			for _, event in ipairs(thread.events) do
				if event.deadline <= now then
					event.pending = true
				end
			end
			thread_move(thread, self.pending)
		end
	end
	-- Run pending threads
	local thread = self.pending.next
	while thread do
		local next_thread = thread.next -- Save next one incase current gets moved
		local polled_ready, i = {}, 0
		for _, v in ipairs(thread.events) do
			if v.pending then
				i = i + 1
				polled_ready[i] = v.value
			end
		end
		-- These events answer the poll being resumed now; a new poll
		-- registers a fresh set, so stale events must not be delivered again.
		thread.events = {}
		self.timers:remove(thread)
		local ok, err, errno, thd, ob, fd = do_resume(self, thread, unpack(polled_ready, 1, i))
		if not ok then
			return nil, err, errno, thd, ob, fd
		end
		thread = next_thread
	end
	return true
end
end
--- Attach an existing coroutine to this scheduler.
-- It is placed on the pending list, so it is first resumed on a later step.
-- @return self
function methods:attach(co)
	assert(type(co) == "thread")
	local thread = {
		co = co,
		events = {},
		scheduler = self,
	}
	thread_add_head(thread, self.pending)
	self.thread_count = self.thread_count + 1
	-- tryalert
	return self
end
--- Run `func(...)` in a new thread managed by this scheduler.
-- The coroutine is started immediately only to bind the arguments: it parks
-- at its first yield, so `func` itself starts when a later step resumes it.
-- @return self
function methods:wrap(func, ...)
	local co = coroutine.create(function (...)
		coroutine.yield()
		return func(...)
	end)
	coroutine.resume(co, ...) -- bind arguments; stops at the yield above
	local thread = {
		co = co,
		events = {},
		scheduler = self,
	}
	thread_add_head(thread, self.pending)
	self.thread_count = self.thread_count + 1
	-- tryalert
	return self
end
--- True when no threads are attached (neither polling nor pending).
function methods:empty()
	local n = self.thread_count
	return n == 0
end

--- Number of attached threads (polling + pending).
function methods:count()
	local n = self.thread_count
	return n
end
--- Cancel outstanding polls on the given descriptors.
-- Each argument is either a descriptor number or an object whose `pollfd`
-- is a number or a method returning one.
-- NOTE: only descriptor resolution is implemented; the cancellation itself
-- is still a no-op (see the TODO below).
-- @raise if an object's pollfd method fails or yields a non-number
function methods:cancel(...)
	for _, v in each_arg(...) do
		local fd = v
		if type(fd) ~= "number" then
			-- resolve the descriptor from the object (field or method)
			fd = v.pollfd
			if type(fd) == "function" then
				local ok, err = pcall(fd, v)
				if not ok then
					error("error calling method pollfd: "..safe_tostring(err))
				end
				fd = err
			end
			if type(fd) ~= "number" then
				error("error loading field pollfd")
			end
		end
		-- TODO: cancel polls registered on `fd`
	end
end
-- function methods:reset()
-- -- Move polling list to pending
-- local n = #self.pending
-- local e = #self.polling
-- for i=1, e do
-- self.pending[n+i] = self.polling[i]
-- self.polling[i] = nil
-- end
-- end
-- function methods:pollfd()
-- return nil
-- end
--- Poll interface: a scheduler asks to be polled for readability ("r").
function methods.events(_)
	return "r"
end
--- Poll interface: how long until this scheduler has work to do.
-- @return 0 if threads are pending, time until the earliest timer (may be
--   negative once it has passed), or nil when nothing is scheduled
function methods:timeout()
	if self.pending.next then
		return 0
	end
	local earliest = self.timers:min()
	if not earliest then
		return nil
	end
	return earliest - monotime()
end
--- Convert a relative timeout into an absolute monotime() deadline.
-- @return nil for no timeout; 0 for a non-positive timeout (so totimeout can
--   report expiry without reading the clock); otherwise monotime() + timeout
local function todeadline(timeout)
	if not timeout then
		return nil
	end
	if timeout > 0 then
		return monotime() + timeout
	end
	return 0
end -- todeadline
--- Convert an absolute deadline (from todeadline) back into a relative timeout.
-- @return remaining time and whether the deadline has expired:
--   (nil, false) for no deadline, (0, true) once it has passed
local function totimeout(deadline)
	if not deadline then
		return nil, false
	end
	if deadline == 0 then
		return 0, true
	end
	local now = monotime()
	if now < deadline then
		return deadline - now, false
	end
	return 0, true
end
do
	-- Keep stepping until a step fails, the deadline expires or the queue is
	-- empty. `ok, ...` are the results of the previous step.
	local function continue_loop(self, deadline, ok, ...)
		local remaining, expired = totimeout(deadline)
		if not ok then
			return false, ...
		end
		if expired or self:empty() then
			return true
		end
		return continue_loop(self, deadline, self:step(remaining))
	end

	--- Run the scheduler until it is empty, a thread errors, or `timeout` elapses.
	-- @return true when finished; otherwise false followed by step()'s error values
	function methods:loop(timeout)
		local deadline = todeadline(timeout)
		return continue_loop(self, deadline, self:step(timeout))
	end
end
--- Iterator over thread errors produced while looping, within an optional
-- overall `timeout`. Each call runs loop() with the remaining time and returns
-- its error values; iteration ends when loop() finishes without error.
function methods:errors(timeout)
	local deadline = todeadline(timeout)
	local function next_error()
		local remaining = totimeout(deadline)
		return select(2, self:loop(remaining))
	end
	return next_error
end
-- Public module table.
return {
	_POLL = _POLL,
	cancel = cancel,
	monotime = monotime,
	new = new,
	poll = poll,
	reset = reset,
	running = running,
	sleep = sleep,
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment