Created May 5, 2024 00:44
-
-
Save marvinborner/c60b4a4db8bbdefaf34f5e3424c50d06 to your computer and use it in GitHub Desktop.
Optimal lambda-calculus evaluator in Python: a 1:1 translation of @VictorTaelin's JavaScript implementation of the abstract algorithm.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# translated from https://github.com/VictorTaelin/abstract-algorithm | |
import math | |
import re | |
# --- name.js --- | |
# Module-wide counter backing fresh(); shared by every caller.
count = 0
def fresh():
    """Return a new unique name of the form "$<n>".

    Each call bumps the module-level ``count`` and embeds the new value, so
    names are never repeated within one process.
    """
    global count
    count = count + 1
    return f"${count}"
def rank(num):
    """Map a non-negative integer to a short lowercase name.

    The encoding is bijective base 26 with the least-significant letter first:
    0 -> "a", 25 -> "z", 26 -> "aa", 27 -> "ba". ``rank(-1)`` is "".

    Integer ``divmod`` is used rather than float division, so arbitrarily large
    ``num`` values are encoded exactly.
    """
    letters = []
    num += 1
    while num > 0:
        num, digit = divmod(num - 1, 26)
        letters.append(chr(97 + digit))
    return "".join(letters)
def Var(name):
    """Build a variable-reference term node."""
    return dict(ctor="Var", name=name)
def Lam(name, body):
    """Build a lambda-abstraction term node binding ``name`` over ``body``."""
    return dict(ctor="Lam", name=name, body=body)
def App(func, argm):
    """Build an application term node applying ``func`` to ``argm``."""
    return dict(ctor="App", func=func, argm=argm)
def lambda_show(term):
    """Render a Var/Lam/App term tree as text.

    Variables print as their name, abstractions as ``λname.body`` and
    applications as ``(func argm)``. A node with any other ``ctor`` yields None.
    """
    ctor = term["ctor"]
    if ctor == "Var":
        return term["name"]
    if ctor == "Lam":
        return "".join(("λ", term["name"], ".", lambda_show(term["body"])))
    if ctor == "App":
        left = lambda_show(term["func"])
        right = lambda_show(term["argm"])
        return "".join(("(", left, " ", right, ")"))
    return None
def lambda_read(code):
    """Parse lambda-calculus source text into a Var/Lam/App term tree.

    Syntax accepted:
      (f a b ...)     application, nested to the left: ((f a) b) ...
      λx.body         abstraction binding x in body
      @x value body   let: parsed as ((λx.body) value)
      $x value rest   definition: later references to x (outside any binder
                      of the same name) expand to the parsed ``value`` term,
                      then ``rest`` is parsed and returned
      name            a variable bound by an enclosing λ/@, or a $-defined name

    Names are runs of ASCII letters and digits. Whitespace between tokens is
    skipped.

    Raises:
        ValueError: on an unexpected character, premature end of input, a
            missing ')' or '.', or a name that is neither bound nor defined.
    """
    name_char = re.compile(r"[a-zA-Z0-9]")
    # Anchored at the position passed to .match(); avoids slicing code[indx:].
    closing = re.compile(r"\s*\)")
    indx = 0
    defs = {}

    def skip_space():
        nonlocal indx
        # Skip all whitespace, matching the \s used by the closing-paren check.
        while indx < len(code) and code[indx].isspace():
            indx += 1

    def read_name():
        nonlocal indx
        start = indx
        while indx < len(code) and name_char.match(code[indx]):
            indx += 1
        return code[start:indx]

    def read_char(ch):
        nonlocal indx
        skip_space()
        if indx >= len(code) or code[indx] != ch:
            found = code[indx] if indx < len(code) else "<eof>"
            raise ValueError(f"Expected '{ch}', found '{found}' at {indx}.")
        indx += 1

    def read_term(scope):
        nonlocal indx
        skip_space()
        if indx >= len(code):
            raise ValueError(f"Unexpected end of input at {indx}.")
        head = code[indx]
        indx += 1
        if head == "(":
            term = read_term(scope)
            while indx < len(code) and not closing.match(code, indx):
                term = App(term, read_term(scope))
            read_char(")")
            return term
        if head == "λ":
            name = read_name()
            read_char(".")
            return Lam(name, read_term(scope + [name]))
        if head == "@":
            name = read_name()
            value = read_term(scope)
            body = read_term(scope + [name])
            return App(Lam(name, body), value)
        if head == "$":
            name = read_name()
            defs[name] = read_term(scope)
            return read_term(scope)
        if not name_char.match(head):
            raise ValueError(f"Unexpected '{head}' at {indx - 1}.")
        name = head + read_name()
        if name in scope:
            return Var(name)
        if name in defs:
            return defs[name]
        raise ValueError(f"Unbound name '{name}' at {indx - len(name)}.")

    return read_term([])
# --- comp.js --- | |
def compile(term):
    """Compile a Var/Lam/App term into the text net format read by inet_read.

    Each output line describes one node as ``<kind> <port0> <port1> <port2>``.
    Ports that share a wire name are connected, and ``-`` marks a port left
    unconnected (a variable that is never used).

    * ``- L v b``: a lambda whose own wire is ``L``, with variable wire ``v``
      and body wire ``b``.
    * ``- f a r``: an application, emitted with the same ``-`` kind as a
      lambda. Port 0 faces the function ``f``, port 1 the argument ``a`` and
      port 2 the result ``r``.
    * ``* o x y``: a duplicator splitting wire ``o`` into ``x`` and ``y``.
      These are chained when a variable occurs more than once.

    Lines are produced bottom-up and then reversed, so the root line
    ``- root <term> root`` comes first.

    Raises:
        KeyError: if the term contains a variable not bound by an enclosing
            lambda.
    """
    uses = {}  # lambda wire -> wires of each occurrence of its variable
    lines = []

    def go(term, lams):
        # lams maps each variable name in scope to its binder's lambda wire.
        ctor = term["ctor"]
        if ctor == "Lam":
            this = fresh()
            uses[this] = []
            scope = dict(lams)
            scope[term["name"]] = this
            body = go(term["body"], scope)
            occurrences = uses[this]
            if not occurrences:
                varn = "-"
            else:
                # Fold the occurrences into a chain of duplicators so that a
                # single variable wire feeds all of them.
                varn = occurrences[0]
                for extra in occurrences[1:]:
                    dup1 = varn
                    varn = fresh()
                    lines.append("* " + varn + " " + dup1 + " " + extra)
            lines.append("- " + this + " " + varn + " " + body)
            return this
        if ctor == "App":
            this = fresh()
            func = go(term["func"], lams)
            argm = go(term["argm"], lams)
            lines.append("- " + func + " " + argm + " " + this)
            return this
        if ctor == "Var":
            this = fresh()
            name = term["name"]
            if name not in lams:
                raise KeyError(f"unbound variable {name!r} in compile()")
            uses[lams[name]].append(this)
            return this
        return None

    init = go(term, {})
    lines.append("- root " + init + " root")
    return "\n".join(reversed(lines))
def decompile(inet):
    """Read the term wired to root port ("0", 1) back out as a Var/Lam/App tree.

    The walk interprets each port it enters:
    * "Lam"-ctor node entered at slot 0: an abstraction. Its variable is
      named ``rank(depth)`` and its body is read through slot 2.
    * "Lam"-ctor node entered at slot 1: a variable. It resolves to the
      innermost recorded binder whose variable port is this one; None if there
      is no such binder.
    * "Lam"-ctor node entered at slot 2: an application. Argument is read
      through slot 1, function through slot 0.
    * Any other node (a duplicator) entered at an auxiliary slot: that slot is
      remembered and the walk continues through slot 0. Entered at slot 0, the
      most recently remembered slot is taken back out and used as the exit.
    """
    def build_term(inet, ptr, binders, exits):
        node, slot = ptr["node"], ptr["slot"]
        if inet[node]["ctor"] != "Lam":
            if slot != 0:
                exits.append(slot)
                term = build_term(inet, enter(inet, Ptr(node, 0)), binders, exits)
                exits.pop()
                return term
            side = exits.pop()
            term = build_term(inet, enter(inet, Ptr(node, side)), binders, exits)
            exits.append(side)
            return term
        if slot == 0:
            name = rank(len(binders))
            binders.append({"ptr": Ptr(node, 1), "name": name})
            body = build_term(inet, enter(inet, Ptr(node, 2)), binders, exits)
            binders.pop()
            return Lam(name, body)
        if slot == 1:
            for binder in reversed(binders):
                if equal(binder["ptr"], ptr):
                    return Var(binder["name"])
            return None
        if slot == 2:
            argm = build_term(inet, enter(inet, Ptr(node, 1)), binders, exits)
            func = build_term(inet, enter(inet, Ptr(node, 0)), binders, exits)
            return App(func, argm)
        return None

    return build_term(inet, enter(inet, Ptr("0", 1)), [], [])
# --- inet.js --- | |
def New(ctor, kind):
    """Create a net node with constructor ``ctor`` and three unconnected ports."""
    return {"ctor": ctor, "port": [None] * 3, "kind": kind}
def Ptr(node, slot):
    """Build a pointer to port ``slot`` of the node keyed ``node``."""
    return dict(ctor="Ptr", node=node, slot=slot)
def inet_show(inet):
    """Render a net as text, one line per node in dict order.

    Each line is "-" for a "Lam" node and "+" for any other node, followed by
    one label per port and a trailing space. A port pointing at itself prints
    "-". Otherwise both ends of a wire share a label generated by rank(), with
    labels numbered in order of first appearance.

    NOTE(review): inet_read maps "+" to ctor "Let", so nodes printed as "+"
    here do not round-trip to their original (fresh) ctor.
    """
    labels = {}
    counter = 0
    rows = []
    for node in inet:
        cells = ["-" if inet[node]["ctor"] == "Lam" else "+"]
        for slot in range(3):
            neighbour = inet[node]["port"][slot]
            if equal(neighbour, Ptr(node, slot)):
                cells.append("-")
                continue
            neighbour_id = f'{neighbour["node"]}-{neighbour["slot"]}'
            if labels.get(neighbour_id):
                cells.append(labels[neighbour_id])
            else:
                own_id = f"{node}-{slot}"
                labels[own_id] = rank(counter)
                counter += 1
                cells.append(labels[own_id])
        rows.append(" ".join(cells) + " ")
    return "\n".join(rows)
def inet_read(code):
    """Parse the text net format (as produced by compile) into a node dict.

    Each non-blank line is ``<kind> <p0> <p1> <p2>``. Kind "-" makes a "Lam"
    node, "+" a "Let" node, and any other kind (such as compile's "*"
    duplicators) a node with a fresh, unique ctor. Node keys are the line
    indices as strings, so the first non-blank line is node "0". A port named
    "-" points at itself. Two ports that carry the same wire name are linked
    to each other.

    Lines are stripped and split on any whitespace, so repeated spaces, tabs,
    CRLF endings and whitespace-only lines are tolerated.

    NOTE(review): a wire name used more than twice is linked to its first
    occurrence again; this is not validated.
    """
    rows = [line.strip() for line in code.split("\n")]
    rows = [row for row in rows if row]
    wires = {}
    inet = {}
    for index, row in enumerate(rows):
        snode = str(index)
        kind = row[0]
        names = row[1:].split()
        ctor = "Lam" if kind == "-" else ("Let" if kind == "+" else fresh())
        inet[snode] = New(ctor, {})
        for slot in range(3):
            name = names[slot]
            if name == "-":
                inet[snode]["port"][slot] = Ptr(snode, slot)
            elif name not in wires:
                wires[name] = Ptr(snode, slot)
            else:
                other = wires[name]
                inet[snode]["port"][slot] = other
                inet[other["node"]]["port"][other["slot"]] = Ptr(snode, slot)
    return inet
def equal(a, b):
    """Return True when pointers ``a`` and ``b`` refer to the same port.

    Either argument may be None (an unconnected port), in which case the
    result is False instead of a TypeError.
    """
    return (
        a is not None
        and b is not None
        and a["node"] == b["node"]
        and a["slot"] == b["slot"]
    )
def enter(inet, ptr):
    """Return the pointer stored in port ``ptr`` (what that port is wired to).

    Returns None if ``ptr`` is None or refers to a node that is not in
    ``inet``, for example one deleted by rewrite(). The original
    ``inet[...] != None`` test could never take its None branch because dict
    indexing raises KeyError first.
    """
    if ptr is None:
        return None
    node = inet.get(ptr["node"])
    if node is None:
        return None
    return node["port"][ptr["slot"]]
def link(inet, a_ptr, b_ptr):
    """Wire ports ``a_ptr`` and ``b_ptr`` to each other.

    Each side is written only when its pointer is truthy, so passing None for
    one side stores None in the other side's port.
    """
    for source, target in ((a_ptr, b_ptr), (b_ptr, a_ptr)):
        if source:
            inet[source["node"]]["port"][source["slot"]] = target
def unlink(inet, a_ptr):
    """Detach port ``a_ptr`` from its neighbour if the wire is mutual.

    When ``a_ptr``'s neighbour points back at ``a_ptr``, both ports are reset to
    point at themselves. Otherwise nothing changes.
    """
    b_ptr = enter(inet, a_ptr)
    if not equal(a_ptr, enter(inet, b_ptr)):
        return
    for ptr in (a_ptr, b_ptr):
        inet[ptr["node"]]["port"][ptr["slot"]] = ptr
def annihilate(inet, a, b):
    """Connect node ``a``'s auxiliary ports to ``b``'s: slot 1 to 1, slot 2 to 2.

    Slot 2 is read only after slot 1 has been relinked. This ordering matters
    when one node's auxiliary ports are wired to each other.
    """
    for slot in (1, 2):
        link(inet, enter(inet, Ptr(a, slot)), enter(inet, Ptr(b, slot)))
def commute(inet, a, b):
    """Let nodes ``a`` and ``b`` pass through each other by duplication.

    Creates p and q as copies of ``b`` and r and s as copies of ``a``, and
    cross-wires their auxiliary ports. It then attaches p and q to whatever
    ``a``'s ports 1 and 2 pointed to, and r and s to whatever ``b``'s ports 1
    and 2 pointed to, in that order. The old nodes are removed by the caller.
    """
    p, q, r, s = fresh(), fresh(), fresh(), fresh()
    for copy, source in ((p, b), (q, b), (r, a), (s, a)):
        inet[copy] = New(inet[source]["ctor"], inet[source]["kind"])
    inner_wires = (
        (Ptr(r, 1), Ptr(p, 1)),
        (Ptr(s, 1), Ptr(p, 2)),
        (Ptr(r, 2), Ptr(q, 1)),
        (Ptr(s, 2), Ptr(q, 2)),
    )
    for left, right in inner_wires:
        link(inet, left, right)
    # Each enter() runs after the previous link, as in the original sequence.
    for copy, origin, slot in ((p, a, 1), (q, a, 2), (r, b, 1), (s, b, 2)):
        link(inet, Ptr(copy, 0), enter(inet, Ptr(origin, slot)))
def rewrite(inet, a, b):
    """Reduce the active pair formed by nodes ``a`` and ``b``.

    Nodes with the same ctor annihilate; nodes with different ctors commute.
    Afterwards every port of both nodes is unlinked (slot by slot, ``a`` then
    ``b``), and both nodes are deleted from ``inet``.
    """
    interact = annihilate if inet[a]["ctor"] == inet[b]["ctor"] else commute
    interact(inet, a, b)
    for slot in range(3):
        for node in (a, b):
            unlink(inet, Ptr(node, slot))
    del inet[a]
    del inet[b]
def reduce(inet):
    """Reduce ``inet`` in place and return the number of rewrites performed.

    The walk starts at the port wired to root port ("0", 1) and follows these
    rules:
    * Entering a node at an auxiliary slot records that slot and moves on
      through the node's principal port (slot 0).
    * Entering at slot 0 from another slot 0 is an active pair. The pair is
      rewritten, and the walk resumes from the port the partner node was
      entered through.
    * Entering at slot 0 from an auxiliary port defers slot 2 and continues
      through slot 1.
    Each time the walk returns to the root node, it resumes from the most
    recently deferred port, and stops when none remain.
    """
    deferred = []
    exits = []
    cursor = enter(inet, Ptr("0", 1))
    rewrites = 0
    while cursor["node"] != "0" or deferred:
        if cursor["node"] == "0":
            cursor = enter(inet, deferred.pop())
        partner = enter(inet, cursor)
        if cursor["slot"] == 0 and partner["slot"] == 0:
            back = enter(inet, Ptr(partner["node"], exits.pop()))
            rewrite(inet, partner["node"], cursor["node"])
            cursor = enter(inet, back)
            rewrites += 1
        elif cursor["slot"] == 0:
            deferred.append(Ptr(cursor["node"], 2))
            cursor = enter(inet, Ptr(cursor["node"], 1))
        else:
            exits.append(cursor["slot"])
            cursor = enter(inet, Ptr(cursor["node"], 0))
    return rewrites
# --- testing ---
# Guarded so that importing this module does not run the demo or print.
if __name__ == "__main__":
    # Demo: (2 2) on Church numerals, whose normal form is the numeral 4.
    term = lambda_read("(λf.λx.(f (f x)) λf.λx.(f (f x)))")
    inet = inet_read(compile(term))
    print(inet_show(inet))  # net before reduction
    redu = reduce(inet)
    print(redu)  # number of rewrites performed
    print(inet_show(inet))  # net after reduction
    print(lambda_show(decompile(inet)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment