Skip to content

Instantly share code, notes, and snippets.

@HugoGranstrom
Created November 1, 2021 19:04
Show Gist options
  • Save HugoGranstrom/a38f17f26fe67ca34ea3e34e06c2cfde to your computer and use it in GitHub Desktop.
Save HugoGranstrom/a38f17f26fe67ca34ea3e34e06c2cfde to your computer and use it in GitHub Desktop.
Basic registration of adjoint procs
import macros, macrocache, math
# Compile-time registry filled by the `adjoint` pragma macro and read by the
# `pullback` macro. Keys are `procName & $params` strings; values are the
# gensym'd proc symbols under which the adjoint procs are emitted.
# CacheTable names live in a single program-wide namespace, so the table gets
# a unique, descriptive name instead of a throwaway one.
const pullbackFunctions* = CacheTable"adjoint.pullbackFunctions"
proc `$`(t: CacheTable): string =
  ## Debug rendering of a `CacheTable` as `Table(k1: v1, k2: v2)`, with each
  ## value shown via `repr`. Entries are separated by ", " with no trailing
  ## separator. `CacheTable` (std/macrocache) is a compile-time facility, so
  ## call this from a macro or a `static:` block.
  result = "Table("
  var first = true
  for (key, val) in t.pairs:
    if not first:
      result.add ", "
    first = false
    result.add key
    result.add ": "
    result.add val.repr   # repr already yields a string; no extra `$` needed
  result.add ")"
proc getParams*(f: NimNode): tuple[returnType: NimNode, params: seq[tuple[ident: NimNode, paramType: NimNode]]] =
  ## Splits the formal parameters of routine node `f` into its return-type node
  ## and a flat list of `(ident, paramType)` pairs, one per declared name.
  ## A grouped declaration such as `x, y: float` yields one pair per name,
  ## each sharing the group's type node.
  let formals = f[3]
  formals.expectKind nnkFormalParams
  result.returnType = formals[0]
  # Every child after the return type is laid out as
  # (name1, name2, ..., type, default), so the type sits at len - 2.
  for groupIdx in 1 ..< formals.len:
    let group = formals[groupIdx]
    let typeIdx = group.len - 2
    let groupType = group[typeIdx]
    for nameIdx in 0 ..< typeIdx:
      result.params.add((ident: group[nameIdx], paramType: groupType))
proc isAssignment*(n: NimNode): bool =
  ## True when `n` is a plain assignment (`a = b`) or a `let` / `var` section.
  case n.kind
  of nnkAsgn, nnkLetSection, nnkVarSection: true
  else: false
proc getProcName*(n: NimNode): NimNode =
  ## Returns the name node of routine definition `n`, stripping a trailing
  ## export marker: for `proc foo*()` this is `foo`.
  ## Any other name shape (e.g. an unexported accent-quoted operator) raises a
  ## compile-time error reported at the offending name node.
  let nameNode = n[0]
  case nameNode.kind
  of nnkIdent, nnkSym:
    result = nameNode
  of nnkPostfix:
    # `foo*` is nnkPostfix(`*`, foo); the real name is the second child.
    result = nameNode[1]
  else:
    # Passing the node gives the user a source location; the compiler adds
    # its own "Error:" prefix, so the message does not repeat it.
    error("getProcName doesn't support kind = " & $nameNode.kind, nameNode)
macro adjoint*(f: untyped): untyped =
  ## Pragma macro that registers `f` as the adjoint of the same-named primal
  ## proc with the same parameter list.
  ##
  ## The proc is re-emitted under a fresh gensym'd, exported name, so it does
  ## not clash with the primal overload. That symbol is stored in
  ## `pullbackFunctions` under the key `procName & $params`, which is where
  ## `pullback` looks it up.
  ##
  ## Registering two adjoints with the same key is a compile-time error
  ## reported at `f`.
  f.expectKind(nnkProcDef)
  let procName = getProcName(f)
  let params = getParams(f).params
  # TODO: check that the return type is a tuple of the correct type
  let key = procName.strVal & $params
  # Report duplicates ourselves; otherwise the CacheTable insert fails with
  # an unlocated VM error.
  for entry in pullbackFunctions.pairs:
    if entry[0] == key:
      error("an adjoint for '" & procName.strVal &
        "' with this parameter list is already registered", f)
  result = f.copy()
  let newName = genSym(nskProc, procName.strVal & "pullback")
  result[0] = nnkPostfix.newTree(ident"*", newName)
  pullbackFunctions[key] = newName
macro pullback*(f: typed, args: varargs[typed]): untyped =
  ## Expands to a call of the adjoint registered (via `{.adjoint.}`) for the
  ## primal proc `f`, forwarding `args` unchanged.
  ##
  ## Raises a compile-time error, located at `f`, if the number of `args`
  ## differs from `f`'s parameter count, or if no adjoint was registered for
  ## `f`'s name and parameter list.
  let p = f.getImpl
  p.expectKind(nnkProcDef)
  let params = getParams(p).params
  # `error`, unlike `assert`, is not compiled out when assertions are
  # disabled (e.g. -d:danger), and it points at the user's call.
  if params.len != args.len:
    error("Number of arguments doesn't match!", f)
  let procName = p.getProcName
  let key = procName.strVal & $params
  var pullbackFunc: NimNode = nil
  for entry in pullbackFunctions.pairs:
    if entry[0] == key:
      pullbackFunc = entry[1]
      break
  if pullbackFunc.isNil:
    error("no adjoint registered for '" & procName.strVal &
      "' with this parameter list; declare one with {.adjoint.}", f)
  result = nnkCall.newTree(pullbackFunc)
  for arg in args:
    result.add arg
proc f(x, y: float, z: int): float =
  ## Demo primal function computing cos(sin(x)).
  ## `y` and `z` are accepted but not used by the computation.
  result = cos(sin(x))
proc f(x: float, y: float, z: int): (float, proc(y: float): (float, float, int)) {.adjoint.} =
  ## Adjoint of the primal `f` (cos(sin(x))).
  ## Returns the primal value together with a pullback closure. Given an
  ## upstream gradient `g`, the closure returns
  ## (g * df/dx, g * df/dy, g * df/dz) = (-g * sin(sin(x)) * cos(x), 0.0, 0);
  ## `y` and `z` do not influence the primal result.
  result[0] = f(x, y, z)
  # Closure parameter named `g` to avoid shadowing the outer `y`.
  result[1] = proc(g: float): (float, float, int) =
    (-g * sin(sin(x)) * cos(x), 0.0, 0)
# Expands at compile time into a call of the adjoint registered for `f`;
# the resulting (value, pullback) tuple is not needed here.
discard f.pullback(1.0, 2.0, 3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment