Skip to content

Instantly share code, notes, and snippets.

@tekknolagi
Last active October 10, 2024 16:59
Show Gist options
  • Save tekknolagi/b587de40ea55dc9d65b70282fb58e262 to your computer and use it in GitHub Desktop.
Save tekknolagi/b587de40ea55dc9d65b70282fb58e262 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
from __future__ import annotations
import dataclasses
import math
import unittest
import typing
@dataclasses.dataclass
class Expr:
    """Root of the symbolic expression tree; every node type derives from it."""

    def diff(self, var: str) -> Expr:
        """Return d(self)/d(var).

        Concrete node types override this; reaching the base version means the
        node type has no differentiation rule.
        """
        raise NotImplementedError
@dataclasses.dataclass
class Const(Expr):
    """A numeric literal holding an ``int`` or ``float``."""

    value: int | float

    def diff(self, var: str) -> Expr:
        """A literal is independent of every variable: its derivative is zero."""
        return Const(0)
@dataclasses.dataclass
class Var(Expr):
    """A symbolic variable identified by ``name``."""

    name: str

    def diff(self, var: str) -> Expr:
        """d(name)/d(var) is 1 for the same variable and 0 for any other."""
        if self.name != var:
            return Const(0)
        return Const(1)
@dataclasses.dataclass
class Unary(Expr):
    """Common base for operator nodes with a single operand, ``expr``."""

    expr: Expr
@dataclasses.dataclass
class Negate(Unary):
    """Arithmetic negation ``-expr``."""

    def diff(self, var: str) -> Expr:
        """Differentiate ``-f`` as ``-(f')``.

        When the inner derivative is a constant, the sign is folded into it, so
        that d/dx(-x) is ``Const(-1)`` rather than ``Negate(Const(1))``. This
        mirrors the constant folding done by the other builders in this module.
        """
        inner = self.expr.diff(var)
        if isinstance(inner, Const):
            return Const(-inner.value)
        return Negate(inner)
@dataclasses.dataclass
class Binary(Expr):
    """Common base for infix operator nodes with operands ``left`` and ``right``."""

    left: Expr
    right: Expr
@dataclasses.dataclass
class Add(Binary):
    """Sum node ``left + right``."""

    def diff(self, var: str) -> Expr:
        # Sum rule: (f + g)' = f' + g', simplified through add().
        d_left = self.left.diff(var)
        d_right = self.right.diff(var)
        return add(d_left, d_right)
def add(x: Expr, y: Expr) -> Expr:
    """Build ``x + y``, folding two constants and dropping a zero addend."""
    if isinstance(x, Const) and isinstance(y, Const):
        return Const(x.value + y.value)
    if isinstance(x, Const) and x.value == 0:
        return y
    if isinstance(y, Const) and y.value == 0:
        return x
    return Add(x, y)
@dataclasses.dataclass
class Sub(Binary):
    """Difference node ``left - right``."""

    def diff(self, var: str) -> Expr:
        # Difference rule: (f - g)' = f' - g', simplified through sub().
        d_left = self.left.diff(var)
        d_right = self.right.diff(var)
        return sub(d_left, d_right)
def sub(x: Expr, y: Expr) -> Expr:
    """Build ``x - y``, folding two constants and dropping a subtracted zero."""
    if isinstance(x, Const) and isinstance(y, Const):
        return Const(x.value - y.value)
    if isinstance(y, Const) and y.value == 0:
        return x
    return Sub(x, y)
@dataclasses.dataclass
class Mul(Binary):
    """Product node ``left * right``."""

    def diff(self, var: str) -> Expr:
        # Product rule: (f * g)' = f * g' + f' * g.
        f, g = self.left, self.right
        f_dg = mul(f, g.diff(var))
        df_g = mul(f.diff(var), g)
        return add(f_dg, df_g)
def mul(x: Expr, y: Expr) -> Expr:
    """Build ``x * y`` with light simplification.

    Two constants are multiplied out, a zero factor collapses the product to
    ``Const(0)``, and a unit factor is dropped.
    """
    if isinstance(x, Const) and isinstance(y, Const):
        return Const(x.value * y.value)
    if (isinstance(x, Const) and x.value == 0) or (
        isinstance(y, Const) and y.value == 0
    ):
        return Const(0)
    if isinstance(x, Const) and x.value == 1:
        return y
    if isinstance(y, Const) and y.value == 1:
        return x
    return Mul(x, y)
@dataclasses.dataclass
class Div(Binary):
    """Quotient node ``left / right``."""

    def diff(self, var: str) -> Expr:
        # Rewrite f / g as f * g^-1 and let the product and power rules do the work.
        reciprocal = pow(self.right, Const(-1))
        as_product = mul(self.left, reciprocal)
        return as_product.diff(var)
def _mentions_var(expr: Expr, var: str) -> bool:
    """Return True if a ``Var`` named *var* occurs anywhere inside *expr*."""
    if isinstance(expr, Var):
        return expr.name == var
    # Every node type is a dataclass, so its children are its Expr-typed fields.
    children = (getattr(expr, field.name) for field in dataclasses.fields(expr))
    return any(
        _mentions_var(child, var) for child in children if isinstance(child, Expr)
    )


@dataclasses.dataclass
class Pow(Binary):
    """Power node ``left ^ right``."""

    def diff(self, var: str) -> Expr:
        """Differentiate ``base ^ exp`` with the power rule.

        The power rule ``(f^c)' = c * f^(c-1) * f'`` is only valid when the
        exponent does not depend on *var*. Applied to e.g. ``x^x`` it would
        silently return a wrong result. The general rule needs a logarithm,
        which this module does not model, so that case raises instead.

        Raises:
            NotImplementedError: if the exponent mentions *var*.
        """
        base = self.left
        exp = self.right
        if _mentions_var(exp, var):
            raise NotImplementedError(
                f"Cannot differentiate a power whose exponent depends on {var!r}"
            )
        return mul(exp, mul(pow(base, sub(exp, Const(1))), base.diff(var)))
def pow(x: Expr, y: Expr) -> Expr:
match (x, y):
case (Const(xval), Const(yval)):
return Const(xval**yval)
case (Const(0), _):
return Const(0)
case (_, Const(0)):
return Const(1)
case (_, Const(1)):
return x
case _:
return Pow(x, y)
@dataclasses.dataclass
class Function(Expr):
    """A named one-argument function applied to ``expr``.

    Subclasses provide ``self_diff``, the outer derivative evaluated at
    ``expr``. ``diff`` multiplies it by the inner derivative (chain rule).
    """

    expr: Expr

    def self_diff(self, var: str) -> Expr:
        raise NotImplementedError(
            f"Function {self.__class__.__name__} is not differentiable"
        )

    def diff(self, var: str) -> Expr:
        # Chain rule: (F(g))' = F'(g) * g'.
        outer = self.self_diff(var)
        inner = self.expr.diff(var)
        return mul(outer, inner)
@dataclasses.dataclass
class Sin(Function):
    """The sine function."""

    def self_diff(self, var: str) -> Expr:
        # sin'(u) = cos(u); Function.diff multiplies in u'.
        derivative = cos(self.expr)
        return derivative
def sin(expr: Expr) -> Expr:
    """Build ``Sin(expr)``, evaluating it numerically when *expr* is a constant."""
    if isinstance(expr, Const):
        return Const(math.sin(expr.value))
    return Sin(expr)
@dataclasses.dataclass
class Cos(Function):
    """The cosine function."""

    def self_diff(self, var: str) -> Expr:
        # cos'(u) = -sin(u), written as (-1) * sin(u); Function.diff multiplies in u'.
        minus_one = Const(-1)
        return mul(minus_one, sin(self.expr))
def cos(expr: Expr) -> Expr:
    """Build ``Cos(expr)``, evaluating it numerically when *expr* is a constant."""
    if isinstance(expr, Const):
        return Const(math.cos(expr.value))
    return Cos(expr)
class DiffTests(unittest.TestCase):
    """Checks for Expr.diff and the simplification applied to its results."""

    def test_const(self) -> None:
        self.assertEqual(Const(42).diff("x"), Const(0))

    def test_var(self) -> None:
        x, y = Var("x"), Var("y")
        self.assertEqual(x.diff("x"), Const(1))
        self.assertEqual(y.diff("x"), Const(0))

    def test_add(self) -> None:
        x, y = Var("x"), Var("y")
        self.assertEqual(Add(x, Const(3)).diff("x"), Const(1))
        self.assertEqual(Add(x, x).diff("x"), Const(2))
        self.assertEqual(Add(x, y).diff("x"), Const(1))
        self.assertEqual(Add(x, y).diff("y"), Const(1))

    def test_sub(self) -> None:
        x, y = Var("x"), Var("y")
        self.assertEqual(Sub(x, Const(3)).diff("x"), Const(1))
        self.assertEqual(Sub(x, x).diff("x"), Const(0))
        self.assertEqual(Sub(x, y).diff("x"), Const(1))
        self.assertEqual(Sub(x, y).diff("y"), Const(-1))

    def test_mul(self) -> None:
        x, y = Var("x"), Var("y")
        product = Mul(x, y)
        self.assertEqual(product.diff("x"), y)
        self.assertEqual(product.diff("y"), x)

    def test_div(self) -> None:
        x = Var("x")
        self.assertEqual(
            Div(Const(1), x).diff("x"), Mul(Const(-1), Pow(x, Const(-2)))
        )

    def test_pow(self) -> None:
        x = Var("x")
        cases = [
            (0, Const(0)),
            (1, Const(1)),
            (2, Mul(Const(2), x)),
            (3, Mul(Const(3), Pow(x, Const(2)))),
        ]
        for exponent, expected in cases:
            self.assertEqual(Pow(x, Const(exponent)).diff("x"), expected)

    def test_integration(self) -> None:
        x = Var("x")
        inner = Add(Mul(Const(3), x), Pow(x, Const(4)))
        expr = Pow(inner, Const(2))
        d_inner = Add(Const(3), Mul(Const(4), Pow(x, Const(3))))
        self.assertEqual(expr.diff("x"), Mul(Const(2), Mul(inner, d_inner)))

    def test_sin(self) -> None:
        x = Var("x")
        self.assertEqual(Sin(x).diff("x"), Cos(x))

    def test_cos(self) -> None:
        x = Var("x")
        self.assertEqual(Cos(x).diff("x"), Mul(Const(-1), Sin(x)))

    def test_chain_rule(self) -> None:
        x_squared = Pow(Var("x"), Const(2))
        self.assertEqual(
            Sin(x_squared).diff("x"),
            Mul(Cos(x_squared), Mul(Const(2), Var("x"))),
        )
# Operator table shared by the pretty-printer and the parser. Each row is
# (node class, surface token, associativity, precedence); a larger precedence
# binds more tightly. "any" rows are printed without parentheses when nested
# on either side of an equal-precedence operator. Row order matters:
# op_info_by_name returns the first row whose token matches.
PREC = [
    (Add, "+", "any", 1),
    (Sub, "-", "left", 1),
    (Negate, "-", "any", 2),
    (Mul, "*", "any", 3),
    (Div, "/", "left", 3),
    (Pow, "^", "right", 4),
]
# All operator tokens in table order ("-" appears twice: Sub and Negate).
OPERATORS = [token for _cls, token, _assoc, _prec in PREC]
class DiffError(Exception):
    """Raised on malformed input to the tokenizer or parser, or an unknown operator."""
def op_info_by_type(expr: Expr) -> typing.Tuple[str, str, int]:
    """Return ``(token, associativity, precedence)`` for *expr*'s exact node type.

    Raises:
        DiffError: if ``type(expr)`` has no row in ``PREC``.
    """
    node_type = type(expr)
    found = next(
        (
            (token, assoc, op_prec)
            for cls, token, assoc, op_prec in PREC
            if cls == node_type
        ),
        None,
    )
    if found is None:
        raise DiffError(f"Unknown operator: {type(expr)}")
    return found
def op_info_by_name(name: str) -> typing.Tuple[type[Expr], str, int]:
    """Return ``(node class, associativity, precedence)`` for operator token *name*.

    The first matching row of ``PREC`` wins, so "-" resolves to ``Sub``.

    Raises:
        DiffError: if no row uses *name* as its token.
    """
    for row in PREC:
        cls, token, assoc, op_prec = row
        if token != name:
            continue
        return cls, assoc, op_prec
    raise DiffError(f"Unknown operator: {name}")
def pretty(expr: Expr, prec: int = 0) -> str:
match expr:
case Const(val):
return str(val)
case Var(name):
return name
case Function(x):
return f"{type(expr).__name__}({pretty(x, 0)})"
case Unary(x):
op, assoc, op_prec = op_info_by_type(expr)
result = f"{op}{pretty(x, op_prec - 1)}"
if prec >= op_prec:
return "(" + result + ")"
return result
case Binary(left, right):
op, assoc, op_prec = op_info_by_type(expr)
left_prec = op_prec if assoc == "right" else op_prec - 1
right_prec = op_prec if assoc == "left" else op_prec - 1
result = f"{pretty(left, left_prec)}{op}{pretty(right, right_prec)}"
if prec >= op_prec:
return "(" + result + ")"
return result
raise NotImplementedError(type(expr))
class PrettyPrintTests(unittest.TestCase):
    """Checks that pretty() emits exactly the parentheses precedence demands."""

    def _check(self, expr: Expr, expected: str) -> None:
        self.assertEqual(pretty(expr), expected)

    def test_const(self) -> None:
        self._check(Const(3), "3")

    def test_var(self) -> None:
        self._check(Var("x"), "x")

    def test_unary_negate_var(self) -> None:
        self._check(Negate(Var("x")), "-x")

    def test_unary_negate_unary_negate(self) -> None:
        self._check(Negate(Negate(Var("x"))), "--x")

    def test_add_const(self) -> None:
        self._check(Add(Const(3), Const(4)), "3+4")

    def test_add_unary_negate(self) -> None:
        self._check(Add(Const(3), Negate(Const(4))), "3+-4")
        self._check(Add(Negate(Const(3)), Const(4)), "-3+4")

    def test_unary_negate_add(self) -> None:
        self._check(Negate(Add(Const(3), Const(4))), "-(3+4)")

    def test_add_add(self) -> None:
        self._check(Add(Const(3), Add(Const(4), Const(5))), "3+4+5")

    def test_add_sub(self) -> None:
        self._check(Add(Const(3), Sub(Const(4), Const(5))), "3+4-5")
        self._check(Add(Sub(Const(3), Const(4)), Const(5)), "3-4+5")

    def test_sub_unary_negate(self) -> None:
        self._check(Sub(Const(3), Negate(Const(4))), "3--4")

    def test_sub_sub(self) -> None:
        self._check(Sub(Const(3), Sub(Const(4), Const(5))), "3-(4-5)")
        self._check(Sub(Sub(Const(3), Const(4)), Const(5)), "3-4-5")

    def test_sub_add(self) -> None:
        self._check(Sub(Add(Const(3), Const(4)), Const(5)), "3+4-5")
        self._check(Sub(Const(3), Add(Const(4), Const(5))), "3-(4+5)")

    def test_add_mul(self) -> None:
        self._check(Add(Const(3), Mul(Const(4), Const(5))), "3+4*5")
        self._check(Add(Mul(Const(4), Const(5)), Const(3)), "4*5+3")

    def test_mul_negate(self) -> None:
        self._check(Mul(Negate(Const(3)), Const(4)), "(-3)*4")
        self._check(Negate(Mul(Const(3), Const(4))), "-3*4")

    def test_mul_mul(self) -> None:
        self._check(Mul(Const(3), Mul(Const(4), Const(5))), "3*4*5")
        self._check(Mul(Mul(Const(3), Const(4)), Const(5)), "3*4*5")

    def test_mul_add(self) -> None:
        self._check(Mul(Const(3), Add(Const(4), Const(5))), "3*(4+5)")
        self._check(Mul(Add(Const(4), Const(5)), Const(3)), "(4+5)*3")

    def test_mul_sub(self) -> None:
        self._check(Mul(Const(3), Sub(Const(4), Const(5))), "3*(4-5)")
        self._check(Mul(Sub(Const(4), Const(5)), Const(3)), "(4-5)*3")

    def test_mul_div(self) -> None:
        self._check(Mul(Const(3), Div(Const(4), Const(5))), "3*4/5")
        self._check(Mul(Div(Const(3), Const(4)), Const(5)), "3/4*5")

    def test_mul_pow(self) -> None:
        self._check(Mul(Const(3), Pow(Const(4), Const(5))), "3*4^5")
        self._check(Mul(Pow(Const(4), Const(5)), Const(3)), "4^5*3")

    def test_div(self) -> None:
        self._check(Div(Const(3), Const(4)), "3/4")

    def test_div_div(self) -> None:
        self._check(Div(Const(3), Div(Const(4), Const(5))), "3/(4/5)")
        self._check(Div(Div(Const(3), Const(4)), Const(5)), "3/4/5")

    def test_div_mul(self) -> None:
        self._check(Div(Const(3), Mul(Const(4), Const(5))), "3/(4*5)")
        self._check(Div(Mul(Const(3), Const(4)), Const(5)), "3*4/5")

    def test_pow_pow(self) -> None:
        self._check(Pow(Const(3), Pow(Const(4), Const(5))), "3^4^5")
        self._check(Pow(Pow(Const(3), Const(4)), Const(5)), "(3^4)^5")

    def test_pow_mul(self) -> None:
        self._check(Pow(Const(3), Mul(Const(4), Const(5))), "3^(4*5)")
        self._check(Pow(Mul(Const(3), Const(4)), Const(5)), "(3*4)^5")

    def test_function(self) -> None:
        self._check(Sin(Const(3)), "Sin(3)")

    def test_function_add(self) -> None:
        self._check(Sin(Add(Const(1), Const(2))), "Sin(1+2)")

    def test_function_mul(self) -> None:
        self._check(Mul(Sin(Add(Const(1), Const(2))), Const(3)), "Sin(1+2)*3")

    def test_function_function(self) -> None:
        self._check(Sin(Sin(Const(3))), "Sin(Sin(3))")
def tokenize(source: str) -> list[typing.Any]:
    """Split *source* into tokens.

    Whitespace is skipped. A run of digits becomes an ``int`` token, and a run
    of letters becomes a ``str`` identifier. Every operator in ``OPERATORS``
    and each parenthesis becomes a one-character ``str`` token.

    Raises:
        DiffError: on non-ASCII input or any other unexpected character.
    """
    assert isinstance(source, str)
    if not source.isascii():
        raise DiffError("Only ASCII characters are supported")
    tokens: list[typing.Any] = []
    pos, end = 0, len(source)
    # Walk an index over the string instead of repeatedly slicing it.
    while pos < end:
        start = pos
        ch = source[pos]
        pos += 1
        if ch.isspace():
            continue
        if ch.isdigit():
            while pos < end and source[pos].isdigit():
                pos += 1
            tokens.append(int(source[start:pos]))
            continue
        if ch in OPERATORS or ch in "()":
            tokens.append(ch)
            continue
        if ch.isalpha():
            while pos < end and source[pos].isalpha():
                pos += 1
            tokens.append(source[start:pos])
            continue
        raise DiffError(f"Unexpected character: {ch}")
    return tokens
class TokenizeTests(unittest.TestCase):
    """Checks for tokenize()."""

    def _expect(self, source: str, expected: list[typing.Any]) -> None:
        self.assertEqual(tokenize(source), expected)

    def test_empty(self) -> None:
        self._expect("", [])

    def test_strip_leading_whitespace(self) -> None:
        for blank in (" ", " ", " \t "):
            self._expect(blank, [])

    def test_digit(self) -> None:
        self._expect("1", [1])

    def test_digits(self) -> None:
        self._expect("123", [123])

    def test_multiple_numbers(self) -> None:
        self._expect("12 34 56", [12, 34, 56])

    def test_operator(self) -> None:
        for symbol in "+-*/^()":
            self._expect(symbol, [symbol])

    def test_plus(self) -> None:
        self._expect("1+2", [1, "+", 2])

    def test_minus(self) -> None:
        self._expect("1-2", [1, "-", 2])

    def test_times(self) -> None:
        self._expect("1*2", [1, "*", 2])

    def test_divide(self) -> None:
        self._expect("1/2", [1, "/", 2])

    def test_power(self) -> None:
        self._expect("1^2", [1, "^", 2])

    def test_compound(self) -> None:
        self._expect("1+2*3", [1, "+", 2, "*", 3])
        self._expect("(1+2)*3", ["(", 1, "+", 2, ")", "*", 3])
def parse_(tokens: list[typing.Any], prec: int) -> Expr:
    """Precedence-climbing parser that consumes tokens from the front of *tokens*.

    Args:
        tokens: token list as produced by ``tokenize``. It is mutated in place:
            every token that belongs to the parsed expression is popped.
        prec: minimum precedence an infix operator needs to be absorbed at this
            level.

    Returns:
        The parsed expression. Tokens after it are left in *tokens*.

    Raises:
        DiffError: on an unexpected token, a missing ``)``, or premature end.
    """

    def paren() -> Expr:
        # Parse one primary: unary minus, integer, parenthesized group or name.
        if not tokens:
            raise DiffError("Unexpected end of input")
        if tokens[0] == "-":
            # Unary negate. Per PREC, Negate binds looser than "*" and "^", so
            # its operand absorbs those operators ("-x^2" is -(x^2)). This is
            # the reading pretty() assumes when it prints Negate without
            # parentheses. The operand never absorbs operators looser than the
            # surrounding context, so "2^-3*4" still reads as (2^-3)*4.
            tokens.pop(0)
            negate_prec = next(
                (p for cls, _name, _assoc, p in PREC if cls is Negate), None
            )
            if negate_prec is None:
                # No Negate row configured: bind only the next primary.
                return Negate(paren())
            return Negate(parse_(tokens, max(prec, negate_prec)))
        if tokens[0] in OPERATORS:
            raise DiffError(f"Unexpected operator: {tokens[0]}")
        if isinstance(tokens[0], int):
            return Const(tokens.pop(0))
        if tokens[0] == "(":
            tokens.pop(0)
            result = parse_(tokens, 0)
            if not tokens or tokens.pop(0) != ")":
                raise DiffError("Expected closing parenthesis")
            return result
        if isinstance(tokens[0], str) and tokens[0].isalpha():
            return Var(tokens.pop(0))
        raise DiffError(f"Unexpected token: {tokens[0]}")

    lhs = paren()
    while tokens and (token := tokens[0]) in OPERATORS:
        op, assoc, op_prec = op_info_by_name(token)
        if op_prec < prec:
            break
        tokens.pop(0)
        # Left-associative operators demand a strictly tighter right operand.
        # "right" and "any" operators nest to the right.
        next_prec = op_prec + 1 if assoc == "left" else op_prec
        rhs = parse_(tokens, next_prec)
        lhs = op(lhs, rhs)
    return lhs
def parse(tokens: list[typing.Any]) -> Expr:
    """Parse a complete token list into an expression tree.

    Unlike ``parse_``, every token must be consumed.

    Raises:
        DiffError: on malformed input or on tokens left over after the expression.
    """
    expr = parse_(tokens, 0)
    if not tokens:
        return expr
    leftover = " ".join(str(token) for token in tokens)
    raise DiffError("Unexpected tokens: " + leftover)
class ParseTests(unittest.TestCase):
    """Checks for parse(): precedence, associativity and error reporting."""

    def test_const(self) -> None:
        self.assertEqual(parse([3]), Const(3))

    def test_const_leftover_raises(self) -> None:
        with self.assertRaisesRegex(DiffError, "Unexpected tokens: 4"):
            parse([3, 4])

    def test_const_paren(self) -> None:
        self.assertEqual(parse(["(", 3, ")"]), Const(3))

    def test_const_paren_missing(self) -> None:
        with self.assertRaisesRegex(DiffError, "Expected closing parenthesis"):
            parse(["(", 3])
        # Raw string: "\)" in a plain literal is an invalid escape sequence.
        with self.assertRaisesRegex(DiffError, r"Unexpected tokens: \)"):
            parse([3, ")"])

    def test_negate_const(self) -> None:
        with self.assertRaisesRegex(DiffError, "Unexpected end of input"):
            parse(["-"])
        self.assertEqual(parse(["-", 3]), Negate(Const(3)))

    def test_add(self) -> None:
        self.assertEqual(parse([1, "+", 2]), Add(Const(1), Const(2)))

    def test_add_negate(self) -> None:
        self.assertEqual(parse([1, "+", "-", 2]), Add(Const(1), Negate(Const(2))))

    def test_begin_add_raises(self) -> None:
        # "+" is escaped: unescaped, ": +" means "colon then one or more
        # spaces" and would accept any "Unexpected operator: " message.
        with self.assertRaisesRegex(DiffError, r"Unexpected operator: \+"):
            parse(["+"])
        with self.assertRaisesRegex(DiffError, r"Unexpected operator: \+"):
            parse(["+", 2])

    def test_double_add_raises(self) -> None:
        with self.assertRaisesRegex(DiffError, r"Unexpected operator: \+"):
            parse([1, "+", "+", 2])

    def test_add_add(self) -> None:
        self.assertEqual(
            parse([1, "+", 2, "+", 3]), Add(Const(1), Add(Const(2), Const(3)))
        )

    def test_add_mul(self) -> None:
        # Single definition: a second method with this name used to shadow this
        # one, so the parenthesized case never ran.
        self.assertEqual(
            parse([1, "+", 2, "*", 3]), Add(Const(1), Mul(Const(2), Const(3)))
        )
        self.assertEqual(
            parse(["(", 1, "+", 2, ")", "*", 3]), Mul(Add(Const(1), Const(2)), Const(3))
        )
        self.assertEqual(
            parse([1, "*", 2, "+", 3]), Add(Mul(Const(1), Const(2)), Const(3))
        )

    def test_mul_add(self) -> None:
        self.assertEqual(
            parse([1, "*", 2, "+", 3]), Add(Mul(Const(1), Const(2)), Const(3))
        )

    def test_sub(self) -> None:
        self.assertEqual(parse([1, "-", 2]), Sub(Const(1), Const(2)))

    def test_sub_negate(self) -> None:
        self.assertEqual(parse([1, "-", "-", 2]), Sub(Const(1), Negate(Const(2))))

    def test_sub_sub(self) -> None:
        self.assertEqual(
            parse([1, "-", 2, "-", 3]), Sub(Sub(Const(1), Const(2)), Const(3))
        )
class EndToEndTests(unittest.TestCase):
    """Round trips through tokenize(), parse() and pretty()."""

    def _expect_parse(self, source: str, expected: Expr) -> None:
        # Text -> tokens -> tree, compared structurally with *expected*.
        self.assertEqual(parse(tokenize(source)), expected)

    def _expect_reparse(self, expr: Expr, expected: Expr) -> None:
        # Tree -> text -> tree. *expected* can differ structurally from *expr*
        # when the printed text regroups on reparse.
        rendered = pretty(expr)
        self._expect_parse(rendered, expected)

    def _run(self, expr: Expr) -> None:
        # The tree must come back exactly as it went in.
        self._expect_reparse(expr, expr)

    def test_const(self) -> None:
        self._run(Const(3))

    def test_add(self) -> None:
        self._run(Add(Const(1), Const(2)))

    def test_sub(self) -> None:
        self._run(Sub(Const(1), Const(2)))

    def test_mul(self) -> None:
        self._run(Mul(Const(2), Const(3)))

    def test_div(self) -> None:
        self._run(Div(Const(3), Const(4)))

    def test_pow(self) -> None:
        self._run(Pow(Const(2), Const(3)))

    def test_add_add(self) -> None:
        self._run(Add(Const(1), Add(Const(2), Const(3))))
        # "1+2+3" parses right-nested regardless of how the tree was built.
        self._expect_reparse(
            Add(Add(Const(1), Const(2)), Const(3)),
            Add(Const(1), Add(Const(2), Const(3))),
        )

    def test_sub_sub(self) -> None:
        for expr in (
            Sub(Const(1), Sub(Const(2), Const(3))),
            Sub(Sub(Const(1), Const(2)), Const(3)),
        ):
            self._run(expr)

    def test_add_mul(self) -> None:
        for expr in (
            Add(Const(1), Mul(Const(2), Const(3))),
            Add(Mul(Const(1), Const(2)), Const(3)),
        ):
            self._run(expr)

    def test_mul_add(self) -> None:
        for expr in (
            Mul(Const(1), Add(Const(2), Const(3))),
            Mul(Add(Const(1), Const(2)), Const(3)),
        ):
            self._run(expr)

    def test_eli(self) -> None:
        # From https://eli.thegreenplace.net/2012/08/02/parsing-expressions-by-precedence-climbing
        self._expect_parse(
            "2 + 3 ^ 2 * 3 + 4",
            Add(Const(2), Add(Mul(Pow(Const(3), Const(2)), Const(3)), Const(4))),
        )
if __name__ == "__main__":
    # Lift unittest's private repr-truncation limit so failing comparisons of
    # large expression trees are shown in full.
    unittest_util = __import__("sys").modules["unittest.util"]
    unittest_util._MAX_LENGTH = 999999999
    unittest.main()
import hypothesis
import hypothesis.strategies as st
import unittest
import diff
# The round-trip property below needs pretty() and parse() to agree on
# grouping. The stock table marks Add and Mul as "any", which the parser nests
# to the right, so an arbitrary left-nested tree would not come back unchanged.
# Replace the table with every binary operator except Pow left-associative.
# There is no Negate row, so generated trees hold only binary nodes, constants
# and variables.
diff.PREC = [
    (diff.Add, "+", "left", 1),
    (diff.Sub, "-", "left", 1),
    (diff.Mul, "*", "left", 2),
    (diff.Div, "/", "left", 2),
    (diff.Pow, "^", "right", 3),
]

# Leaves: small non-negative integer constants and 1-3 letter lowercase names.
consts = st.builds(diff.Const, st.integers(0, 10))
vars = st.builds(diff.Var, st.from_regex(r"[a-z]{1,3}", fullmatch=True))

# Any binary node class from the table above.
nodeclasses = st.sampled_from([row[0] for row in diff.PREC])
@st.composite
def binary_nodes(draw, elements):
    """Draw one binary node whose two operands are drawn from *elements*."""
    lhs, rhs = draw(elements), draw(elements)
    node_cls = draw(nodeclasses)
    return node_cls(lhs, rhs)
# Trees grown by repeatedly wrapping constant/variable leaves in binary nodes.
recursive_tree = st.recursive(base=st.one_of(consts, vars), extend=binary_nodes)
class TestDiff(unittest.TestCase):
    """Property test: printing a tree and parsing the text back is lossless."""

    @hypothesis.given(recursive_tree)
    @hypothesis.settings(max_examples=1000)
    def test_diff(self, tree):
        text = diff.pretty(tree)
        reparsed = diff.parse(diff.tokenize(text))
        self.assertEqual(tree, reparsed)
if __name__ == "__main__":
    # Run the property tests when this file is executed directly.
    unittest.main()
hypothesis==6.111.2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment