Skip to content

Instantly share code, notes, and snippets.

@Varriount
Last active September 9, 2017 19:23
Show Gist options
  • Save Varriount/9e32806b5c9dbdfdea2b66ebc4e6b83f to your computer and use it in GitHub Desktop.
Save Varriount/9e32806b5c9dbdfdea2b66ebc4e6b83f to your computer and use it in GitHub Desktop.
import linalg
{.experimental.}
# Global templates
template foreach(expression, iter: untyped): untyped =
  ## Calls `expression(item)` once for every element of `iter`, in order.
  ## Elements are yielded through `mitems`, so `expression` receives a
  ## mutable view and may update the container in place.
  for item in iter.mitems:
    expression(item)
type
  BackPropRoot = ref object of RootObj
    ## Type-erased base for backward-propagation steps, so closures with
    ## different gradient types can be stored side by side in
    ## `Node.weights`.
  BackPropClosure[T, U] = proc (gradient: T): U {.noSideEffect, closure.}
    ## Backward-propagation procs have this type (and are wrapped in a
    ## `BackProp`); they are closures without side effects.
    ## To ease searching, backward-propagation procedures are prefixed
    ## with `bp_`.
  BackProp[T, U] = ref object of BackPropRoot
    ## Concrete backward-propagation step wrapping one
    ## `BackPropClosure[T, U]`.
    procedure: BackPropClosure[T, U]
  Node = object
    ## Represents one recorded operation.
    ## Stores the gradient transformation for backprop in `weights`.
    ## Stores indices of the parent operations in `parents`.
    weights: array[2, BackPropRoot]
    parents: array[2, int] # indices of parent nodes; presumably into `Context.nodes` — TODO confirm
  Context* = object
    ## Tape / Wengert list. Contains the list of applied operations.
    nodes: ref seq[Node]
# BackProp implementations
method call[T, U](target: BackPropRoot, param: T): U {.base.} =
  ## Fallback of the `call` dispatch. It is reached when the runtime type of
  ## `target` is not a `BackProp[T, U]` for the requested `T`/`U`, for
  ## example when passing an int gradient to a float closure.
  ##
  ## Raises `ValueError` rather than silently returning a zero-initialised
  ## `U`, which would corrupt a gradient computation without any signal.
  raise newException(ValueError,
    "call: no BackProp implementation matches the requested gradient type")
method call[T, U](target: BackProp[T, U], param: T): U =
  ## Dispatch target of `call` for a `BackProp[T, U]`. Applies the wrapped
  ## closure to the incoming gradient `param` and returns its result.
  ##
  ## Raises `ValueError` if no closure is stored, for example when the
  ## object was created with `new` instead of `newBackProp`. Without the
  ## check, calling the nil closure would crash.
  if target.procedure.isNil:
    raise newException(ValueError, "call: BackProp has no procedure set")
  result = target.procedure(param)
proc newBackProp[T, U](p: BackPropClosure[T, U]): BackProp[T, U] =
  ## Allocates a fresh `BackProp[T, U]` holding the closure `p`. The result
  ## can be stored type-erased wherever a `BackPropRoot` is expected.
  BackProp[T, U](procedure: p)
# Context implementations
proc newContext*(): Context {.noSideEffect.} =
  ## Creates a Context (Tape / Wengert list) backed by a freshly allocated
  ## `ref seq[Node]`. Copies of the returned Context share that same tape,
  ## because `nodes` is a reference.
  var tape: ref seq[Node]
  new(tape)
  Context(nodes: tape)
proc main() =
  ## Smoke test. Builds one Node holding a float and an int squaring
  ## closure, checks that both are reached through `call`, and records the
  ## Node on a fresh Context.
  proc squareInt(gradient: int): int {.noSideEffect, closure.} =
    result = gradient * gradient
  proc squareFloat(gradient: float): float {.noSideEffect, closure.} =
    result = gradient * gradient
  let ctx = newContext()   # `nodes` is a ref, so the tape is mutable via `let`
  var n: Node
  n.weights[0] = newBackProp(squareFloat)
  n.weights[1] = newBackProp(squareInt)
  # `U` occurs only in `call`'s return type, so it cannot be inferred from
  # the arguments; instantiate it explicitly. Asking for a type that does
  # not match the stored closure (e.g. `call[int, int]` on weights[0]) does
  # not select the `BackProp[T, U]` override and ends in the base `call`.
  doAssert call[float, float](n.weights[0], 1.5) == 2.25
  doAssert call[int, int](n.weights[1], 3) == 9
  # Dereference explicitly rather than relying on experimental implicit
  # dereferencing to reach `add` on the `ref seq`.
  ctx.nodes[].add(n)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment