Skip to content

Instantly share code, notes, and snippets.

@cheery
Last active January 29, 2022 20:21
Show Gist options
  • Save cheery/285e81d55ca4d9ec44d628978545a866 to your computer and use it in GitHub Desktop.
Save cheery/285e81d55ca4d9ec44d628978545a866 to your computer and use it in GitHub Desktop.
ukanren with untested CHR
# -*- encoding: utf-8 -*-
class TermOrVar(object):
    """Common base for every node that may appear inside a term tree."""

    def to_str(self, rbp=0):
        # Abstract: concrete node classes supply their own rendering.
        assert False


# Binding power of each operator signature. Signatures not listed here bind
# tighter than every operator (see get_rbp).
_BINDING_POWERS = {
    "./2": 10,
    "⊸/2": 20,
    "+/2": 30,
    "⅋/2": 35,
    "⊗/2": 40,
    "!/1": 50,
}

# Signatures printed infix as "left OP right" (right-associative).
_INFIX_SIGNATURES = ("./2", "⊸/2", "+/2", "⅋/2", "⊗/2")


class Term(TermOrVar):
    """A compound term: functor ``name`` applied to the list ``args``."""

    def __init__(self, name, args):
        self.name = name
        self.args = args
        for arg in args:
            assert isinstance(arg, TermOrVar)

    def to_str(self, rbp=0):
        """Render the term, wrapping it in parentheses if ``rbp`` binds tighter.

        Infix operators are right-associative: the left operand is rendered
        with binding power one above the operator's, the right operand with
        the operator's own power. Unknown functors print as a prefix
        application with every argument parenthesised.
        """
        sig = self.get_signature()
        power = get_rbp(sig)
        if rbp > power:
            return "(" + self.to_str() + ")"
        if not self.args:
            return self.name
        if sig in _INFIX_SIGNATURES:
            left = self.args[0].to_str(power + 1)
            right = self.args[1].to_str(power)
            return left + " " + self.name + " " + right
        if sig == "!/1":
            return self.name + self.args[0].to_str(power + 1)
        pieces = [self.name]
        for arg in self.args:
            pieces.append(" (" + arg.to_str() + ")")
        return "".join(pieces)

    def get_signature(self):
        """Return ``name/arity``, e.g. ``"dual/2"``."""
        return "/".join([self.name, str(len(self.args))])


def get_rbp(sig):
    """Return the binding power of signature ``sig`` (60 when unlisted)."""
    return _BINDING_POWERS.get(sig, 60)
class Var(TermOrVar):
    """Logic variable; ``key`` is looked up in the substitution dict."""

    def __init__(self, key):
        self.key = key

    def to_str(self, rbp=0):
        # Leaves never need parentheses, so ``rbp`` is ignored.
        return "#{0}".format(self.key)


class Const(TermOrVar):
    """Opaque constant, compared only by ``key``."""

    def __init__(self, key):
        self.key = key

    def to_str(self, rbp=0):
        return "${0}".format(self.key)


# Pattern hole for constraint handling rule patterns.
class Hole(TermOrVar):
    """Placeholder in a CHR head pattern; match() binds ``key`` to a term."""

    def __init__(self, key):
        self.key = key

    def to_str(self, rbp=0):
        return "$HOLE{" + "{0}".format(self.key) + "}"
class Storage(object):
    """Solver state threaded through every goal.

    The helper functions in this module usually build a new Storage rather
    than mutating an existing one.

    Attributes (as used in this module):
        rules: Rule objects consulted by reactivate().
        subst: dict variable key -> bound term.
        nextvar: key that the next fresh() variable will receive.
        constraints: dict signature -> {constraint id: Constraint}.
        nextcid: id that the next new constraint will receive.
        notes: dict variable key -> {constraint id: Constraint} to wake up
            when that variable is bound.
        killed: dict constraint id -> Constraint killed since the last
            suspend().
    """

    def __init__(self, rules, subst, nextvar, constraints, nextcid, notes, killed):
        (self.rules, self.subst, self.nextvar, self.constraints,
         self.nextcid, self.notes, self.killed) = (
            rules, subst, nextvar, constraints, nextcid, notes, killed)
# Global rule registry. chr_rule() appends to it after this point. empty()
# sees those additions because its default argument is this same list object
# (shared on purpose).
default_rules = []


def empty(rules=default_rules):
    """Return a Storage with no bindings, variables or constraints."""
    return Storage(rules, {}, 0, {}, 0, {}, {})
def copy_constraints(constraints):
    """Return a two-level copy of the constraint store.

    The outer dict and each per-signature table are new objects. The
    Constraint values themselves are shared.
    """
    return {sig: table.copy() for sig, table in constraints.items()}
def copy_notes(notes):
    """Return a two-level copy of the wake-up notes (values are shared)."""
    return {var_key: wakeups.copy() for var_key, wakeups in notes.items()}
class Constraint(object):
    """A stored constraint: its term plus the id reserved for it."""

    def __init__(self, term, id):
        # ``id`` shadows the builtin, but the name is kept for keyword callers.
        self.term, self.id = term, id
def make_constraint(term, state):
    """Post ``term`` as a brand-new constraint and run its rules.

    The term is walked under the current substitution and a fresh constraint
    id is reserved. The new constraint is then handed to reactivate(), whose
    list of result states is returned.
    """
    assert isinstance(term, Term)
    new_id = state.nextcid
    walked = walk(term, state.subst)
    advanced = Storage(state.rules, state.subst, state.nextvar,
                       state.constraints, new_id + 1,
                       state.notes, state.killed)
    return reactivate(Constraint(walked, new_id), advanced)
def alive(c, state):
    """Return False iff constraint ``c`` is recorded in ``state.killed``.

    Only the kills since the last suspend() are consulted. The constraint
    store itself is deliberately not searched.
    """
    return c.id not in state.killed
def kill(c, state):
    """Return a copy of ``state`` in which constraint ``c`` is marked killed.

    Only the ``killed`` dict is copied; the caller's state is left unchanged.
    """
    killed = state.killed.copy()
    killed[c.id] = c
    return Storage(state.rules, state.subst, state.nextvar,
                   state.constraints, state.nextcid, state.notes, killed)
def suspend(c, state):
    """Commit pending kills and store constraint ``c`` if it is still alive.

    Returns a new Storage whose constraint store and notes are copies of the
    originals, with:
      * ``c`` inserted into the store and registered (via add_notes) for
        wake-up on its variables, unless ``c`` itself was killed;
      * every constraint killed since the last suspend removed from both the
        store and the wake-up notes;
      * ``killed`` reset to an empty dict.
    """
    killed = state.killed
    state = Storage(state.rules,
                    state.subst,
                    state.nextvar,
                    copy_constraints(state.constraints),
                    state.nextcid,
                    copy_notes(state.notes),
                    {})
    if c.id not in killed:
        add_notes(c.term, state, c)
        sig = c.term.get_signature()
        table = state.constraints.get(sig, None)
        if table is None:
            state.constraints[sig] = table = {}
        table[c.id] = c
    for k in killed.values():
        # Remove the killed constraint k itself (not the active constraint c).
        table = state.constraints.get(k.term.get_signature(), None)
        if table is not None:
            table.pop(k.id, None)
        # Also drop it from the wake-up notes. Once ``killed`` is reset,
        # alive() can no longer recognise it, so a later binding would revive
        # it. The notes were copied above, so mutating them is safe.
        for wakeups in state.notes.values():
            wakeups.pop(k.id, None)
    return state
def add_notes(term, state, c):
    """Register ``c`` for wake-up on every Var occurring in ``term``.

    Mutates ``state.notes`` in place. Leaves other than Var are ignored.
    """
    if isinstance(term, Term):
        for arg in term.args:
            add_notes(arg, state, c)
    elif isinstance(term, Var):
        state.notes.setdefault(term.key, {})[c.id] = c
def var_to_const(term, table):
    """Replace every Var and Const leaf in ``term`` with a numbered Const.

    ``table`` maps a leaf's printed form (e.g. "#3" or "$0") to its
    replacement, so equal leaves get the same Const. Unseen leaves are
    numbered by the table's current size. ``table`` is updated in place and
    ``(new_term, table)`` is returned.
    """
    if isinstance(term, Term):
        converted = []
        for arg in term.args:
            new_arg, table = var_to_const(arg, table)
            converted.append(new_arg)
        return Term(term.name, converted), table
    if isinstance(term, Var) or isinstance(term, Const):
        label = term.to_str()
        replacement = table.get(label, None)
        if replacement is None:
            replacement = Const(len(table))
            table[label] = replacement
        return replacement, table
    assert False
def const_to_var(term, table, state):
    """Counterpart of var_to_const: replace each Const with a variable.

    ``table`` maps Const keys to the Var already chosen for them. A Const
    seen for the first time gets a new variable from fresh(), which advances
    ``state``. Existing Vars are kept as they are. Returns
    ``(new_term, table, state)``.
    """
    if isinstance(term, Term):
        converted = []
        for arg in term.args:
            new_arg, table, state = const_to_var(arg, table, state)
            converted.append(new_arg)
        return Term(term.name, converted), table, state
    if isinstance(term, Const):
        var = table.get(term.key, None)
        if var is None:
            var, state = fresh(state)
            table[term.key] = var
        return var, table, state
    if isinstance(term, Var):
        # The namespace entries may contain variables; leave them untouched.
        return term, table, state
    assert False
def walk(term, subst):
    """Fully resolve ``term`` under substitution ``subst``.

    Bound variables are followed until an unbound Var or a non-variable is
    reached. Compound terms are rebuilt with every argument walked.
    """
    while isinstance(term, Var):
        bound = subst.get(term.key, None)
        if bound is None:
            return term
        term = bound
    if isinstance(term, Term):
        return Term(term.name, [walk(arg, subst) for arg in term.args])
    if isinstance(term, Const):
        return term
    assert False
def extS(key, term, state):
    """Bind variable ``key`` to ``term`` and wake the constraints watching it.

    Returns a list of result states. The list is empty when the occurs check
    fails, or when every branch fails while re-running a woken constraint's
    rules.
    """
    if occurs(key, term, state.subst):
        # Failure: callers (unify, eq, conj) expect a list, never None.
        return []
    nsubst = state.subst.copy()
    nsubst[key] = term
    nnotes = state.notes
    wakeups = {}
    if key in state.notes:
        # Copy before popping so the caller's notes are not mutated.
        nnotes = copy_notes(state.notes)
        wakeups = nnotes.pop(key)
    states = [Storage(state.rules, nsubst, state.nextvar, state.constraints,
                      state.nextcid, nnotes, state.killed)]
    for c in wakeups.values():
        nstates = []
        for st in states:
            if not alive(c, st):
                # Nothing to wake on this branch, but the branch itself is
                # still a valid result and must be kept.
                nstates.append(st)
                continue
            woken = Constraint(walk(c.term, st.subst), c.id)
            nstates.extend(reactivate(woken, st))
        states = nstates
    return states
def fresh(state):
    """Allocate a new logic variable; return ``(var, advanced_state)``."""
    key = state.nextvar
    advanced = Storage(state.rules, state.subst, key + 1, state.constraints,
                       state.nextcid, state.notes, state.killed)
    return Var(key), advanced
class Combinator:
    """Base for goals: ``go(state)`` returns a list of result states."""

    def go(self, state):
        # Abstract; every concrete goal overrides this.
        assert False
class eq(Combinator):
    """Goal that unifies ``t1`` with ``t2``."""

    def __init__(self, t1, t2):
        self.t1, self.t2 = t1, t2

    def go(self, state):
        outcomes = unify(self.t1, self.t2, state)
        # An empty result means unification failed; return a fresh empty list.
        return outcomes if len(outcomes) > 0 else []
def unify(t1, t2, state):
    """Unify ``t1`` with ``t2`` under ``state``; return the list of result states.

    An empty list means the terms do not unify. Variable bindings go through
    extS, which may wake constraints and therefore yield several states.
    """
    t1 = walk(t1, state.subst)
    t2 = walk(t2, state.subst)
    if isinstance(t1, Var) and isinstance(t2, Var) and t1.key == t2.key:
        return [state]
    if isinstance(t1, Var):
        return extS(t1.key, t2, state)
    if isinstance(t2, Var):
        return extS(t2.key, t1, state)
    if isinstance(t1, Term) and isinstance(t2, Term):
        if t1.name != t2.name or len(t1.args) != len(t2.args):
            return []
        # Seed with the incoming state (so nullary terms unify trivially),
        # then thread every surviving state through each argument pair.
        states = [state]
        for a1, a2 in zip(t1.args, t2.args):
            next_states = []
            for st in states:
                next_states.extend(unify(a1, a2, st))
            states = next_states
        return states
    if isinstance(t1, Const) and isinstance(t2, Const):
        return [state] if t1.key == t2.key else []
    if isinstance(t1, (Term, Const)) and isinstance(t2, (Term, Const)):
        # A compound term and an opaque constant never unify.
        return []
    assert False
def occurs(key, term, subst):
    """Return True if variable ``key`` appears in ``term`` once walked."""
    resolved = walk(term, subst)
    if isinstance(resolved, Term):
        for arg in resolved.args:
            if occurs(key, arg, subst):
                return True
        return False
    if isinstance(resolved, Var):
        return resolved.key == key
    if isinstance(resolved, Const):
        return False
    assert False
class Unit(Combinator):
    """Goal that always succeeds with the input state unchanged."""

    def go(self, state):
        return [state]


class Fail(Combinator):
    """Goal that always fails (no result states)."""

    def go(self, state):
        return []


# Shared instances of the trivial goals.
tt, ff = Unit(), Fail()
class disj(Combinator):
    """Disjunction: all states from ``g1`` followed by all states from ``g2``."""

    def __init__(self, g1, g2):
        self.g1, self.g2 = g1, g2

    def go(self, state):
        first = self.g1.go(state)
        second = self.g2.go(state)
        return first + second
class conj(Combinator):
    """Conjunction: run ``g2`` on every state produced by ``g1``."""

    def __init__(self, g1, g2):
        self.g1, self.g2 = g1, g2

    def go(self, state):
        return [final
                for middle in self.g1.go(state)
                for final in self.g2.go(middle)]
class constraint(Combinator):
    """Goal that posts ``term`` as a CHR constraint via make_constraint()."""

    def __init__(self, term):
        self.term = term

    def go(self, state):
        posted = self.term
        return make_constraint(posted, state)
def match(term, pattern, table):
    """One-way match of ``term`` against a head ``pattern``.

    ``pattern`` may contain Holes, and ``table`` maps hole keys to the terms
    they have matched so far. Returns the (possibly extended) table on
    success, or None on failure. The incoming table is never mutated: a copy
    is made when a hole is bound. Var and Const leaves match only an
    identical Var/Const. A hole that is already bound must match its earlier
    value structurally.
    """
    if isinstance(pattern, Hole):
        bound = table.get(pattern.key, None)
        if bound is None:
            table = table.copy()
            table[pattern.key] = term
            return table
        return match(term, bound, table)
    if isinstance(term, Var) and isinstance(pattern, Var):
        return table if term.key == pattern.key else None
    if isinstance(term, Const) and isinstance(pattern, Const):
        # Equal constants match. Without this, a hole bound to a Const could
        # never match that same Const a second time.
        return table if term.key == pattern.key else None
    if isinstance(term, Term) and isinstance(pattern, Term):
        if term.name != pattern.name or len(term.args) != len(pattern.args):
            return None
        for sub_term, sub_pattern in zip(term.args, pattern.args):
            table = match(sub_term, sub_pattern, table)
            if table is None:
                return None
        return table
    return None
def lookup(sig, state):
    """Return an iterator over the stored constraints with signature ``sig``."""
    return iter(state.constraints.get(sig, {}).values())
def reactivate(c, state):
    """Run every applicable rule with ``c`` as the active constraint.

    For each head slot whose signature matches ``c`` (rule.filter), every
    combination of partner constraints is enumerated (cartesian). A
    combination fires when all of its constraints are still alive and the
    guard accepts the bindings. Its removed heads are then killed
    immediately and its body is queued. Processing stops once ``c`` itself
    has been killed. Finally ``c`` is suspended, and the queued bodies run in
    order on the resulting state; their list of result states is returned.
    """
    goals = tt
    for rule in state.rules:
        holes = rule.filter(c.term.get_signature())
        slots = rule.select()
        for index in holes:
            for table, row, kills in cartesian(slots, index, c, state):
                # A partner may have been killed by an earlier firing. Skip
                # this row only; it says nothing about c's own liveness.
                if not all(alive(r, state) for r in row):
                    continue
                if not rule.guard(table):
                    continue
                for k in kills:
                    state = kill(k, state)
                goals = conj(goals, Handler(rule.body, table))
                if not alive(c, state):
                    break
            if not alive(c, state):
                break
        if not alive(c, state):
            break
    state = suspend(c, state)
    return goals.go(state)
class Handler(Combinator):
    """Goal that runs a rule body with the hole bindings captured at firing."""

    def __init__(self, body, binding):
        self.body, self.binding = body, binding

    def go(self, state):
        body = self.body
        return body(self.binding, state)
# Cartesian match across all slots.
def cartesian(slots, index, c, state):
    """Enumerate head matches for a rule with ``c`` placed in slot ``index``.

    ``slots`` is the list returned by Rule.select(): ``(is_removed, signature,
    pattern)`` triples. Slot ``index`` is matched against ``c``. Every other
    slot is matched against live stored constraints of its signature, and
    each matched constraint is distinct. Returns a list of
    ``(table, row, kills)``:
      * ``table`` holds the hole bindings;
      * ``row`` lists the matched constraints in slot order;
      * ``kills`` lists those matched by removed-head slots.
    """
    if not slots:
        return [({}, [], [])]
    is_removed, sig, pat = slots[0]
    out = []
    for table, row, kills in cartesian(slots[1:], index - 1, c, state):
        if index == 0:
            candidates = [c]
        else:
            candidates = []
            for nc in lookup(sig, state):
                # Head constraints must be distinct. Skip c itself (a stale
                # stored copy of it exists while it is being re-activated)
                # and anything already used by a later slot.
                if nc.id == c.id or not alive(nc, state):
                    continue
                if any(nc.id == r.id for r in row):
                    continue
                candidates.append(nc)
        for nc in candidates:
            tab = match(nc.term, pat, table)
            if tab is None:
                continue
            # A removed head kills the constraint matched in THIS slot.
            new_kills = kills + [nc] if is_removed else kills
            out.append((tab, [nc] + row, new_kills))
    return out
class Rule(object):
    """A CHR rule: kept heads, removed heads, a guard and a body.

    ``guard(bindings)`` decides whether a match may fire.
    ``body(bindings, state)`` produces the resulting list of states.
    """

    def __init__(self, keep, remove, guard, body):
        self.keep = keep
        self.remove = remove
        self.guard = guard
        self.body = body

    def _heads(self):
        # All head patterns in slot order: kept heads first, then removed.
        return list(self.keep) + list(self.remove)

    def filter(self, sig):
        """Return the slot indices whose head has signature ``sig``, highest first."""
        heads = self._heads()
        return [i for i in reversed(range(len(heads)))
                if heads[i].get_signature() == sig]

    def select(self):
        """Return ``(is_removed, signature, pattern)`` for every head slot."""
        n_keep = len(self.keep)
        return [(i >= n_keep, head.get_signature(), head)
                for i, head in enumerate(self._heads())]
#def dual_c(args):
# assert len(args) == 2
# return
# c = self.create("dual", args)
# self.store(c)
# activate(self, c)
def dual(x, y):
    """Build the ``dual(x, y)`` constraint term."""
    pair = [x, y]
    return Term("dual", pair)
def always_true(bindings):
    """Default rule guard: accept every match regardless of its bindings."""
    accepted = True
    return accepted
# Decorator that registers a CHR rule. The guard is checked before the body
# runs; when it rejects a match, the body is not called at all.
def chr_rule(keep, remove, guard=always_true):
    """Decorator factory registering a rule in ``default_rules``.

    ``keep`` and ``remove`` are lists of head patterns. The decorated function
    becomes the rule body, called as ``fn(bindings, state)``. It is returned
    unchanged, so decorators can be stacked to register one body under
    several heads.
    """
    def register(fn):
        default_rules.append(Rule(keep, remove, guard, fn))
        return fn
    return register
# Pattern holes for the rule heads below. Their keys 0, 1, 2 are the indices
# that the rule bodies read from ``bindings``.
X, Y, Z = Hole(0), Hole(1), Hole(2)
@chr_rule([], [dual(X, X)])
def rule_0(bindings, state):
    """A term dual to itself is rejected: this branch fails."""
    failing = ff
    return failing.go(state)
@chr_rule([dual(X, Y)], [dual(Y, Z)])
def rule_1(bindings, state):
    """Keep dual(X, Y), remove dual(Y, Z), and require X = Z."""
    x, _, z = [bindings[k] for k in range(3)]
    return eq(x, z).go(state)
@chr_rule([dual(X, Y)], [dual(Z, Y)])
def rule_2(bindings, state):
    """Keep dual(X, Y), remove dual(Z, Y), and require X = Z."""
    x, _, z = [bindings[k] for k in range(3)]
    return eq(x, z).go(state)
@chr_rule([dual(X, Y)], [dual(X, Z)])
def rule_3(bindings, state):
    """Keep dual(X, Y), remove dual(X, Z), and require Y = Z."""
    _, y, z = [bindings[k] for k in range(3)]
    return eq(y, z).go(state)
@chr_rule([dual(X, Y)], [dual(Z, X)])
def rule_4(bindings, state):
    """Keep dual(X, Y), remove dual(Z, X), and require Y = Z."""
    _, y, z = [bindings[k] for k in range(3)]
    return eq(y, z).go(state)
# Term constructors for the connectives used by the rules.
def tensor(x, y):
    """Binary ``⊗`` term."""
    return Term("⊗", [x, y])


def par(x, y):
    """Binary ``⅋`` term."""
    return Term("⅋", [x, y])


def plus(x, y):
    """Binary ``+`` term."""
    return Term("+", [x, y])


def band(x, y):
    """Binary ``&`` term."""
    return Term("&", [x, y])


def ofc(x):
    """Unary ``!`` term."""
    return Term("!", [x])


def que(x):
    """Unary ``?`` term."""
    return Term("?", [x])
# Nullary constants, each a Term with no arguments.
unit, zero, top, bot = (Term("1", []), Term("0", []),
                        Term("⊤", []), Term("⊥", []))
# Pattern holes for the rule heads that follow (keys 0, 1, 2).
X, Y, Z = Hole(0), Hole(1), Hole(2)
@chr_rule([], [dual(tensor(X,Y),Z)])
@chr_rule([], [dual(Z,tensor(X,Y))])
def rule_5(bindings, state):
    """Remove dual(tensor(X, Y), Z) or dual(Z, tensor(X, Y)).

    Replace it with Z = par(C, D), dual(X, C) and dual(Y, D), where C and D
    are fresh variables.
    """
    c_var, state = fresh(state)
    d_var, state = fresh(state)
    x, y, z = bindings[0], bindings[1], bindings[2]
    split = conj(constraint(dual(x, c_var)), constraint(dual(y, d_var)))
    return conj(eq(z, par(c_var, d_var)), split).go(state)
@chr_rule([], [dual(Z, par(X,Y))])
@chr_rule([], [dual(par(X,Y), Z)])
def rule_6(bindings, state):
    """Remove dual(Z, par(X, Y)) or dual(par(X, Y), Z).

    Replace it with Z = tensor(C, D), dual(X, C) and dual(Y, D), where C and
    D are fresh variables.
    """
    c_var, state = fresh(state)
    d_var, state = fresh(state)
    x, y, z = bindings[0], bindings[1], bindings[2]
    split = conj(constraint(dual(x, c_var)), constraint(dual(y, d_var)))
    return conj(eq(z, tensor(c_var, d_var)), split).go(state)
## #rule_5 @ dual(plus(A,B), Z) <=> Z=band(C,D), dual(A,C), dual(B,D).
## #rule_6 @ dual(Z, band(A,B)) <=> Z=plus(C,D), dual(A,C), dual(B,D).
## #rule_7 @ dual(Z, plus(A,B)) <=> dual(plus(A,B), Z).
## #rule_8 @ dual(band(A,B), Z) <=> dual(Z, band(A,B)).
## #
## #rule_10 @ dual(ofc(A), Z) <=> Z=que(C), dual(A,C).
## #rule_11 @ dual(Z, que(A)) <=> Z=ofc(C), dual(A,C).
## #rule_12 @ dual(Z, ofc(A)) <=> dual(ofc(A), Z).
## #rule_13 @ dual(que(A), Z) <=> dual(Z, que(A)).
## #
## #rule_14 @ dual(unit, Z) <=> Z=bot.
## #rule_15 @ dual(Z, bot) <=> Z=unit.
## #rule_16 @ dual(Z, unit) <=> Z=bot.
## #rule_17 @ dual(bot, Z) <=> Z=unit.
## #
## #rule_18 @ dual(zero, Z) <=> Z=top.
## #rule_19 @ dual(Z, top) <=> Z=zero.
## #rule_20 @ dual(Z, zero) <=> Z=top.
## #rule_21 @ dual(top, Z) <=> Z=zero.
def demonstration():
    """Post two dual constraints over fresh variables and print the answers.

    For each resulting state, writes the raw substitution (variable key and
    bound term) and every constraint left in the store to file descriptor 1.
    """
    import os
    state = empty()
    x, state = fresh(state)
    y, state = fresh(state)
    z, state = fresh(state)
    w, state = fresh(state)
    h, state = fresh(state)
    goal = conj(constraint(Term("dual", [x, y])),
                constraint(Term("dual", [tensor(z, w), h])))
    os.write(1, "results for\n")
    for answer in goal.go(state):
        os.write(1, "result:\n")
        for key in answer.subst:
            os.write(1, " %d = %s\n" % (key, answer.subst[key].to_str()))
        for table in answer.constraints.values():
            for stored in table.values():
                if alive(stored, answer):
                    os.write(1, " %s\n" % stored.term.to_str())
if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    demonstration()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment