Created
January 28, 2018 12:27
-
-
Save op8867555/59d246a54188fe0d282656fe83e84a65 to your computer and use it in GitHub Desktop.
Reverse-Mode Autodiff
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math
import operator as ops
from collections import OrderedDict
from collections.abc import Iterable
from itertools import chain
class Var:
    """Leaf cell of the computation graph: an independent input variable.

    Attributes:
        v: the value held by the variable.
        ss: the sensitivity (adjoint); starts at 0 and is written by
            ``Expr.backprop``.
    """

    def __init__(self, v):
        self.v, self.ss = v, 0
class Unary:
    """Cell produced by an operation with a single graph input.

    Attributes:
        v: the operation's result value.
        x: the input cell.
        dx: the local derivative of ``v`` with respect to ``x``'s value.
        ss: the sensitivity (adjoint); starts at 0.
    """

    def __init__(self, v, x, dx):
        self.v, self.x, self.dx, self.ss = v, x, dx, 0
class Binary:
    """Cell produced by an operation with two graph inputs.

    Attributes:
        v: the operation's result value.
        x, y: the first and second input cells.
        dx, dy: local derivatives of ``v`` with respect to ``x``'s and
            ``y``'s values.
        ss: the sensitivity (adjoint); starts at 0.
    """

    def __init__(self, v, x, y, dx, dy):
        self.v, self.x, self.dx = v, x, dx
        self.y, self.dy = y, dy
        self.ss = 0
class Tape(OrderedDict):
    """An insertion-ordered, duplicate-free sequence of graph cells.

    Only the keys are meaningful; every value is a placeholder ``0``.
    Concatenation keeps the first occurrence of each cell, so merging the
    tapes of two sub-expressions never lists a shared cell twice.
    """

    def __init__(self, xs=()):
        """Build a tape from ``xs``.

        ``xs`` is either an iterable of cells or a single non-iterable cell.
        It defaults to empty so that ``Tape()`` works. ``OrderedDict``'s
        ``__reduce__`` (used by ``copy`` and ``pickle``) reconstructs
        instances by calling the class with no arguments.
        """
        super().__init__()
        if not isinstance(xs, Iterable):
            xs = [xs]
        for v in xs:
            self[v] = 0

    def __add__(self, other):
        """Return a new Tape with this tape's cells followed by ``other``'s.

        ``other`` may be a Tape, a list, or any other iterable of cells.
        Iterating a Tape yields its keys, so no special-casing is needed.
        Cells already present keep their original position.
        """
        return Tape(chain(self, other))

    def __repr__(self):
        from pprint import pformat
        return 'Tape({})'.format(pformat(list(self.keys())))

    def _repr_pretty_(self, p, cycle):
        # IPython pretty-printer hook: reuse the plain repr.
        return p.text(self.__repr__())
def var(v):
    """Create an independent variable holding ``v``.

    Returns an Expr wrapping a fresh Var cell whose tape contains only
    that cell.
    """
    leaf = Var(v)
    tape = Tape(leaf)
    return Expr(leaf, tape=tape)
def unary(f, dfda):
    """Lift a scalar function ``f`` with derivative ``dfda`` to Expr inputs.

    The returned callable applies ``f`` directly to a non-Expr argument.
    For an Expr argument it records a Unary cell holding ``f(a)`` and the
    local derivative ``dfda(a)``, where ``a`` is the input's value. It
    then returns a new Expr whose tape is the input's tape followed by
    that cell.
    """
    def lifted(x):
        if isinstance(x, Expr):
            a = x.cell.v
            node = Unary(f(a), x.cell, dfda(a))
            return Expr(node, tape=x.tape + [node])
        return f(x)
    return lifted
def binary(f, dfda, dfdb):
    """Lift a two-argument function into an Expr method ``(self, other)``.

    ``dfda`` and ``dfdb`` give the partial derivatives of ``f`` with
    respect to its first and second arguments.
    - If ``other`` is an Expr, the method records a Binary cell linking
      both inputs.
    - Otherwise ``other`` is treated as a constant and only a Unary cell
      on ``self`` is recorded.
    """
    def lifted(self, other):
        a = self.cell.v
        if isinstance(other, Expr):
            b = other.cell.v
            node = Binary(f(a, b), self.cell, other.cell,
                          dfda(a, b), dfdb(a, b))
            tape = self.tape + other.tape + [node]
        else:
            # Constant operand: only self needs a local derivative.
            node = Unary(f(a, other), self.cell, dfda(a, other))
            tape = self.tape + [node]
        return Expr(node, tape=tape)
    return lifted
class Expr:
    """A node of the computation graph together with the tape behind it.

    ``cell`` is the Var/Unary/Binary record holding this node's value
    (``v``) and sensitivity (``ss``). ``tape`` is a Tape of the cells this
    node depends on. As built by ``var``/``unary``/``binary``, the tape
    lists inputs before the cells that consume them and ends with ``cell``.
    """

    def __init__(self, cell, tape):
        self.cell = cell
        self.tape = tape

    @property
    def value(self):
        """The numeric value computed at this node."""
        return self.cell.v

    @property
    def sensitivity(self):
        """d(out)/d(this node), written by the last ``out.backprop()``.

        This is only meaningful when this node's cell is on ``out``'s tape.
        """
        return self.cell.ss

    def backprop(self):
        """Run reverse-mode accumulation with this node as the output.

        Steps:
        1. Zero the sensitivity of every cell on this node's tape.
        2. Seed this node with 1.
        3. Walk the tape backwards, adding each cell's sensitivity times
           its local derivative(s) onto its input cells (chain rule).
        """
        for cell in self.tape:
            cell.ss = 0
        self.cell.ss = 1
        # Reverse order reaches a cell only after every cell that consumes
        # it, so its ss is complete before it is pushed to its inputs.
        for cell in reversed(self.tape):
            if isinstance(cell, Binary):
                cell.x.ss += cell.ss * cell.dx
                cell.y.ss += cell.ss * cell.dy
            elif isinstance(cell, Unary):
                cell.x.ss += cell.ss * cell.dx

    __add__ = binary(ops.add, lambda x, y: 1, lambda x, y: 1)
    __sub__ = binary(ops.sub, lambda x, y: 1, lambda x, y: -1)
    __mul__ = binary(ops.mul, lambda x, y: y, lambda x, y: x)
    __truediv__ = binary(ops.truediv, lambda x, y: 1 / y, lambda x, y: -x / y ** 2)

    # Addition and multiplication commute, so the reflected forms can
    # reuse the forward implementations.
    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other):
        """Return ``other - self`` for a non-Expr ``other``."""
        return unary(lambda x: other - x, lambda x: -1)(self)

    def __rtruediv__(self, other):
        """Return ``other / self`` for a non-Expr ``other``."""
        # d/dx (c / x) = -c / x**2: the constant numerator scales the slope.
        # (Previously -1 / x**2, which is only correct when c == 1.)
        return unary(lambda x: other / x, lambda x: -other / x ** 2)(self)

    # Unary sign operators; d(-x)/dx = -1 and d(+x)/dx = 1.
    __neg__ = unary(ops.neg, lambda x: -1)
    __pos__ = unary(ops.pos, lambda x: 1)

    # TODO: Implement rest arith. operators
# Differentiable elementary functions built with ``unary``.
# - On an Expr argument, each records the operation (with its derivative)
#   on the tape.
# - On a plain number, each simply evaluates the math function.
sin = unary(math.sin, math.cos)
cos = unary(math.cos, lambda x: -math.sin(x))
exp = unary(math.exp, math.exp)
log = unary(math.log, lambda x: 1 / x)
tanh = unary(math.tanh, lambda x: 1 - math.tanh(x) ** 2)
# TODO: Implement rest math functions
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment