Skip to content

Instantly share code, notes, and snippets.

@pervognsen
Last active January 18, 2024 02:30
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pervognsen/17e637e877040d336ba3abc7a13ef8d5 to your computer and use it in GitHub Desktop.
Save pervognsen/17e637e877040d336ba3abc7a13ef8d5 to your computer and use it in GitHub Desktop.
# Reverse-mode automatic differentiation
import math
# d(-x) = -dx
def func_neg(x):
    """Negation rule.

    Returns ``(-x, [partial])`` where the single partial derivative with
    respect to ``x`` is the constant -1.
    """
    partials = [-1]
    return -x, partials
# d(x + y) = dx + dy
def func_add(x, y):
    """Addition rule.

    Returns ``(x + y, [d/dx, d/dy])``; both partials are the constant 1.
    """
    total = x + y
    return total, [1, 1]
# d(x - y) = dx - dy
def func_sub(x, y):
    """Subtraction rule.

    Returns ``(x - y, [d/dx, d/dy])`` with partials 1 and -1.
    """
    difference = x - y
    return difference, [1, -1]
# d(x y) = y dx + x dy
def func_mul(x, y):
    """Product rule.

    Returns ``(x * y, [d/dx, d/dy])``: the partial with respect to each
    factor is the other factor.
    """
    product = x * y
    d_dx, d_dy = y, x
    return product, [d_dx, d_dy]
# d(x / y) = d(x 1/y) = 1/y dx - x/y^2 dy
def func_div(x, y):
    """Quotient rule.

    Returns ``(x / y, [d/dx, d/dy])`` with partials ``1/y`` and
    ``-x/y**2``.
    """
    quotient = x / y
    d_dx = 1 / y
    d_dy = -x / (y * y)
    return quotient, [d_dx, d_dy]
# d(cos(x)) = -sin(x) dx
def func_cos(x):
    """Cosine rule.

    Returns ``(cos(x), [-sin(x)])``. ``cos`` and ``sin`` are the
    module-level functions of this file, so ``x`` is whatever they accept.
    """
    value = cos(x)
    slope = -sin(x)
    return value, [slope]
# d(sin(x)) = cos(x) dx
def func_sin(x):
    """Sine rule.

    Returns ``(sin(x), [cos(x)])``. ``sin`` and ``cos`` are the
    module-level functions of this file, so ``x`` is whatever they accept.
    """
    value = sin(x)
    slope = cos(x)
    return value, [slope]
# d(exp(x)) = exp(x) dx
def func_exp(x):
    """Exponential rule.

    Returns ``(exp(x), [exp(x)])``; the value is computed once and reused
    as its own derivative.
    """
    result = exp(x)
    partials = [result]
    return result, partials
# d(log(x)) = 1/x dx
def func_log(x):
    """Natural-logarithm rule.

    Returns ``(log(x), [1/x])``.
    """
    value = log(x)
    reciprocal = 1 / x
    return value, [reciprocal]
# d(x**y) = d(exp(log(x) y)) = x**y y/x dx + x**y log(x) dy
def func_pow(x, y):
    """Power rule for ``x**y``.

    Returns ``(x**y, [d/dx, d/dy])`` where ``d/dx = y * x**(y-1)`` and
    ``d/dy = x**y * log(x)``.

    ``log`` is this file's module-level function. When it raises
    ``ValueError`` for ``x`` (a non-positive plain number), the partial with
    respect to ``y`` has no real value and is reported as NaN. The value and
    the partial with respect to ``x`` are unaffected, so constant exponents
    such as ``x**2`` still differentiate correctly for negative ``x``.
    """
    pow_xy = x**y
    try:
        log_x = log(x)
    except ValueError:
        # Base outside the real logarithm's domain: only the exponent's
        # partial becomes undefined. math.nan is a single shared object.
        log_x = math.nan
    return pow_xy, [x**(y-1) * y, pow_xy * log_x]
def func_when(x, y, z):
    """Selection rule for ``when(x, y, z)``.

    Returns ``(when(x, y, z), [0, x, 1 - x])``: the condition ``x`` gets a
    zero partial, and ``x`` / ``1 - x`` weight the two branches ``y`` and
    ``z`` respectively.
    """
    selected = when(x, y, z)
    partials = [0, x, 1 - x]
    return selected, partials
def func_le(x, y):
    """Comparison ``x <= y``.

    Returns ``(x <= y, [0, 0])``; both partials are zero.
    """
    outcome = x <= y
    return outcome, [0, 0]
def func_ge(x, y):
    """Comparison ``x >= y``.

    Returns ``(x >= y, [0, 0])``; both partials are zero.
    """
    outcome = x >= y
    return outcome, [0, 0]
class State:
    """Evaluation record for one node: its value and its local partials.

    ``weights`` is zipped against the node's ``args`` during the backward
    pass, so ``weights[i]`` is the partial derivative of the node's value
    with respect to its i-th argument. Nodes bound directly to a value keep
    the default empty tuple.
    """

    def __init__(self, value=0, weights=()):
        self.value = value
        self.weights = weights


class Node:
    """Expression-graph node: a function applied to a sequence of nodes.

    ``func`` is called with the evaluated values of ``args`` and must return
    a ``(value, weights)`` pair, one weight per argument.
    """

    def __init__(self, func, args):
        self.func = func
        self.args = args

    # Each operator builds a new graph node through make_node, which wraps
    # non-Node operands as constant nodes. The reflected variants keep the
    # original operand order.

    def __neg__(self):
        return make_node(func_neg, self)

    def __add__(self, other):
        return make_node(func_add, self, other)

    def __radd__(self, other):
        return make_node(func_add, other, self)

    def __sub__(self, other):
        return make_node(func_sub, self, other)

    def __rsub__(self, other):
        return make_node(func_sub, other, self)

    def __mul__(self, other):
        return make_node(func_mul, self, other)

    def __rmul__(self, other):
        return make_node(func_mul, other, self)

    def __pow__(self, other):
        return make_node(func_pow, self, other)

    def __rpow__(self, other):
        return make_node(func_pow, other, self)

    def __truediv__(self, other):
        return make_node(func_div, self, other)

    def __rtruediv__(self, other):
        return make_node(func_div, other, self)

    def __le__(self, other):
        return make_node(func_le, self, other)

    def __ge__(self, other):
        return make_node(func_ge, self, other)

    def evaluate(self, bindings=None):
        """Evaluate the graph rooted at this node.

        Args:
            bindings: optional mapping ``{node: value}``. A bound node takes
                the given value and is not computed from its arguments.

        Returns:
            A dict mapping each bound or visited node to its ``State``.
            Bound nodes come first, then visited nodes with every argument
            inserted before the node that uses it.
        """
        if bindings is None:
            bindings = {}
        states = {node: State(value) for node, value in bindings.items()}

        def visit(node):
            # Bound or already-computed nodes are reused, so shared
            # sub-expressions are evaluated once.
            if node in states:
                return states[node].value
            value, weights = node.func(*(visit(arg) for arg in node.args))
            states[node] = State(value, weights)
            return value

        visit(self)
        return states

    def gradients(self, bindings=None):
        """Compute the derivative of this node with respect to every node.

        Args:
            bindings: optional mapping ``{node: value}`` forwarded to
                ``evaluate``.

        Returns:
            A dict mapping each node in the evaluation to the accumulated
            derivative of this node with respect to it (1 for this node).
        """
        states = self.evaluate(bindings)
        gradients = {node: 0 for node in states}
        gradients[self] = 1
        # Walking the insertion-ordered states backwards visits every node
        # after all of its users, so its gradient is complete before it is
        # pushed on to its arguments.
        for node, state in reversed(list(states.items())):
            gradient = gradients[node]
            for arg, weight in zip(node.args, state.weights):
                gradients[arg] += weight * gradient
        return gradients
def memo(func):
    """Return a memoizing wrapper around ``func``.

    Results are cached by the tuple of positional arguments, which must
    therefore be hashable. A repeated call with equal arguments returns the
    cached object without calling ``func`` again. The cache is unbounded and
    lives as long as the wrapper.
    """
    # Function-scope import keeps this block self-contained.
    import functools

    cache = {}

    # wraps copies __name__, __doc__ and friends and records __wrapped__.
    @functools.wraps(func)
    def wrapped(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]

    return wrapped
@memo
def const(value):
    """Return the leaf node for the constant ``value``.

    The node's function takes no arguments and yields ``(value, ())``.
    Because of ``@memo``, equal values share one node.
    """
    def constant_func():
        return value, ()

    return make_node(constant_func)
@memo
def make_node(func, *args):
    """Build the node applying ``func`` to ``args``.

    Arguments that are not already ``Node`` instances are wrapped with
    ``const``. Because of ``@memo``, a repeated call with the same function
    and arguments returns the same node.
    """
    operands = []
    for arg in args:
        if not isinstance(arg, Node):
            arg = const(arg)
        operands.append(arg)
    return Node(func, operands)
# Zero-valued leaf node. A node function must take one parameter per argument
# and return a (value, weights) pair, so the leaf takes no arguments and
# returns an empty weights tuple.
# NOTE(review): intent inferred; `none` is not referenced elsewhere in this file.
none = make_node(lambda: (0, ()))
class Var(Node):
    """Leaf node whose value is read from its mutable ``value`` attribute."""

    def __init__(self, value=None):
        self.value = value
        super().__init__(self._func, ())

    def _func(self):
        """Return ``(self.value, ())``.

        Raises:
            ValueError: if ``value`` is still ``None``.
        """
        current = self.value
        if current is not None:
            return current, ()
        raise ValueError("Unassigned variable")
def var(value=None):
    """Create a new ``Var`` holding ``value`` (unassigned when ``None``)."""
    variable = Var(value)
    return variable
def wrap_unary(math_func, node_func):
    """Combine a numeric function and its node rule into one callable.

    The returned function builds a graph node with ``node_func`` when given
    a ``Node`` and otherwise calls ``math_func`` directly. It takes the name
    of ``math_func``.
    """
    def wrapper(x):
        if isinstance(x, Node):
            return make_node(node_func, x)
        return math_func(x)

    wrapper.__name__ = math_func.__name__
    return wrapper
# Number-or-Node versions of the math functions: a Node argument produces a
# graph node using the matching func_* rule, and any other argument is passed
# straight to the corresponding math module function.
cos = wrap_unary(math.cos, func_cos)
sin = wrap_unary(math.sin, func_sin)
exp = wrap_unary(math.exp, func_exp)
log = wrap_unary(math.log, func_log)
def when(x, y, z):
    """Select ``y`` when ``x`` is truthy, else ``z``.

    If any of the three operands is a ``Node``, a graph node using
    ``func_when`` is returned instead of an immediate value.
    """
    if not any(isinstance(operand, Node) for operand in (x, y, z)):
        return y if x else z
    return make_node(func_when, x, y, z)
# Tests
# First-order check: differentiate f(x, y) = exp(sin(x*y)/y) at x=2, y=3 and
# print each automatic gradient next to its closed-form partial derivative.
x = var(2)
y = var(3)
f = exp(sin(x * y) / y)
# Plain numeric copies of the variable values for the closed-form expressions.
x0 = x.value
y0 = y.value
gradients = f.gradients()
# df/dx, then its closed form exp(sin(xy)/y) * cos(xy).
print(gradients[x]) # 0.8747797595113476
print(exp(sin(x0*y0)/y0)*cos(x0*y0)) # 0.8747797595113477
# df/dy, then its closed form exp(sin(xy)/y) * (xy*cos(xy) - sin(xy)) / y**2.
print(gradients[y]) # 0.6114716536871055
print(exp(sin(x0*y0)/y0)*(x0*y0*cos(x0*y0) - sin(x0*y0))/(y0**2)) # 0.6114716536871057
def hessian(node, *args):
    """Differentiate ``node`` twice with respect to ``args``.

    Each argument is bound to itself for the first pass, so the first-order
    gradients come back as nodes that can be differentiated again.

    Returns:
        A dict mapping each argument to the ``gradients()`` dict of the
        first-order gradient taken with respect to that argument.
    """
    self_bindings = {}
    for arg in args:
        self_bindings[arg] = arg
    first_order = node.gradients(self_bindings)
    second_order = {}
    for arg in args:
        second_order[arg] = first_order[arg].gradients()
    return second_order
# Second-order check: the second derivative of f with respect to x, printed
# next to its closed form exp(sin(xy)/y) * (cos(xy)**2 - y*sin(xy)).
gradients = hessian(f, x, y)
print(gradients[x][x]) # 1.603636510793541
print(exp(sin(x0*y0)/y0)*(cos(x0*y0)**2 - y0*sin(x0*y0))) # 1.603636510793541
# Piecewise function: x**2 where x >= 0, -x**3 otherwise.
g = when(x >= 0, x**2, -x**3)
# First derivative on each side of the switch, with x overridden by bindings.
print(g.gradients({x: 2})[x]) # 4
print(g.gradients({x: -2})[x]) # -12
# Binding x to itself keeps the first derivative as a node, which is then
# differentiated again at a concrete x.
print(g.gradients({x: x})[x].gradients({x: 1})[x]) # 2
print(g.gradients({x: x})[x].gradients({x: -1})[x]) # 6
class LinearModel:
    """Line ``a*x + b`` whose parameters are variables fitted by gradient steps."""

    def __init__(self, a=1, b=0):
        self.a, self.b = var(a), var(b)

    def __call__(self, x):
        """Return the model's prediction for ``x``."""
        scaled = self.a * x
        return scaled + self.b

    def loss(self, data):
        """Return the mean squared error over ``data``, a sequence of ``(x, y)`` pairs."""
        total = 0
        for x, y in data:
            residual = self(x) - y
            total = total + residual**2
        return total / len(data)

    def train(self, data, rate):
        """Take one gradient-descent step of size ``rate`` on the loss over ``data``."""
        grads = self.loss(data).gradients()
        for param in (self.a, self.b):
            param.value -= rate * grads[param]
import random
# Fit the model to noise-free samples of y = 2*x - 3, one point per step,
# for ten passes over the shuffled data.
model = LinearModel()
n = 1000
data = [(x_i, 2*x_i - 3) for x_i in (1 + 10*i/n for i in range(n))]
random.shuffle(data)
for i in range(10):
    for point in data:
        model.train([point], 0.01)
print(model.a.value, model.b.value) # 1.999999999999998 -2.9999999999999893
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment