Skip to content

Instantly share code, notes, and snippets.

@iffy
Last active May 7, 2019 18:20
Show Gist options
  • Save iffy/092f5b46760d9b4ba62a649afaa69f56 to your computer and use it in GitHub Desktop.
Save iffy/092f5b46760d9b4ba62a649afaa69f56 to your computer and use it in GitHub Desktop.
How to generate an equality operator, proc `==`(a, b: Thing): bool, for a variant object using a macro.
import sequtils
import strutils
import macros
import sugar
export macros
proc replaceNodes*(ast: NimNode): NimNode =
  ## Rebuild ``ast`` so that every identifier and every resolved symbol
  ## is replaced by a newly created ident node with the same name.
  ##
  ## Meant for the output of ``quote do: ...``: the returned tree
  ## carries no symbol resolution performed ahead of time.
  proc rebuild(n: NimNode): NimNode =
    let k = n.kind
    if k in {nnkIdent, nnkSym}:
      # Keep only the spelling; any symbol binding is dropped.
      result = ident(n.strVal)
    elif k == nnkEmpty or k in nnkLiterals:
      # Empty nodes and literals are passed through untouched.
      result = n
    elif k == nnkOpenSymChoice:
      # Collapse the choice to a fresh ident made from its first entry.
      result = rebuild(n[0])
    else:
      # Any other node: same kind, children rebuilt recursively.
      result = newTree(k)
      for i in 0 ..< n.len:
        result.add rebuild(n[i])

  rebuild(ast)
proc findDescendent(x: NimNode, k: NimNodeKind): NimNode {.compileTime.} =
  ## Depth-first search of ``x`` (itself included) for a node of kind ``k``.
  ##
  ## A node is tested before its children, and siblings are examined from
  ## the last one to the first. Yields nil when nothing matches.
  if x.kind == k:
    return x
  for i in countdown(x.len - 1, 0):
    let hit = findDescendent(x[i], k)
    if not hit.isNil:
      return hit
proc mkComparisonBody(reclist: varargs[string, `$`]): NimNode {.compileTime.} =
  ## Build the expression ``a.f1 == b.f1 and a.f2 == b.f2 and ...`` over the
  ## field names in ``reclist``, chaining the comparisons left to right.
  ##
  ## With no field names there is nothing to compare, so the ident ``true``
  ## is returned. This keeps the result usable as an expression instead of
  ## handing a nil node back to the caller.
  for name in reclist:
    let comparison = nnkInfix.newTree(
      newIdentNode("=="),
      nnkDotExpr.newTree(newIdentNode("a"), newIdentNode(name)),
      nnkDotExpr.newTree(newIdentNode("b"), newIdentNode(name))
    )
    if result.isNil:
      # First field: it is the whole expression so far.
      result = comparison
    else:
      result = nnkInfix.newTree(newIdentNode("and"), result, comparison)
  if result.isNil:
    # No fields at all: an empty conjunction is vacuously true.
    result = newIdentNode("true")
proc mkComparisonBody(reclist: NimNode): NimNode {.compileTime.} =
  ## Overload taking an ``nnkRecList`` node: each child is stringified
  ## with ``$`` and the resulting names are forwarded to the
  ## string-based overload.
  expectKind(reclist, nnkRecList)
  var fieldNames: seq[string]
  for field in reclist:
    fieldNames.add($field)
  result = mkComparisonBody(fieldNames)
proc mkVariantEqProc(x: typedesc): NimNode {.compileTime.} =
  ## Build the AST of an exported ``==`` proc taking two values of the
  ## variant object type ``x`` and returning ``bool``.
  ##
  ## The generated proc first compares the discriminator field, then the
  ## fields that are direct children of the type's record list, and finally
  ## the fields of the branch selected by ``a``'s discriminator.
  ##
  ## Returns an empty statement list when no record list or no ``case``
  ## section is found in the type description of ``x``.
  let
    typedef = x.getType()
    recListNode = typedef.findDescendent(nnkRecList)
  if recListNode.isNil:
    # No field list found; searching inside a nil node would fail.
    return newStmtList()
  let recCaseNode = recListNode.findDescendent(nnkRecCase)
  if recCaseNode.isNil:
    # not a variant object
    return newStmtList()
  recCaseNode[0].expectKind(nnkSym)
  let
    procname = newIdentNode("==")
    typename = newIdentNode($x)
    casefield = newIdentNode(recCaseNode[0].strVal())

  # Fields common to all kinds
  var
    commonGuard = newStmtList()
    commonNames: seq[string]
  for node in recListNode:
    if node.kind == nnkSym:
      commonNames.add(node.strVal())
  if commonNames.len > 0:
    let commonCmp = mkComparisonBody(commonNames)
    commonGuard = quote do:
      if not (`commonCmp`):
        return false

  # Kind-specific fields
  var caseStmt = nnkCaseStmt.newTree(
    nnkDotExpr.newTree(
      newIdentNode("a"),
      casefield,
    ),
  )
  for branch in recCaseNode:
    case branch.kind
    of nnkOfBranch, nnkElse:
      # The field list is the last child; every child before it is a branch
      # label (an else branch has none). Copying all of them keeps branches
      # such as ``of A, B:`` intact.
      var body = mkComparisonBody(branch[branch.len - 1])
      if body.isNil:
        # Branch without fields: nothing to compare, so the values match.
        body = newIdentNode("true")
      let assignment = quote do:
        result = `body`
      var generated = branch.kind.newTree()
      for i in 0 ..< branch.len - 1:
        generated.add(branch[i])
      generated.add(assignment)
      caseStmt.add(generated)
    else:
      # The discriminator symbol itself (first child) needs no branch.
      discard

  # replaceNodes strips the symbols bound while quoting (e.g. ``result``)
  # so they are resolved again inside the generated proc.
  result = replaceNodes(quote do:
    proc `procname` *(a,b: `typename`):bool =
      if a.`casefield` != b.`casefield`:
        return false
      `commonGuard`
      `caseStmt`
  )
template mkEqualityProc*(T:typedesc):untyped =
  ## Emit the result of ``mkVariantEqProc(T)`` at the point of use.
  ##
  ## Declares a parameterless macro whose body is the call
  ## ``mkVariantEqProc(T)`` and then invokes that macro right away, so the
  ## AST it returns is inserted where ``mkEqualityProc`` was called.
  macro bootstrapEqualityProc(): untyped =
    mkVariantEqProc(T)
  bootstrapEqualityProc()
type
  # Discriminator values for the example variant object `Thing`.
  ThingKind = enum
    Unknown,
    Variant1,
    Variant2,
  # Example variant object: `common` is declared outside the `case`
  # section, while the remaining fields depend on `kind`.
  Thing = object of RootObj
    common*: string
    case kind*: ThingKind
    of Variant1:
      field1*: string
    of Variant2:
      field2*: string
      field3*: bool
    else:
      # Every `ThingKind` value not listed above (here: `Unknown`).
      another*: string
# Derive `==` for Thing at compile time.
mkEqualityProc(Thing)

# The call above generates:
#
# proc `==`(a, b: Thing): bool =
# if (
# not (a.kind == b.kind)):
# return result = false
# if not (a.common == b.common):
# return result = false
# case a.kind
# of Variant1:
# result = a.field1 == b.field1
# of Variant2:
# result = a.field2 == b.field2 and a.field3 == b.field3
# else:
# result = a.another == b.another

# Different kinds never compare equal.
assert not (Thing(kind: Variant1) == Thing(kind: Variant2))
# Same kind and same branch field: equal.
assert Thing(kind: Variant1, field1: "hi") ==
  Thing(kind: Variant1, field1: "hi")
# Same kind but a different branch field: not equal.
assert not (Thing(kind: Variant1, field1: "hi") ==
  Thing(kind: Variant1, field1: "bye"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment