Skip to content

Instantly share code, notes, and snippets.

@CubeFlix
Created September 3, 2023 01:49
Show Gist options
  • Save CubeFlix/cc906cd4fe24e954944f1f15af2b3b0b to your computer and use it in GitHub Desktop.
Save CubeFlix/cc906cd4fe24e954944f1f15af2b3b0b to your computer and use it in GitHub Desktop.
Node-based auto-differentiation engine in Python.
"""
An auto-differentiation engine.
Written by cubeflix (https://github.com/cubeflix).
"""
from enum import Enum
from numbers import Number
from typing import Any, Dict, List, Optional
class OpError(Exception):
    """Exception raised by the autograd engine for operation-type errors."""
class OpType(Enum):
    """Operation types a Node can carry; values are numbered from 1."""
    CONST = 1
    ADD = 2
    NEG = 3
    MUL = 4
    DIV = 5
class Node:
    """
    A single node in an auto-differentiation tree. A node contains a value,
    an operation type, a list of children nodes, and optional arguments.

    The auto-differentiation engine builds a tree of nodes by tracking each
    operation carried out on children nodes. A recursive function then
    back-propagates through the tree in order to calculate individual
    derivatives.

    For example, a constant value is a node of OpType `CONST` with no
    children. The product of two nodes is represented by a node of OpType
    `MUL` with two children. Note that the order of children nodes in the
    list of children is important, especially for non-commutative
    operations such as division; children are stored in the same order as
    the operands appear in the written expression.

    The arithmetic operators accept another Node or any `numbers.Number`;
    plain numbers are first wrapped in a constant node with `val`. Any other
    operand type raises TypeError.
    """

    def __init__(self, value: Number, op: OpType, children: List['Node'],
                 args: Optional[Dict[str, Any]] = None):
        """
        Create the node.

        Args:
            value: The numeric value held by this node.
            op: The operation that produced this node.
            children: The operand nodes, in operand order.
            args: Optional extra arguments. When omitted, the node gets its
                own new empty dict.
        """
        self.value = value
        self.op = op
        self.children = children
        # A literal `{}` default is evaluated once and would be shared by
        # every node created without `args`, so mutating one node's args
        # would silently change all of them.
        self.args = {} if args is None else args

    @staticmethod
    def _coerce(other, verb: str) -> 'Node':
        """Return `other` as a Node (wrapping numbers); raise TypeError otherwise."""
        if isinstance(other, Node):
            return other
        if isinstance(other, Number):
            return val(other)
        raise TypeError(f"cannot {verb} Node with type {type(other)}")

    def __add__(self, other) -> 'Node':
        """Perform the add operation `self + other`."""
        other = Node._coerce(other, "add")
        return OPS[OpType.ADD][0](self, other)

    def __radd__(self, other) -> 'Node':
        """Perform the add operation `other + self` (reflected)."""
        other = Node._coerce(other, "add")
        return OPS[OpType.ADD][0](other, self)

    def __neg__(self) -> 'Node':
        """Perform the negation operation `-self`."""
        return OPS[OpType.NEG][0](self)

    def __sub__(self, other) -> 'Node':
        """Perform the subtract operation `self - other`, built as `self + (-other)`."""
        other = Node._coerce(other, "subtract")
        # The right operand must always be negated; previously only numeric
        # operands were, so `x - y` for two Nodes computed `x + y`.
        return OPS[OpType.ADD][0](self, OPS[OpType.NEG][0](other))

    def __rsub__(self, other) -> 'Node':
        """Perform the subtract operation `other - self` (reflected), built as `other + (-self)`."""
        other = Node._coerce(other, "subtract")
        return OPS[OpType.ADD][0](other, OPS[OpType.NEG][0](self))

    def __mul__(self, other) -> 'Node':
        """Perform the multiply operation `self * other`."""
        other = Node._coerce(other, "multiply")
        return OPS[OpType.MUL][0](self, other)

    def __rmul__(self, other) -> 'Node':
        """Perform the multiply operation `other * self` (reflected)."""
        other = Node._coerce(other, "multiply")
        return OPS[OpType.MUL][0](other, self)

    def __truediv__(self, other) -> 'Node':
        """Perform the divide operation `self / other`."""
        other = Node._coerce(other, "divide")
        return OPS[OpType.DIV][0](self, other)

    def __itruediv__(self, other) -> 'Node':
        """Handle `self /= other`; returns a new node, exactly like `/`."""
        other = Node._coerce(other, "divide")
        return OPS[OpType.DIV][0](self, other)

    def __rtruediv__(self, other) -> 'Node':
        """Perform the divide operation `other / self` (reflected)."""
        other = Node._coerce(other, "divide")
        return OPS[OpType.DIV][0](other, self)

    def __float__(self) -> float:
        """Convert the node's value into a float object."""
        return float(self.value)

    def __int__(self) -> int:
        """Convert the node's value into an int object."""
        return int(self.value)

    def __str__(self) -> str:
        """Get a readable string representation of the node (its value)."""
        return str(self.value)

    def __repr__(self) -> str:
        """Get an unambiguous string representation of the node."""
        children = ', '.join(str(child) for child in self.children)
        return f"<Node value={str(self)} op={self.op} children=[{children}]>"
def val(x: Number) -> Node:
    """Wrap the number `x` in a leaf node of OpType CONST with no children."""
    leaf = Node(x, OpType.CONST, [])
    return leaf
def d_const(a: Node, b: Node) -> float:
    """
    Calculate the derivative of a constant (leaf) node with respect to another node.

    A leaf's derivative is 1 with respect to itself and 0 with respect to
    any other node.

    Args:
        a: The leaf node being differentiated.
        b: The node to differentiate with respect to.

    Returns:
        1 if `a` and `b` are the same node object, otherwise 0.

    Raises:
        TypeError: If either argument is not a Node.
    """
    if not isinstance(a, Node):
        raise TypeError("argument must be of type Node")
    if not isinstance(b, Node):
        raise TypeError("argument must be of type Node")
    # Identity rather than equality: two distinct leaves holding the same
    # value are still different variables.
    return 1 if a is b else 0
def add(a: Node, b: Node) -> Node:
    """
    Build the node representing `a + b`.

    Raises:
        TypeError: If either operand is not a Node.
    """
    for operand in (a, b):
        if not isinstance(operand, Node):
            raise TypeError("argument must be of type Node")
    return Node(a.value + b.value, OpType.ADD, [a, b])
def d_add(a: Node, b: Node) -> float:
    """
    Calculate the derivative of a sum node with respect to another node.

    By the sum rule, d(x + y)/db = dx/db + dy/db.

    Args:
        a: A node of OpType `ADD`.
        b: The node to differentiate with respect to.

    Raises:
        TypeError: If either argument is not a Node.
        OpError: If `a` is not an `ADD` node.
    """
    if not isinstance(a, Node):
        raise TypeError("argument must be of type Node")
    if not isinstance(b, Node):
        raise TypeError("argument must be of type Node")
    # An explicit raise (rather than `assert`) keeps this check active when
    # Python runs with -O.
    if a.op is not OpType.ADD:
        raise OpError("OpType of argument must be `ADD`")
    # Calculate the derivative of both child nodes and add them.
    return grad(a.children[0], b) + grad(a.children[1], b)
def neg(a: Node) -> Node:
    """
    Build the node representing `-a`.

    Raises:
        TypeError: If `a` is not a Node.
    """
    if isinstance(a, Node):
        return Node(-a.value, OpType.NEG, [a])
    raise TypeError("argument must be of type Node")
def d_neg(a: Node, b: Node) -> float:
    """
    Calculate the derivative of a negation node with respect to another node.

    d(-x)/db = -(dx/db).

    Args:
        a: A node of OpType `NEG`.
        b: The node to differentiate with respect to.

    Raises:
        TypeError: If either argument is not a Node.
        OpError: If `a` is not a `NEG` node.
    """
    if not isinstance(a, Node):
        raise TypeError("argument must be of type Node")
    if not isinstance(b, Node):
        raise TypeError("argument must be of type Node")
    # An explicit raise (rather than `assert`) keeps this check active when
    # Python runs with -O.
    if a.op is not OpType.NEG:
        raise OpError("OpType of argument must be `NEG`")
    # Calculate the derivative of the child node and negate it.
    return -grad(a.children[0], b)
def mul(a: Node, b: Node) -> Node:
    """
    Build the node representing `a * b`.

    Raises:
        TypeError: If either operand is not a Node.
    """
    if not all(isinstance(operand, Node) for operand in (a, b)):
        raise TypeError("argument must be of type Node")
    product = a.value * b.value
    return Node(product, OpType.MUL, [a, b])
def d_mul(a: Node, b: Node) -> float:
    """
    Calculate the derivative of a product node with respect to another node.

    By the product rule, d(x * y)/db = dx/db * y + dy/db * x, where x and y
    are the values stored on the two child nodes.

    Args:
        a: A node of OpType `MUL`.
        b: The node to differentiate with respect to.

    Raises:
        TypeError: If either argument is not a Node.
        OpError: If `a` is not a `MUL` node.
    """
    if not isinstance(a, Node):
        raise TypeError("argument must be of type Node")
    if not isinstance(b, Node):
        raise TypeError("argument must be of type Node")
    # An explicit raise (rather than `assert`) keeps this check active when
    # Python runs with -O.
    if a.op is not OpType.MUL:
        raise OpError("OpType of argument must be `MUL`")
    left, right = a.children[0], a.children[1]
    return grad(left, b) * right.value + grad(right, b) * left.value
def div(a: Node, b: Node) -> Node:
    """
    Build the node representing `a / b`.

    Raises:
        TypeError: If either operand is not a Node.
        ZeroDivisionError: Propagated from the value division when it fails.
    """
    if not isinstance(a, Node) or not isinstance(b, Node):
        raise TypeError("argument must be of type Node")
    quotient = a.value / b.value
    return Node(quotient, OpType.DIV, [a, b])
def d_div(a: Node, b: Node) -> float:
    """
    Calculate the derivative of a quotient node with respect to another node.

    By the quotient rule, d(x / y)/db = (dx/db * y - dy/db * x) / y**2,
    where x (numerator) and y (denominator) are the values stored on the
    first and second child nodes respectively.

    Args:
        a: A node of OpType `DIV`.
        b: The node to differentiate with respect to.

    Raises:
        TypeError: If either argument is not a Node.
        OpError: If `a` is not a `DIV` node.
    """
    if not isinstance(a, Node):
        raise TypeError("argument must be of type Node")
    if not isinstance(b, Node):
        raise TypeError("argument must be of type Node")
    # An explicit raise (rather than `assert`) keeps this check active when
    # Python runs with -O.
    if a.op is not OpType.DIV:
        raise OpError("OpType of argument must be `DIV`")
    num, den = a.children[0], a.children[1]
    return (grad(num, b) * den.value - grad(den, b) * num.value) / (den.value ** 2)
# Dispatch table mapping each OpType to [forward function, derivative function].
# CONST nodes are created by `val`, so they have no forward entry here.
OPS = {
    op_type: [forward, derivative]
    for op_type, forward, derivative in (
        (OpType.CONST, None, d_const),
        (OpType.ADD, add, d_add),
        (OpType.NEG, neg, d_neg),
        (OpType.MUL, mul, d_mul),
        (OpType.DIV, div, d_div),
    )
}
def grad(a, b):
    """
    Calculate the derivative of node `a` with respect to node `b`.

    Recursively walks the tree below `a`, dispatching to the derivative
    function registered in `OPS` for each node's OpType; those functions call
    back into `grad` for the children and combine the results (chain rule).
    Results are not cached, so a subtree reached along several paths is
    differentiated once per path.

    Args:
        a: The node to differentiate.
        b: The node to differentiate with respect to. It may be a leaf or an
            intermediate node of the tree.

    Returns:
        The numeric derivative da/db.

    Raises:
        TypeError: If either argument is not a Node.
        OpError: If `a` has an OpType with no registered derivative.
    """
    if not isinstance(a, Node):
        raise TypeError("argument must be of type Node")
    if not isinstance(b, Node):
        raise TypeError("argument must be of type Node")
    # Check that the OpType of `a` is valid.
    if a.op not in OPS:
        raise OpError("invalid operation type")
    # Every node's derivative with respect to itself is 1. Previously only
    # CONST leaves recognized this, so differentiating with respect to an
    # intermediate node (e.g. grad(z * 2, z) with z = x + y) returned 0.
    if a is b:
        return 1
    # Calculate the derivative of the current node with respect to `b`.
    return OPS[a.op][1](a, b)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment