Skip to content

Instantly share code, notes, and snippets.

@marvinborner
Created May 5, 2024 00:44
Show Gist options
  • Save marvinborner/c60b4a4db8bbdefaf34f5e3424c50d06 to your computer and use it in GitHub Desktop.
Save marvinborner/c60b4a4db8bbdefaf34f5e3424c50d06 to your computer and use it in GitHub Desktop.
Optimal lambda calculus evaluator in Python based on @VictorTaelin's JavaScript abstract algorithm implementation (1:1 translation)
# translated from https://github.com/VictorTaelin/abstract-algorithm
import math
import re
# --- name.js ---
count = 0


def fresh():
    """Return a brand-new name "$<n>" drawn from the module-wide counter.

    Used for compiler wire names, commuted node ids and duplicator labels.
    """
    global count
    count += 1
    return f"${count}"
def rank(num):
    """Return a short lowercase name for the index ``num``.

    The encoding is bijective base 26 with the least significant letter
    first: 0 -> "a", 25 -> "z", 26 -> "aa", 27 -> "ba". Negative input
    yields "".

    Integer ``divmod`` keeps the result exact for arbitrarily large ints.
    Float division (``math.floor(num / 26)``) silently loses precision
    above 2**53.
    """
    letters = []
    num += 1
    while num > 0:
        num -= 1
        num, digit = divmod(num, 26)
        letters.append(chr(97 + digit))
    return "".join(letters)
def Var(name):
    """Build a lambda-term node for an occurrence of the variable ``name``."""
    return dict(ctor="Var", name=name)
def Lam(name, body):
    """Build a lambda-abstraction node binding ``name`` over ``body``."""
    return dict(ctor="Lam", name=name, body=body)
def App(func, argm):
    """Build an application node applying ``func`` to ``argm``."""
    return dict(ctor="App", func=func, argm=argm)
def lambda_show(term):
    """Render a lambda term as text.

    Var -> its name, Lam -> "λname.body", App -> "(func argm)".
    Returns None for a dict whose "ctor" is none of these.
    """
    ctor = term["ctor"]
    if ctor == "Var":
        return term["name"]
    if ctor == "Lam":
        # str.join raises TypeError on non-str parts, just like "+" concatenation.
        return "".join(["λ", term["name"], ".", lambda_show(term["body"])])
    if ctor == "App":
        return "".join(
            ["(", lambda_show(term["func"]), " ", lambda_show(term["argm"]), ")"]
        )
    return None
def lambda_read(code):
    """Parse lambda-calculus source text into a Var/Lam/App term.

    Accepted syntax:
      (f a b ...)      left-nested application ((f a) b) ...
      λx.body          abstraction
      @x value body    let-binding, parsed as (λx.body value)
      $x value rest    definition: later occurrences of x that are not bound
                       by an enclosing λ/@ are replaced by ``value``
      name             variable (alphanumeric after its first character)

    Text after the first complete term is ignored.

    Raises:
        ValueError: on a missing expected character, premature end of input,
            or a name that is neither bound nor defined with ``$``.
    """
    indx = 0
    defs = {}

    def skip_space():
        nonlocal indx
        # Use one notion of whitespace everywhere so lookahead and
        # consumption agree (tabs, CR, etc. included).
        while indx < len(code) and code[indx].isspace():
            indx += 1

    def read_name():
        nonlocal indx
        start = indx
        while indx < len(code) and re.match(r"[a-zA-Z0-9]", code[indx]):
            indx += 1
        return code[start:indx]

    def read_char(ch):
        nonlocal indx
        skip_space()
        if indx >= len(code) or code[indx] != ch:
            found = code[indx] if indx < len(code) else "<eof>"
            raise ValueError(f"Expected '{ch}', found '{found}' at {indx}.")
        indx += 1

    def read_term(bound):
        nonlocal indx
        skip_space()
        if indx >= len(code):
            raise ValueError(f"Unexpected end of input at {indx}.")
        head = code[indx]
        indx += 1
        if head == "(":
            term = read_term(bound)
            while True:
                skip_space()
                if indx >= len(code) or code[indx] == ")":
                    break
                term = App(term, read_term(bound))
            read_char(")")  # raises on EOF instead of silently succeeding
            return term
        if head == "λ":
            name = read_name()
            read_char(".")
            return Lam(name, read_term(bound + [name]))
        if head == "@":
            name = read_name()
            value = read_term(bound)
            body = read_term(bound + [name])
            return App(Lam(name, body), value)
        if head == "$":
            name = read_name()
            defs[name] = read_term(bound)
            return read_term(bound)
        name = head + read_name()
        # list.index raises instead of returning -1, so test membership directly.
        if name in bound:
            return Var(name)
        if name in defs:
            return defs[name]
        raise ValueError(f"Unbound variable '{name}' at {indx}.")

    return read_term([])
# --- comp.js ---
def compile(term):
    """Compile a lambda term into the textual net format consumed by inet_read.

    Every line is "<kind> <port0> <port1> <port2>". Two ports sharing a wire
    name are connected, and "-" marks a port with no partner.
      - Abstractions emit "- <lam> <var> <body>". <var> is "-" when the
        variable is unused.
      - Applications emit "- <func> <argm> <result>".
      - Extra occurrences of a variable are merged by "*" lines
        "* <out> <left> <right>".
    A final "- root <term> root" line is added. The lines are returned in
    reverse order of generation, so that root line comes first.
    """
    occurrences = {}  # lambda wire name -> wire names of its variable's uses
    out = []

    def go(node, scope):
        kind = node["ctor"]
        if kind == "Lam":
            here = fresh()
            occurrences[here] = []
            inner = dict(scope)
            inner[node["name"]] = here
            body = go(node["body"], inner)
            uses = occurrences[here]
            if not uses:
                var_wire = "-"
            else:
                var_wire = uses[0]
                for extra in uses[1:]:
                    left = var_wire
                    var_wire = fresh()
                    out.append(" ".join(["*", var_wire, left, extra]))
            out.append(" ".join(["-", here, var_wire, body]))
            return here
        if kind == "App":
            here = fresh()
            func = go(node["func"], scope)
            argm = go(node["argm"], scope)
            out.append(" ".join(["-", func, argm, here]))
            return here
        if kind == "Var":
            here = fresh()
            occurrences[scope[node["name"]]].append(here)
            return here
        return None

    init = go(term, {})
    out.append(" ".join(["-", "root", init, "root"]))
    return "\n".join(reversed(out))
def decompile(inet):
    """Read a lambda term back out of a net, starting at the wire on root port ("0", 1).

    Arriving at a "Lam"-ctor node through:
      - slot 0 gives an abstraction whose variable is that node's slot 1 and
        whose body continues from slot 2;
      - slot 1 gives a variable, matched against the enclosing binders;
      - slot 2 gives an application: function via slot 0, argument via slot 1.
    Any other ctor is treated as a duplicator:
      - arriving through an auxiliary slot pushes that slot and continues
        through slot 0;
      - arriving through slot 0 pops the stack to choose the auxiliary slot
        to leave by.
    Bound variables are named rank(binder depth).
    """
    def walk(net, here, binders, exits):
        node = here["node"]
        slot = here["slot"]
        if net[node]["ctor"] != "Lam":
            if slot == 0:
                out_slot = exits.pop()
                result = walk(net, enter(net, Ptr(node, out_slot)), binders, exits)
                exits.append(out_slot)
            else:
                exits.append(slot)
                result = walk(net, enter(net, Ptr(node, 0)), binders, exits)
                exits.pop()
            return result
        if slot == 0:
            name = rank(len(binders))
            binders.append({"ptr": Ptr(node, 1), "name": name})
            body = walk(net, enter(net, Ptr(node, 2)), binders, exits)
            binders.pop()
            return Lam(name, body)
        if slot == 1:
            # Innermost binder first, so shadowing resolves correctly.
            for binder in reversed(binders):
                if equal(binder["ptr"], here):
                    return Var(binder["name"])
            return None
        if slot == 2:
            argm = walk(net, enter(net, Ptr(node, 1)), binders, exits)
            func = walk(net, enter(net, Ptr(node, 0)), binders, exits)
            return App(func, argm)
        return None

    return walk(inet, enter(inet, Ptr("0", 1)), [], [])
# --- inet.js ---
def New(ctor, kind):
    """Create a net node with constructor ``ctor`` and three unconnected (None) ports."""
    return dict(ctor=ctor, port=[None] * 3, kind=kind)
def Ptr(node, slot):
    """Create a pointer to port ``slot`` of node id ``node``."""
    return dict(ctor="Ptr", node=node, slot=slot)
def inet_show(inet):
    """Render a net as text, one line per node in dict order.

    Each line starts with "- " for "Lam" nodes and "+ " for anything else,
    followed by three port labels, each with a trailing space. A
    self-pointing port prints as "-". Otherwise both ends of a wire share
    a label rank(0), rank(1), ..., assigned when the first end is printed.
    """
    wire_names = {}  # "node-slot" of the first-printed end -> wire label
    next_id = 0
    rows = []
    for node_id, node in inet.items():
        row = "- " if node["ctor"] == "Lam" else "+ "
        for slot in (0, 1, 2):
            target = node["port"][slot]
            if equal(target, Ptr(node_id, slot)):
                row += "- "
                continue
            known = wire_names.get(f'{target["node"]}-{target["slot"]}')
            if known:
                row += known + " "
            else:
                label = rank(next_id)
                wire_names[f"{node_id}-{slot}"] = label
                next_id += 1
                row += label + " "
        rows.append(row)
    return "\n".join(rows)
def inet_read(code):
    """Parse the textual net format (as emitted by compile) into a dict of nodes.

    Each non-blank line "<kind> <name0> <name1> <name2>" becomes node
    str(i), where i is its index among non-blank lines. The kind maps to a
    ctor as follows:
      - "-" -> "Lam";
      - "+" -> "Let";
      - anything else (compile emits "*") -> a fresh unique ctor from
        fresh(), so each such duplicator gets its own label.
    A port named "-" points at itself. The second port carrying a given
    name is wired to the first one with that name.

    Surrounding whitespace, repeated spaces and CRLF line endings are
    tolerated.

    Raises:
        ValueError: if a line has fewer than three port names.
    """
    rows = [row.strip() for row in code.splitlines()]
    rows = [row for row in rows if row]
    wires = {}  # wire name -> first port seen with that name
    inet = {}
    for index, row in enumerate(rows):
        node_id = str(index)
        kind = row[0]
        names = row[1:].split()
        if len(names) < 3:
            raise ValueError(f"Line {index}: expected 3 port names in {row!r}")
        if kind == "-":
            ctor = "Lam"
        elif kind == "+":
            ctor = "Let"
        else:
            ctor = fresh()
        inet[node_id] = New(ctor, {})
        for slot, name in enumerate(names[:3]):
            if name == "-":
                inet[node_id]["port"][slot] = Ptr(node_id, slot)
            elif name not in wires:
                wires[name] = Ptr(node_id, slot)
            else:
                other = wires[name]
                inet[node_id]["port"][slot] = other
                inet[other["node"]]["port"][other["slot"]] = Ptr(node_id, slot)
    return inet
def equal(a, b):
    """Return True when pointers ``a`` and ``b`` address the same (node, slot).

    Either argument may be None (an unconnected port). None never compares
    equal, not even to another None.
    """
    if a is None or b is None:
        return False
    return a["node"] == b["node"] and a["slot"] == b["slot"]
def enter(inet, ptr):
    """Return the pointer stored in port ``ptr``, i.e. the other end of its wire.

    Returns None when ``ptr`` is None, or when ``ptr``'s node is no longer in
    ``inet`` (rewrite deletes the nodes it consumes). A plain ``inet[...]``
    lookup would raise KeyError before any None check could run.
    """
    if ptr is None:
        return None
    node = inet.get(ptr["node"])
    if node is None:
        return None
    return node["port"][ptr["slot"]]
def link(inet, a_ptr, b_ptr):
    """Wire two ports together: each given endpoint's port is set to the other pointer.

    A falsy endpoint (e.g. None) is not written to, but its partner is still
    overwritten with that falsy value.
    """
    for here, there in ((a_ptr, b_ptr), (b_ptr, a_ptr)):
        if here:
            inet[here["node"]]["port"][here["slot"]] = there
def unlink(inet, a_ptr):
    """Cut the wire at ``a_ptr``, making both ends point at themselves.

    Self-pointing is how inet_read represents a "-" (unconnected) port.
    Nothing happens unless the wire is mutual (the partner points back at
    ``a_ptr``) or when the port holds None. Without the None guard, the
    second lookup would dereference None.
    """
    b_ptr = enter(inet, a_ptr)
    if b_ptr is None:
        return
    if equal(a_ptr, enter(inet, b_ptr)):
        inet[a_ptr["node"]]["port"][a_ptr["slot"]] = a_ptr
        inet[b_ptr["node"]]["port"][b_ptr["slot"]] = b_ptr
def annihilate(inet, a, b):
    """Interact two nodes of the same ctor by joining their auxiliary neighbours.

    The neighbour of a's slot 1 is linked to the neighbour of b's slot 1,
    then the same is done for slot 2. The slots are processed strictly in
    that order, reading then linking. Linking slot 1 can overwrite a port
    that the slot-2 reads see.
    """
    for aux in (1, 2):
        link(inet, enter(inet, Ptr(a, aux)), enter(inet, Ptr(b, aux)))
def commute(inet, a, b):
    """Interact two nodes of different ctors by copying each through the other.

    Two copies of b (p, q) and two copies of a (r, s) are created under
    fresh ids. Their auxiliary ports are cross-wired so that each a-copy
    meets each b-copy once. Each copy's principal port (slot 0) is then
    attached to the former neighbour of one auxiliary port of a or b.
    Keep the loop order: each enter() must observe the links made by
    earlier iterations.
    """
    p, q, r, s = fresh(), fresh(), fresh(), fresh()
    for copy, proto in ((p, b), (q, b), (r, a), (s, a)):
        inet[copy] = New(inet[proto]["ctor"], inet[proto]["kind"])
    for x, x_slot, y, y_slot in ((r, 1, p, 1), (s, 1, p, 2), (r, 2, q, 1), (s, 2, q, 2)):
        link(inet, Ptr(x, x_slot), Ptr(y, y_slot))
    for copy, proto, aux in ((p, a, 1), (q, a, 2), (r, b, 1), (s, b, 2)):
        link(inet, Ptr(copy, 0), enter(inet, Ptr(proto, aux)))
def rewrite(inet, a, b):
    """Perform one interaction between nodes ``a`` and ``b``, then delete both.

    The interaction is annihilate when the ctors match and commute
    otherwise. Afterwards every still-mutual wire on the old nodes' ports
    is cut. This is done slot by slot, a before b. Finally ``a`` and ``b``
    are removed from ``inet``.
    """
    step = annihilate if inet[a]["ctor"] == inet[b]["ctor"] else commute
    step(inet, a, b)
    for slot in range(3):
        for node in (a, b):
            unlink(inet, Ptr(node, slot))
    del inet[a], inet[b]
def reduce(inet):
    """Reduce the net in place by walking from the root; return the rewrite count.

    ``cur`` is always the port just arrived at, and ``enter(inet, cur)`` is
    the port we came from.
      - Arriving at an auxiliary slot: push that slot on ``exit_slots`` and
        leave through the node's slot 0.
      - Arriving at slot 0 from another slot 0: rewrite that pair, then
        resume from the port that previously led into the partner node.
      - Arriving at slot 0 from an auxiliary slot: defer slot 2 on ``warp``
        and continue through slot 1.
    The walk stops once it is back at root node "0" with nothing deferred.

    The locals are named so they do not shadow the built-ins ``next`` and
    ``exit``.
    """
    warp = []        # deferred slot-2 pointers, resumed whenever the walk reaches root
    exit_slots = []  # auxiliary slot each node on the current path was entered by
    cur = enter(inet, Ptr("0", 1))
    rewrites = 0
    while cur["node"] != "0" or warp:
        if cur["node"] == "0":
            cur = enter(inet, warp.pop())
        came_from = enter(inet, cur)
        if cur["slot"] == 0 and came_from["slot"] == 0:
            back = enter(inet, Ptr(came_from["node"], exit_slots.pop()))
            rewrite(inet, came_from["node"], cur["node"])
            cur = enter(inet, back)
            rewrites += 1
        elif cur["slot"] == 0:
            warp.append(Ptr(cur["node"], 2))
            cur = enter(inet, Ptr(cur["node"], 1))
        else:
            exit_slots.append(cur["slot"])
            cur = enter(inet, Ptr(cur["node"], 0))
    return rewrites
# --- testing ---
def _demo():
    """Compile λf.λx.(f (f x)) applied to itself, reduce it, and print each stage."""
    source = "(λf.λx.(f (f x)) λf.λx.(f (f x)))"
    net = inet_read(compile(lambda_read(source)))
    print(inet_show(net))
    print(reduce(net))
    print(inet_show(net))
    print(lambda_show(decompile(net)))


_demo()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment