Symbolic differentiator
"""
The following program implements a numeric system with primitive data
structures and algorithms that can calculate derivatives for arbitrary
arithmetic expressions.
The core data structure is called `Dual`, which represents a number with its
numeric value and the operation performed in the expression AST. This means that
`Dual` remembers how the value is computed.
All primitive operations (including binary and unary) have the signature
`[]Duals -> Dual`. Additionally, primitive operations calculate the derivative
with respect to their parameters using differential calculus.
Composed operations use the chain rule to accumulate gradients while traversing
the chain.
TODO:
- Merge the data type with Python's number typing system (int, float)
- Implement other number protocols (__iadd__, __radd__, etc.)
- Support vector calculus
"""
from collections import defaultdict
import math


def treemap(fn, tree):
    """Treemap works like `map`, but on a nested list instead of a flat list.
    It applies `fn` to the leaf nodes of the tree.
    treemap(add1, [1, [2, [3, 4]]]) => [2, [3, [4, 5]]]
    """
    if tree is None:
        return tree
    elif isinstance(tree, list):
        return [treemap(fn, elem) for elem in tree]
    else:
        return fn(tree)


assert treemap(lambda x: x + 1, None) is None
assert treemap(lambda x: x + 1, []) == []
assert treemap(lambda x: x + 1, 1) == 2
assert treemap(lambda x: x + 1, [1]) == [2]
assert treemap(lambda x: x + 1, [1, [2, [3, 4]]]) == [2, [3, [4, 5]]]


def linkOperator(dual, accumulated, store):
    """A LinkOperator takes a dual, an accumulated gradient value, and a store.
    It returns the store updated with the gradients that flow into this dual.
    This is only the function signature; the concrete link operators are defined below.
    """
    pass


def endOperator(dual, accumulated, store):
    """Special LinkOperator that saves the accumulated gradient value in the given store."""
    store[dual] += accumulated
    return store
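
# Illustrative check (not part of the original gist): endOperator just adds the
# accumulated gradient into the store. String keys are used here purely for
# demonstration; in practice the keys are the leaf Dual objects themselves.
_demo_store = defaultdict(lambda: 0.0)
endOperator("x", 2.0, _demo_store)
endOperator("x", 3.0, _demo_store)
assert _demo_store["x"] == 5.0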


def prim2(opName, valueFn, gradientFn):
    """Returns the derivative link method for a binary operator: +, -, *, /, **.
    valueFn returns the numeric value.
    gradientFn returns the derivative with respect to each operand according to
    the corresponding formula, with the accumulated gradient already applied.
    """

    def binaryOp(self, other):
        if isinstance(other, (int, float)):
            other = Dual(other)

        def dop(dual, accumulated, store):
            a = self
            b = other
            ga, gb = gradientFn(a.value, b.value, accumulated)
            newStore = a.link(a, ga, store)
            newStore = b.link(b, gb, newStore)
            return newStore

        name = f"{self.value} {opName} {other.value}"
        return self.__class__(valueFn(self.value, other.value), dop, name=name)

    return binaryOp
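
# For instance, the multiplication link installed below as
#     __mul__ = prim2("*", lambda a, b: a * b, lambda a, b, z: ((b * z), (a * z)))
# encodes the product rule: given the accumulated gradient z, d(a*b)/da = b and
# d(a*b)/db = a, so the gradients pushed to the two operands are b*z and a*z.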


def prim1(opName, valueFn, gradientFn):
    """Returns the derivative link method for a unary operator: log, exp, sqrt.
    valueFn returns the numeric value.
    gradientFn returns the derivative with respect to the operand according to
    the corresponding formula, with the accumulated gradient already applied.
    """

    def unaryOp(self):
        def dop(dual, accumulated, store):
            a = self
            # gradientFn already folds the accumulated gradient into ga.
            ga = gradientFn(a.value, accumulated)
            return a.link(a, ga, store)

        name = f"{opName}({self.value})"
        return self.__class__(valueFn(self.value), dop, name=name)

    return unaryOp
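
# Similarly, the `log` link installed below as
#     log = prim1("log", lambda a: math.log(a), lambda a, z: z * (1.0 / a))
# applies the chain rule d(log(a))/da = 1/a, scaled by the gradient z that has
# already accumulated downstream of the log node.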


class Dual:
    def __init__(self, value, link=endOperator, name=None):
        self.value = value
        # link is a function which captures the operator's gradient calculation
        # on the chain
        self.link = link
        if name is None:
            self.name = str(value)
        else:
            self.name = name

    def truncate(self):
        """Truncate resets the link information for the dual."""
        self.link = endOperator
        return self

    def grad(self, store):
        return self.link(self, 1.0, store)

    # Binary operators
    __add__ = prim2("+", lambda a, b: a + b, lambda a, b, z: (z, z))
    __radd__ = prim2("+", lambda b, a: a + b, lambda b, a, z: (z, z))
    __sub__ = prim2("-", lambda a, b: a - b, lambda a, b, z: (z, -z))
    __rsub__ = prim2("-", lambda b, a: a - b, lambda b, a, z: (-z, z))
    __mul__ = prim2("*", lambda a, b: a * b, lambda a, b, z: ((b * z), (a * z)))
    __rmul__ = prim2("*", lambda b, a: a * b, lambda b, a, z: ((a * z), (b * z)))
    __truediv__ = prim2(
        "/", lambda a, b: a / b, lambda a, b, z: ((z * (1.0 / b)), (z * (-a / (b * b))))
    )
    __rtruediv__ = prim2(
        "/", lambda b, a: a / b, lambda b, a, z: ((z * (-a / (b * b))), (z * (1.0 / b)))
    )
    # dx^y/dx = y * x^(y-1)
    # dx^y/dy = x^y * ln(x)
    __pow__ = prim2(
        "^",
        lambda a, b: a**b,
        lambda a, b, z: ((z * (b * (a ** (b - 1)))), (z * (a**b * math.log(a)))),
    )

    # Unary operators
    log = prim1("log", lambda a: math.log(a), lambda a, z: z * (1.0 / a))
    exp = prim1("exp", lambda a: math.exp(a), lambda a, z: z * math.exp(a))
    sqrt = prim1("sqrt", lambda a: math.sqrt(a), lambda a, z: z / (2 * math.sqrt(a)))

    def __str__(self):
        return f"[Dual(name={self.name}, val:{self.value})]"


def Del(fn, theta):
    """
    Del is an operator that takes a function and a (possibly nested) list of
    parameters, and returns the corresponding gradients.
    """
    theta = treemap(lambda v: Dual(v), theta)
    store = defaultdict(lambda: 0.0)
    fn(*theta).grad(store)
    # Return gradients in theta's shape
    return treemap(lambda t: store[t], theta)
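
# Illustrative check (not part of the original gist): because `Del` maps over
# theta with `treemap`, the gradients come back in the same nested shape as the inputs.
assert Del(lambda pair: pair[0] * pair[1], [[2.0, 3.0]]) == [[3.0, 2.0]]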


if __name__ == "__main__":
    print(
        [
            # Basic binary operators
            Del(lambda x, y: x + y, [3.0, 2.0]),  # => [1.0, 1.0]
            Del(lambda x, y: x * y, [2.0, 3.0]),  # => [3.0, 2.0]
            Del(lambda x, y: x - y, [2.0, 3.0]),  # => [1.0, -1.0]
            Del(lambda x, y: x / y, [2.0, 3.0]),  # => [0.3333, -0.2222]
            Del(lambda x, y: x**y, [2.0, 3.0]),  # => [12.0, 5.5452]
            # Unary operators
            Del(lambda x: x.exp(), [2.0]),  # => [e^2]
            Del(lambda x: x.log(), [2.0]),  # => [0.5]
            Del(lambda x: x.sqrt(), [2.0]),  # => [0.3536]
            # Composed function
            Del(lambda x, y: (x * x) + x + y, [3.0, 2.0]),  # => [7.0, 1.0]
            # fn = (e^x * log(y) + sqrt(x)) / y
            Del(lambda x, y: (x.exp() * y.log() + x.sqrt()) / y, [3.0, 2.0]),
            # supports Python int and float
            Del(lambda x, b: (x * x * 2.0) + x * 1 + b, [3.0, 2.0]),
            # right-hand operators
            Del(lambda x: 2.0 + x, [3.0]),
            Del(lambda x: 2.0 - x, [3.0]),
            Del(lambda x: 2.0 * x, [3.0]),
            Del(lambda x: 2.0 / x, [3.0]),
            Del(lambda x, b: (2.0 * x * x) + 1 * x + b, [3.0, 2.0]),
        ]
    )