# Source: GitHub gist by @bartvm (gist id bfed6fac8786bbd1a78ed42c38ec6940),
# created March 7, 2018.
# Sketches of several reverse-mode automatic-differentiation strategies,
# each applied to the same function f(x, y) = x * y.
def f(x, y):
    """Compute the primal (undifferentiated) function f(x, y) = x * y.

    Returns the product directly; no gradient information is produced.
    """
    return x * y
# Operator overloading
adjoints = {
'mul': lambda dz, x, y: dz * y, dz * x
}
def f(x, y):
# Forward
tape.append(('mul', x, y))
z = x * y
# Backward
dz = 1
op, *args = tape.pop()
dx, dy = adjoints[op](*args)
return dx, dy
# Tape-based SCT
def f(x, y):
# Forward
tape.append(x)
tape.append(y)
z = x * y
# Backward
dz = 1
y = tape.pop()
x = tape.pop()
dx = dz * y
dy = dz * x
return dx, dy
# Backpropagator SCT
def f(x, y):
# Forward
z = x * y
def df(dz):
return dz * y, dz * dz * x
# Backward
dz = 1
dx, dy = df(dz)
return dx, dy
# Delimited continuation SCT
def f(x, y):
# This forward step gets automatically captured by the shift/reset control operators
def mul(k):
# One step forward
z = x * y
# Rest of the forward pass
k(z)
# One step backward
x.d = z.d * y
y.d = z.d * x
# Start backward
z.d = 1
# Rest of the forward pass is nothing
mul(lambda x: None)
return x.d, y.d
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment