-
-
Save tekknolagi/b587de40ea55dc9d65b70282fb58e262 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
__pycache__ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from __future__ import annotations | |
import dataclasses | |
import math | |
import unittest | |
import typing | |
@dataclasses.dataclass
class Expr:
    """Root of the expression-tree hierarchy.

    Concrete node types are expected to override :meth:`diff`.
    """

    def diff(self, var: str) -> Expr:
        """Return the derivative of this expression with respect to ``var``.

        Raises:
            NotImplementedError: always, on the base class itself.
        """
        raise NotImplementedError()
@dataclasses.dataclass
class Const(Expr):
    """Leaf node wrapping a numeric literal."""

    value: int | float

    def diff(self, var: str) -> Expr:
        """Return ``Const(0)``; a literal does not depend on any variable."""
        zero = Const(0)
        return zero
@dataclasses.dataclass
class Var(Expr):
    """Leaf node naming a variable."""

    name: str

    def diff(self, var: str) -> Expr:
        """Return ``Const(1)`` when differentiating by this variable, else ``Const(0)``."""
        if self.name == var:
            return Const(1)
        return Const(0)
@dataclasses.dataclass
class Unary(Expr):
    """Base class for operators applied to a single operand."""

    # The operand the operator is applied to.
    expr: Expr
@dataclasses.dataclass | |
class Negate(Unary): | |
def diff(self, var: str) -> Expr: | |
raise NotImplementedError("TODO") | |
@dataclasses.dataclass
class Binary(Expr):
    """Base class for operators taking a left and a right operand."""

    # Operand on the left-hand side of the operator.
    left: Expr
    # Operand on the right-hand side of the operator.
    right: Expr
@dataclasses.dataclass
class Add(Binary):
    """Sum of two expressions."""

    def diff(self, var: str) -> Expr:
        """Sum rule: differentiate each operand and combine them with ``add``."""
        d_left = self.left.diff(var)
        d_right = self.right.diff(var)
        return add(d_left, d_right)
def add(x: Expr, y: Expr) -> Expr: | |
match (x, y): | |
case (Const(xval), Const(yval)): | |
return Const(xval + yval) | |
case (Const(0), x) | (x, Const(0)): | |
return x | |
case _: | |
return Add(x, y) | |
@dataclasses.dataclass
class Sub(Binary):
    """Difference of two expressions."""

    def diff(self, var: str) -> Expr:
        """Differentiate each operand and combine them with ``sub``."""
        d_left = self.left.diff(var)
        d_right = self.right.diff(var)
        return sub(d_left, d_right)
def sub(x: Expr, y: Expr) -> Expr: | |
match (x, y): | |
case (Const(xval), Const(yval)): | |
return Const(xval - yval) | |
case (x, Const(0)): | |
return x | |
case _: | |
return Sub(x, y) | |
@dataclasses.dataclass
class Mul(Binary):
    """Product of two expressions."""

    def diff(self, var: str) -> Expr:
        """Product rule: ``left * right' + left' * right``."""
        left_times_dright = mul(self.left, self.right.diff(var))
        dleft_times_right = mul(self.left.diff(var), self.right)
        return add(left_times_dright, dleft_times_right)
def mul(x: Expr, y: Expr) -> Expr: | |
match (x, y): | |
case (Const(xval), Const(yval)): | |
return Const(xval * yval) | |
case (Const(0), _) | (_, Const(0)): | |
return Const(0) | |
case (Const(1), x) | (x, Const(1)): | |
return x | |
case _: | |
return Mul(x, y) | |
@dataclasses.dataclass
class Div(Binary):
    """Quotient of two expressions."""

    def diff(self, var: str) -> Expr:
        """Differentiate ``left / right`` by rewriting it as ``left * right^-1``."""
        reciprocal = pow(self.right, Const(-1))
        as_product = mul(self.left, reciprocal)
        return as_product.diff(var)
@dataclasses.dataclass
class Pow(Binary):
    """``left`` raised to the power ``right``."""

    def diff(self, var: str) -> Expr:
        """Power rule with chain rule on the base: ``exp * base^(exp-1) * base'``.

        NOTE: the exponent is never differentiated here, so the result is
        only the full derivative when the exponent does not depend on
        ``var``.
        """
        base, exp = self.left, self.right
        lowered = pow(base, sub(exp, Const(1)))
        return mul(exp, mul(lowered, base.diff(var)))
def pow(x: Expr, y: Expr) -> Expr: | |
match (x, y): | |
case (Const(xval), Const(yval)): | |
return Const(xval**yval) | |
case (Const(0), _): | |
return Const(0) | |
case (_, Const(0)): | |
return Const(1) | |
case (_, Const(1)): | |
return x | |
case _: | |
return Pow(x, y) | |
@dataclasses.dataclass
class Function(Expr):
    """Base class for named single-argument functions such as ``Sin``."""

    # The argument the function is applied to.
    expr: Expr

    def self_diff(self, var: str) -> Expr:
        """Return the function's own derivative evaluated at its argument.

        Raises:
            NotImplementedError: on this base class; subclasses override it.
        """
        message = f"Function {self.__class__.__name__} is not differentiable"
        raise NotImplementedError(message)

    def diff(self, var: str) -> Expr:
        """Chain rule: ``self_diff`` multiplied by the argument's derivative."""
        outer = self.self_diff(var)
        inner = self.expr.diff(var)
        return mul(outer, inner)
@dataclasses.dataclass
class Sin(Function):
    """Sine of the wrapped argument."""

    def self_diff(self, var: str) -> Expr:
        """Return the cosine of the argument, built with ``cos``."""
        argument = self.expr
        return cos(argument)
def sin(expr: Expr) -> Expr:
    """Build ``Sin(expr)``, evaluating it with ``math.sin`` when ``expr`` is a ``Const``."""
    if isinstance(expr, Const):
        return Const(math.sin(expr.value))
    return Sin(expr)
@dataclasses.dataclass
class Cos(Function):
    """Cosine of the wrapped argument."""

    def self_diff(self, var: str) -> Expr:
        """Return ``-1 * sin(argument)``, built with ``mul`` and ``sin``."""
        minus_one = Const(-1)
        return mul(minus_one, sin(self.expr))
def cos(expr: Expr) -> Expr:
    """Build ``Cos(expr)``, evaluating it with ``math.cos`` when ``expr`` is a ``Const``."""
    if isinstance(expr, Const):
        return Const(math.cos(expr.value))
    return Cos(expr)
class DiffTests(unittest.TestCase):
    """Unit tests for ``diff`` on each node type, including simplification."""

    def test_const(self) -> None:
        self.assertEqual(Const(42).diff("x"), Const(0))

    def test_var(self) -> None:
        self.assertEqual(Var("x").diff("x"), Const(1))
        self.assertEqual(Var("y").diff("x"), Const(0))

    def test_add(self) -> None:
        x, y = Var("x"), Var("y")
        self.assertEqual(Add(x, Const(3)).diff("x"), Const(1))
        self.assertEqual(Add(x, x).diff("x"), Const(2))
        self.assertEqual(Add(x, y).diff("x"), Const(1))
        self.assertEqual(Add(x, y).diff("y"), Const(1))

    def test_sub(self) -> None:
        x, y = Var("x"), Var("y")
        self.assertEqual(Sub(x, Const(3)).diff("x"), Const(1))
        self.assertEqual(Sub(x, x).diff("x"), Const(0))
        self.assertEqual(Sub(x, y).diff("x"), Const(1))
        self.assertEqual(Sub(x, y).diff("y"), Const(-1))

    def test_mul(self) -> None:
        product = Mul(Var("x"), Var("y"))
        self.assertEqual(product.diff("x"), Var("y"))
        self.assertEqual(product.diff("y"), Var("x"))

    def test_div(self) -> None:
        derivative = Div(Const(1), Var("x")).diff("x")
        self.assertEqual(derivative, Mul(Const(-1), Pow(Var("x"), Const(-2))))

    def test_pow(self) -> None:
        x = Var("x")
        self.assertEqual(Pow(x, Const(0)).diff("x"), Const(0))
        self.assertEqual(Pow(x, Const(1)).diff("x"), Const(1))
        self.assertEqual(Pow(x, Const(2)).diff("x"), Mul(Const(2), Var("x")))
        self.assertEqual(
            Pow(x, Const(3)).diff("x"),
            Mul(Const(3), Pow(Var("x"), Const(2))),
        )

    def test_integration(self) -> None:
        # (3x + x^4)^2
        inner = Add(
            Mul(Const(3), Var("x")),
            Pow(Var("x"), Const(4)),
        )
        expr = Pow(inner, Const(2))
        expected = Mul(
            Const(2),
            Mul(
                Add(Mul(Const(3), Var("x")), Pow(Var("x"), Const(4))),
                Add(Const(3), Mul(Const(4), Pow(Var("x"), Const(3)))),
            ),
        )
        self.assertEqual(expr.diff("x"), expected)

    def test_sin(self) -> None:
        self.assertEqual(Sin(Var("x")).diff("x"), Cos(Var("x")))

    def test_cos(self) -> None:
        expected = Mul(Const(-1), Sin(Var("x")))
        self.assertEqual(Cos(Var("x")).diff("x"), expected)

    def test_chain_rule(self) -> None:
        squared = Pow(Var("x"), Const(2))
        self.assertEqual(
            Sin(squared).diff("x"),
            Mul(Cos(Pow(Var("x"), Const(2))), Mul(Const(2), Var("x"))),
        )
# Operator table. Each entry is
# (node class, operator symbol, associativity, precedence level).
# "-" appears twice (Sub and Negate); lookups by symbol return the first
# match, which is Sub.
PREC = [
    (Add, "+", "any", 1),
    (Sub, "-", "left", 1),
    (Negate, "-", "any", 2),
    (Mul, "*", "any", 3),
    (Div, "/", "left", 3),
    (Pow, "^", "right", 4),
]

# Operator symbols in table order (so "-" is listed twice).
OPERATORS = [symbol for _cls, symbol, _assoc, _level in PREC]
class DiffError(Exception):
    """Error raised by the tokenizer, the parser and the operator-table lookups."""
def op_info_by_type(expr: Expr) -> typing.Tuple[str, str, int]:
    """Look up the ``PREC`` entry whose node class is exactly ``type(expr)``.

    Returns:
        ``(symbol, associativity, precedence)`` of the first matching entry.

    Raises:
        DiffError: if no entry has that node class.
    """
    node_type = type(expr)
    for entry in PREC:
        if entry[0] == node_type:
            return entry[1], entry[2], entry[3]
    raise DiffError(f"Unknown operator: {type(expr)}")
def op_info_by_name(name: str) -> typing.Tuple[type[Expr], str, int]:
    """Look up the first ``PREC`` entry whose operator symbol equals ``name``.

    Returns:
        ``(node class, associativity, precedence)`` of that entry.

    Raises:
        DiffError: if no entry uses that symbol.
    """
    for entry in PREC:
        if entry[1] == name:
            return entry[0], entry[2], entry[3]
    raise DiffError(f"Unknown operator: {name}")
def pretty(expr: Expr, prec: int = 0) -> str: | |
match expr: | |
case Const(val): | |
return str(val) | |
case Var(name): | |
return name | |
case Function(x): | |
return f"{type(expr).__name__}({pretty(x, 0)})" | |
case Unary(x): | |
op, assoc, op_prec = op_info_by_type(expr) | |
result = f"{op}{pretty(x, op_prec - 1)}" | |
if prec >= op_prec: | |
return "(" + result + ")" | |
return result | |
case Binary(left, right): | |
op, assoc, op_prec = op_info_by_type(expr) | |
left_prec = op_prec if assoc == "right" else op_prec - 1 | |
right_prec = op_prec if assoc == "left" else op_prec - 1 | |
result = f"{pretty(left, left_prec)}{op}{pretty(right, right_prec)}" | |
if prec >= op_prec: | |
return "(" + result + ")" | |
return result | |
raise NotImplementedError(type(expr)) | |
class PrettyPrintTests(unittest.TestCase):
    """Unit tests for ``pretty``, focused on where parentheses appear."""

    def test_const(self) -> None:
        self.assertEqual(pretty(Const(3)), "3")

    def test_var(self) -> None:
        self.assertEqual(pretty(Var("x")), "x")

    def test_unary_negate_var(self) -> None:
        self.assertEqual(pretty(Negate(Var("x"))), "-x")

    def test_unary_negate_unary_negate(self) -> None:
        self.assertEqual(pretty(Negate(Negate(Var("x")))), "--x")

    def test_add_const(self) -> None:
        self.assertEqual(pretty(Add(Const(3), Const(4))), "3+4")

    def test_add_unary_negate(self) -> None:
        self.assertEqual(pretty(Add(Const(3), Negate(Const(4)))), "3+-4")
        self.assertEqual(pretty(Add(Negate(Const(3)), Const(4))), "-3+4")

    def test_unary_negate_add(self) -> None:
        self.assertEqual(pretty(Negate(Add(Const(3), Const(4)))), "-(3+4)")

    def test_add_add(self) -> None:
        self.assertEqual(pretty(Add(Const(3), Add(Const(4), Const(5)))), "3+4+5")

    def test_add_sub(self) -> None:
        self.assertEqual(pretty(Add(Const(3), Sub(Const(4), Const(5)))), "3+4-5")
        self.assertEqual(pretty(Add(Sub(Const(3), Const(4)), Const(5))), "3-4+5")

    def test_sub_unary_negate(self) -> None:
        self.assertEqual(pretty(Sub(Const(3), Negate(Const(4)))), "3--4")

    def test_sub_sub(self) -> None:
        self.assertEqual(pretty(Sub(Const(3), Sub(Const(4), Const(5)))), "3-(4-5)")
        self.assertEqual(pretty(Sub(Sub(Const(3), Const(4)), Const(5))), "3-4-5")

    def test_sub_add(self) -> None:
        self.assertEqual(pretty(Sub(Add(Const(3), Const(4)), Const(5))), "3+4-5")
        self.assertEqual(pretty(Sub(Const(3), Add(Const(4), Const(5)))), "3-(4+5)")

    def test_add_mul(self) -> None:
        self.assertEqual(pretty(Add(Const(3), Mul(Const(4), Const(5)))), "3+4*5")
        self.assertEqual(pretty(Add(Mul(Const(4), Const(5)), Const(3))), "4*5+3")

    def test_mul_negate(self) -> None:
        self.assertEqual(pretty(Mul(Negate(Const(3)), Const(4))), "(-3)*4")
        self.assertEqual(pretty(Negate(Mul(Const(3), Const(4)))), "-3*4")

    def test_mul_mul(self) -> None:
        self.assertEqual(pretty(Mul(Const(3), Mul(Const(4), Const(5)))), "3*4*5")
        self.assertEqual(pretty(Mul(Mul(Const(3), Const(4)), Const(5))), "3*4*5")

    def test_mul_add(self) -> None:
        self.assertEqual(pretty(Mul(Const(3), Add(Const(4), Const(5)))), "3*(4+5)")
        self.assertEqual(pretty(Mul(Add(Const(4), Const(5)), Const(3))), "(4+5)*3")

    def test_mul_sub(self) -> None:
        self.assertEqual(pretty(Mul(Const(3), Sub(Const(4), Const(5)))), "3*(4-5)")
        self.assertEqual(pretty(Mul(Sub(Const(4), Const(5)), Const(3))), "(4-5)*3")

    def test_mul_div(self) -> None:
        self.assertEqual(pretty(Mul(Const(3), Div(Const(4), Const(5)))), "3*4/5")
        self.assertEqual(pretty(Mul(Div(Const(3), Const(4)), Const(5))), "3/4*5")

    def test_mul_pow(self) -> None:
        self.assertEqual(pretty(Mul(Const(3), Pow(Const(4), Const(5)))), "3*4^5")
        self.assertEqual(pretty(Mul(Pow(Const(4), Const(5)), Const(3))), "4^5*3")

    def test_div(self) -> None:
        self.assertEqual(pretty(Div(Const(3), Const(4))), "3/4")

    def test_div_div(self) -> None:
        self.assertEqual(pretty(Div(Const(3), Div(Const(4), Const(5)))), "3/(4/5)")
        self.assertEqual(pretty(Div(Div(Const(3), Const(4)), Const(5))), "3/4/5")

    def test_div_mul(self) -> None:
        self.assertEqual(pretty(Div(Const(3), Mul(Const(4), Const(5)))), "3/(4*5)")
        self.assertEqual(pretty(Div(Mul(Const(3), Const(4)), Const(5))), "3*4/5")

    def test_pow_pow(self) -> None:
        self.assertEqual(pretty(Pow(Const(3), Pow(Const(4), Const(5)))), "3^4^5")
        self.assertEqual(pretty(Pow(Pow(Const(3), Const(4)), Const(5))), "(3^4)^5")

    def test_pow_mul(self) -> None:
        self.assertEqual(pretty(Pow(Const(3), Mul(Const(4), Const(5)))), "3^(4*5)")
        self.assertEqual(pretty(Pow(Mul(Const(3), Const(4)), Const(5))), "(3*4)^5")

    def test_function(self) -> None:
        self.assertEqual(pretty(Sin(Const(3))), "Sin(3)")

    def test_function_add(self) -> None:
        self.assertEqual(pretty(Sin(Add(Const(1), Const(2)))), "Sin(1+2)")

    def test_function_mul(self) -> None:
        self.assertEqual(
            pretty(Mul(Sin(Add(Const(1), Const(2))), Const(3))), "Sin(1+2)*3"
        )

    def test_function_function(self) -> None:
        self.assertEqual(pretty(Sin(Sin(Const(3)))), "Sin(Sin(3))")
def tokenize(source: str) -> list[typing.Any]:
    """Split ``source`` into a flat list of tokens.

    Tokens are ``int`` values for runs of digits, single-character strings
    for operator symbols and parentheses, and strings for runs of letters.
    Whitespace separates tokens and is otherwise dropped.

    The input is scanned by index rather than by repeatedly slicing off
    the first character, so the scan does not re-copy the remaining input
    for every character consumed.

    Args:
        source: ASCII text to tokenize.

    Returns:
        The tokens in source order.

    Raises:
        DiffError: if ``source`` contains non-ASCII characters or a
            character that cannot start any token.
    """
    assert isinstance(source, str)
    if not source.isascii():
        raise DiffError("Only ASCII characters are supported")
    result: list[typing.Any] = []
    pos = 0
    end = len(source)
    while pos < end:
        c = source[pos]
        if c.isspace():
            pos += 1
            continue
        if c.isdigit():
            # Accumulate the value digit by digit.
            num = 0
            while pos < end and source[pos].isdigit():
                num = num * 10 + int(source[pos])
                pos += 1
            result.append(num)
            continue
        if c in OPERATORS or c in "()":
            result.append(c)
            pos += 1
            continue
        if c.isalpha():
            start = pos
            while pos < end and source[pos].isalpha():
                pos += 1
            # One slice per identifier instead of one concatenation per letter.
            result.append(source[start:pos])
            continue
        raise DiffError(f"Unexpected character: {c}")
    return result
class TokenizeTests(unittest.TestCase):
    """Unit tests for ``tokenize``."""

    def test_empty(self) -> None:
        self.assertEqual(tokenize(""), [])

    def test_strip_leading_whitespace(self) -> None:
        self.assertEqual(tokenize(" "), [])
        self.assertEqual(tokenize(" "), [])
        self.assertEqual(tokenize(" \t "), [])

    def test_digit(self) -> None:
        self.assertEqual(tokenize("1"), [1])

    def test_digits(self) -> None:
        self.assertEqual(tokenize("123"), [123])

    def test_multiple_numbers(self) -> None:
        self.assertEqual(tokenize("12 34 56"), [12, 34, 56])

    def test_operator(self) -> None:
        # Every operator symbol and both parentheses are one-character tokens.
        for symbol in ("+", "-", "*", "/", "^", "(", ")"):
            self.assertEqual(tokenize(symbol), [symbol])

    def test_plus(self) -> None:
        self.assertEqual(tokenize("1+2"), [1, "+", 2])

    def test_minus(self) -> None:
        self.assertEqual(tokenize("1-2"), [1, "-", 2])

    def test_times(self) -> None:
        self.assertEqual(tokenize("1*2"), [1, "*", 2])

    def test_divide(self) -> None:
        self.assertEqual(tokenize("1/2"), [1, "/", 2])

    def test_power(self) -> None:
        self.assertEqual(tokenize("1^2"), [1, "^", 2])

    def test_compound(self) -> None:
        self.assertEqual(tokenize("1+2*3"), [1, "+", 2, "*", 3])
        self.assertEqual(tokenize("(1+2)*3"), ["(", 1, "+", 2, ")", "*", 3])
def parse_(tokens: list[typing.Any], prec: int) -> Expr:
    """Parse one expression from the front of ``tokens`` by precedence climbing.

    Consumed tokens are removed from ``tokens`` in place; whatever is left
    belongs to an enclosing expression or is trailing input.

    Unary minus parses its operand at Negate's precedence from ``PREC``
    instead of always binding tightest. This makes parsing agree with
    ``pretty``: ``-3*4`` becomes ``Negate(Mul(3, 4))`` and ``(-3)*4``
    becomes ``Mul(Negate(3), 4)``. If ``PREC`` has no ``Negate`` entry the
    operand is a single primary, as before.

    Args:
        tokens: token list as produced by ``tokenize``; mutated.
        prec: minimum precedence an infix operator must have to be
            consumed at this level.

    Returns:
        The parsed expression tree.

    Raises:
        DiffError: on premature end of input, an operator where an operand
            is expected, a missing ``)``, or an unrecognised token.
    """
    # Looked up on every call so a replaced PREC table is honoured.
    negate_prec = next(
        (op_prec for op, _name, _assoc, op_prec in PREC if op is Negate),
        None,
    )

    def paren() -> Expr:
        # Parse a primary: literal, variable, parenthesised group or negation.
        if not tokens:
            raise DiffError("Unexpected end of input")
        if tokens[0] == "-":
            # Unary negate
            tokens.pop(0)
            if negate_prec is None:
                return Negate(paren())
            # Let the operand absorb operators that bind tighter than negation.
            return Negate(parse_(tokens, negate_prec))
        if tokens[0] in OPERATORS:
            raise DiffError(f"Unexpected operator: {tokens[0]}")
        if isinstance(tokens[0], int):
            return Const(tokens.pop(0))
        if tokens[0] == "(":
            tokens.pop(0)
            result = parse_(tokens, 0)
            if not tokens or tokens.pop(0) != ")":
                raise DiffError("Expected closing parenthesis")
            return result
        if isinstance(tokens[0], str) and tokens[0].isalpha():
            return Var(tokens.pop(0))
        raise DiffError(f"Unexpected token: {tokens[0]}")

    lhs = paren()
    while tokens and (token := tokens[0]) in OPERATORS:
        op, assoc, op_prec = op_info_by_name(token)
        if op_prec < prec:
            # Operator belongs to an enclosing, lower-precedence level.
            break
        tokens.pop(0)
        # Left-associative operators must not re-consume themselves on the
        # right; every other associativity nests to the right.
        next_prec = op_prec + 1 if assoc == "left" else op_prec
        rhs = parse_(tokens, next_prec)
        lhs = op(lhs, rhs)
    return lhs
def parse(tokens: list[typing.Any]) -> Expr:
    """Parse a complete token list into one expression tree.

    ``tokens`` is consumed in place by ``parse_``.

    Raises:
        DiffError: if tokens remain after a full expression was parsed, or
            for any error raised by ``parse_``.
    """
    expr = parse_(tokens, 0)
    if not tokens:
        return expr
    leftover = " ".join(map(str, tokens))
    raise DiffError("Unexpected tokens: " + leftover)
class ParseTests(unittest.TestCase):
    """Unit tests for ``parse`` on hand-written token lists.

    Expected-message patterns are raw strings with regex metacharacters
    escaped, so each pattern matches the literal operator or token text.
    """

    def test_const(self) -> None:
        self.assertEqual(parse([3]), Const(3))

    def test_const_leftover_raises(self) -> None:
        with self.assertRaisesRegex(DiffError, "Unexpected tokens: 4"):
            parse([3, 4])

    def test_const_paren(self) -> None:
        self.assertEqual(parse(["(", 3, ")"]), Const(3))

    def test_const_paren_missing(self) -> None:
        with self.assertRaisesRegex(DiffError, "Expected closing parenthesis"):
            parse(["(", 3])
        with self.assertRaisesRegex(DiffError, r"Unexpected tokens: \)"):
            parse([3, ")"])

    def test_negate_const(self) -> None:
        with self.assertRaisesRegex(DiffError, "Unexpected end of input"):
            parse(["-"])
        self.assertEqual(parse(["-", 3]), Negate(Const(3)))

    def test_add(self) -> None:
        self.assertEqual(parse([1, "+", 2]), Add(Const(1), Const(2)))

    def test_add_negate(self) -> None:
        self.assertEqual(parse([1, "+", "-", 2]), Add(Const(1), Negate(Const(2))))

    def test_begin_add_raises(self) -> None:
        with self.assertRaisesRegex(DiffError, r"Unexpected operator: \+"):
            parse(["+"])
        with self.assertRaisesRegex(DiffError, r"Unexpected operator: \+"):
            parse(["+", 2])

    def test_double_add_raises(self) -> None:
        with self.assertRaisesRegex(DiffError, r"Unexpected operator: \+"):
            parse([1, "+", "+", 2])

    def test_add_add(self) -> None:
        self.assertEqual(
            parse([1, "+", 2, "+", 3]), Add(Const(1), Add(Const(2), Const(3)))
        )

    def test_add_mul(self) -> None:
        self.assertEqual(
            parse([1, "+", 2, "*", 3]), Add(Const(1), Mul(Const(2), Const(3)))
        )
        self.assertEqual(
            parse(["(", 1, "+", 2, ")", "*", 3]), Mul(Add(Const(1), Const(2)), Const(3))
        )

    def test_mul_add(self) -> None:
        self.assertEqual(
            parse([1, "*", 2, "+", 3]), Add(Mul(Const(1), Const(2)), Const(3))
        )

    def test_sub(self) -> None:
        self.assertEqual(parse([1, "-", 2]), Sub(Const(1), Const(2)))

    def test_sub_negate(self) -> None:
        self.assertEqual(parse([1, "-", "-", 2]), Sub(Const(1), Negate(Const(2))))

    def test_sub_sub(self) -> None:
        self.assertEqual(
            parse([1, "-", 2, "-", 3]), Sub(Sub(Const(1), Const(2)), Const(3))
        )

    # Renamed: a second method called test_add_mul would replace the one
    # above, so its parenthesised case would never run.
    def test_add_mul_and_mul_add(self) -> None:
        self.assertEqual(
            parse([1, "+", 2, "*", 3]), Add(Const(1), Mul(Const(2), Const(3)))
        )
        self.assertEqual(
            parse([1, "*", 2, "+", 3]), Add(Mul(Const(1), Const(2)), Const(3))
        )
class EndToEndTests(unittest.TestCase):
    """Round-trip tests chaining ``pretty``, ``tokenize`` and ``parse``."""

    def _expect_parse(self, source: str, expected: Expr) -> None:
        # Tokenize and parse ``source``, then compare the resulting tree.
        self.assertEqual(parse(tokenize(source)), expected)

    def _expect_reparse(self, expr: Expr, expected: Expr) -> None:
        # Print ``expr`` and check that the text parses to ``expected``.
        rendered = pretty(expr)
        self._expect_parse(rendered, expected)

    def _run(self, expr: Expr) -> None:
        # Full round trip: the printed form must parse back to ``expr``.
        self._expect_reparse(expr, expr)

    def test_const(self) -> None:
        self._run(Const(3))

    def test_add(self) -> None:
        self._run(Add(Const(1), Const(2)))

    def test_sub(self) -> None:
        self._run(Sub(Const(1), Const(2)))

    def test_mul(self) -> None:
        self._run(Mul(Const(2), Const(3)))

    def test_div(self) -> None:
        self._run(Div(Const(3), Const(4)))

    def test_pow(self) -> None:
        self._run(Pow(Const(2), Const(3)))

    def test_add_add(self) -> None:
        right_nested = Add(Const(1), Add(Const(2), Const(3)))
        self._run(right_nested)
        # The left-nested tree prints without parentheses, so it comes
        # back right-nested.
        left_nested = Add(Add(Const(1), Const(2)), Const(3))
        self._expect_reparse(
            left_nested,
            Add(Const(1), Add(Const(2), Const(3))),
        )

    def test_sub_sub(self) -> None:
        self._run(Sub(Const(1), Sub(Const(2), Const(3))))
        self._run(Sub(Sub(Const(1), Const(2)), Const(3)))

    def test_add_mul(self) -> None:
        self._run(Add(Const(1), Mul(Const(2), Const(3))))
        self._run(Add(Mul(Const(1), Const(2)), Const(3)))

    def test_mul_add(self) -> None:
        self._run(Mul(Const(1), Add(Const(2), Const(3))))
        self._run(Mul(Add(Const(1), Const(2)), Const(3)))

    def test_eli(self) -> None:
        # From https://eli.thegreenplace.net/2012/08/02/parsing-expressions-by-precedence-climbing
        expected = Add(
            Const(2), Add(Mul(Pow(Const(3), Const(2)), Const(3)), Const(4))
        )
        self._expect_parse("2 + 3 ^ 2 * 3 + 4", expected)
if __name__ == "__main__":
    import sys

    # Overrides a private unittest setting; presumably this keeps long
    # assertEqual failure messages from being truncated (verify against the
    # unittest.util source for the Python version in use).
    sys.modules["unittest.util"]._MAX_LENGTH = 999999999
    unittest.main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hypothesis | |
import hypothesis.strategies as st | |
import unittest | |
import diff | |
# Make add/mul left associative to make round trip testing work.
# This replaces the module's table wholesale; the replacement has no
# Negate entry and uses its own precedence levels.
diff.PREC = [
    (diff.Add, "+", "left", 1),
    (diff.Sub, "-", "left", 1),
    (diff.Mul, "*", "left", 2),
    (diff.Div, "/", "left", 2),
    (diff.Pow, "^", "right", 3),
]

# Leaf strategies: integer constants in [0, 10] and one-to-three-letter
# lowercase variable names.
consts = st.builds(diff.Const, st.integers(0, 10))
vars = st.builds(diff.Var, st.from_regex(r"[a-z]{1,3}", fullmatch=True))

# Strategy choosing one of the binary node classes from the table above.
nodeclasses = st.sampled_from([entry[0] for entry in diff.PREC])
@st.composite
def binary_nodes(draw, elements):
    """Draw two children from ``elements`` and join them with a drawn node class."""
    lhs = draw(elements)
    rhs = draw(elements)
    node_cls = draw(nodeclasses)
    return node_cls(lhs, rhs)
# Expression trees: constants or variables at the leaves, binary operator
# nodes above them.
recursive_tree = st.recursive(base=st.one_of(consts, vars), extend=binary_nodes)
class TestDiff(unittest.TestCase):
    """Property test: printing a generated tree and parsing it back is lossless."""

    @hypothesis.given(recursive_tree)
    @hypothesis.settings(max_examples=1000)
    def test_diff(self, tree):
        rendered = diff.pretty(tree)
        tokens = diff.tokenize(rendered)
        reparsed = diff.parse(tokens)
        self.assertEqual(tree, reparsed)
# Run the property test when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
hypothesis==6.111.2 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment