Last active
September 9, 2017 19:23
-
-
Save Varriount/9e32806b5c9dbdfdea2b66ebc4e6b83f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import linalg | |
{.experimental.} | |
# Global templates
template foreach(expression, iter: untyped): untyped =
  ## Calls `expression` once for every element of `iter`.
  ##
  ## Elements are yielded through `mitems`, so `expression` receives each
  ## element as a mutable location and may update it in place.
  for item in iter.mitems:
    expression(item)
type
  BackPropRoot = ref object of RootObj
    ## Non-generic base type so that backprop operations with different
    ## gradient/result types can be stored side by side (see `Node.weights`).

  BackPropClosure[T, U] = proc (gradient: T): U {.noSideEffect, closure.}
    ## Procs for backward propagation are (implicit) closures without side
    ## effects. To ease search, backward propagation procedures are prefixed
    ## with `bp_`.

  BackProp[T, U] = ref object of BackPropRoot
    ## Concrete backprop operation wrapping a `BackPropClosure[T, U]`.
    procedure: BackPropClosure[T, U] ## closure applied to the gradient

  Node = object
    ## Represents an operation recorded on the tape.
    weights: array[2, BackPropRoot] ## gradient transformations for backprop
    parents: array[2, int]          ## indices of the parent nodes

  Context* = object
    ## Tape / Wengert list. Contains the list of applied operations.
    nodes: ref seq[Node]
# BackProp implementations
method call[T, U](target: BackPropRoot, param: T): U {.base.} =
  ## Applies the backward-propagation operation held by `target` to `param`.
  ##
  ## This base method holds no closure of its own. It is the fallback that
  ## dynamic dispatch reaches when no override of `call` matches the runtime
  ## type of `target`. Returning a default-initialised `U` here would silently
  ## produce a bogus gradient, so it raises instead.
  ##
  ## Raises `ValueError`.
  raise newException(ValueError,
    "call: target has no BackProp implementation for the requested types")
method call[T, U](target: BackProp[T, U], param: T): U =
  ## Override for `BackProp[T, U]`: forwards `param` to the stored
  ## `procedure` closure and returns its result.
  ##
  ## Assumes `target.procedure` was set (as `newBackProp` does); a nil
  ## closure is not checked here.
  target.procedure(param)
proc newBackProp[T, U](p: BackPropClosure[T, U]): BackProp[T, U] =
  ## Allocates a new `BackProp[T, U]` whose `procedure` is the closure `p`.
  BackProp[T, U](procedure: p)
# Context implementations
proc newContext*(): Context {.noSideEffect.} =
  ## Builds a `Context` (tape / Wengert list) whose `nodes` field refers to
  ## a freshly allocated, default-initialised sequence of `Node`s.
  Context(nodes: new(seq[Node]))
proc main() =
  ## Demo: builds one `Node` holding two differently-typed backprop
  ## closures, applies each closure once, then records the node on a tape.
  proc squareInt(gradient: int): int {.noSideEffect, closure.} =
    result = gradient * gradient
  proc squareFloat(gradient: float): float {.noSideEffect, closure.} =
    result = gradient * gradient

  let ctx = newContext()
  var n: Node
  n.weights[0] = newBackProp(squareFloat)
  n.weights[1] = newBackProp(squareInt)

  # `call` on a plain `BackPropRoot` cannot infer `U`, because `U` appears
  # only in the return type. Its result also may not be silently dropped.
  # Converting back to the concrete `BackProp[T, U]` lets both generic
  # parameters be inferred; Nim checks such down-conversions at runtime.
  echo call(BackProp[float, float](n.weights[0]), 1.0)
  echo call(BackProp[int, int](n.weights[1]), 1)

  # Explicit `[]` so this does not rely on the experimental implicit
  # dereferencing of the `ref seq` receiver.
  ctx.nodes[].add(n)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment