# SAT-based KenKen solver (originally shared as a GitHub gist).
"""
A SAT-based KenKen (http://www.kenken.com/) solver.
The implementation of this solver is based on the ideas contained in
the paper:
"A SAT-based Sudoku solver" by Tjark Weber
https://www.lri.fr/~conchon/mpri/weber.pdf
and a Python implementation of a SAT-based Sudoku solver which is found
in the pycosat examples:
https://pypi.python.org/pypi/pycosat
This module requires Python 3.3 or later.
"""
import itertools
import pycosat
########################################################################
# Exceptions
#
class InvalidPuzzle(Exception):
    """The puzzle specification is invalid."""


class UnsatisfyablePuzzle(InvalidPuzzle):
    """The puzzle has no solution.

    Derives from InvalidPuzzle so that ``except InvalidPuzzle`` handlers
    also catch it.
    """


# Correctly spelled alias; the original (misspelled) name is kept for
# backward compatibility.
UnsatisfiablePuzzle = UnsatisfyablePuzzle

# Alias so that code raising ``InvalidPuzzleSpecification`` resolves to a
# real exception class instead of failing with NameError.
InvalidPuzzleSpecification = InvalidPuzzle
########################################################################
# Partitioners
#
# A partitioner yields all possible partitions of `result` into `nparts`
# according to a specified operation, where each value is in the range
# `1..dim`.
#
# See the docstring of `partition` for some examples.
#
def _partition_add(nparts, result, dim):
if nparts == 0:
raise ValueError("Expected nparts >= 1")
elif nparts == 1:
if 1 <= result <= dim:
yield (result, )
else:
raise StopIteration
else:
for i1 in range(1, dim+1):
for i2 in _partition_add(nparts-1, result-i1, dim):
yield (i1,) + i2
def _partition_mul(nparts, result, dim):
if nparts == 0:
raise ValueError("Expected nparts >= 1")
elif nparts == 1:
if 1 <= result <= dim:
yield (result, )
else:
raise StopIteration
else:
for i1 in range(1, dim+1):
if result % i1 != 0:
continue
for i2 in _partition_mul(nparts-1, int(result//i1), dim):
yield (i1,) + i2
def _partition_sub(nparts, result, dim):
if nparts != 2:
raise ValueError("Expected nparts = 2.")
for i1 in range(1, dim+1):
for i2 in (i1-result, i1+result):
if 1 <= i2 <= dim:
yield(i1, i2)
def _partition_div(nparts, result, dim):
if nparts != 2:
raise ValueError("Expected nparts = 2.")
for i1 in range(1, dim+1):
if i1 % result == 0:
i2 = int(i1 // result)
if 1 <= i2 <= dim:
yield (i1, i2)
i2 = result * i1
if 1 <= i2 <= dim:
yield (i1, i2)
def _partition_eq(nparts, result, dim):
if nparts != 1:
raise ValueError("Expected nparts = 1.")
if 1 <= result <= dim:
yield (result,)
# Dispatch table mapping a cage's operator symbol to the generator that
# enumerates candidate value tuples for that cage (looked up by `partition`).
_partitioners = {
    '+': _partition_add,
    '*': _partition_mul,
    '-': _partition_sub,   # two-cell cages only
    '/': _partition_div,   # two-cell cages only
    # '!' and '=' are synonyms: a one-cell cage holding exactly `result`.
    '!': _partition_eq,
    '=': _partition_eq,
}
def partition(op, nparts, result, dim):
    """Yield each `nparts`-tuple of values in ``1..dim`` that combines,
    under the operation `op`, to give `result`.

    `op` selects an entry of `_partitioners`; an unknown operator raises
    KeyError once iteration starts.

    For example, all pairs of numbers in the range 1..6 whose ratio is 3:
    >>> sorted(partition('/', 2, 3, 6))
    [(1, 3), (2, 6), (3, 1), (6, 2)]
    All triples of numbers in the range 1..6 whose product is 4:
    >>> sorted(partition('*', 3, 4, 6))
    [(1, 1, 4), (1, 2, 2), (1, 4, 1), (2, 1, 2), (2, 2, 1), (4, 1, 1)]
    All single numbers which are exactly equal to 4:
    >>> sorted(partition('=', 1, 4, 6))
    [(4,)]
    """
    partitioner = _partitioners[op]
    yield from partitioner(nparts, result, dim)
########################################################################
# A cage is a sub-region of the puzzle with a specified arithmetic
# operation and resulting value.
#
class Cage:
    """A group of cells whose digits, combined by one arithmetic operation,
    must produce a target value.

    Attributes
    ----------
    op : str
        Operator symbol, looked up via `partition` (e.g. '+', '-', '*', '/',
        '=' or '!').
    value
        Target result of applying `op` to the cells' digits.
    cells : sequence of (row, col) pairs
        The cells belonging to this cage.
    """

    def __init__(self, op, value, cells):
        self.op, self.value, self.cells = op, value, cells

    def __str__(self):
        fields = (type(self).__name__, self.op, self.value, self.cells)
        return "%s(%r, %s, %s)" % fields

    def dnf_clauses(self, dim, variable):
        """Yield the cage constraint as a disjunction of conjunctions.

        Each yielded tuple is one conjunction, i.e. one complete local
        assignment of digits to the cage's cells::

            (digit(cells[0]) == sol[0] AND digit(cells[1]) == sol[1] AND ...)

        and the cage is satisfied when any one of the yielded tuples holds.
        Only the cage's own arithmetic is considered; global rules such as
        "no repeated digit in a row or column" are not applied here.

        Parameters
        ----------
        dim : int
            Puzzle dimension.
        variable : callable, as variable(i, j, d)
            Returns the variable number corresponding to `d` in cell `(i,j)`.
        """
        solutions = partition(self.op, len(self.cells), self.value, dim)
        for digits in solutions:
            literals = tuple(
                variable(cell[0], cell[1], digit)
                for cell, digit in zip(self.cells, digits)
            )
            yield literals
########################################################################
# Main Entry Point
#
class KenKenPuzzle:
    """A KenKen puzzle.

    Cells are addressed as 1-based ``(row, column)`` tuples and hold digits
    ``1..size``. The puzzle is encoded as a SAT problem: one boolean
    variable per (cell, digit) pair, plus auxiliary variables for the cage
    constraints. It is solved with pycosat.
    """

    def __init__(self, size, cages=None):
        """Create a puzzle of dimension `size`.

        Parameters
        ----------
        size : int
            Width (and height) of the grid; digits range over ``1..size``.
        cages : iterable of Cage, optional
            Initial cages. They are copied into a new list, so later
            `add_cage` calls do not mutate the caller's container.
        """
        self.size = size
        self._cages = list(cages) if cages is not None else []

    @classmethod
    def from_text(cls, text):
        """Instantiate a puzzle from a text description.

        The first non-blank line must be ``# <dim>``. Every following
        non-blank line is a cage in the form accepted by `add_text_cage`,
        e.g. ``+ 11 A1 B1``.

        Raises
        ------
        InvalidPuzzle
            If the header line is missing or does not start with ``#``, or
            the cages do not cover every cell exactly once.
        ValueError
            If the header does not consist of exactly two tokens, or a
            number fails to parse.
        """
        lines = iter(text.splitlines())
        for line in lines:
            line = line.strip()
            if not line:
                continue
            op, value = line.split()
            if op != '#':
                raise InvalidPuzzle(
                    "Expected a '# <dim>' header line, got %r" % (line,))
            size = int(value)
            break
        else:
            raise InvalidPuzzle("Missing '# <dim>' header line")
        puzzle = cls(size)
        # `lines` is the same iterator, so this continues after the header.
        for line in lines:
            line = line.strip()
            if line:
                puzzle.add_text_cage(line)
        puzzle.assert_valid()
        return puzzle

    def assert_valid(self):
        """Check that each cell belongs to exactly one cage.

        Raises
        ------
        InvalidPuzzle
            If a cell appears in more than one cage, a grid cell is in no
            cage, or a cage references a cell outside the grid.
        """
        visited = set()
        for cage in self._cages:
            frontier = set(cage.cells)
            if not visited.isdisjoint(frontier):
                raise InvalidPuzzle("Duplicate cells: %s"
                                    % (visited.intersection(frontier),))
            visited.update(frontier)
        expected = set(itertools.product(range(1, self.size+1),
                                         range(1, self.size+1)))
        missing = expected.difference(visited)
        unknown = visited.difference(expected)
        if missing or unknown:
            messages = []
            if missing:
                messages.append("Missing cells: %s" % (missing,))
            if unknown:
                messages.append("Unknown cells: %s" % (unknown,))
            raise InvalidPuzzle(" ".join(messages))

    def add_cage(self, op, result, cells):
        """Add a cage applying `op` to `cells` to obtain `result`."""
        self._cages.append(Cage(op, result, cells))

    def _cell_as_tuple(self, cell):
        """Convert cell notation from text form to tuple form.

        The leading letter is the row (``A`` = 1, case-insensitive) and the
        rest is the column number, e.g. ``'B7' -> (2, 7)``.
        """
        row = ord(cell[:1].upper()) - ord('A') + 1
        col = int(cell[1:])
        return row, col

    def add_text_cage(self, text):
        """Add a cage written as ``<op> <value> <cell1> ... <cellN>``."""
        op, result, *cells = text.split()
        result = int(result)
        cells = tuple(self._cell_as_tuple(cell) for cell in cells)
        self.add_cage(op, result, cells)

    def variable(self, i, j, d):
        """Return the number of the variable which corresponds to cell (i, j)
        containing digit d.

        Numbers run from 1 to ``size**3``; auxiliary variables start above
        that range.

        Raises
        ------
        ValueError
            If `i`, `j` or `d` is outside ``1..size``. This is an explicit
            check rather than `assert`, so it still runs under ``python -O``.
        """
        n = self.size
        if not (1 <= i <= n and 1 <= j <= n and 1 <= d <= n):
            raise ValueError("Cell (%s, %s) / digit %s out of range 1..%s"
                             % (i, j, d, n))
        return n*n * (i - 1) + n * (j - 1) + d

    def _unvariable(self, v):
        """Return ``((i, j), d)`` for a variable number `v` (inverse of `variable`).

        Only meaningful for ``1 <= v <= size**3``; auxiliary variables have
        no cell/digit meaning.
        """
        n = self.size
        # `variable` is a 1-based mixed-radix encoding; shift to 0-based
        # first so that d == n (and j == n) do not carry into the next cell.
        rest, d0 = divmod(v - 1, n)
        i0, j0 = divmod(rest, n)
        return ((i0 + 1, j0 + 1), d0 + 1)

    def clauses(self):
        """Yield all CNF clauses (lists of ints) for the puzzle.

        Raises
        ------
        ValueError
            During iteration, if some cage has no possible local solution.
        """
        n = self.size
        v = self.variable
        digits = range(1, n + 1)

        # Every cell holds exactly one digit.
        for i in digits:
            for j in digits:
                # At least one of the n digits ...
                yield [v(i, j, d) for d in digits]
                # ... and never two different digits at once.
                for d, dp in itertools.combinations(digits, 2):
                    yield [-v(i, j, d), -v(i, j, dp)]

        def distinct(cells):
            # No two of `cells` share a digit.
            for (r1, c1), (r2, c2) in itertools.combinations(cells, 2):
                for d in digits:
                    yield [-v(r1, c1, d), -v(r2, c2, d)]

        # Rows and columns have distinct values.
        for i in digits:
            yield from distinct([(i, j) for j in digits])
            yield from distinct([(j, i) for j in digits])

        # The cages return their clauses in disjunctive normal form, but the
        # SAT solver needs conjunctive normal form. To avoid exponential
        # growth we introduce one auxiliary variable per DNF term, numbered
        # after the n**3 cell/digit variables.
        auxiliary_vars = itertools.count(n**3 + 1)
        for cage in self._cages:
            dnf = list(cage.dnf_clauses(n, v))
            if not dnf:
                raise ValueError("Invalid cage: %s" % (cage,))
            yield from self._dnf_to_cnf(dnf, auxiliary_vars)

    @staticmethod
    def _dnf_to_cnf(dnf, auxiliary_vars):
        """Convert dnf to cnf.

        Each DNF term gets an auxiliary variable ``a`` that is equivalent to
        the conjunction of the term's literals. One extra clause requires at
        least one ``a`` to be true.

        Parameters
        ----------
        dnf : list of sequences of int
            Clauses in disjunctive normal form.
        auxiliary_vars : iterator of int
            Fresh variable numbers, of which len(dnf) will be taken.
        """
        auxs = list(itertools.islice(auxiliary_vars, len(dnf)))
        # At least one term holds.
        yield auxs
        for a, term in zip(auxs, dnf):
            # a -> every literal of the term.
            for lit in term:
                yield [-a, lit]
            # All literals of the term -> a. This pins each auxiliary
            # variable to its term, so `itersolve` does not enumerate the
            # same grid more than once.
            yield [a] + [-lit for lit in term]

    def solve(self):
        """Return a solution as a list of rows of digits.

        Raises
        ------
        InvalidPuzzle
            If the puzzle has no solution.
        """
        sol = pycosat.solve(self.clauses())
        if sol == 'UNSAT':
            raise InvalidPuzzle("The puzzle has no solution")
        return self._sol_to_grid(set(sol))

    def itersolve(self):
        """Return an iterator to all solutions of the KenKen puzzle."""
        for sol in pycosat.itersolve(self.clauses()):
            yield self._sol_to_grid(set(sol))

    def _sol_to_grid(self, sol):
        """Convert a set of true literals into a ``size`` x ``size`` grid.

        A cell with no true digit variable is reported as None.
        """
        digits = range(1, self.size + 1)

        def read_cell(i, j):
            for d in digits:
                if self.variable(i, j, d) in sol:
                    return d
            return None

        return [[read_cell(i, j) for j in digits] for i in digits]
if __name__ == "__main__":
    # Demo: solve the 6x6 example puzzle from Wikipedia and check the result.
    #
    # Text format (there may be a standard KenKen notation; this is ad hoc):
    #
    #   # <dim>
    #   <cage1>
    #   ...
    #   <cageN>
    #
    # where <dim> is the grid size (6 for a 6-by-6 grid) and each cage is
    #
    #   <op> <value> <cell1> ... <cellN>
    #
    # with <op> one of +, -, *, / and = (or !), <value> the cage's result,
    # and each cell written as <row><column>: rows are lettered A, B, C, ...
    # and columns numbered 1, 2, 3, ...
    print("Solving the 6x6 KenKen Puzzle from Wikipedia.")
    print("See http://en.wikipedia.org/wiki/File:KenKenProblem.svg")
    puzzle_text = """
# 6
+ 11 A1 B1
/ 2 A2 A3
* 20 A4 B4
* 6 A5 A6 B6 C6
- 3 B2 B3
/ 3 B5 C5
* 240 C1 C2 D1 D2
* 6 C3 C4
* 6 D3 E3
+ 7 D4 E4 E5
* 30 D5 D6
* 6 E1 E2
+ 9 E6 F6
+ 8 F1 F2 F3
/ 2 F4 F5
"""
    puzzle = KenKenPuzzle.from_text(puzzle_text)
    grid = puzzle.solve()

    # Known solution:
    # http://en.wikipedia.org/wiki/File:KenKenSolution.svg
    expected_grid = [
        [5, 6, 3, 4, 1, 2],
        [6, 1, 4, 5, 2, 3],
        [4, 5, 2, 3, 6, 1],
        [3, 4, 1, 2, 5, 6],
        [2, 3, 6, 1, 4, 5],
        [1, 2, 5, 6, 3, 4],
    ]
    assert grid == expected_grid

    print("\n".join(str(solved_row) for solved_row in grid))
    print()
# End of the SAT-based KenKen solver.