Created
September 3, 2023 01:49
-
-
Save CubeFlix/cc906cd4fe24e954944f1f15af2b3b0b to your computer and use it in GitHub Desktop.
Node-based auto-differentiation engine in Python.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
An auto-differentiation engine. | |
Written by cubeflix (https://github.com/cubeflix). | |
""" | |
from enum import Enum
from numbers import Number
from typing import Any, Dict, List, Optional
class OpError(Exception):
    """Error raised by the autograd engine for an invalid operation type."""
class OpType(Enum):
    """The kinds of operation a node can represent."""

    # Values start at 1 and follow declaration order.
    CONST = 1
    ADD = 2
    NEG = 3
    MUL = 4
    DIV = 5
class Node:
    """
    A single node in an auto-differentiation tree. A node contains a value,
    an operation type, a list of children nodes, and optional arguments.

    The auto-differentiation engine builds a tree of nodes by tracking each
    operation carried out on children nodes. A recursive function then walks
    the tree in order to calculate individual derivatives.

    For example, a constant value is a node of OpType `CONST` with no
    children. The product of two nodes is a node of OpType `MUL` with two
    children. Note that the order of children nodes in the list of children
    is important, especially for non-commutative operations such as division.
    """

    def __init__(self, value: Number, op: 'OpType', children: List['Node'],
                 args: Optional[Dict[str, Any]] = None):
        """
        Create the node.

        Args:
            value: The numeric value held by this node.
            op: The operation type that produced this node.
            children: The operand nodes, in operand order.
            args: Optional extra arguments stored on the node. When omitted,
                the node gets its own empty dict.
        """
        self.value = value
        self.op = op
        self.children = children
        # Build a fresh dict per node: a `{}` default in the signature would
        # be one shared object mutated by every node created without `args`.
        self.args = {} if args is None else args

    @staticmethod
    def _as_node(other, verb: str) -> 'Node':
        """
        Return `other` as a Node, wrapping plain numbers in a constant node.

        Raises:
            TypeError: If `other` is neither a Node nor a Number. `verb`
                names the attempted operation in the error message.
        """
        if isinstance(other, Node):
            return other
        if isinstance(other, Number):
            return val(other)
        raise TypeError(f"cannot {verb} Node with type {type(other)}")

    def __add__(self, other) -> 'Node':
        """Perform the add operation on a node (`self + other`)."""
        other = self._as_node(other, "add")
        return OPS[OpType.ADD][0](self, other)

    def __radd__(self, other) -> 'Node':
        """Perform the add operation on a node (reverse, `other + self`)."""
        other = self._as_node(other, "add")
        return OPS[OpType.ADD][0](self, other)

    def __neg__(self) -> 'Node':
        """Perform the negation operation on a node."""
        return OPS[OpType.NEG][0](self)

    def __sub__(self, other) -> 'Node':
        """Perform the subtract operation on a node (`self - other`)."""
        other = self._as_node(other, "subtract")
        # Subtraction is addition of the negation. The right operand must be
        # negated whether it arrived as a Node or as a plain number.
        return OPS[OpType.ADD][0](self, OPS[OpType.NEG][0](other))

    def __rsub__(self, other) -> 'Node':
        """Perform the subtract operation on a node (reverse, `other - self`)."""
        other = self._as_node(other, "subtract")
        return OPS[OpType.ADD][0](OPS[OpType.NEG][0](self), other)

    def __mul__(self, other) -> 'Node':
        """Perform the multiply operation on a node (`self * other`)."""
        other = self._as_node(other, "multiply")
        return OPS[OpType.MUL][0](self, other)

    def __rmul__(self, other) -> 'Node':
        """Perform the multiply operation on a node (reverse, `other * self`)."""
        other = self._as_node(other, "multiply")
        return OPS[OpType.MUL][0](self, other)

    def __truediv__(self, other) -> 'Node':
        """Perform the divide operation on a node (`self / other`)."""
        other = self._as_node(other, "divide")
        return OPS[OpType.DIV][0](self, other)

    def __itruediv__(self, other) -> 'Node':
        """Perform the divide operation on a node (`self /= other`)."""
        other = self._as_node(other, "divide")
        return OPS[OpType.DIV][0](self, other)

    def __rtruediv__(self, other) -> 'Node':
        """Perform the divide operation on a node (reverse, `other / self`)."""
        other = self._as_node(other, "divide")
        # Operand order matters here: `other` is the numerator.
        return OPS[OpType.DIV][0](other, self)

    def __float__(self) -> float:
        """Convert the node into a float object."""
        return float(self.value)

    def __int__(self) -> int:
        """Convert the node into an int object."""
        return int(self.value)

    def __str__(self) -> str:
        """Get a readable string representation of the node."""
        return str(self.value)

    def __repr__(self) -> str:
        """Get an unambiguous string representation of the node."""
        children = ', '.join(str(child) for child in self.children)
        return f"<Node value={self} op={self.op} children=[{children}]>"
def val(x: Number) -> 'Node':
    """Wrap the plain number `x` in a childless node of OpType `CONST`."""
    return Node(value=x, op=OpType.CONST, children=[])
def d_const(a: 'Node', b: 'Node') -> float:
    """
    Calculate the derivative of a constant node with respect to another node.

    Returns 1 when `a` and `b` compare equal and 0 otherwise.
    """
    return 1 if a == b else 0
def add(a: 'Node', b: 'Node') -> 'Node':
    """
    Build the `ADD` node for `a + b`.

    The result holds the sum of both values and keeps `[a, b]` as children.

    Raises:
        TypeError: If either argument is not a Node.
    """
    for operand in (a, b):
        if not isinstance(operand, Node):
            raise TypeError("argument must be of type Node")
    return Node(a.value + b.value, OpType.ADD, [a, b])
def d_add(a: 'Node', b: 'Node') -> float:
    """
    Calculate the derivative of the sum node `a` with respect to node `b`.

    Raises:
        TypeError: If either argument is not a Node.
        AssertionError: If `a` is not an `ADD` node.
    """
    for operand in (a, b):
        if not isinstance(operand, Node):
            raise TypeError("argument must be of type Node")
    assert a.op == OpType.ADD, "OpType of argument must be `ADD`"

    # Sum rule: differentiate each operand and add the results.
    d_left = grad(a.children[0], b)
    d_right = grad(a.children[1], b)
    return d_left + d_right
def neg(a: 'Node') -> 'Node':
    """
    Build the `NEG` node for `-a`.

    The result holds the negated value and keeps `[a]` as its only child.

    Raises:
        TypeError: If the argument is not a Node.
    """
    if isinstance(a, Node):
        return Node(-a.value, OpType.NEG, [a])
    raise TypeError("argument must be of type Node")
def d_neg(a: 'Node', b: 'Node') -> float:
    """
    Calculate the derivative of the negation node `a` with respect to `b`.

    Raises:
        TypeError: If either argument is not a Node.
        AssertionError: If `a` is not a `NEG` node.
    """
    for operand in (a, b):
        if not isinstance(operand, Node):
            raise TypeError("argument must be of type Node")
    assert a.op == OpType.NEG, "OpType of argument must be `NEG`"

    # The derivative of a negation is the negated derivative of its operand.
    d_operand = grad(a.children[0], b)
    return -d_operand
def mul(a: 'Node', b: 'Node') -> 'Node':
    """
    Build the `MUL` node for `a * b`.

    The result holds the product of both values and keeps `[a, b]` as
    children.

    Raises:
        TypeError: If either argument is not a Node.
    """
    for operand in (a, b):
        if not isinstance(operand, Node):
            raise TypeError("argument must be of type Node")
    return Node(a.value * b.value, OpType.MUL, [a, b])
def d_mul(a: 'Node', b: 'Node') -> float:
    """
    Calculate the derivative of the product node `a` with respect to `b`.

    Raises:
        TypeError: If either argument is not a Node.
        AssertionError: If `a` is not a `MUL` node.
    """
    for operand in (a, b):
        if not isinstance(operand, Node):
            raise TypeError("argument must be of type Node")
    assert a.op == OpType.MUL, "OpType of argument must be `MUL`"

    # Product rule: (f * g)' = f' * g + g' * f.
    first_term = grad(a.children[0], b) * a.children[1].value
    second_term = grad(a.children[1], b) * a.children[0].value
    return first_term + second_term
def div(a: 'Node', b: 'Node') -> 'Node':
    """
    Build the `DIV` node for `a / b`.

    The result holds the quotient of both values and keeps `[a, b]` as
    children, numerator first.

    Raises:
        TypeError: If either argument is not a Node.
    """
    for operand in (a, b):
        if not isinstance(operand, Node):
            raise TypeError("argument must be of type Node")
    return Node(a.value / b.value, OpType.DIV, [a, b])
def d_div(a: 'Node', b: 'Node') -> float:
    """
    Calculate the derivative of the quotient node `a` with respect to `b`.

    Raises:
        TypeError: If either argument is not a Node.
        AssertionError: If `a` is not a `DIV` node.
    """
    for operand in (a, b):
        if not isinstance(operand, Node):
            raise TypeError("argument must be of type Node")
    assert a.op == OpType.DIV, "OpType of argument must be `DIV`"

    # Quotient rule: (f / g)' = (f' * g - g' * f) / g ** 2.
    numerator = (grad(a.children[0], b) * a.children[1].value
                 - grad(a.children[1], b) * a.children[0].value)
    denominator = a.children[1].value ** 2
    return numerator / denominator
# Dispatch table keyed by OpType. Each entry is a two-item list:
# index 0 is the function that builds a node of that type, and
# index 1 is the function that differentiates a node of that type.
OPS = {
    # Constants have no builder here; `val` creates them directly.
    OpType.CONST: [None, d_const],
    OpType.ADD: [add, d_add],
    OpType.NEG: [neg, d_neg],
    OpType.MUL: [mul, d_mul],
    OpType.DIV: [div, d_div],
}
def grad(a: 'Node', b: 'Node') -> Number:
    """
    Calculate the derivative of node `a` with respect to node `b`.

    The derivative is computed recursively: the derivative function
    registered in `OPS` for `a.op` applies its differentiation rule and
    calls back into `grad` for the children of `a`.

    Raises:
        TypeError: If either argument is not a Node.
        OpError: If `a.op` has no entry in `OPS`.
    """
    if not isinstance(a, Node):
        raise TypeError("argument must be of type Node")
    if not isinstance(b, Node):
        raise TypeError("argument must be of type Node")

    # Check that the OpType of `a` is valid.
    if a.op not in OPS:
        raise OpError("invalid operation type")

    # The derivative of any node with respect to itself is 1. Without this
    # check only constant nodes were recognised as the target, so
    # differentiating with respect to an intermediate node recursed past it
    # and produced 0.
    if a is b:
        return 1

    # Calculate the derivative of the current node with respect to `b`.
    return OPS[a.op][1](a, b)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment