Skip to content

Instantly share code, notes, and snippets.

@Alwinfy
Created December 11, 2023 01:10
Show Gist options
  • Save Alwinfy/6472ff913caf6426b6641304d199cfba to your computer and use it in GitHub Desktop.
Save Alwinfy/6472ff913caf6426b6641304d199cfba to your computer and use it in GitHub Desktop.
from __future__ import annotations

import math
from dataclasses import dataclass
def prod(args):
    """Return the product of every item in *args*.

    An empty iterable yields 1. Delegates to the standard library's
    ``math.prod`` (start value 1) instead of a hand-written accumulator loop.
    """
    return math.prod(args)
@dataclass
class Function:
name: str
args: list[str | int | Function]
def eval_fun(f: str | int | Function, mappings: dict[str, int]) -> int:
match f:
case str(s):
if s not in mappings:
raise ValueError("No value for: " + s)
return mappings[s]
case int(s):
return s
case Function(name, args):
children = [eval_fun(arg, mappings) for arg in args]
if name == "add":
return sum(children)
if name == "mul":
return prod(children)
raise ValueError("Bad function: " + self.name)
# Demo: build x + 3*y from nested Function nodes and evaluate it with
# x=1, y=2 (the second print shows 7).
simple_ast = Function(
    "add",
    [
        "x",
        Function("mul", [3, "y"]),
    ],
)
print(simple_ast)
print(eval_fun(simple_ast, mappings={"x": 1, "y": 2}))
from dataclasses import dataclass
def prod(args):
    """Return the product of every item in *args*.

    An empty iterable yields 1. Delegates to the standard library's
    ``math.prod`` (start value 1) instead of a hand-written accumulator loop.
    Note: this re-declares the ``prod`` helper defined earlier in the file.
    """
    return math.prod(args)
class PValue:
    """Base class for nodes of the polymorphic (OOP-style) expression tree.

    Concrete node types subclass this and override :meth:`eval`.
    """

    def eval(self, mappings: dict[str, int]) -> int:
        """Evaluate this node using the variable bindings *mappings*.

        Raises:
            NotImplementedError: always; concrete node types must override it.
        """
        # Raise rather than `pass`: an implicit None would silently violate
        # the `-> int` contract.
        raise NotImplementedError(f"{type(self).__name__} must override eval()")


@dataclass
class PFunction(PValue):
    """Function-call node that applies ``name`` to the evaluated ``args``.

    Supported names are "add" (sum of args) and "mul" (product of args).
    """

    name: str
    args: list[PValue]

    def eval(self, mappings: dict[str, int]) -> int:
        """Evaluate every argument, then combine the results.

        Raises:
            ValueError: if ``name`` is unsupported, or propagated from a child.
        """
        children = [a.eval(mappings) for a in self.args]
        if self.name == "add":
            return sum(children)
        if self.name == "mul":
            return prod(children)
        raise ValueError("Bad function: " + self.name)


@dataclass
class PInt(PValue):
    """Integer-literal node; evaluates to ``value``."""

    value: int

    def eval(self, mappings: dict[str, int]) -> int:
        """Return the literal value; *mappings* is not consulted."""
        return self.value


@dataclass
class PSym(PValue):
    """Variable node; evaluates to ``mappings[sym]``."""

    sym: str

    def eval(self, mappings: dict[str, int]) -> int:
        """Look up ``sym`` in *mappings*.

        Raises:
            ValueError: if ``sym`` has no binding in *mappings*.
        """
        if self.sym not in mappings:
            raise ValueError("No value for: " + self.sym)
        return mappings[self.sym]
# Demo: build x + 3*y from polymorphic node objects and evaluate it with
# x=1, y=2 (the second print shows 7).
poly_ast = PFunction(
    "add",
    [
        PSym("x"),
        PFunction("mul", [PInt(3), PSym("y")]),
    ],
)
print(poly_ast)
print(poly_ast.eval(mappings={"x": 1, "y": 2}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment