Skip to content

Instantly share code, notes, and snippets.

@evertheylen
Created November 15, 2016 17:40
Show Gist options
  • Save evertheylen/4df34243635754b412615cac02a41119 to your computer and use it in GitHub Desktop.
Save evertheylen/4df34243635754b412615cac02a41119 to your computer and use it in GitHub Desktop.
AI CSP
import inspect
from collections import defaultdict
from itertools import combinations
from util import *
# Output-format switch consulted by Solution's string conversion: when True,
# solutions render as an RST ``.. math::`` block (see Solution._latex_str);
# when False, as plain "var\t --> domain" lines (Solution._normal_str).
latex = False
class Constraint:
    """A constraint over an ordered list of CSP variables.

    ``func`` receives one value per variable, positionally, in the order of
    ``variables``; ``doc`` is an optional human-readable description.
    """

    def __init__(self, variables, func, doc=""):
        # The order of ``variables`` matters: ``valid`` passes values to
        # ``func`` positionally in exactly this order.
        self.variables = variables
        self.func = func
        self.doc = doc

    def valid(self, assignment):
        """Evaluate the constraint against ``assignment.domains``.

        Only decidable once every constrained variable has exactly one value
        left; until then this returns True so the search keeps going.
        """
        fixed = []
        for name in self.variables:
            domain = assignment.domains[name]
            if len(domain) == 1:
                fixed.append(next(iter(domain)))
        if len(fixed) != len(self.variables):
            # Some variable is still open (or emptied): cannot decide yet.
            return True
        return self.func(*fixed)

    def is_binary(self):
        """True when the constraint involves exactly two variables."""
        return len(self.variables) == 2

    def is_unary(self):
        """True when the constraint involves exactly one variable."""
        return len(self.variables) == 1
# Solutions are used functionally: ``copy`` and ``assign`` return new objects
# rather than mutating the receiver.
class Solution:
    """A partial assignment: the remaining candidate values of every variable.

    Attributes:
        unassigned: set of variable names not yet assigned.
        domains: dict mapping each variable name to its set of remaining values.
        csp: the CSP this assignment belongs to.
    """

    @classmethod
    def from_csp(cls, csp):
        """Start a search: nothing assigned, domains copied from ``csp.variables``."""
        return cls(set(csp.variables.keys()), dict_value_copy(csp.variables), csp)

    def __init__(self, unassigned, domains, csp):
        self.unassigned = unassigned
        self.domains = domains
        self.csp = csp

    def is_complete(self):
        """True once every variable has been assigned."""
        return len(self.unassigned) == 0

    def can_continue(self):
        """True if no fully-determined constraint of the CSP is violated."""
        return self.csp.valid(self)

    def has_empty_domain(self):
        """True if some variable has no candidate values left."""
        return any(len(v) == 0 for v in self.domains.values())

    def copy(self):
        """Return an independent copy (domain sets are copied, csp is shared)."""
        return type(self)(self.unassigned.copy(), dict_value_copy(self.domains), self.csp)

    def assign(self, x, x_val):
        """Return a copy with ``x`` fixed to ``x_val``.

        Raises:
            KeyError: if ``x`` is not currently unassigned.
        """
        new = self.copy()
        new.domains[x] = {x_val}
        new.unassigned.remove(x)
        return new

    def recheck_domains(self):
        """Propagation hook; the base class does no propagation and returns self."""
        return self

    def _latex_str(self):
        "RST/latex string"
        l = []
        l.append(".. math::")
        l.append("\t\\begin{aligned}")
        for k, v in sorted(self.domains.items(), key=str):
            l.append("\t{} & \\rightarrow \\{{ {} \\}} \\\\".format(k, ", ".join(
                "\\text{{{}}}".format(i) for i in v)))
        # Escaped backslash: "\e" is an invalid escape sequence (SyntaxWarning).
        l.append("\t\\end{aligned}")
        return "\n".join(l)

    def _normal_str(self):
        """One "var\t --> domain" line per variable, sorted."""
        l = []
        for k, v in sorted(self.domains.items(), key=str):
            l.append("{}\t --> {}".format(k, v))
        return "\n".join(l)

    def __str__(self):
        # Read the module-level ``latex`` flag at call time so toggling it
        # after import takes effect (it used to be frozen at class creation).
        return self._latex_str() if latex else self._normal_str()
class ForwardCheckingSolution(Solution):
    """Solution that prunes neighbouring domains each time a variable is assigned.

    Pruning only uses binary constraints; when the CSP also has non-binary
    constraints, ``can_continue`` falls back to the full constraint check.
    """

    def can_continue(self):
        """Return False if this partial assignment is already a dead end."""
        if self.csp.all_binary:
            # Forward checking already removed every value that conflicts
            # with an assigned variable, so only an emptied domain can fail.
            return not self.has_empty_domain()
        return super().can_continue()

    def assign(self, x, x_val):
        """Return a copy with ``x = x_val`` and unsupported neighbour values removed."""
        new = super().assign(x, x_val)
        for y in new.unassigned:
            for c in self.csp.constraint_info.get(x, y):
                if not c.is_binary():
                    continue
                for y_val in new.domains[y].copy():
                    # Call positionally in c.variables order (as Constraint.valid
                    # does) so explicit Constraint objects whose function uses
                    # different parameter names also work.
                    values = {x: x_val, y: y_val}
                    if not c.func(*(values[v] for v in c.variables)):
                        new.domains[y].remove(y_val)
        return new
class AC3Solution(Solution):
    """Solution that enforces arc consistency (AC-3) before each branching step."""

    def recheck_domains(self, *, Q=None):
        """Return an arc-consistent copy, or None if a domain gets wiped out.

        Args:
            Q: optional iterable of variables whose outgoing arcs are revised;
               defaults to every variable. It is copied, never mutated.

        Returns:
            A new solution with pruned domains, or None when some unassigned
            variable is left with no values (the search can backtrack at once;
            CSP._solve already treats None as failure).
        """
        new = self.copy()
        queue = list(self.domains.keys()) if Q is None else list(Q)
        while queue:
            x = queue.pop()
            for y, constraint in new.csp.constraint_info.related_to(x):
                if (constraint.is_binary() and y in new.unassigned
                        and new.remove_values(x, y, constraint)):
                    if len(new.domains[y]) == 0:
                        return None
                    # y shrank: arcs leaving y must be revised again.
                    queue.insert(0, y)
        return new

    # mutates
    def remove_values(self, x, y, constraint):
        """Remove values of ``y`` with no supporting value of ``x``.

        Mutates ``self.domains[y]`` in place; returns True if anything was removed.
        """
        removed = False
        for y_val in self.domains[y].copy():
            supported = False
            for x_val in self.domains[x]:
                # Positional call in constraint.variables order, matching
                # Constraint.valid, so parameter names need not match variables.
                values = {x: x_val, y: y_val}
                if constraint.func(*(values[v] for v in constraint.variables)):
                    supported = True
                    break
            if not supported:
                self.domains[y].remove(y_val)
                removed = True
        return removed
class AC3ForwardSolution(AC3Solution, ForwardCheckingSolution):
    """Forward checking on assignment combined with AC-3 propagation."""

    def recheck_domains(self, *, Q=None):
        """Run AC-3, seeded by default with the unassigned variables only.

        Forward checking has already made arcs from assigned variables
        consistent, so only arcs leaving unassigned variables need revising.
        """
        if Q is None:
            Q = list(self.unassigned)
        # Must delegate to AC3Solution: calling self.recheck_domains here
        # would re-enter this override (and the old attribute name
        # ``uninitialised`` does not exist).
        return super().recheck_domains(Q=Q)


class ConstraintInfo:
    """Index of constraints by the variables they mention.

    ``vars_to_cons[key]`` holds the constraints containing every variable of
    ``key`` (a sorted tuple of 1..max_len variable names); ``var_to_var[x][y]``
    holds the constraints that mention both x and y (x != y).
    """

    def __init__(self, constraints, max_len = 2):
        self.vars_to_cons = multimap()
        self.var_to_var = defaultdict(multimap)
        for c in constraints:
            ordered = sorted(c.variables)
            for size in range(1, max_len + 1):
                for key in combinations(ordered, r=size):
                    self.vars_to_cons[key].add(c)
            for x in c.variables:
                for y in c.variables:
                    if x != y:
                        self.var_to_var[x][y].add(c)

    def get(self, *names):
        """Return the constraints mentioning all of ``names`` (possibly empty)."""
        # Plain .get: don't insert empty entries into the multimap on a miss.
        return self.vars_to_cons.get(tuple(sorted(names)), set())

    def related_to(self, x):
        """Iterate (y, constraint) for every constraint linking x to another y."""
        neighbours = self.var_to_var.get(x)
        return iter(()) if neighbours is None else neighbours.flat_items()


# Strategies
class SelectVar:
    """Variable-selection strategies: ``(csp, unassigned) -> variable``."""

    @staticmethod
    def first(csp, unassigned):
        return pick_first(unassigned)


class OrderDomain:
    """Value-ordering strategies: ``(csp, domain) -> iterable of values``."""

    @staticmethod
    def same(csp, domain):
        return domain


class CSP(DotOutput):
    """A constraint satisfaction problem built from a problem description.

    ``problem.variables`` maps variable names to domains; ``problem.constraints``
    holds Constraint objects or plain functions whose positional parameter
    names are the variables they constrain.
    """

    def __init__(self, problem):
        self.problem = problem
        self.variables = problem.variables.copy()
        self.constraints = []
        for lam in self.problem.constraints:
            if isinstance(lam, Constraint):
                self.constraints.append(lam)
            else:
                # inspect.getargspec was removed in Python 3.11;
                # getfullargspec(...).args gives the same positional names.
                names = inspect.getfullargspec(lam).args
                self.constraints.append(Constraint(names, lam))
        self.all_binary = all(c.is_binary() for c in self.constraints)
        self.constraint_info = ConstraintInfo(self.constraints)

    def valid(self, assignment):
        """True if no constraint is violated by ``assignment``'s fixed values."""
        return all(c.valid(assignment) for c in self.constraints)

    def solve(self, cls = ForwardCheckingSolution,
              select_var = SelectVar.first,
              order_domain = OrderDomain.same):
        """Search for a complete assignment.

        Args:
            cls: Solution subclass that selects the propagation strategy.
            select_var: picks the next variable to branch on.
            order_domain: orders the values tried for that variable.

        Returns:
            A complete solution, or None if the search is exhausted.
        """
        # Honour ``cls`` (previously ignored: Solution was always used).
        return self._solve(cls.from_csp(self), select_var, order_domain)

    def _solve(self, A, select_var, order_domain):
        """Depth-first backtracking search from partial assignment ``A``."""
        if A.is_complete():
            assert self.valid(A)
            return A
        A = A.recheck_domains()
        if A is None:
            # Propagation detected a contradiction.
            return None
        x = select_var(self, A.unassigned)
        for val in order_domain(self, A.domains[x]):
            new_A = A.assign(x, val)
            if new_A.can_continue():
                result = self._solve(new_A, select_var, order_domain)
                if result is not None:
                    return result
        return None

    # Dot output methods
    def get_dot_nodes(self):
        """One node per variable."""
        for v in self.variables.keys():
            yield Node(v)

    def get_dot_transitions(self):
        """One edge per binary constraint; others are reported and skipped."""
        for c in self.constraints:
            if c.is_binary():
                yield Transition(c.variables[0], c.variables[1])
            else:
                print("WARNING: can't draw a non-binary constraint")
from csp import *
def question(i):
    """Print a header for question ``i``, underlined with dashes and padded by blank lines."""
    title = "Question {}".format(i)
    underline = "-" * len(title)
    print("", "", title, underline, "", sep="\n")
class Problem:
    """Example CSP with variables S, M, V, G and three binary constraints."""

    # Candidate values for each variable.
    variables = {
        "S": {"to", "ch", "pu"},
        "M": {"me", "co", "pi"},
        "V": {"qu", "ta", "sr"},
        "G": {"sa", "st", "tu"},
    }

    # CSP reads each lambda's parameter names (in order) to decide which
    # variables the constraint involves, so names and order matter.
    constraints = [
        # Unless S is "pu", M must be "me" or "co".
        lambda M, S: M in ("me", "co") or S == "pu",
        # S == "to" forbids V being "ta" or "sr".
        lambda S, V: not (S == "to" and V in ("ta", "sr")),
        # V == "qu" forbids G being "st" or "tu".
        lambda V, G: V != "qu" or G not in ("st", "tu"),
    ]
# Driver: build the example CSP and print the answer to each question.
c = CSP(Problem())
print("Domains:")
print(Solution.from_csp(c))

question(1)
c.save_dot("constraint_graph.dot")
print("see constraint_graph.dot")

question(2)
# Domains after S = "to" with forward checking.
A = ForwardCheckingSolution.from_csp(c)
print(A.assign("S", "to"))

question(3)
# Domains after S = "to" followed by AC-3 propagation.
A = AC3Solution.from_csp(c)
print(A.assign("S", "to").recheck_domains())

question(4)
# Pass Solution explicitly: solve()'s default is ForwardCheckingSolution,
# so relying on it would not be the "simple" backtracking run.
print("Simple:")
print(c.solve(Solution))
print("\nForward checking:")
print(c.solve(ForwardCheckingSolution))
print("\nAC3:")
print(c.solve(AC3Solution))
print("\nBoth:")
print(c.solve(AC3ForwardSolution))
from typing import List
from itertools import groupby, chain, combinations
from collections import defaultdict
# Etc
# ===
def dict_value_copy(d):
    """Return a new dict with the same keys, each value replaced by ``value.copy()``."""
    copied = {}
    for key, value in d.items():
        copied[key] = value.copy()
    return copied
def pick_first(it):
    """Return the first item obtained by iterating ``it``, or None if there is none."""
    return next(iter(it), None)
# Multimap
# ========
class multimap(defaultdict):
    """A dict from keys to sets of values; a missing key yields an empty set."""

    def __init__(self, *a, **kw):
        super().__init__(set, *a, **kw)

    @classmethod
    def from_pairs(cls, l: list):
        """Build a multimap from an iterable of (key, value) pairs."""
        dct = cls()
        for k, v in l:
            dct[k].add(v)
        return dct

    def copy(self):
        """Shallow copy.

        defaultdict.copy would call ``type(self)(set, self)``, which this
        class's __init__ turns into ``defaultdict(set, set, self)`` (TypeError).
        """
        return type(self)(self)

    __copy__ = copy

    def __reduce__(self):
        # Same reason as copy(): reconstruct with no args, then refill items.
        # Makes copy.deepcopy and pickle work.
        return (type(self), (), None, None, iter(self.items()))

    def flat_items(self):
        """Yield every (key, value) pair, one per value."""
        for k, values in self.items():
            for v in values:
                yield k, v

    def flat_values(self):
        """Yield every value of every key."""
        for values in self.values():
            yield from values

    def flat_len(self):
        """Total number of values across all keys."""
        return sum(len(values) for values in self.values())

    def flatten(self):
        """Return a plain dict mapping each key to its single value.

        Does not modify the multimap (the old version pop()ed every value,
        leaving empty sets behind).
        """
        d = {}
        for k, v in self.items():
            if len(v) != 1:
                # NOTE(review): NotFlat is not defined in this file; confirm
                # it is provided elsewhere, otherwise this raises NameError.
                raise NotFlat(v)
            d[k] = next(iter(v))
        return d
# Dot helpers
# ===========
class Node:
    """A graph vertex, drawn as a circle in DOT output."""

    def __init__(self, name):
        self.name = name

    def to_dot(self):
        """Return the DOT statement declaring this node."""
        template = 'node [shape=circle] "{0}";'
        return template.format(self.name)
class Transition:
    """An undirected DOT edge between two nodes, with an optional label."""

    def __init__(self, frm, to, label: str = None):
        self.frm = frm
        self.to = to
        self.label = label

    def to_dot(self):
        """Return the DOT edge statement, always terminated by ';'.

        Previously operator precedence attached ';' only when there was no
        label, so labelled edges lacked the terminator.
        """
        s = '"{}" -- "{}"'.format(self.frm, self.to)
        if self.label:
            s += ' [label="{}"]'.format(self.label)
        return s + ';'

    def key(self):
        """(frm, to) pair used to group parallel edges."""
        return (self.frm, self.to)
def dot_str(obj):
    """Render ``obj`` as text for a DOT label.

    Empty ``str``/``String`` values become "ε"; non-empty ``String`` values are
    stringified with leading/trailing backticks stripped; anything else is
    passed through ``str``.
    """
    # NOTE(review): ``String`` is not defined or imported in this module's
    # visible code, so building the tuple below raises NameError on every call
    # unless it is provided elsewhere -- confirm before relying on this helper.
    if isinstance(obj, (str, String)):
        if len(obj) == 0:
            return "ε"
        elif isinstance(obj, String):
            return str(obj).strip("`")
    return str(obj)
dot_fmt = """
graph csp {{
rankdir=LR
{nodes}
{transitions}
}}
"""
class Dot:
    """A DOT graph assembled from Node and Transition objects."""

    def __init__(self, nodes = None, transitions = None):
        # None defaults: a literal [] default would be one list shared by
        # every Dot created without arguments.
        self.nodes = [] if nodes is None else nodes
        self.transitions = [] if transitions is None else transitions

    def to_dot(self):
        """Render the graph as DOT text.

        Transitions with the same (frm, to) key are merged into one edge whose
        non-None labels are joined with a DOT line break ("\\n").
        """
        nodes = "\n".join(" " + n.to_dot() for n in self.nodes)
        by_key = lambda t: t.key()
        ordered = sorted(self.transitions, key=by_key)  # groupby needs sorted input
        merged = [Transition(k[0], k[1], "\\n".join(t.label for t in g if t.label is not None))
                  for k, g in groupby(ordered, key=by_key)]
        transitions = "\n".join(" " + t.to_dot() for t in merged)
        return dot_fmt.format(nodes = nodes, transitions = transitions)
class DotOutput:
    """Mixin that writes an object as a Graphviz DOT file.

    Subclasses override get_dot_nodes / get_dot_transitions; any iterable
    of Node / Transition objects is accepted (save_dot wraps them in list()).
    """

    def save_dot(self, fname):
        """Write this object's graph to ``fname`` in DOT format."""
        d = Dot(list(self.get_dot_nodes()), list(self.get_dot_transitions()))
        # Explicit UTF-8 so non-ASCII labels don't depend on the locale encoding.
        with open(fname, "w", encoding="utf-8") as f:
            f.write(d.to_dot())

    def get_dot_nodes(self) -> "List[Node]":
        return [Node("default")]

    def get_dot_transitions(self) -> "List[Transition]":
        return []
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment