Last active August 9, 2020 05:09
nim getType() fun
# Here is an implementation of curry: the arguments passed are wrapped up in a
# new closure, and a function is returned that accepts the remaining arguments.
# usage:
# proc foo (a,b: int): int = a + b
# let f = curry(foo, 10)
# assert f(10) == 20
# note: to use an overloaded function you must annotate its type
# curry((proc(c:char,len:int):string)strutils.repeat, 'x')
import macros
proc type_to_nim (n: NimNode): NimNode {.compileTime.} =
  ## Returns AST usable as a type annotation for `n`, a node taken from the
  ## type graph produced by `macros.getType`.
  ##
  ## For most kinds the node getType() yields can be spliced into generated
  ## code as-is, so `n` is returned unchanged. Ref, ptr, range and proc types
  ## are not handled yet: compilation stops with a diagnostic that includes
  ## the node's tree, instead of silently producing an unusable type.
  case n.typeKind
  of ntyRef, ntyPtr, ntyRange, ntyProc:
    error("type_to_nim: unhandled type kind " & $n.typeKind & "\n" &
          n.treeRepr, n)
  else:
    result = n
when false:
  # The symbol returned from getType() can now be used directly in AST to
  # stand for the type it represents, so this old function is no longer
  # needed. It is kept here, disabled, for reference only.
  proc type_to_nim (n: NimNode): NimNode {.compileTime.} =
    ## Rebuilds type-declaration AST from a getType() type-graph node `n`.
    ## Kinds it does not know print the node and abort.
    let ty = n.typeKind
    case ty
    of ntyRef:
      result = newNimNode(nnkRefTy).add(n[1].type_to_nim)
    of ntyPtr:
      # Must be nnkPtrTy; building nnkRefTy here turned `ptr T` into `ref T`.
      result = newNimNode(nnkPtrTy).add(n[1].type_to_nim)
    of ntyRange:
      # `range[lo .. hi]` needs the `range` head inside the bracket.
      result = newNimNode(nnkBracketExpr).add(
        ident"range", infix(n[1].type_to_nim, "..", n[2].type_to_nim))
    of ntyArray:
      # `array[index, element]` needs the `array` head inside the bracket.
      result = newNimNode(nnkBracketExpr).add(
        ident"array", n[1].type_to_nim, n[2].type_to_nim)
    of ntyEmpty:
      result = newEmptyNode()
    of ntyProc:
      # n[1] becomes the return type, n[2..] the parameter types.
      let params = newNimNode(nnkFormalParams)
      params.add n[1].type_to_nim
      for i in 2 ..< len(n):
        params.add newIdentDefs(ident("arg" & $(i-2)), n[i].type_to_nim)
      result = newNimNode(nnkProcTy).add(params, newEmptyNode())
    of ntyInt .. ntyFloat128, ntyString:
      if n.kind == nnkSym:
        result = ident($n)
      else:
        # presumably a literal node; passed through unchanged
        result = n
    else:
      echo ty, ": ", n.repr
      quit "unhandled type"
macro curry (f: typed; args: varargs[untyped]): untyped =
  ## Partially applies `f` to `args`.
  ##
  ## Expands to an anonymous proc that takes the parameters of `f` which
  ## `args` did not fill (in declaration order) and returns `f`'s result.
  ## The `args` expressions are spliced into the lambda body, so they are
  ## evaluated on every call of the returned proc, not once at the call site.
  ##
  ## `f` must resolve to a single routine. To use an overloaded proc, convert
  ## it to an explicit proc type first, e.g.
  ## `curry((proc(c: char, n: Natural): string)repeat, 'x')`.
  let fnTy = getTypeInst(f)
  if fnTy.kind != nnkProcTy:
    error("first param is not a function", f)
  let formal = fnTy[0]  # nnkFormalParams: [returnType, IdentDefs...]
  # One entry per parameter; an IdentDefs may group several names.
  var paramTypes: seq[NimNode] = @[]
  for i in 1 ..< formal.len:
    let defs = formal[i]
    for _ in 0 ..< defs.len - 2:
      paramTypes.add defs[^2]
  let nRemaining = paramTypes.len - args.len
  if nRemaining <= 0:
    error("cannot curry all the parameters", f)
  let callExpr = newCall(f)
  for a in args:
    callExpr.add a
  var params: seq[NimNode] = @[formal[0]]  # return type (Empty for void)
  for i in 0 ..< nRemaining:
    # genSym keeps the lambda's parameters from shadowing names used in args
    let param = genSym(nskParam, "arg")
    params.add newIdentDefs(param, paramTypes[args.len + i])
    callExpr.add param
  result = newProc(procType = nnkLambda, params = params, body = callExpr)
  when defined(Debug):
    echo result.repr
when isMainModule:
  # Self-test for `curry`. doAssert is used so the checks also run in
  # -d:danger builds.
  proc foo (a, b, c: int): int =
    a + b * c
  let f = curry(foo, 1)
  doAssert f(2, 3) == 7  # 1 + 2*3 = 7
  doAssert curry(foo, 1, 2)(3) == 7
  proc qux (x: int, y: float, z: float): int =
    (x.float * y * z).int
  let f1 = curry(qux, 42)
  let f2 = curry(f1, 100.0)  # currying an already-curried proc
  doAssert f2(6) == int(42 * 100 * 6)
  import strutils
  # strutils.repeat is overloaded, so the char version is selected with an
  # explicit proc-type conversion; its count parameter is `Natural`.
  let fz = curry((proc(c: char, n: Natural): string)strutils.repeat, 'x')
  doAssert fz(3) == "xxx"
# This is an implementation of `==` that properly handles object variants.
# For the `Variant` type in the self-test below, it generates the equivalent
# of:
#
#   a.xx == b.xx and a.k == b.k and a.k2 == b.k2 and
#   (case a.k
#    of aa, cc: a.i == b.i and a.z == b.z
#    of bb: a.f == b.f
#    of dd: true
#    else: true) and
#   (case a.k2
#    of xa: a.xi == b.xi
#    of xb: a.xb == b.xb)
#
# Discriminators are compared before the case parts, so `and` short-circuits
# before a field of an inactive branch is read.
import macros
proc flatten_reclist (n: NimNode): seq[NimNode] {.compileTime.} =
  ## Flattens `n` into a plain seq of record members.
  ##
  ## A `nnkRecList` yields its direct children in order, a lone `nnkSym`
  ## yields a one-element seq, and any other node kind yields an empty seq.
  result = @[]
  case n.kind
  of nnkRecList:
    for member in n.children:
      result.add member
  of nnkSym:
    result.add n
  else:
    discard
proc fieldEq (a, b, field: NimNode): NimNode {.compileTime.} =
  ## Builds `a.field == b.field`.
  let name = $field
  infix(newDotExpr(a.copyNimTree, ident(name)), "==",
        newDotExpr(b.copyNimTree, ident(name)))

proc conjoin (conds: seq[NimNode]): NimNode {.compileTime.} =
  ## Joins `conds` left to right with `and`; an empty list yields `true`.
  if conds.len == 0:
    return newLit(true)
  result = conds[0]
  for i in 1 ..< conds.len:
    result = infix(result, "and", conds[i])

proc eqRecord (a, b, rec: NimNode): NimNode {.compileTime.} =
  ## Builds the comparison of `a` and `b` for one record node taken from a
  ## `getTypeImpl` object tree: a `RecList`, a single `IdentDefs`, or a
  ## `RecCase`. Any other node (an empty branch body) compares as `true`.
  ##
  ## Plain fields and discriminators are compared first; each case part is
  ## appended afterwards. Because `and` short-circuits, a branch's fields are
  ## only read once both objects are known to be in the same branch.
  var plain: seq[NimNode] = @[]
  var variants: seq[NimNode] = @[]
  var members: seq[NimNode] = @[]
  if rec.kind == nnkRecList:
    for m in rec:
      members.add m
  else:
    members.add rec
  for m in members:
    case m.kind
    of nnkIdentDefs:
      for i in 0 ..< m.len - 2:
        plain.add fieldEq(a, b, m[i])
    of nnkRecCase:
      let disc = if m[0].kind == nnkIdentDefs: m[0][0] else: m[0]
      plain.add fieldEq(a, b, disc)
      let cs = newNimNode(nnkCaseStmt)
      cs.add newDotExpr(a.copyNimTree, ident($disc))
      # Mirror every `of`/`else` branch of the type, so the generated case is
      # exactly as exhaustive as the object declaration.
      for i in 1 ..< m.len:
        let branch = m[i]
        let newBranch = newNimNode(branch.kind)
        for j in 0 ..< branch.len - 1:
          newBranch.add branch[j].copyNimTree
        newBranch.add eqRecord(a, b, branch[^1])
        cs.add newBranch
      variants.add newNimNode(nnkStmtListExpr).add(cs)
    else:
      discard
  result = conjoin(plain & variants)

macro mkEq (a, b: typed): untyped =
  ## Expands to a field-by-field comparison of the objects `a` and `b`
  ## (fields of an inherited base object are not compared).
  let impl = getTypeImpl(a)
  if impl.kind != nnkObjectTy or impl.len < 3:
    error("typekind not handled for ==: " & $impl.kind, a)
  result = eqRecord(a, b, impl[2])
  when defined(Debug):
    echo "result: ", repr(result)

proc `==` [T: object] (a, b: T): bool =
  ## Structural equality for objects, including object variants: only the
  ## fields of the active branch are compared, after the discriminators.
  mkEq(a, b)
when isMainModule:
  # Self-test for the variant-aware `==`. Field and enum layout follow the
  # expression documented for `Variant` above `==`.
  type
    En = enum
      aa, bb, cc, dd, ee
    En2 = enum
      xa, xb
    Variant = object
      xx: int
      case k: En
      of aa, cc:
        i: int
        z: float
      of bb:
        f: float
      of dd:
        discard
      else: discard
      case k2: En2
      of xa: xi: int
      of xb: xb: int
  template test (boolExpr: untyped) =
    echo "[", (if boolExpr: "pass" else: "FAIL"), "] ", astToStr(boolExpr)
  let
    v1 = Variant(k: aa, i: 42)
    v2 = Variant(k: bb, f: 1.9)
  test v1 == v1
  test v2 == v2
  test v1 != v2
oderwat commented Mar 2, 2015

That's pretty cool!

I ran into this, though. Can it be fixed? (It works when I define a helper proc to separate them into two different functions.)

curry_1.nim(88, 21) Error: type mismatch: got (proc (string, int): string{.noSideEffect, gcsafe, locks: 0.} | proc (char, int): string{.noSideEffect, gcsafe, locks: 0.}, string)
but expected one of: 
proc repeat*(c: char, count: int): string {.noSideEffect,
  rtl, extern: "nsuRepeatChar".} =
  result = newString(count)
  for i in 0..count-1: result[i] = c

proc repeat*(s: string, n: int): string {.noSideEffect,
  rtl, extern: "nsuRepeatStr".} =
  ## Returns String `s` concatenated `n` times.
  result = newStringOfCap(n * s.len)
  for i in 1..n: result.add(s)

Copy link

oderwat commented Mar 12, 2015

Another example (it gives the all-important Answer to life, the universe, and — well, you get it...):

    proc multThree(a, b, c: int): int =
      a * b * c
    let multTwoWith7 = curry(multThree, 7)
    let multWith21 = curry(multTwoWith7, 3)
    echo multWith21 2

