Last active
January 29, 2022 20:21
-
-
Save cheery/285e81d55ca4d9ec44d628978545a866 to your computer and use it in GitHub Desktop.
ukanren with untested CHR
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- encoding: utf-8 -*- | |
class TermOrVar(object):
    """Common base of every node that may appear inside a term tree."""

    def to_str(self, rbp=0):
        # Abstract: each concrete node kind supplies its own rendering.
        assert False, "to_str is not implemented for this node"
class Term(TermOrVar):
    """A compound term: the functor `name` applied to the list `args`."""

    def __init__(self, name, args):
        self.name = name
        self.args = args
        for child in args:
            assert isinstance(child, TermOrVar)

    def to_str(self, rbp=0):
        """Render the term, wrapping it in parentheses when it binds looser than `rbp`.

        The binary operators listed below print infix. The left operand is
        rendered one binding-power level tighter than the right one, so
        chains group to the right. `!` prints as a prefix. Every other
        functor prints as `name (arg1) (arg2) ...`.
        """
        sig = self.get_signature()
        power = get_rbp(sig)
        if rbp > power:
            return "(%s)" % self.to_str()
        if not self.args:
            return self.name
        if sig in ("./2", "⊸/2", "+/2", "⅋/2", "⊗/2"):
            lhs, rhs = self.args
            return " ".join((lhs.to_str(power + 1), self.name, rhs.to_str(power)))
        if sig == "!/1":
            return self.name + self.args[0].to_str(power + 1)
        return self.name + "".join(" (%s)" % child.to_str() for child in self.args)

    def get_signature(self):
        """Return the functor signature as "name/arity", e.g. "dual/2"."""
        return "/".join((self.name, str(len(self.args))))
def get_rbp(sig):
    """Return the binding power used when printing a term with signature `sig`.

    Term.to_str parenthesizes a term whose power is below the requested
    one, so lower numbers bind looser. Unlisted signatures get 60.
    """
    for candidate, power in (("./2", 10), ("⊸/2", 20), ("+/2", 30),
                             ("⅋/2", 35), ("⊗/2", 40), ("!/1", 50)):
        if sig == candidate:
            return power
    return 60
class Var(TermOrVar):
    """A logic variable identified by `key`; printed as `#key`."""

    def __init__(self, key):
        self.key = key

    def to_str(self, rbp=0):
        return "#%s" % (self.key,)
class Const(TermOrVar):
    """An opaque constant identified by `key`; printed as `$key`."""

    def __init__(self, key):
        self.key = key

    def to_str(self, rbp=0):
        return "$%s" % (self.key,)
class Hole(TermOrVar):
    """Placeholder in a CHR rule-head pattern.

    A hole is bound to the concrete term it matches; the binding table is
    keyed by the hole's `key`.
    """

    def __init__(self, key):
        self.key = key

    def to_str(self, rbp=0):
        return "$HOLE{%s}" % (self.key,)
class Storage(object):
    """Solver state threaded through every goal.

    Helpers generally build a new Storage rather than mutating one in place.
    """

    def __init__(self, rules, subst, nextvar, constraints, nextcid, notes, killed):
        # rules:       CHR rules tried when a constraint is (re)activated
        # subst:       variable key -> bound term
        # nextvar:     key handed to the next fresh variable
        # constraints: signature -> {constraint id: Constraint} (the store)
        # nextcid:     id handed to the next new constraint
        # notes:       variable key -> {constraint id: Constraint} to wake on binding
        # killed:      constraint id -> Constraint removed since the last suspend
        (self.rules, self.subst, self.nextvar, self.constraints,
         self.nextcid, self.notes, self.killed) = (
            rules, subst, nextvar, constraints, nextcid, notes, killed)
# Rule registry; the @chr_rule decorator appends to it further down.
default_rules = []


def empty(rules=default_rules):
    """Return a fresh state with no bindings, constraints, notes or kills.

    The default argument is the module-level `default_rules` list object
    itself, so rules registered after this definition are still seen.
    """
    return Storage(rules, {}, 0, {}, 0, {}, {})
def copy_constraints(constraints):
    """Copy the signature -> table store one level deep.

    Each per-signature table is a new dict; the Constraint objects are shared.
    """
    return {signature: entries.copy() for signature, entries in constraints.items()}
def copy_notes(notes):
    """Copy the variable -> wakeup-table mapping one level deep.

    Each wakeup table is a new dict; the Constraint objects are shared.
    """
    return dict((var_key, wakeups.copy()) for var_key, wakeups in notes.items())
class Constraint(object):
    """A CHR constraint instance: a `term` tagged with a constraint `id`.

    Re-activated copies of a constraint keep the same id.
    """

    def __init__(self, term, id):
        self.term, self.id = term, id
def make_constraint(term, state):
    """Post `term` as a new CHR constraint and activate it.

    The term is resolved through the current substitution and receives the
    next constraint id. Returns the list of states produced by reactivate.
    """
    assert isinstance(term, Term)
    cid = state.nextcid
    new_c = Constraint(walk(term, state.subst), cid)
    advanced = Storage(state.rules, state.subst, state.nextvar,
                       state.constraints, cid + 1, state.notes, state.killed)
    return reactivate(new_c, advanced)
def alive(c, state):
    """Return whether constraint `c` has not been killed in `state`.

    Only the kill set is consulted. The constraint store is deliberately
    not searched, since callers stay within their own scope.
    """
    return c.id not in state.killed
def kill(c, state):
    """Return a new state in which constraint `c` is marked as killed.

    The kill set is copied, so `state` itself is left unchanged.
    """
    killed = state.killed.copy()
    killed[c.id] = c
    return Storage(state.rules, state.subst, state.nextvar,
                   state.constraints, state.nextcid, state.notes, killed)
def suspend(c, state):
    """Store `c` and flush every constraint killed since the last suspend.

    Returns a new state with copied store and notes:
      * unless `c` itself was killed, it is recorded in the store under its
        signature and noted on each of its variables, so binding one of them
        wakes it up again;
      * every killed constraint is removed from the store and from all
        variable notes;
      * the kill set is reset to empty.

    The kill set can only be cleared safely because killed constraints are
    purged from both the store and the notes here. Otherwise they would later
    look alive and be woken up again.
    """
    killed = state.killed
    state = Storage(state.rules,
                    state.subst,
                    state.nextvar,
                    copy_constraints(state.constraints),
                    state.nextcid,
                    copy_notes(state.notes),
                    dict())
    if c.id not in killed:
        add_notes(c.term, state, c)
        sig = c.term.get_signature()
        table = state.constraints.get(sig, None)
        if table is None:
            state.constraints[sig] = table = dict()
        table[c.id] = c
    # dict.values() works on both Python 2 and 3 (itervalues is 2-only).
    for k in killed.values():
        table = state.constraints.get(k.term.get_signature(), None)
        if table is not None:
            # Remove the killed constraint itself (previously c.id was popped,
            # which evicted the constraint being suspended instead).
            table.pop(k.id, None)
        for wakeups in state.notes.values():
            wakeups.pop(k.id, None)
    return state
def add_notes(term, state, c):
    """Register `c` as a wakeup on every variable occurring in `term`.

    This mutates `state.notes` in place: for each Var key it records
    constraint id -> `c`. Const and Hole nodes are ignored.
    """
    pending = [term]
    while pending:
        node = pending.pop()
        if isinstance(node, Var):
            state.notes.setdefault(node.key, dict())[c.id] = c
        elif isinstance(node, Term):
            # Reverse so the stack visits arguments left to right.
            pending.extend(reversed(node.args))
def var_to_const(term, table):
    """Replace each Var and Const in `term` with a numbered Const.

    `table` maps a node's printed form to its replacement, so equal nodes
    share one Const. New Consts are numbered by the table size. The table
    is updated in place and returned as `(new_term, table)`.
    """
    if isinstance(term, (Var, Const)):
        label = term.to_str()
        found = table.get(label, None)
        if found is None:
            found = Const(len(table))
            table[label] = found
        return found, table
    if isinstance(term, Term):
        converted = []
        for child in term.args:
            child, table = var_to_const(child, table)
            converted.append(child)
        return Term(term.name, converted), table
    assert False
def const_to_var(term, table, state):
    """Replace each Const in `term` with a logic variable.

    The first time a Const key is seen, a fresh variable is allocated and
    recorded in `table`. Existing Vars pass through untouched. Returns
    `(new_term, table, new_state)`.
    """
    if isinstance(term, Var):
        # The namespace entries may contain variables; keep them as they are.
        return term, table, state
    if isinstance(term, Const):
        var = table.get(term.key, None)
        if var is None:
            var, state = fresh(state)
            table[term.key] = var
        return var, table, state
    if isinstance(term, Term):
        converted = []
        for child in term.args:
            child, table, state = const_to_var(child, table, state)
            converted.append(child)
        return Term(term.name, converted), table, state
    assert False
def walk(term, subst):
    """Fully resolve `term` through the substitution `subst`.

    Bound variables are followed until an unbound Var or a non-variable is
    reached. Compound terms are rebuilt with resolved arguments. Consts are
    returned as is.
    """
    while isinstance(term, Var):
        bound = subst.get(term.key, None)
        if bound is None:
            return term
        term = bound
    if isinstance(term, Term):
        return Term(term.name, [walk(child, subst) for child in term.args])
    if isinstance(term, Const):
        return term
    assert False
def _in_store(c, state):
    """Return True if a constraint with `c`'s id is still in `state`'s store."""
    table = state.constraints.get(c.term.get_signature(), None)
    return table is not None and c.id in table


def extS(key, term, state):
    """Bind variable `key` to `term` and wake the constraints noted on it.

    Always returns a list of states:
      * [] when the occurs check fails, since the binding would build a
        cyclic term;
      * otherwise the states produced by re-activating, one after another,
        every constraint that had a note on `key`.

    A noted constraint that was killed or already removed from the store
    is skipped. The branch it was skipped on is kept unchanged.
    """
    if occurs(key, term, state.subst):
        # Callers (unify, eq.go) treat the result as a list of states;
        # returning None here used to crash them.
        return []
    nsubst = state.subst.copy()
    nsubst[key] = term
    nnotes = state.notes
    wakeups = dict()
    if key in state.notes:
        nnotes = copy_notes(state.notes)
        wakeups = nnotes.pop(key)
    state = Storage(state.rules,
                    nsubst,
                    state.nextvar,
                    state.constraints,
                    state.nextcid,
                    nnotes,
                    state.killed)
    states = [state]
    for c in wakeups.values():
        nstates = []
        for st in states:
            # An earlier wakeup's rule may have removed c. Keep this branch
            # as is; it was previously dropped, losing answers.
            if not alive(c, st) or not _in_store(c, st):
                nstates.append(st)
                continue
            woken = Constraint(walk(c.term, st.subst), c.id)
            nstates.extend(reactivate(woken, st))
        states = nstates
    return states
def fresh(state):
    """Allocate a new logic variable.

    Returns `(var, new_state)`, where the new state's variable counter has
    been advanced by one.
    """
    key = state.nextvar
    bumped = Storage(state.rules, state.subst, key + 1, state.constraints,
                     state.nextcid, state.notes, state.killed)
    return Var(key), bumped
class Combinator:
    """Base class of goals: `go(state)` returns a list of result states."""

    def go(self, state):
        # Abstract: subclasses implement the actual search step.
        assert False, "go is not implemented for this goal"
class eq(Combinator):
    """Goal that unifies `t1` with `t2`."""

    def __init__(self, t1, t2):
        self.t1, self.t2 = t1, t2

    def go(self, state):
        outcomes = unify(self.t1, self.t2, state)
        return outcomes if len(outcomes) > 0 else []
def unify(t1, t2, state):
    """Unify `t1` with `t2` under `state`.

    Returns a list of result states; an empty list means the terms do not
    unify. Binding a variable goes through extS, which may wake suspended
    constraints and so yield several states.
    """
    t1 = walk(t1, state.subst)
    t2 = walk(t2, state.subst)
    if isinstance(t1, Var) and isinstance(t2, Var) and t1.key == t2.key:
        return [state]
    if isinstance(t1, Var):
        return extS(t1.key, t2, state)
    if isinstance(t2, Var):
        return extS(t2.key, t1, state)
    if isinstance(t1, Term) and isinstance(t2, Term):
        if t1.name != t2.name or len(t1.args) != len(t2.args):
            return []
        # Seed with the input state (it was seeded with [] before, so every
        # compound or nullary term failed), then thread each partial
        # solution through the argument pairs in turn.
        states = [state]
        for a1, a2 in zip(t1.args, t2.args):
            nstates = []
            for st in states:
                nstates.extend(unify(a1, a2, st))
            states = nstates
            if not states:
                break
        return states
    if isinstance(t1, Const) and isinstance(t2, Const):
        return [state] if t1.key == t2.key else []
    if (isinstance(t1, Term) or isinstance(t1, Const)) and \
            (isinstance(t2, Term) or isinstance(t2, Const)):
        # A compound term never unifies with an opaque constant.
        return []
    assert False, "unify: unexpected term kind"
def occurs(key, term, subst):
    """Return True if variable `key` occurs anywhere in `term` under `subst`.

    walk() already resolves every bound variable recursively, so a single
    walk followed by a purely structural scan is enough. The previous
    version re-walked each subterm at every nesting level.
    """
    pending = [walk(term, subst)]
    while pending:
        node = pending.pop()
        if isinstance(node, Var):
            if node.key == key:
                return True
        elif isinstance(node, Term):
            pending.extend(node.args)
        elif not isinstance(node, Const):
            assert False
    return False
class Unit(Combinator):
    """Goal that always succeeds once, returning the input state unchanged."""

    def go(self, state):
        answers = [state]
        return answers
class Fail(Combinator):
    """Goal that never succeeds."""

    def go(self, state):
        answers = []
        return answers
# Shared singleton goals: tt always succeeds, ff always fails.
tt, ff = Unit(), Fail()
class disj(Combinator):
    """Disjunction: all answers of `g1` followed by all answers of `g2`."""

    def __init__(self, g1, g2):
        self.g1, self.g2 = g1, g2

    def go(self, state):
        answers = self.g1.go(state)
        return answers + self.g2.go(state)
class conj(Combinator):
    """Conjunction: run `g2` on every answer of `g1`, collecting all results in order."""

    def __init__(self, g1, g2):
        self.g1, self.g2 = g1, g2

    def go(self, state):
        return [final for middle in self.g1.go(state) for final in self.g2.go(middle)]
class constraint(Combinator):
    """Goal that posts `term` as a CHR constraint when run."""

    def __init__(self, term):
        self.term = term

    def go(self, state):
        posted = self.term
        return make_constraint(posted, state)
def match(term, pattern, table):
    """Match a concrete `term` against a rule-head `pattern`.

    `pattern` may contain Hole placeholders. `table` maps hole keys to the
    terms they are already bound to. Returns the (possibly extended) table
    on success, or None on failure. The input table is never mutated: it is
    copied before a new hole is bound.
    """
    if isinstance(pattern, Hole):
        bound = table.get(pattern.key, None)
        if bound is None:
            table = table.copy()
            table[pattern.key] = term
            return table
        # A hole seen before must match the same term again.
        return match(term, bound, table)
    if isinstance(term, Var) and isinstance(pattern, Var):
        return table if term.key == pattern.key else None
    if isinstance(term, Const) and isinstance(pattern, Const):
        # Previously unhandled, so equal constants never matched.
        return table if term.key == pattern.key else None
    if isinstance(term, Term) and isinstance(pattern, Term):
        if term.name != pattern.name or len(term.args) != len(pattern.args):
            return None
        for sub, subpattern in zip(term.args, pattern.args):
            table = match(sub, subpattern, table)
            if table is None:
                return None
        return table
    return None
def lookup(sig, state):
    """Return an iterator over the stored constraints with signature `sig`."""
    table = state.constraints.get(sig, None)
    if table is None:
        return iter([])
    # dict.itervalues() does not exist on Python 3; values() works on both.
    return iter(table.values())
def reactivate(c, state):
    """Try every rule with `c` as the active constraint, then store it.

    For each rule and each head position whose signature equals `c`'s, every
    combination of partner constraints produced by `cartesian` is tried. A
    combination fires when all of its constraints are alive and the rule's
    guard accepts the bindings. When it fires:
      * the constraints in remove positions are killed;
      * the rule body is queued.

    Matching stops as soon as `c` itself has been killed. Afterwards `c` is
    suspended (stored unless killed) and the queued bodies run, in firing
    order, on the suspended state. Returns the list of resulting states.
    """
    goals = tt
    sig = c.term.get_signature()
    for rule in state.rules:
        holes = rule.filter(sig)
        slots = rule.select()
        for index in holes:
            for table, row, kills in cartesian(slots, index, c, state):
                # Track partner liveness separately from c's own liveness.
                # Reusing one flag let a dead partner in the last row abort
                # every remaining rule even though c was still alive.
                row_alive = True
                for r in row:
                    if not alive(r, state):
                        row_alive = False
                        break
                if not row_alive:
                    continue
                if not rule.guard(table):
                    continue
                for k in kills:
                    state = kill(k, state)
                goals = conj(goals, Handler(rule.body, table))
                if not alive(c, state):
                    return goals.go(suspend(c, state))
    return goals.go(suspend(c, state))
class Handler(Combinator):
    """Goal that runs a CHR rule body with the bindings of one match."""

    def __init__(self, body, binding):
        self.body, self.binding = body, binding

    def go(self, state):
        run_body = self.body
        return run_body(self.binding, state)
def _row_has(row, cid):
    """Return True if a constraint with id `cid` already fills a slot in `row`."""
    for used in row:
        if used.id == cid:
            return True
    return False


# Cartesian match across all slots.
def cartesian(slots, index, c, state):
    """Enumerate every way to fill a rule's head `slots`.

    `slots` holds one `(removes, signature, pattern)` triple per head, as
    built by Rule.select(). Slot number `index` is filled by the active
    constraint `c`. Every other slot is filled by a distinct live constraint
    from the store with the matching signature.

    Each result is a `(bindings, row, kills)` triple:
      * bindings: the hole table from matching all patterns;
      * row: the constraints filling the slots, in slot order;
      * kills: those of them that sit in remove slots.
    """
    if len(slots) == 0:
        return [(dict(), [], [])]
    removes, sig, pat = slots[0]
    out = []
    for table, row, kills in cartesian(slots[1:], index - 1, c, state):
        if index == 0:
            candidates = [c]
        else:
            candidates = []
            for nc in lookup(sig, state):
                if not alive(nc, state):
                    continue
                # Each head needs its own constraint. c fills only its own
                # slot; the store may also hold an older copy of c (same id)
                # when c is being re-activated after a variable binding.
                if nc.id == c.id or _row_has(row, nc.id):
                    continue
                candidates.append(nc)
        for cand in candidates:
            tab = match(cand.term, pat, table)
            if tab is None:
                continue
            if removes:
                # Kill the constraint actually filling this slot, not c.
                out.append((tab, [cand] + row, kills + [cand]))
            else:
                out.append((tab, [cand] + row, kills))
    return out
class Rule(object):
    """A constraint handling rule.

    `keep` and `remove` are head patterns. Constraints matching `keep` stay
    in the store when the rule fires; those matching `remove` are deleted.
    `guard(bindings)` decides whether a match may fire, and
    `body(bindings, state)` is the goal run afterwards.
    """

    def __init__(self, keep, remove, guard, body):
        self.keep, self.remove = keep, remove
        self.guard, self.body = guard, body

    def filter(self, sig):
        """Return the head positions whose signature equals `sig`.

        Positions number the keep heads first, then the remove heads. They
        are returned from the highest position down to 0.
        """
        heads = list(self.keep) + list(self.remove)
        return [pos for pos in range(len(heads) - 1, -1, -1)
                if sig == heads[pos].get_signature()]

    def select(self):
        """Return one (removes, signature, pattern) triple per head, keep heads first."""
        kept = [(False, head.get_signature(), head) for head in self.keep]
        removed = [(True, head.get_signature(), head) for head in self.remove]
        return kept + removed
def dual(x, y):
    """Build the constraint term dual(x, y)."""
    operands = [x, y]
    return Term("dual", operands)
def always_true(bindings):
    """Default rule guard: accept every match regardless of its bindings."""
    accepted = True
    return accepted
def chr_rule(keep, remove, guard=always_true):
    """Decorator that registers a CHR rule in `default_rules`.

    The decorated function becomes the rule body. It is called as
    `fn(bindings, state)` and must return a list of states. It is only
    scheduled when `guard(bindings)` is true. The function is returned
    unchanged, so decorators can be stacked to share one body.
    """
    def register(fn):
        default_rules.append(Rule(keep, remove, guard, fn))
        return fn
    return register
# Pattern holes shared by the rule heads below.
X, Y, Z = Hole(0), Hole(1), Hole(2)
@chr_rule([], [dual(X, X)])
def rule_0(bindings, state):
    """dual(A, A) <=> fail: remove the constraint and fail this branch."""
    failing = ff
    return failing.go(state)
@chr_rule([dual(X, Y)], [dual(Y, Z)])
def rule_1(bindings, state):
    r"""dual(A, B) \ dual(B, C) <=> A = C."""
    first, middle, last = bindings[0], bindings[1], bindings[2]
    return eq(first, last).go(state)
@chr_rule([dual(X, Y)], [dual(Z, Y)])
def rule_2(bindings, state):
    r"""dual(A, B) \ dual(C, B) <=> A = C."""
    first, shared, other = bindings[0], bindings[1], bindings[2]
    return eq(first, other).go(state)
@chr_rule([dual(X, Y)], [dual(X, Z)])
def rule_3(bindings, state):
    r"""dual(A, B) \ dual(A, C) <=> B = C."""
    shared, first, other = bindings[0], bindings[1], bindings[2]
    return eq(first, other).go(state)
@chr_rule([dual(X, Y)], [dual(Z, X)])
def rule_4(bindings, state):
    r"""dual(A, B) \ dual(C, A) <=> B = C."""
    shared, first, other = bindings[0], bindings[1], bindings[2]
    return eq(first, other).go(state)
def tensor(x, y):
    """Build the term x ⊗ y."""
    operands = [x, y]
    return Term("⊗", operands)
def par(x, y):
    """Build the term x ⅋ y."""
    operands = [x, y]
    return Term("⅋", operands)
def plus(x, y):
    """Build the term x + y."""
    operands = [x, y]
    return Term("+", operands)
def band(x, y):
    """Build the term x & y."""
    operands = [x, y]
    return Term("&", operands)
def ofc(x):
    """Build the prefix term !x."""
    operands = [x]
    return Term("!", operands)
def que(x):
    """Build the term ?x."""
    operands = [x]
    return Term("?", operands)
# Nullary unit terms 1, 0, ⊤ and ⊥ (each with its own empty argument list).
unit, zero, top, bot = Term("1", []), Term("0", []), Term("⊤", []), Term("⊥", [])
# Fresh holes (same keys 0, 1, 2) for the connective rules below.
X, Y, Z = Hole(0), Hole(1), Hole(2)
@chr_rule([], [dual(tensor(X, Y), Z)])
@chr_rule([], [dual(Z, tensor(X, Y))])
def rule_5(bindings, state):
    """Remove dual(A ⊗ B, Z) or dual(Z, A ⊗ B).

    Then post Z = C ⅋ D for fresh C and D, together with dual(A, C) and
    dual(B, D).
    """
    c, state = fresh(state)
    d, state = fresh(state)
    left, right, other = bindings[0], bindings[1], bindings[2]
    components = conj(constraint(dual(left, c)), constraint(dual(right, d)))
    goal = conj(eq(other, par(c, d)), components)
    return goal.go(state)
@chr_rule([], [dual(Z, par(X, Y))])
@chr_rule([], [dual(par(X, Y), Z)])
def rule_6(bindings, state):
    """Remove dual(Z, A ⅋ B) or dual(A ⅋ B, Z).

    Then post Z = C ⊗ D for fresh C and D, together with dual(A, C) and
    dual(B, D).
    """
    c, state = fresh(state)
    d, state = fresh(state)
    left, right, other = bindings[0], bindings[1], bindings[2]
    components = conj(constraint(dual(left, c)), constraint(dual(right, d)))
    goal = conj(eq(other, tensor(c, d)), components)
    return goal.go(state)
## #rule_5 @ dual(plus(A,B), Z) <=> Z=band(C,D), dual(A,C), dual(B,D). | |
## #rule_6 @ dual(Z, band(A,B)) <=> Z=plus(C,D), dual(A,C), dual(B,D). | |
## #rule_7 @ dual(Z, plus(A,B)) <=> dual(band(A,B), Z). | |
## #rule_8 @ dual(band(A,B), Z) <=> dual(Z, plus(A,B)). | |
## # | |
## #rule_10 @ dual(ofc(A), Z) <=> Z=que(C), dual(A,C). | |
## #rule_11 @ dual(Z, que(A)) <=> Z=ofc(C), dual(A,C). | |
## #rule_12 @ dual(Z, ofc(A)) <=> dual(ofc(A), Z). | |
## #rule_13 @ dual(que(A), Z) <=> dual(Z, que(A)). | |
## # | |
## #rule_14 @ dual(unit, Z) <=> Z=bot. | |
## #rule_15 @ dual(Z, bot) <=> Z=unit. | |
## #rule_16 @ dual(Z, unit) <=> Z=bot. | |
## #rule_17 @ dual(bot, Z) <=> Z=unit. | |
## # | |
## #rule_18 @ dual(zero, Z) <=> Z=top. | |
## #rule_19 @ dual(Z, top) <=> Z=zero. | |
## #rule_20 @ dual(Z, zero) <=> Z=top. | |
## #rule_21 @ dual(top, Z) <=> Z=zero. | |
def demonstration():
    """Run a small CHR query and print every resulting state to stdout.

    The query posts dual(x, y) and dual(z ⊗ w, h) on fresh variables. For
    each answer it prints:
      * every bound variable, in key order, with its value fully resolved
        through the substitution;
      * the constraints left alive in the store.
    """
    import os

    def emit(text):
        # os.write() needs bytes on Python 3; on Python 2 `str` already is bytes.
        if not isinstance(text, bytes):
            text = text.encode("utf-8")
        os.write(1, text)

    state = empty()
    x, state = fresh(state)
    y, state = fresh(state)
    z, state = fresh(state)
    w, state = fresh(state)
    h, state = fresh(state)
    goal = conj(constraint(Term("dual", [x, y])),
                constraint(Term("dual", [tensor(z, w), h])))
    emit("results for\n")
    for result in goal.go(state):
        emit("result:\n")
        # Sorted for stable output; walk so chained bindings show final values.
        for var in sorted(result.subst):
            value = walk(result.subst[var], result.subst)
            emit(" %d = %s\n" % (var, value.to_str()))
        for table in result.constraints.values():
            for c in table.values():
                if alive(c, result):
                    emit(" %s\n" % c.term.to_str())
# Script entry point: run the demonstration query.
if __name__ == "__main__":
    demonstration()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment