Skip to content

Instantly share code, notes, and snippets.

@nsf
Last active December 13, 2019 14:55
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save nsf/466640dd541ed71f5479 to your computer and use it in GitHub Desktop.
Save nsf/466640dd541ed71f5479 to your computer and use it in GitHub Desktop.
Coroutines in nim
import queues
import locks
import macros
import sequtils
type
  # Kind of request a coroutine's iterator hands back to `schedule` each time
  # it is resumed.
  SchedulerCommandType = enum
    scDone,              # the coroutine has finished
    scYield,             # requeue the coroutine itself
    scWaitForCoroutine,  # suspend until `coroutine` has finished
    scWaitForCoroutines, # suspend until every element of `coroutines` finished
  # Command yielded by a coroutine's iterator; the wait variants carry the
  # coroutine(s) to wait for.
  SchedulerCommand = object
    case kind: SchedulerCommandType
    of scDone: discard
    of scYield: discard
    of scWaitForCoroutine:
      coroutine: CoroutineBase
    of scWaitForCoroutines:
      coroutines: seq[CoroutineBase]
  # Shared by a group of awaited coroutines. `count` is how many of them have
  # not finished yet; `schedule` decrements it and requeues
  # `waitingCoroutine` when it reaches zero.
  Counter = ref object
    count: int
    waitingCoroutine: CoroutineBase
  # Type-erased coroutine. `iter` is resumed by `schedule`; `counter` is set
  # while some other coroutine is awaiting this one.
  CoroutineBase = ref object of RootObj
    iter: iterator(): SchedulerCommand
    counter: Counter
  # Coroutine producing a value of type T, stored in `result` when it returns
  # (no such field for Coroutine[void]).
  Coroutine[T] = ref object of CoroutineBase
    when T is not void:
      result: T
# Blocking FIFO shared between threads: `mutex` guards `queue`, and `cond` is
# signalled on every push so that a thread blocked in `pop` can wake up.
type ThreadQueue[T] = object
  queue: Queue[T]
  mutex: Lock
  cond: Cond
proc initThreadQueue[T]: ThreadQueue[T] =
  ## Returns an empty queue whose lock and condition variable are ready to
  ## use.
  result.mutex.initLock()
  result.cond.initCond()
  result.queue = initQueue[T]()
proc push[T](tq: var ThreadQueue[T], item: T) =
  ## Appends `item` under the lock and wakes one thread blocked in `pop`.
  acquire(tq.mutex)
  try:
    enqueue(tq.queue, item)
    signal(tq.cond)
  finally:
    release(tq.mutex)
proc pop[T](tq: var ThreadQueue[T]): T =
  ## Removes and returns the oldest item, blocking while the queue is empty.
  acquire(tq.mutex)
  try:
    # Re-check after every wake-up: a wait may return spuriously, or another
    # popper may have taken the item first.
    while len(tq.queue) == 0:
      wait(tq.cond, tq.mutex)
    result = dequeue(tq.queue)
  finally:
    release(tq.mutex)
# Global run queue of coroutines ready to be resumed. Worker threads pop from
# it; a nil entry tells a worker to exit (see `worker` and `shutdown`).
var queue = initThreadQueue[CoroutineBase]()
proc go(c: CoroutineBase) =
  ## Makes `c` runnable by appending it to the global run queue.
  push(queue, c)
proc schedule(c: CoroutineBase) =
  ## Resumes `c` until its next suspension point and acts on the command it
  ## yields:
  ## - `scDone`: if another coroutine awaits `c`, decrement the shared counter
  ##   and requeue that waiter once the last awaited coroutine has finished.
  ## - `scYield`: requeue `c` itself.
  ## - `scWaitForCoroutine(s)`: park `c` behind a fresh `Counter` and queue
  ##   the awaited coroutine(s); the last one to finish requeues `c`.
  ##   Awaiting an empty batch resumes `c` right away.
  let cmd = c.iter()
  case cmd.kind
  of scDone:
    # atomicDec yields the decremented count, so exactly one finisher of the
    # group sees 0 and requeues the waiter.
    if c.counter != nil and atomicDec(c.counter.count) == 0:
      go c.counter.waitingCoroutine
    c.counter = nil
  of scYield:
    go c
  of scWaitForCoroutine:
    let counter = Counter(count: 1, waitingCoroutine: c)
    cmd.coroutine.counter = counter
    go cmd.coroutine
  of scWaitForCoroutines:
    if cmd.coroutines.len == 0:
      # No child will ever report scDone, so a zero counter would park `c`
      # forever; there is nothing to wait for, so resume it immediately.
      go c
    else:
      let counter = Counter(count: cmd.coroutines.len, waitingCoroutine: c)
      for child in cmd.coroutines:
        child.counter = counter
        go child
proc worker(n: int) =
  ## Thread body: resumes coroutines from the global run queue one step at a
  ## time until it pops the nil sentinel queued by `shutdown`.
  echo "worker " & $n & " on duty"
  var next = queue.pop()
  while not next.isNil:
    schedule(next)
    next = queue.pop()
  echo "worker " & $n & " shutting down"
# Start the worker pool: one thread per array slot, each running worker(i).
# NOTE(review): the cast forces `worker` to be treated as {.gcsafe.}. It
# presumably is not: it touches the global `queue`, and GC'd coroutine refs
# are handed between threads. Confirm this is sound for the GC in use.
var workers: array[4, Thread[int]]
for i in 0..high(workers):
  createThread(workers[i], cast[proc(n: int) {.gcsafe.}](worker), i)
proc shutdown() =
  ## Asks every worker thread to exit by queueing one nil sentinel per worker.
  for slot in 1..workers.len:
    push(queue, nil)
template ppTree(e: expr) =
  ## Compile-time debug dump: prints the AST of `e` (`treeRepr`) followed by
  ## the source it renders to (`toStrLit`), framed by banner lines.
  echo("-------------------- AST ---------------------")
  echo(e.treeRepr)
  echo("------------------- Code ---------------------")
  echo(e.toStrLit)
  echo("----------------------------------------------")
proc wrapAwaitValue(tmpSym, cmd, n: NimNode): NimNode =
  ## Builds the two-statement prologue shared by every value-producing await:
  ##
  ##   let <tmpSym> = coroutineAwaitValue(<awaited expr>)
  ##   yield <tmpSym>.command
  ##
  ## `cmd` is the `await <expr>` node (child 1 is the awaited expression);
  ## `n` only supplies line info for the resulting statement list. Callers
  ## append whatever statement consumes `<tmpSym>.value()`.
  let awaited = cmd[1]
  let bindTmp = newLetStmt(tmpSym,
    newCall(newIdentNode(!"coroutineAwaitValue"), awaited))
  let suspend = newNimNode(nnkYieldStmt)
  suspend.add(newDotExpr(tmpSym, newIdentNode(!"command")))
  result = newNimNode(nnkStmtList, n)
  result.add(bindTmp)
  result.add(suspend)
proc isAwaitCall(node: NimNode): bool =
  ## True when `node` is `await expr` (command syntax) or `await(expr)`.
  result = node.kind in {nnkCommand, nnkCall} and
    node[0].kind == nnkIdent and node[0].ident == !"await"

proc awaitTmpName(target: NimNode): string =
  ## Base name for the temporary holding a value awaited into `target`.
  ## Only a plain identifier contributes its name. Other targets (`obj.f`,
  ## `a[i]`, `x {.pragma.}`, or the `retCoroutine.result` node produced for
  ## `return await ...`) have no `.ident`, so they get a generic name.
  ## genSym makes the final symbol unique in every case.
  if target.kind == nnkIdent:
    result = "await" & $target.ident
  else:
    result = "awaitValue"

# Recursively converts the parts of a coroutine's body that need rewriting:
#
# 1. return expr
#    ->
#    retCoroutine.result = expr
#    yield SchedulerCommand(kind: scDone)
#
# 2. return
#    ->
#    retCoroutine.result = result
#    yield SchedulerCommand(kind: scDone)
#    or (if there is no result)
#    yield SchedulerCommand(kind: scDone)
#
# 3. await expr
#    ->
#    yield coroutineAwait(expr)
#
# 4. let x = await expr        (also `let x: T = await expr`)
#    ->
#    let tmp = coroutineAwaitValue(expr)
#    yield tmp.command
#    let x = tmp.value()
#
# 5. var x = await expr (same as 4)
# 6. x = await expr (same as 4; `x` may be any assignable expression)
# 7. discard await expr (same as 4, without the final binding)
# 8. try statements are not allowed
# 9. nested routines (procs, lambdas, iterators, ...) are left untouched:
#    their `return` statements belong to them, not to the coroutine.
proc convertToCoroutineBody(n, retSym: NimNode, hasResult: bool): NimNode =
  ## Applies the rewrites listed above to `n` and, recursively, to its
  ## children. `retSym` is the symbol bound to the `Coroutine[T]` being
  ## built; `hasResult` is true when T is not void. `n` is modified in place
  ## where possible, and the (possibly replaced) node is returned.
  result = n
  case n.kind
  of RoutineNodes:
    # A nested routine owns its `return`s; turning them into yields of the
    # enclosing coroutine would silently change their meaning.
    return
  of nnkReturnStmt:
    result = newNimNode(nnkStmtList, n)
    if n[0].kind == nnkEmpty:
      # return
      if hasResult:
        result.add(
          newAssignment(
            newDotExpr(retSym, newIdentNode(!"result")),
            newIdentNode(!"result"),
          )
        )
    else:
      # return expr
      if not hasResult:
        error("Non-void return inside a void coroutine")
      result.add(
        newAssignment(
          newDotExpr(retSym, newIdentNode(!"result")),
          n[0],
        )
      )
    result.add(
      newNimNode(nnkYieldStmt).add(
        newNimNode(nnkObjConstr).add(
          bindSym"SchedulerCommand",
          newColonExpr(
            newIdentNode(!"kind"),
            bindSym"scDone"
          )
        )
      )
    )
  of nnkCommand, nnkCall:
    if isAwaitCall(n):
      # await expr
      expectLen(n, 2)
      result = newNimNode(nnkYieldStmt, n).add(
        newCall(
          newIdentNode(!"coroutineAwait"),
          n[1]
        )
      )
  of nnkVarSection, nnkLetSection:
    # NOTE: only the first definition in the section is examined.
    let def = n[0]
    if def.kind == nnkIdentDefs and def.len == 3 and isAwaitCall(def[2]):
      # let x = await expr / let x: T = await expr
      let cmd = def[2]
      expectLen(cmd, 2)
      let tmpSym = genSym(nskLet, awaitTmpName(def[0]))
      result = wrapAwaitValue(tmpSym, cmd, n).add(
        newNimNode(n.kind).add(
          newNimNode(nnkIdentDefs).add(
            def[0],
            def[1], # keep the declared type (nnkEmpty when there is none)
            newCall(
              newDotExpr(tmpSym, newIdentNode(!"value"))
            )
          )
        )
      )
  of nnkAsgn:
    if isAwaitCall(n[1]):
      # x = await expr
      let cmd = n[1]
      expectLen(cmd, 2)
      let tmpSym = genSym(nskLet, awaitTmpName(n[0]))
      result = wrapAwaitValue(tmpSym, cmd, n).add(
        newAssignment(
          n[0],
          newCall(
            newDotExpr(tmpSym, newIdentNode(!"value"))
          )
        )
      )
  of nnkDiscardStmt:
    if isAwaitCall(n[0]):
      # discard await x
      expectLen(n[0], 2)
      let tmpSym = genSym(nskLet, "awaitDiscard")
      result = wrapAwaitValue(tmpSym, n[0], n)
  of nnkTryStmt:
    error("try statements are not allowed in coroutine functions")
  else:
    discard
  # TODO: implicit return?
  for i in 0..<result.len:
    result[i] = convertToCoroutineBody(result[i], retSym, hasResult)
# We create a coroutine with an iterator here:
#   let retCoroutine = Coroutine[T]()
#   retCoroutine.iter = iterator(): SchedulerCommand =
#     {.push warning[resultshadowed]: off.}
#     var result: T
#     {.pop.}
#     # <<< body >>>
#     retCoroutine.result = result
#     yield SchedulerCommand(kind: scDone)
#   retCoroutine
#
# The existing body is preprocessed by convertToCoroutineBody and becomes the
# iterator body.
proc convertToCoroutine(n: NimNode): NimNode =
  ## Rewrites the routine definition `n` so that calling it returns a
  ## Coroutine[T] whose `iter` runs the original body one step at a time.
  ## A missing return type is treated as Coroutine[void].
  ## Pass -d:coroutineDebug to dump the generated code at compile time.
  if n.kind notin {nnkProcDef, nnkLambda}:
    error("Cannot transform this node kind into a coroutine")
  # Lambdas have an empty name and exported procs a postfix node; neither
  # has `.ident`, so render the name node as source instead.
  let procName =
    if n[0].kind == nnkEmpty: "<anonymous>"
    else: n[0].toStrLit.strVal
  hint("Converting " & procName & " to coroutine")

  let unRetType = n[3][0]
  var retType: NimNode
  case unRetType.kind
  of nnkBracketExpr:
    if unRetType[0].kind != nnkIdent or unRetType[0].ident != !"Coroutine":
      error("Return type of a coroutine should be Coroutine[T] or void")
    retType = unRetType[1]
  of nnkEmpty:
    # No return type means a void coroutine. The rewritten body still ends in
    # the coroutine handle, so declare it as the return type; otherwise Nim
    # rejects the trailing `retCoroutine` as an undiscarded value.
    retType = newIdentNode(!"void")
    n[3][0] = newNimNode(nnkBracketExpr, n).add(
      newIdentNode(!"Coroutine"), newIdentNode(!"void"))
  else:
    error("Return type of a coroutine should be Coroutine[T] or void")

  let retSym = genSym(nskLet, "retCoroutine")
  # Only a bare `void` identifier means "no result". Other result types may be
  # arbitrary type expressions (`seq[int]`, `mymod.T`, ...) with no `.ident`.
  let hasResult = not (retType.kind == nnkIdent and retType.ident == !"void")
  let coBody = newNimNode(nnkStmtList, n[6]) # second arg is used for line info
  var itBody = convertToCoroutineBody(n[6], retSym, hasResult)
  if itBody.kind != nnkStmtList:
    itBody = newNimNode(nnkStmtList, n[6]).add(itBody)
  if hasResult:
    # {.push warning[resultshadowed]: off.}
    itBody.insert(0,
      newNimNode(nnkPragma).add(
        newIdentNode("push"),
        newNimNode(nnkExprColonExpr).add(
          newNimNode(nnkBracketExpr).add(
            newIdentNode("warning"),
            newIdentNode("resultshadowed")
          ),
          newIdentNode("off")
        )
      )
    )
    # var result: T   (the iterator's stand-in for the proc's result)
    itBody.insert(1,
      newNimNode(nnkVarSection, n[6]).add(
        newIdentDefs(newIdentNode("result"), retType)
      )
    )
    itBody.insert(2,
      newNimNode(nnkPragma).add(newIdentNode("pop"))
    )
    # Falling off the end of the body publishes `result`, like a bare return.
    itBody.add(
      newAssignment(
        newDotExpr(retSym, newIdentNode(!"result")),
        newIdentNode(!"result"),
      )
    )
  # Finish with an explicit scDone rather than relying on an exhausted closure
  # iterator returning a zero-initialized SchedulerCommand (whose `kind`
  # merely happens to be scDone, the first enum value).
  itBody.add(
    newNimNode(nnkYieldStmt).add(
      newNimNode(nnkObjConstr).add(
        bindSym"SchedulerCommand",
        newColonExpr(newIdentNode(!"kind"), bindSym"scDone")
      )
    )
  )
  # let retCoroutine = Coroutine[T]()
  coBody.add(
    newLetStmt(
      retSym,
      newCall(
        newNimNode(nnkBracketExpr, n[6]).add(newIdentNode(!"Coroutine"), retType)
      )
    )
  )
  # retCoroutine.iter = iterator(): SchedulerCommand = <itBody>
  coBody.add(
    newAssignment(
      newDotExpr(retSym, newIdentNode(!"iter")),
      newProc(
        procType = nnkIteratorDef,
        params = [bindSym"SchedulerCommand"],
        body = itBody,
      )
    )
  )
  coBody.add(retSym)
  result = n
  # TODO: do I need to strip the `coroutine` pragma from result[4]?
  result[6] = coBody
  when defined(coroutineDebug):
    ppTree(result)
macro coroutine(n: stmt): stmt {.immediate.} =
  ## Pragma macro: `proc f(...): Coroutine[T] {.coroutine.} = body` turns `f`
  ## into a constructor returning a coroutine that runs `body` step by step.
  result = convertToCoroutine(n)
#==============================================================================
# Produced by `coroutineAwaitValue`. The generated code yields `command` to the
# scheduler to suspend until the awaited coroutine(s) finish, then calls
# `value()` to read their result(s).
type Awaiter[T] = object
  command: SchedulerCommand
  value: proc(): T
proc coroutineAwait(coroutines: seq[CoroutineBase]): SchedulerCommand =
  ## Command that suspends the caller until every coroutine in `coroutines`
  ## has finished (target of the statement form `await coroutines`).
  SchedulerCommand(kind: scWaitForCoroutines, coroutines: coroutines)

proc coroutineAwait[T](coroutines: seq[Coroutine[T]]): SchedulerCommand =
  ## Same as above for a typed batch: seq[Coroutine[T]] is not implicitly a
  ## seq[CoroutineBase], so without this overload `await s` fails to compile
  ## for the usual `seq[Coroutine[int]]` and similar.
  coroutineAwait(map(coroutines, proc(c: Coroutine[T]): CoroutineBase = c))
proc coroutineAwait(coroutine: CoroutineBase): SchedulerCommand =
  ## Command that suspends the caller until `coroutine` has finished.
  result = SchedulerCommand(kind: scWaitForCoroutine, coroutine: coroutine)
proc coroutineAwaitValue[T](coroutines: seq[Coroutine[T]]): Awaiter[seq[T]] =
  ## Awaiter for a batch. Its command waits for every coroutine in
  ## `coroutines`; its `value` then collects their results in the same order.
  let cb = map(coroutines, proc(c: Coroutine[T]): CoroutineBase = c)
  result.command = SchedulerCommand(kind: scWaitForCoroutines, coroutines: cb)
  result.value = proc(): seq[T] =
    result = newSeq[T](coroutines.len)
    for i in 0..high(coroutines):
      result[i] = coroutines[i].result

proc coroutineAwaitValue[T](coroutine: Coroutine[T]): Awaiter[T] =
  ## Awaiter for a single coroutine, so that `let x = await someCoroutine`
  ## compiles as well. `value` is left nil for Coroutine[void], which only
  ## makes sense with `discard await` (that form never calls `value`).
  result.command = SchedulerCommand(kind: scWaitForCoroutine,
                                    coroutine: CoroutineBase(coroutine))
  when T is not void:
    result.value = proc(): T = coroutine.result
# Demo coroutine: announces itself, then returns `i` plus the sum of 0 ..< n.
proc computeNumber(i, n: int): Coroutine[int] {.coroutine.} =
  echo "computing... ", i
  var total = i
  for step in 0..<n:
    total += step
  return total
# Demo root coroutine: launches 100 computeNumber coroutines, awaits all of
# their results, prints them, then stops the worker pool.
proc computeAll(n: int): Coroutine[void] {.coroutine.} =
  var jobs: seq[Coroutine[int]] = @[]
  for i in 0..<100:
    jobs.add(computeNumber(i, n))
  let values = await jobs
  for value in values:
    echo value
  echo "done"
  shutdown()
# Kick off the root coroutine (it calls shutdown() when it finishes), then
# block until every worker has consumed its nil sentinel and exited.
queue.push(computeAll(10000))
joinThreads(workers)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment