-
-
Save cfbolz/ffd49e7a654c5ba224f73c1fc39673d0 to your computer and use it in GitHub Desktop.
Various sketches of implementing unification with rational-tree support
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import pytest | |
| from dataclasses import dataclass, field | |
class Expr:
    """Common base class of the expression types ``Var`` and ``Term``."""
@dataclass
class Var(Expr):
    """A logic variable.

    Equality and hashing look only at ``name``; the current binding in
    ``bound`` (None while unbound) takes no part in them. That keeps a
    variable usable as a set member / dict key while its binding changes,
    and comparing two variables never follows a (possibly cyclic) binding.
    """

    name: str
    bound: Expr | None = None  # expression this variable is bound to, if any

    def __eq__(self, other):
        if not isinstance(other, Var):
            return False
        return self.name == other.name

    def __hash__(self):
        # Must agree with __eq__: only the name matters.
        return hash(self.name)
@dataclass
class Term(Expr):
    """A compound term ``name(children...)``; a constant has no children.

    Two terms are equal when their functor names match and their child
    lists are equal element-wise. The hash is derived from the same data
    (children are hashed as a tuple), so terms can be stored in sets.
    """

    name: str
    children: list[Expr] = field(default_factory=list)

    def __eq__(self, other):
        if not isinstance(other, Term):
            return False
        return self.name == other.name and self.children == other.children

    def __hash__(self):
        key = (self.name, tuple(self.children))
        return hash(key)
class UnificationError(Exception):
    """Raised when two expressions (or, in ``is_same``, two DFAs) cannot be matched."""
def unify(e1, e2):
    """Unify ``e1`` with ``e2`` in place by setting ``Var.bound``.

    Rational (cyclic) trees are supported by a rebinding trick. When an
    already-bound variable meets ``e2``, it is first re-pointed at ``e2``,
    and only then is its old binding unified with ``e2``. If the same
    (variable, ``e2``) pair is reached again later, the variable is already
    bound to ``e2``. The recursive call is then ``unify(e2, e2)``, which
    stops at the identity check.

    Bindings made before a failure are not undone.

    Raises:
        UnificationError: if the expressions cannot be unified.
    """
    # Put a variable (if there is one) on the left, so only one
    # orientation of each mixed pair has to be handled below.
    if isinstance(e2, Var):
        e1, e2 = e2, e1
    if e1 is e2:
        return

    if isinstance(e1, Var):
        old_binding = e1.bound
        # Rebind before recursing; this is what lets cycles terminate.
        e1.bound = e2
        if old_binding is not None:
            unify(old_binding, e2)
        return

    if not (isinstance(e1, Term) and isinstance(e2, Term)):
        raise UnificationError
    if e1.name != e2.name or len(e1.children) != len(e2.children):
        raise UnificationError
    for left, right in zip(e1.children, e2.children):
        unify(left, right)
def test_simple():
    """A fresh variable binds to an atom and then clashes with another atom."""
    x = Var("X")
    atom_a = Term("a")
    unify(x, atom_a)
    assert x.bound == atom_a
    with pytest.raises(UnificationError):
        unify(x, Term("b"))
def test_rational_trees():
    """Cyclic terms X = a(X) and Y = a(Y) can be unified without looping."""
    x = Var("X")
    ax = Term("a", [x])
    unify(x, ax)  # X = a(X)
    unify(ax, ax)  # loop

    y = Var("Y")
    ay = Term("a", [y])
    unify(y, ay)  # Y = a(Y)
    unify(ax, ay)  # loop
def linear_unify(e1, e2, subst=None):
    """Unify ``e1`` with ``e2`` iteratively and return a new substitution.

    Bindings live in a dict that maps variable *names* to expressions;
    ``Var.bound`` is never read or written. The caller's ``subst`` is copied
    and never mutated. An explicit work stack replaces recursion. Cyclic
    terms are handled with the same rebinding trick as ``unify``: a bound
    variable is re-pointed at the new expression before its old binding is
    queued for unification against it.

    Raises:
        UnificationError: if the expressions cannot be unified.
    """
    bindings = {} if subst is None else subst.copy()
    stack = [(e1, e2)]
    while stack:
        left, right = stack.pop()
        # Put a variable (if there is one) on the left.
        if isinstance(right, Var):
            left, right = right, left

        if isinstance(left, Term) and isinstance(right, Term):
            if left.name != right.name or len(left.children) != len(right.children):
                raise UnificationError
            stack.extend(zip(left.children, right.children))
        elif isinstance(left, Var):
            if isinstance(right, Var) and left.name == right.name:
                continue  # X = X: nothing to bind
            previous = bindings.get(left.name)
            bindings[left.name] = right
            if previous is not None:
                stack.append((previous, right))
        else:
            raise UnificationError
    return bindings
def test_rational_trees_linear():
    """linear_unify on X = a(X) and Y = a(Y), checking every substitution."""
    x = Var("X")
    ax = Term("a", [x])
    first = linear_unify(x, ax)
    assert first == {"X": ax}
    linear_unify(ax, ax, first)  # loop

    y = Var("Y")
    ay = Term("a", [y])
    second = linear_unify(y, ay, first)
    assert second == {"X": ax, "Y": ay}

    third = linear_unify(ax, ay, second)  # loop
    assert third == {"X": ay, "Y": x}
def test_rational_trees_linear2():
    """linear_unify on cycles of different lengths: X = a(X) vs Y = a(a(a(Y))).

    Besides terminating, the resulting substitution is pinned, and the
    input substitution must be left unmodified.
    """
    X = Var("X")
    t1 = Term("a", [X])
    subst = linear_unify(X, t1)
    assert subst == {"X": t1}
    assert linear_unify(t1, t1, subst) == subst  # loop

    Y = Var("Y")
    t2 = Term("a", [Term("a", [Term("a", [Y])])])
    subst2 = linear_unify(Y, t2, subst)
    assert subst2 == {"X": t1, "Y": t2}

    subst3 = linear_unify(t2, t1, subst2)  # loop
    # The caller's substitution is copied, never mutated.
    assert subst2 == {"X": t1, "Y": t2}
    # Both variables end up denoting the infinite tree a(a(a(...))).
    assert subst3 == {"X": Term("a", [Y]), "Y": Term("a", [Y])}
number = -1  # index of the last state name handed out; -1 = none yet


def new_str():
    """Return a fresh state name ``stateN``, numbering from 0 upward."""
    global number
    number = number + 1
    label = f"state{number}"
    return label
@dataclass
class DFA:
    """One state of a small deterministic automaton with transitions on 'a' and 'b'.

    A missing transition is represented by ``None``.
    """

    final: bool  # True for an accepting state
    on_a: "DFA | None"  # successor on 'a', or None
    on_b: "DFA | None"  # successor on 'b', or None
    # Fresh label from new_str() by default; is_same compares states by
    # this name to detect that it is looking at one state twice.
    name: str = field(default_factory=new_str)
def is_same(a1, a2):
    """Walk two DFAs in lockstep and raise if they ever disagree.

    Pairs of states are processed from a work stack. A pair disagrees if
    exactly one side is missing (``None``), or if the two states differ in
    ``final``. States with the same ``name`` are treated as already equal.

    This check is destructive. For each visited pair, the first state's
    ``on_a``/``on_b`` are redirected to the second state's successors. That
    redirection is what stops the walk from looping on cyclic automata.

    Returns None on success.

    Raises:
        UnificationError: on the first disagreement found.
    """
    pending = [(a1, a2)]
    while pending:
        left, right = pending.pop()
        if left is None and right is None:
            continue
        if left is None or right is None:
            raise UnificationError
        if left.name == right.name:
            continue
        if left.final != right.final:
            raise UnificationError
        for edge in ("on_a", "on_b"):
            old_target = getattr(left, edge)
            new_target = getattr(right, edge)
            setattr(left, edge, new_target)
            pending.append((old_target, new_target))
def test_dfa_is_same_linear():
    """Two separately built copies of the same acyclic DFA compare equal."""
    accept1 = DFA(True, None, None)
    start1 = DFA(False, accept1, None)
    accept2 = DFA(True, None, None)
    start2 = DFA(False, accept2, None)
    is_same(start1, start2)
def test_dfa_loop():
    """Cyclic DFAs for a* compare equal, including a two-state unrolling."""
    one_state = DFA(True, None, None)
    one_state.on_a = one_state  # a*
    is_same(one_state, one_state)

    other = DFA(True, None, None)
    other.on_a = other
    is_same(one_state, other)

    second = DFA(True, None, None)
    unrolled = DFA(True, second, None)
    second.on_a = unrolled
    is_same(one_state, unrolled)
| def unify_with_tabling(e1, e2, table=None): | |
| if table is None: | |
| table = frozenset() | |
| if isinstance(e2, Var): | |
| e1, e2 = e2, e1 | |
| if (e1, e2) in table: | |
| return | |
| table = table | {(e1, e2)} | |
| match (e1, e2): | |
| case (Term(name1, children1), Term(name2, children2)): | |
| if name1 != name2: | |
| raise UnificationError | |
| if len(children1) != len(children2): | |
| raise UnificationError | |
| for t1, t2 in zip(children1, children2): | |
| unify_with_tabling(t1, t2, table) | |
| case (Var(name1, None), _): | |
| e1.bound = e2 | |
| case (Var(name, bound), _): | |
| unify_with_tabling(bound, e2, table) | |
| case _: | |
| raise UnificationError | |
def test_unify_with_tabling():
    """unify_with_tabling terminates on X = a(X) and Y = a(Y)."""
    x = Var("X")
    ax = Term("a", [x])
    unify_with_tabling(x, ax)  # X = a(X)
    unify_with_tabling(ax, ax)  # loop

    y = Var("Y")
    ay = Term("a", [y])
    unify_with_tabling(y, ay)  # Y = a(Y)
    unify_with_tabling(ax, ay)  # loop
| def linear_unify_with_tabling(e1, e2, subst=None, table=None): | |
| if subst is None: | |
| subst = {} | |
| if table is None: | |
| table = frozenset() | |
| subst = subst.copy() # never mutate existing substs | |
| todo = [(e1, e2, table)] | |
| while todo: | |
| e1, e2, table = todo.pop() | |
| if (e1, e2) in table: | |
| continue | |
| if isinstance(e2, Var): | |
| e1, e2 = e2, e1 | |
| table = table | {(e1, e2)} | |
| match (e1, e2): | |
| case (Term(name1, children1), Term(name2, children2)): | |
| if name1 != name2: | |
| raise UnificationError | |
| if len(children1) != len(children2): | |
| raise UnificationError | |
| for t1, t2 in zip(children1, children2): | |
| todo.append((t1, t2, table)) | |
| case (Var(name1, _), _): | |
| bound = subst.get(name1) | |
| subst[name1] = e2 | |
| if bound is not None: | |
| todo.append((bound, e2, table)) | |
| case _: | |
| raise UnificationError | |
| return subst | |
def test_rational_trees_linear_with_tabling():
    """linear_unify_with_tabling on X = a(X) and Y = a(Y).

    Every call goes through the tabling variant, including the self-unification.
    """
    X = Var("X")
    t1 = Term("a", [X])
    subst = linear_unify_with_tabling(X, t1)
    assert subst == {"X": t1}
    assert linear_unify_with_tabling(t1, t1, subst) == subst  # loop

    Y = Var("Y")
    t2 = Term("a", [Y])
    subst2 = linear_unify_with_tabling(Y, t2, subst)
    assert subst2 == {"X": t1, "Y": t2}

    subst3 = linear_unify_with_tabling(t1, t2, subst2)  # loop
    assert subst3 == {"X": t2, "Y": X}
def test_rational_trees_linear2_with_tabling():
    """linear_unify_with_tabling on X = a(X) vs Y = a(a(a(Y))).

    Every call goes through the tabling variant. The final substitution is
    pinned, and the input substitution must not be mutated.
    """
    X = Var("X")
    t1 = Term("a", [X])
    subst = linear_unify_with_tabling(X, t1)
    assert subst == {"X": t1}
    assert linear_unify_with_tabling(t1, t1, subst) == subst  # loop

    Y = Var("Y")
    t2 = Term("a", [Term("a", [Term("a", [Y])])])
    subst2 = linear_unify_with_tabling(Y, t2, subst)
    assert subst2 == {"X": t1, "Y": t2}

    subst3 = linear_unify_with_tabling(t2, t1, subst2)  # loop
    # The caller's substitution is copied, never mutated.
    assert subst2 == {"X": t1, "Y": t2}
    # X = a(Y) and Y = a(a(Y)): both denote the infinite tree a(a(a(...))).
    assert subst3 == {"X": Term("a", [Y]), "Y": Term("a", [Term("a", [Y])])}
| # doesn't work: | |
| # def unify_with_tabling_occurs_check(e1, e2, table=None): | |
| # if table is None: | |
| # table = frozenset() | |
| # if isinstance(e2, Var): | |
| # e1, e2 = e2, e1 | |
| # if (e1, e2) in table: | |
| # raise UnificationError | |
| # table = table | {(e1, e2)} | |
| # match (e1, e2): | |
| # case (Term(name1, children1), Term(name2, children2)): | |
| # if name1 != name2: | |
| # raise UnificationError | |
| # if len(children1) != len(children2): | |
| # raise UnificationError | |
| # for t1, t2 in zip(children1, children2): | |
| # unify_with_tabling_occurs_check(t1, t2, table) | |
| # case (Var(name1, bound), _): | |
| # if bound is None: | |
| # e1.bound = e2 | |
| # unify_with_tabling_occurs_check(e1.bound, e2, table) | |
| # case _: | |
| # raise UnificationError | |
| # | |
| # | |
| # def test_unify_with_tabling_occurs_check(): | |
| # X = Var("X") | |
| # t1 = Term("a", [X]) | |
| # import pdb;pdb.set_trace() | |
| # unify_with_tabling_occurs_check(X, t1) | |
| # unify_with_tabling_occurs_check(t1, t1) # loop | |
| # | |
| # Y = Var("Y") | |
| # t2 = Term("a", [Y]) | |
| # unify_with_tabling_occurs_check(Y, t2) | |
| # unify_with_tabling_occurs_check(t1, t2) # loop | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment