Last active
December 24, 2021 05:43
-
-
Save glyh/c411ee7dbff26ef49c2048c1fbbe4f9c to your computer and use it in GitHub Desktop.
Macros for subclassing in nimpy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import nimpy, std/macros, ast_pattern_matching, sequtils, unittest, strutils | |
# Handle to Python's `builtins` module. The code generated by `pyInstance`
# calls `py.type(...)` on it to create the ad-hoc class.
let py = pyBuiltinsModule()

proc selfStrippedParams(params: NimNode): NimNode {.compileTime.} =
  ## Returns a fresh copy of the formal parameter list `params` with its
  ## first parameter (the receiver, conventionally `self`) removed.
  ##
  ## `params` itself is never modified, so the proc it belongs to keeps its
  ## receiver parameter. For a grouped declaration such as
  ## `self, other: PyObject` only the first name is dropped.
  ## Reports a compile-time error when there is no parameter to strip.
  params.expectKind nnkFormalParams
  if params.len < 2:
    error("expected a proc whose first parameter is the receiver (`self`)",
          params)
  let receiverDefs = params[1]
  receiverDefs.expectKind nnkIdentDefs
  result = newNimNode(nnkFormalParams)
  result.add params[0].copyNimTree() # return type (may be empty)
  if receiverDefs.len > 3:
    # `self, other: T` -- keep the names after `self` with their type/default.
    let rest = newNimNode(nnkIdentDefs)
    for i in 1 ..< receiverDefs.len:
      rest.add receiverDefs[i].copyNimTree()
    result.add rest
  for i in 2 ..< params.len:
    result.add params[i].copyNimTree()

proc forwardCall(callee, receiver, params: NimNode): NimNode {.compileTime.} =
  ## Builds the call `callee(receiver, p1, p2, ...)`, where `p1, p2, ...` are
  ## all parameter names declared in the formal parameter list `params`, in
  ## declaration order (every name of a grouped `a, b: T` is included).
  result = newCall(callee.copyNimTree(), receiver.copyNimTree())
  for i in 1 ..< params.len:      # params[0] is the return type
    let defs = params[i]
    for j in 0 ..< defs.len - 2:  # the last two children are type and default
      result.add defs[j].copyNimTree()

macro pyInjectMethods*(obj, memFuncs: untyped): untyped =
  ## Attaches each proc in `memFuncs` to the Python object `obj` as an
  ## attribute named after the proc.
  ##
  ## Every entry must be a plain (non-exported) `proc name(self: PyObject,
  ## ...)`. The attribute is an anonymous Nim proc taking the remaining
  ## parameters. It defines the original proc unchanged, calls it with `obj`
  ## as `self` and returns its result.
  ##
  ## Reports a compile-time error if `obj` is not an identifier or if
  ## `memFuncs` contains anything other than such proc definitions.
  obj.expectKind nnkIdent
  memFuncs.expectKind nnkStmtList
  let stmts = newStmtList()
  for memFunc in memFuncs:
    memFunc.matchAst:
    of nnkProcDef(`funcName` @ nnkIdent, _, _,
                  `params` @ nnkFormalParams, _, _, _):
      let
        wrapperParams = selfStrippedParams(params)
        call = forwardCall(funcName, obj, wrapperParams)
        wrapper = newProc(
          params = toSeq(wrapperParams.children),
          body = newStmtList(memFunc, call),
          procType = nnkLambda)
        tmp = genSym(nskLet, "f")
      stmts.add(quote do:
        let `tmp` = `wrapper`
        `obj`.`funcName` = `tmp`)
    else:
      error("Input data wrong: pyInjectMethods expects proc definitions",
            memFunc)
  result = newBlockStmt(stmts)

macro pyInstance*(args: varargs[untyped]): untyped =
  ## Creates an instance of a new, anonymous Python class deriving from the
  ## given base classes: `pyInstance(Base1, Base2, ...): body`.
  ##
  ## All statements of `body` are emitted unchanged before the instance is
  ## created. Additionally:
  ## * every `proc name(self: PyObject, ...)` becomes a class attribute `name`
  ##   holding a wrapper that calls the proc with the new instance as `self`;
  ## * every variable declared in a `var` section becomes a class attribute
  ##   holding that variable's value.
  ##
  ## The whole expression evaluates to the new instance.
  if args.len == 0:
    error("pyInstance expects base classes followed by a body")
  let
    body = args[^1]
    obj = genSym(nskVar)
    bases = newNimNode(nnkTupleConstr)
    properties = newNimNode(nnkTableConstr)
    stmts = newStmtList()
  body.expectKind nnkStmtList
  stmts.add(quote do:
    var `obj`: PyObject)
  for i in 0 ..< args.len - 1:
    bases.add args[i]
  for piece in body:
    stmts.add(piece)
    piece.matchAst:
    of nnkProcDef(`funcName` @ nnkIdent, _, _,
                  `params` @ nnkFormalParams, _, _, _):
      let
        wrapperParams = selfStrippedParams(params)
        wrapperSym = genSym(nskProc)
      stmts.add newProc(
        name = wrapperSym,
        params = toSeq(wrapperParams.children),
        body = newStmtList(forwardCall(funcName, obj, wrapperParams)))
      properties.add newColonExpr(newLit(funcName.strVal), wrapperSym)
    of nnkVarSection:
      for def in piece:
        if def.kind notin {nnkIdentDefs, nnkVarTuple}:
          error("Malformed var section!", def)
        # The last two children are the declared type and the initial value.
        for i in 0 ..< def.len - 2:
          let name = def[i]
          name.expectKind nnkIdent
          properties.add newColonExpr(newLit(name.strVal), name.copyNimTree())
    else:
      discard
  let construct = quote("@") do:
    @obj = py.`type`("_", @bases, toPyDict(@properties)).to(proc(): PyObject {.gcsafe.})()
    @obj # last expression: makes the block evaluate to the new instance
  stmts.add(construct)
  result = newBlockStmt(stmts)
test "multi inheritance":
  # Two unrelated Python bases; `connect` only works if the generated
  # instance inherits from both of them.
  discard py.exec("""
    class Foo:
      def hello(self):
        return "Hello, "
    class Bar:
      def world(self):
        return "World!"
    """.dedent())
  let
    Foo = pyGlobals()["Foo"]
    Bar = pyGlobals()["Bar"]
  let greeter: PyObject = pyInstance(Foo, Bar):
    proc connect(self: PyObject): string =
      let (a, b) = (self.hello().to(string), self.world().to(string))
      a & b
  check greeter.connect().to(string) == "Hello, World!"
test "Inject methods":
  discard py.exec("""
    class Foo:
      ...
    """.dedent())
  let Foo = pyGlobals()["Foo"]
  var counter: PyObject = pyInstance(Foo):
    var (a, b) = (0, 1)
  pyInjectMethods(counter):
    # `set` is not needed by the Fibonacci check below; it is included to
    # show that methods with extra parameters are forwarded as well.
    proc set(self: PyObject, a: int, b: int) =
      self.a = a
      self.b = b
    proc fib(self: PyObject): int =
      let
        (a, b) = (self.a.to(int), self.b.to(int))
        c = a + b
      result = c
      (self.a, self.b) = (b, c)
  # Reference sequence: 1, 2, 3, 5, 8, ... (matches the counter starting at
  # (a, b) = (0, 1)).
  proc fib(x: int): int =
    case x
    of 1, 2: x
    else: fib(x - 1) + fib(x - 2)
  for i in 1..20:
    check counter.fib().to(int) == fib(i)
  discard counter.set(0, 1)
  check counter.fib().to(int) == 1
test "define variables in python classes in different ways":
  discard py.exec("""
    class Foo:
      ...
    """.dedent())
  let Foo = pyGlobals()["Foo"]
  let obj: PyObject = pyInstance(Foo):
    var
      (a, b, c) = (1, 2, 3)
      d = "wow"
      e = 3.14
    var f = 99
    var g = @[1, 2, 3, 4]
  # Every variable declared in the body, whether in a tuple, a grouped or a
  # single `var` section, must have become an attribute carrying its value.
  check obj.a.to(int) == 1
  check obj.b.to(int) == 2
  check obj.c.to(int) == 3
  check obj.d.to(string) == "wow"
  check obj.e.to(float) == 3.14
  check obj.f.to(int) == 99
  check obj.g.to(seq[int]) == @[1, 2, 3, 4]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment