Skip to content

Instantly share code, notes, and snippets.

@op8867555
Created January 28, 2018 12:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save op8867555/59d246a54188fe0d282656fe83e84a65 to your computer and use it in GitHub Desktop.
Save op8867555/59d246a54188fe0d282656fe83e84a65 to your computer and use it in GitHub Desktop.
Reverse-Mode Autodiff
import math
import operator as ops
from collections import OrderedDict
from collections.abc import Iterable
from itertools import chain
class Var:
    """Leaf cell of the computation graph: an independent input value.

    Attributes:
        v: the numeric value of the variable.
        ss: sensitivity (adjoint) written by ``Expr.backprop``; starts at 0.
    """

    def __init__(self, v):
        self.v, self.ss = v, 0
class Unary:
    """Graph cell produced by an operation with one recorded input.

    Attributes:
        v: the operation's result value.
        x: the input cell.
        dx: local partial derivative of the result with respect to ``x``.
        ss: sensitivity written by ``Expr.backprop``; starts at 0.
    """

    def __init__(self, v, x, dx):
        self.v = v
        self.x, self.dx = x, dx
        self.ss = 0
class Binary:
    """Graph cell produced by an operation with two recorded inputs.

    Attributes:
        v: the operation's result value.
        x, y: the two input cells.
        dx, dy: local partial derivatives of the result w.r.t. ``x``/``y``.
        ss: sensitivity written by ``Expr.backprop``; starts at 0.
    """

    def __init__(self, v, x, y, dx, dy):
        self.v = v
        self.x, self.dx = x, dx
        self.y, self.dy = y, dy
        self.ss = 0
class Tape(OrderedDict):
    """Insertion-ordered set of graph cells (the dict values are unused).

    Cells are stored as keys, so a cell reached through several paths is
    recorded once, at the position of its first occurrence. Tapes are built
    by appending a new cell after its operands' tapes, so every cell follows
    its inputs; ``Expr.backprop`` walks the tape in reverse and depends on
    that ordering.
    """

    def __init__(self, xs):
        super().__init__()
        # A single (non-iterable) cell is accepted as a one-element tape.
        if not isinstance(xs, Iterable):
            xs = [xs]
        for cell in xs:
            self[cell] = 0

    def __add__(self, other):
        """Return a new tape with this tape's cells followed by ``other``'s.

        ``other`` may be another ``Tape`` (iterating a mapping yields its
        keys) or any iterable of cells such as ``[cell]`` or a tuple. Cells
        already present keep their earlier position.
        """
        return Tape(chain(self, other))

    def __repr__(self):
        from pprint import pformat
        return 'Tape({})'.format(pformat(list(self.keys())))

    def _repr_pretty_(self, p, cycle):
        # IPython pretty-printer hook; reuse the plain repr.
        return p.text(repr(self))
def var(v):
    """Create an independent input variable holding the value ``v``.

    Returns an ``Expr`` whose tape contains only the new leaf cell.
    """
    leaf = Var(v)
    tape = Tape(leaf)
    return Expr(leaf, tape=tape)
def unary(f, dfda):
    """Lift a scalar function ``f`` with derivative ``dfda`` to ``Expr``s.

    The returned callable applies ``f`` directly to plain (non-``Expr``)
    arguments. For an ``Expr`` it evaluates ``f`` and ``dfda`` at the
    argument's value and returns a new ``Expr`` whose tape is the argument's
    tape followed by the new ``Unary`` cell.
    """
    def method(x):
        if isinstance(x, Expr):
            value = x.cell.v
            node = Unary(f(value), x.cell, dfda(value))
            return Expr(node, tape=x.tape + [node])
        return f(x)
    return method
def binary(f, dfda, dfdb):
    """Lift a two-argument scalar function to a method on ``Expr``.

    ``dfda`` and ``dfdb`` give the partial derivatives with respect to the
    first and second argument. When ``other`` is not an ``Expr`` it is
    treated as a constant: only ``dfda`` is evaluated and a ``Unary`` cell is
    recorded. Otherwise a ``Binary`` cell linking both operands is appended
    after the concatenated operand tapes.
    """
    def method(self, other):
        a = self.cell.v
        if isinstance(other, Expr):
            b = other.cell.v
            node = Binary(f(a, b), self.cell, other.cell,
                          dfda(a, b), dfdb(a, b))
            return Expr(node, tape=self.tape + other.tape + [node])
        node = Unary(f(a, other), self.cell, dfda(a, other))
        return Expr(node, tape=self.tape + [node])
    return method
class Expr:
    """A recorded value supporting reverse-mode automatic differentiation.

    Attributes:
        cell: the graph cell (``Var``, ``Unary`` or ``Binary``) holding this
            expression's value and sensitivity.
        tape: a ``Tape`` of every cell this expression depends on, each
            listed after its inputs, ending with ``cell`` itself.
    """

    def __init__(self, cell, tape):
        self.cell = cell
        self.tape = tape

    @property
    def value(self):
        """The numeric value computed for this expression."""
        return self.cell.v

    @property
    def sensitivity(self):
        """Sensitivity last written to this cell by ``backprop``."""
        return self.cell.ss

    def backprop(self):
        """Propagate sensitivities from this expression back through its tape.

        Afterwards, the ``sensitivity`` of any expression whose cell is on
        this tape is the partial derivative of ``self`` with respect to it.
        Sensitivities on the tape are reset first, so repeated calls do not
        accumulate.
        """
        for cell in self.tape:
            cell.ss = 0
        self.cell.ss = 1
        # Reverse tape order handles each cell only after every cell that
        # consumes it has added its contribution (chain rule).
        for cell in reversed(self.tape):
            if isinstance(cell, Binary):
                cell.x.ss += cell.ss * cell.dx
                cell.y.ss += cell.ss * cell.dy
            elif isinstance(cell, Unary):
                cell.x.ss += cell.ss * cell.dx

    __add__ = binary(ops.add, lambda x, y: 1, lambda x, y: 1)
    __sub__ = binary(ops.sub, lambda x, y: 1, lambda x, y: -1)
    __mul__ = binary(ops.mul, lambda x, y: y, lambda x, y: x)
    __truediv__ = binary(ops.truediv, lambda x, y: 1 / y, lambda x, y: -x / y ** 2)
    # Addition and multiplication commute, so the reflected forms can reuse
    # the forward ones.
    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other):
        """Compute ``other - self`` for a non-``Expr`` constant ``other``."""
        return unary(lambda x: other - x, lambda x: -1)(self)

    def __rtruediv__(self, other):
        """Compute ``other / self`` for a non-``Expr`` constant ``other``.

        The derivative of ``c / x`` is ``-c / x**2``; the constant must be
        kept in the numerator.
        """
        return unary(lambda x: other / x, lambda x: -other / x ** 2)(self)

    # Unary minus: recorded as a Unary cell with local derivative -1.
    __neg__ = unary(ops.neg, lambda x: -1)

    # TODO: Implement rest arith. operators (e.g. __pow__)
# Differentiable elementary functions. Each one accepts a plain number
# (returning the ordinary math result) or an ``Expr`` (returning a new
# recorded ``Expr`` whose cell stores the local derivative).
sin = unary(math.sin, math.cos)
cos = unary(math.cos, lambda x: -math.sin(x))
exp = unary(math.exp, math.exp)
log = unary(math.log, lambda x: 1 / x)
tanh = unary(math.tanh, lambda x: 1 - math.tanh(x) ** 2)
# TODO: Implement rest math functions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment