Skip to content

Instantly share code, notes, and snippets.

@cfbolz
Created January 13, 2026 11:41
Show Gist options
  • Select an option

  • Save cfbolz/ffd49e7a654c5ba224f73c1fc39673d0 to your computer and use it in GitHub Desktop.

Select an option

Save cfbolz/ffd49e7a654c5ba224f73c1fc39673d0 to your computer and use it in GitHub Desktop.
# Various sketches of implementing unification with rational tree support.
import pytest
from dataclasses import dataclass, field
class Expr:
    """Common base class of the expression nodes ``Var`` and ``Term``."""
@dataclass
class Var(Expr):
    """A logic variable.

    Equality and hashing consider ``name`` only; ``bound`` is deliberately
    left out, so it can be reassigned (even to a structure that contains this
    very variable) without changing how the variable compares or hashes.
    """

    name: str
    # Expression this variable is currently bound to; None while unbound.
    bound: Expr | None = None

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if not isinstance(other, Var):
            return False
        return self.name == other.name
@dataclass
class Term(Expr):
    """A compound term: a functor ``name`` applied to a list of ``children``.

    Terms compare and hash structurally, i.e. by name and by their children.
    """

    name: str
    children: list[Expr] = field(default_factory=list)

    def __hash__(self):
        key = (self.name, tuple(self.children))
        return hash(key)

    def __eq__(self, other):
        if not isinstance(other, Term):
            return False
        if self.name != other.name:
            return False
        return self.children == other.children
class UnificationError(Exception):
    """Raised when two expressions (or automata) cannot be unified."""
def unify(e1, e2):
    """Unify ``e1`` and ``e2`` in place by (re)binding ``Var.bound``.

    Two terms unify when their names and arities agree and their children
    unify pairwise.  A variable is pointed at the other side; if it already
    had a binding, that old binding is then unified with the other side.

    Raises:
        UnificationError: on a name or arity clash between two terms, or when
            the operands are not a supported Var/Term combination.  Bindings
            made before the failure are not undone.
    """
    # Normalise: if either side is a variable, make sure one is on the left.
    if isinstance(e2, Var):
        e1, e2 = e2, e1
    if e1 is e2:
        return

    if isinstance(e1, Term) and isinstance(e2, Term):
        if e1.name != e2.name or len(e1.children) != len(e2.children):
            raise UnificationError
        for left_child, right_child in zip(e1.children, e2.children):
            unify(left_child, right_child)
        return

    if not isinstance(e1, Var):
        raise UnificationError

    previous = e1.bound
    # Rebind *before* recursing into the old binding: according to the
    # original note, this is all that's needed for supporting rational trees.
    e1.bound = e2
    if previous is not None:
        unify(previous, e2)
def test_simple():
    """Bind a variable to a term, then check that a clashing term is rejected."""
    var = Var("X")
    term_a = Term("a")
    unify(var, term_a)
    assert var.bound == term_a

    # X is already bound to `a`, so unifying it with `b` must fail.
    term_b = Term("b")
    with pytest.raises(UnificationError):
        unify(var, term_b)
def test_rational_trees():
    """Unify the cyclic terms X = a(X) and Y = a(Y); every call must return."""
    x_var = Var("X")
    x_tree = Term("a", [x_var])
    unify(x_var, x_tree)  # X = a(X)
    unify(x_tree, x_tree)  # cyclic term against itself

    y_var = Var("Y")
    y_tree = Term("a", [y_var])
    unify(y_var, y_tree)  # Y = a(Y)
    unify(x_tree, y_tree)  # two distinct cyclic terms
def linear_unify(e1, e2, subst=None):
    """Unify ``e1`` and ``e2`` iteratively and return the new substitution.

    Bindings live in a dict mapping variable names to expressions instead of
    in ``Var.bound``.  The ``subst`` passed in is copied, never mutated.

    Args:
        e1, e2: the expressions to unify.
        subst: optional existing substitution to extend.

    Returns:
        A new dict with the bindings after unification.

    Raises:
        UnificationError: on a name or arity clash between two terms, or when
            the operands are not a supported Var/Term combination.
    """
    bindings = {} if subst is None else subst.copy()
    pending = [(e1, e2)]
    while pending:
        left, right = pending.pop()
        # Normalise: if either side is a variable, put one on the left.
        if isinstance(right, Var):
            left, right = right, left

        if isinstance(left, Term) and isinstance(right, Term):
            if left.name != right.name or len(left.children) != len(right.children):
                raise UnificationError
            pending.extend(zip(left.children, right.children))
            continue

        if not isinstance(left, Var):
            raise UnificationError

        # A variable unified with itself (same name) needs no work.
        if isinstance(right, Var) and left.name == right.name:
            continue

        previous = bindings.get(left.name)
        bindings[left.name] = right
        if previous is not None:
            # The old binding still has to agree with the new one.
            pending.append((previous, right))
    return bindings
def test_rational_trees_linear():
    """linear_unify on X = a(X) and Y = a(Y), checking each substitution."""
    x_var = Var("X")
    x_tree = Term("a", [x_var])
    first = linear_unify(x_var, x_tree)
    assert first == {"X": x_tree}
    linear_unify(x_tree, x_tree, first)  # cyclic under `first`; must return

    y_var = Var("Y")
    y_tree = Term("a", [y_var])
    second = linear_unify(y_var, y_tree, first)
    assert second == {"X": x_tree, "Y": y_tree}

    third = linear_unify(x_tree, y_tree, second)  # both sides are cyclic
    assert third == {"X": y_tree, "Y": x_var}
def test_rational_trees_linear2():
    """linear_unify of X = a(X) against Y = a(a(a(Y))) must terminate."""
    x_var = Var("X")
    x_tree = Term("a", [x_var])
    first = linear_unify(x_var, x_tree)
    assert first == {"X": x_tree}
    linear_unify(x_tree, x_tree, first)  # cyclic under `first`; must return

    y_var = Var("Y")
    innermost = Term("a", [y_var])
    y_tree = Term("a", [Term("a", [innermost])])
    second = linear_unify(y_var, y_tree, first)
    assert second == {"X": x_tree, "Y": y_tree}

    # Only termination is checked here; the result is not inspected.
    linear_unify(y_tree, x_tree, second)
# Counter behind new_str(); starts at -1 so the first name handed out is "state0".
number = -1


def new_str():
    """Return the next name of the form ``state<N>``, bumping the global counter."""
    global number
    number = number + 1
    return "state" + str(number)
@dataclass
class DFA:
    """A DFA state with at most one outgoing edge for ``a`` and one for ``b``."""

    # Whether this state is a final state.
    final: bool
    # Successor on `a` / `b`, or None when there is no such transition.
    on_a: "DFA | None"
    on_b: "DFA | None"
    # Defaults to a fresh name from new_str(); is_same() treats states with
    # equal names as the same state.
    name: str = field(default_factory=new_str)
def is_same(a1, a2):
    """Compare two DFA states by destructively redirecting ``a1`` onto ``a2``.

    Walks both automata in lockstep with an explicit worklist.  Whenever two
    differently named states with matching ``final`` flags are met, the left
    state's ``on_a``/``on_b`` edges are overwritten with the right state's
    targets, and the old and new targets are queued for comparison.

    Returns None on success.  Note that ``a1`` is mutated.

    Raises:
        UnificationError: if exactly one side of a pair is None, or the
            ``final`` flags of a pair differ.
    """
    worklist = [(a1, a2)]
    while worklist:
        left, right = worklist.pop()
        if left is None or right is None:
            if left is right:
                # Both missing: nothing to compare.
                continue
            raise UnificationError
        if left.name == right.name:
            # Same state (by name): already known to agree.
            continue
        if left.final != right.final:
            raise UnificationError
        old_on_a, old_on_b = left.on_a, left.on_b
        left.on_a, left.on_b = right.on_a, right.on_b
        worklist.append((old_on_a, left.on_a))
        worklist.append((old_on_b, left.on_b))
def test_dfa_is_same_linear():
    """Two separately built, acyclic two-state automata compare without error."""
    left = DFA(False, DFA(True, None, None), None)
    right = DFA(False, DFA(True, None, None), None)
    is_same(left, right)
def test_dfa_loop():
    """Cyclic automata for a* compare without error and without looping."""
    # a* as a single final state looping on `a`.
    one_state = DFA(True, None, None)
    one_state.on_a = one_state
    is_same(one_state, one_state)

    # A second, separately built copy of the same automaton.
    one_state_copy = DFA(True, None, None)
    one_state_copy.on_a = one_state_copy
    is_same(one_state, one_state_copy)

    # Two final states that loop through each other on `a`.
    two_state = DFA(True, DFA(True, None, None), None)
    two_state.on_a.on_a = two_state
    is_same(one_state, two_state)
def unify_with_tabling(e1, e2, table=None):
if table is None:
table = frozenset()
if isinstance(e2, Var):
e1, e2 = e2, e1
if (e1, e2) in table:
return
table = table | {(e1, e2)}
match (e1, e2):
case (Term(name1, children1), Term(name2, children2)):
if name1 != name2:
raise UnificationError
if len(children1) != len(children2):
raise UnificationError
for t1, t2 in zip(children1, children2):
unify_with_tabling(t1, t2, table)
case (Var(name1, None), _):
e1.bound = e2
case (Var(name, bound), _):
unify_with_tabling(bound, e2, table)
case _:
raise UnificationError
def test_unify_with_tabling():
    """unify_with_tabling on X = a(X) and Y = a(Y); every call must return."""
    x_var = Var("X")
    x_tree = Term("a", [x_var])
    unify_with_tabling(x_var, x_tree)  # X = a(X)
    unify_with_tabling(x_tree, x_tree)  # cyclic term against itself

    y_var = Var("Y")
    y_tree = Term("a", [y_var])
    unify_with_tabling(y_var, y_tree)  # Y = a(Y)
    unify_with_tabling(x_tree, y_tree)  # two distinct cyclic terms
def linear_unify_with_tabling(e1, e2, subst=None, table=None):
if subst is None:
subst = {}
if table is None:
table = frozenset()
subst = subst.copy() # never mutate existing substs
todo = [(e1, e2, table)]
while todo:
e1, e2, table = todo.pop()
if (e1, e2) in table:
continue
if isinstance(e2, Var):
e1, e2 = e2, e1
table = table | {(e1, e2)}
match (e1, e2):
case (Term(name1, children1), Term(name2, children2)):
if name1 != name2:
raise UnificationError
if len(children1) != len(children2):
raise UnificationError
for t1, t2 in zip(children1, children2):
todo.append((t1, t2, table))
case (Var(name1, _), _):
bound = subst.get(name1)
subst[name1] = e2
if bound is not None:
todo.append((bound, e2, table))
case _:
raise UnificationError
return subst
def test_rational_trees_linear_with_tabling():
    """linear_unify_with_tabling on X = a(X) and Y = a(Y), checking each result."""
    X = Var("X")
    t1 = Term("a", [X])
    subst = linear_unify_with_tabling(X, t1)
    assert subst == {"X": t1}
    # Exercise the tabling variant here as well (this call used to go to
    # linear_unify, leaving the function under test uncovered on this input).
    linear_unify_with_tabling(t1, t1, subst)  # loop

    Y = Var("Y")
    t2 = Term("a", [Y])
    subst2 = linear_unify_with_tabling(Y, t2, subst)
    assert subst2 == {"X": t1, "Y": t2}

    subst3 = linear_unify_with_tabling(t1, t2, subst2)  # loop
    assert subst3 == {"X": t2, "Y": X}
def test_rational_trees_linear2_with_tabling():
    """linear_unify_with_tabling of X = a(X) against Y = a(a(a(Y))) terminates."""
    X = Var("X")
    t1 = Term("a", [X])
    # Use the tabling variant throughout; the first two calls used to go to
    # linear_unify, so the function under test was only partly exercised.
    subst = linear_unify_with_tabling(X, t1)
    assert subst == {"X": t1}
    linear_unify_with_tabling(t1, t1, subst)  # loop

    Y = Var("Y")
    t2 = Term("a", [Term("a", [Term("a", [Y])])])
    subst2 = linear_unify_with_tabling(Y, t2, subst)
    assert subst2 == {"X": t1, "Y": t2}

    # Only termination is checked here; the result is not inspected.
    linear_unify_with_tabling(t2, t1, subst2)  # loop
# NOTE: the occurs-check variant sketched below doesn't work; it is kept
# commented out for reference only.
# def unify_with_tabling_occurs_check(e1, e2, table=None):
# if table is None:
# table = frozenset()
# if isinstance(e2, Var):
# e1, e2 = e2, e1
# if (e1, e2) in table:
# raise UnificationError
# table = table | {(e1, e2)}
# match (e1, e2):
# case (Term(name1, children1), Term(name2, children2)):
# if name1 != name2:
# raise UnificationError
# if len(children1) != len(children2):
# raise UnificationError
# for t1, t2 in zip(children1, children2):
# unify_with_tabling_occurs_check(t1, t2, table)
# case (Var(name1, bound), _):
# if bound is None:
# e1.bound = e2
# unify_with_tabling_occurs_check(e1.bound, e2, table)
# case _:
# raise UnificationError
#
#
# def test_unify_with_tabling_occurs_check():
# X = Var("X")
# t1 = Term("a", [X])
# import pdb;pdb.set_trace()
# unify_with_tabling_occurs_check(X, t1)
# unify_with_tabling_occurs_check(t1, t1) # loop
#
# Y = Var("Y")
# t2 = Term("a", [Y])
# unify_with_tabling_occurs_check(Y, t2)
# unify_with_tabling_occurs_check(t1, t2) # loop
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment