Created
July 7, 2023 19:52
-
-
Save cgranade/136b806ba24b9efdc3fec70887ae8645 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NB: This sample uses features from Python 3.10, in particular the
# match statement and runtime `X | Y` union types. If you get a
# SyntaxError (or a TypeError from the type aliases), you need a
# more recent version of Python.
from __future__ import annotations

import html
from dataclasses import dataclass
class ExprBase:
    """Common base for symbolic expression nodes.

    Overloads ``+`` and ``*`` so that combining expressions (or mixing them
    with plain numbers) builds ``PlusExpr`` / ``TimesExpr`` trees instead of
    computing a value eagerly.
    """

    def __add__(self, other: Expr) -> Expr:
        return PlusExpr(left=self, right=other)

    # The reflected operators cover a number on the left, e.g. ``2 + x``:
    # int/float arithmetic does not know about ExprBase, so Python falls
    # back to these. Operands are kept in the order they were written.
    def __radd__(self, other: Expr) -> Expr:
        return PlusExpr(left=other, right=self)

    def __mul__(self, other: Expr) -> Expr:
        return TimesExpr(left=self, right=other)

    def __rmul__(self, other: Expr) -> Expr:
        return TimesExpr(left=other, right=self)
@dataclass
class Var(ExprBase):
    """A named symbolic variable such as ``x``."""

    name: str

    def derivative(self, var: Var) -> Expr:
        """Return 1 when differentiating with respect to this variable, else 0."""
        if var == self:
            return 1
        return 0

    def simplify(self) -> Expr:
        """A bare variable is already in simplest form; return it unchanged."""
        return self

    def evaluate(self, assignments) -> Expr:
        """Look up this variable's value in ``assignments``.

        Returns the assigned value when ``self.name`` is a key of
        ``assignments``; otherwise returns this variable itself so it
        stays symbolic.
        """
        found = self.name in assignments
        return assignments[self.name] if found else self
@dataclass | |
class PlusExpr(ExprBase): | |
left: Expr | |
right: Expr | |
def derivative(self, var: Var) -> Expr: | |
return derivative(self.left, var) + derivative(self.right, var) | |
def simplify(self): | |
match (self.left, self.right): | |
case (0, _): | |
return simplify(self.right) | |
case (_, 0): | |
return simplify(self.left) | |
case (int() as n, int() as m): | |
return n + m | |
case (l, r): | |
return PlusExpr(simplify(l), simplify(r)) | |
def evaluate(self, assignments): | |
return evaluate(self.left, **assignments) + evaluate(self.right, **assignments) | |
@dataclass | |
class TimesExpr(ExprBase): | |
left: Expr | |
right: Expr | |
def derivative(self, var: Var) -> Expr: | |
return ( | |
self.left * derivative(self.right, var) + | |
derivative(self.left, var) * self.right | |
) | |
def simplify(self): | |
match (self.left, self.right): | |
case (0, _): | |
return 0 | |
case (_, 0): | |
return 0 | |
case (1, _): | |
return simplify(self.right) | |
case (_, 1): | |
return simplify(self.left) | |
case (int() as n, int() as m): | |
return n + m | |
case (l, r): | |
return TimesExpr(simplify(l), simplify(r)) | |
def evaluate(self, assignments): | |
return evaluate(self.left, **assignments) * evaluate(self.right, **assignments) | |
# Numeric constants that may appear as leaves of an expression tree.
NumberLiteral = int | float
# Any expression: a symbolic node or a plain number. These unions are
# built at runtime, which requires Python 3.10+ (`type | type`).
Expr = ExprBase | NumberLiteral
def derivative(expr: Expr, var: Var) -> Expr:
    """Differentiate ``expr`` with respect to ``var``.

    Plain numbers are constants, so their derivative is 0. Symbolic nodes
    delegate to their own ``derivative`` method.
    """
    if not isinstance(expr, ExprBase):
        return 0
    return expr.derivative(var)
def simplify(expr: Expr) -> Expr:
    """Simplify ``expr`` by applying one-step rewrites until nothing changes.

    Each step calls the node's own ``simplify`` method; plain numbers are
    returned unchanged. Iteration stops at a fixed point: when a step
    returns the very same object, or a result equal to its input.
    """

    def simplify_inner(expr: Expr) -> Expr:
        return expr.simplify() if isinstance(expr, ExprBase) else expr

    prev = expr
    simplified = simplify_inner(expr)
    # The identity check guarantees termination for values that are not
    # equal to themselves, such as float("nan") (e.g. from folding
    # inf + -inf). With only `!=`, nan != nan would loop forever.
    while simplified is not prev and simplified != prev:
        prev = simplified
        simplified = simplify_inner(simplified)
    return simplified
def evaluate(expr: Expr, **assignments) -> Expr:
    """Substitute values for named variables in ``expr``.

    Keyword arguments map variable names to values. Symbolic nodes receive
    them as a single dict via their ``evaluate`` method; plain numbers are
    returned as-is.
    """
    if isinstance(expr, ExprBase):
        return expr.evaluate(assignments)
    return expr
@dataclass
class PmlNodeKind:
    """A tag type (identified by ``name``) that builds ``PmlNode`` elements when called."""

    name: str

    def __call__(self, *children, **attributes) -> PmlNode:
        """Create a node of this kind.

        Positional arguments become the node's children (nested nodes or
        text). Keyword arguments become its attributes.
        """
        # PmlNode.children is declared as a list, so store a list rather
        # than the varargs tuple; callers can then rely on the declared type.
        return PmlNode(self, attributes, list(children))
@dataclass
class PmlNode:
    """One element of a PML document tree that can be rendered as HTML."""

    kind: PmlNodeKind
    attributes: dict[str, str]
    children: list[PmlNode | str]

    def to_html(self) -> str:
        """Render this node and its descendants as an HTML string.

        Attribute values are escaped for use inside double quotes, and
        non-node children are escaped as HTML text. This stops arbitrary
        strings from breaking the markup or injecting tags. Nested
        ``PmlNode`` children are rendered recursively. Non-string values
        are converted with ``str`` first.
        """
        tag = self.kind.name
        # Each attribute renders as ` key="value"`; no attributes -> "".
        attrs = "".join(
            f' {key}="{html.escape(str(value), quote=True)}"'
            for key, value in self.attributes.items()
        )
        children = "".join(
            child.to_html()
            if isinstance(child, PmlNode)
            else html.escape(str(child), quote=False)
            for child in self.children
        )
        return f"<{tag}{attrs}>{children}</{tag}>"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment