Skip to content

Instantly share code, notes, and snippets.

@cgranade
Created July 7, 2023 19:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cgranade/136b806ba24b9efdc3fec70887ae8645 to your computer and use it in GitHub Desktop.
Save cgranade/136b806ba24b9efdc3fec70887ae8645 to your computer and use it in GitHub Desktop.
# NB: This sample uses features from Python 3.10,
# in particular the match statement.
# If you get a syntax error, you may need a more
# recent version of Python.
from __future__ import annotations
from dataclasses import dataclass
class ExprBase:
    """Base class for symbolic expression nodes.

    Overloads ``+`` and ``*`` so that combining a node with another node
    (or with a plain number) builds a ``PlusExpr`` / ``TimesExpr`` node
    instead of computing a value.
    """

    def __add__(self, other: Expr) -> Expr:
        """Build the node for ``self + other``."""
        return PlusExpr(left=self, right=other)

    def __radd__(self, other: Expr) -> Expr:
        """Build the node for ``other + self``.

        Python falls back to this reflected form when the left operand
        (for example an int or a float) does not know how to add a node.
        """
        return PlusExpr(left=other, right=self)

    def __mul__(self, other: Expr) -> Expr:
        """Build the node for ``self * other``."""
        return TimesExpr(left=self, right=other)

    def __rmul__(self, other: Expr) -> Expr:
        """Build the node for ``other * self`` (reflected multiplication)."""
        return TimesExpr(left=other, right=self)
@dataclass
class Var(ExprBase):
    """A symbolic variable identified by its name."""

    name: str

    def derivative(self, var: Var) -> Expr:
        """Differentiate with respect to ``var``.

        Returns 1 when ``var`` equals this variable and 0 otherwise.
        """
        if var == self:
            return 1
        return 0

    def simplify(self) -> Expr:
        """Return this variable unchanged; there is nothing to reduce."""
        return self

    def evaluate(self, assignments) -> Expr:
        """Substitute a value for this variable.

        ``assignments`` is looked up by the variable's name; the bound
        value is returned when present, otherwise the variable itself.
        """
        key = self.name
        return assignments[key] if key in assignments else self
@dataclass
class PlusExpr(ExprBase):
left: Expr
right: Expr
def derivative(self, var: Var) -> Expr:
return derivative(self.left, var) + derivative(self.right, var)
def simplify(self):
match (self.left, self.right):
case (0, _):
return simplify(self.right)
case (_, 0):
return simplify(self.left)
case (int() as n, int() as m):
return n + m
case (l, r):
return PlusExpr(simplify(l), simplify(r))
def evaluate(self, assignments):
return evaluate(self.left, **assignments) + evaluate(self.right, **assignments)
@dataclass
class TimesExpr(ExprBase):
left: Expr
right: Expr
def derivative(self, var: Var) -> Expr:
return (
self.left * derivative(self.right, var) +
derivative(self.left, var) * self.right
)
def simplify(self):
match (self.left, self.right):
case (0, _):
return 0
case (_, 0):
return 0
case (1, _):
return simplify(self.right)
case (_, 1):
return simplify(self.left)
case (int() as n, int() as m):
return n + m
case (l, r):
return TimesExpr(simplify(l), simplify(r))
def evaluate(self, assignments):
return evaluate(self.left, **assignments) * evaluate(self.right, **assignments)
# Plain Python numbers that may appear directly inside an expression.
NumberLiteral = int | float
# Anything accepted as an expression: a symbolic node or a plain number.
# These unions are built at runtime with `|`, which needs Python 3.10+.
Expr = ExprBase | NumberLiteral
def derivative(expr: Expr, var: Var) -> Expr:
    """Differentiate ``expr`` with respect to ``var``.

    Symbolic nodes delegate to their own ``derivative`` method; anything
    else is treated as a constant, whose derivative is 0.
    """
    if not isinstance(expr, ExprBase):
        return 0
    return expr.derivative(var)
def simplify(expr: Expr) -> Expr:
    """Simplify ``expr`` until another pass no longer changes it.

    One pass calls ``simplify()`` on a symbolic node and leaves any other
    value as it is. Passes are repeated until the result compares equal to
    the input of that pass.
    """
    current = expr
    while True:
        # One simplification pass over the current value.
        reduced = (
            current.simplify() if isinstance(current, ExprBase) else current
        )
        if reduced != current:
            current = reduced
            continue
        return reduced
def evaluate(expr: Expr, **assignments) -> Expr:
    """Evaluate ``expr`` with variables bound by keyword arguments.

    Each keyword names a variable and gives its value. Symbolic nodes
    receive the bindings as a dict via their ``evaluate`` method; any
    other value is returned unchanged.
    """
    if isinstance(expr, ExprBase):
        return expr.evaluate(assignments)
    return expr
@dataclass
class PmlNodeKind:
    """A named kind of markup node; calling it creates a ``PmlNode``."""

    name: str

    def __call__(self, *children, **attributes):
        """Create a node of this kind.

        Positional arguments become the node's children (passed on as a
        tuple) and keyword arguments become its attributes.
        """
        return PmlNode(kind=self, attributes=attributes, children=children)
@dataclass
class PmlNode:
    """A markup node: a kind, its attributes and its ordered children."""

    kind: PmlNodeKind
    attributes: dict[str, str]
    children: list[PmlNode | str]

    def to_html(self) -> str:
        """Render this node and its children as an HTML string.

        The tag name is the kind's name. Attributes are written as
        ``key="value"`` pairs in mapping order. Child nodes are rendered
        recursively and string children are inserted as they are; no
        escaping is applied to attribute values or text.
        """
        tag = self.kind.name

        # A leading space is emitted only when there are attributes.
        attr_text = ""
        if self.attributes:
            pairs = [
                f'{key}="{val}"'
                for key, val in self.attributes.items()
            ]
            attr_text = " " + " ".join(pairs)

        rendered = []
        for child in self.children:
            if isinstance(child, PmlNode):
                rendered.append(child.to_html())
            else:
                rendered.append(child)
        body = "".join(rendered)

        return f"<{tag}{attr_text}>{body}</{tag}>"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment