SAT-based KenKen solver.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A SAT-based KenKen (http://www.kenken.com/) solver. | |
The implementation of this solver is based on the ideas contained in | |
the paper: | |
"A SAT-based Sudoku solver" by Tjark Weber | |
https://www.lri.fr/~conchon/mpri/weber.pdf | |
and a Python implementation of a SAT-based Sudoku solver which is found | |
in the pycosat examples: | |
https://pypi.python.org/pypi/pycosat | |
This module requires Python 3.3 or later. | |
""" | |
import itertools | |
import pycosat | |
######################################################################## | |
# Exceptions | |
# | |
class InvalidPuzzle(Exception):
    """The puzzle specification is invalid."""


class InvalidPuzzleSpecification(InvalidPuzzle):
    """A textual puzzle description could not be parsed.

    This name is referenced by the text parser, so it must exist. It
    derives from InvalidPuzzle, so ``except InvalidPuzzle`` handlers also
    catch it.
    """


class UnsatisfyablePuzzle(InvalidPuzzle):
    """The puzzle is unsatisfiable (it has no solution).

    It derives from InvalidPuzzle because unsatisfiable puzzles are also
    reported as invalid; existing ``except InvalidPuzzle`` handlers
    therefore keep working if this more specific error is raised.
    """
######################################################################## | |
# Partitioners | |
# | |
# A partitioner yields all possible partitions of `result` into `nparts` | |
# according to a specified operation, where each value is in the range | |
# `1..dim`. | |
# | |
# See the docstring of `partition` for some examples. | |
# | |
def _partition_add(nparts, result, dim): | |
if nparts == 0: | |
raise ValueError("Expected nparts >= 1") | |
elif nparts == 1: | |
if 1 <= result <= dim: | |
yield (result, ) | |
else: | |
raise StopIteration | |
else: | |
for i1 in range(1, dim+1): | |
for i2 in _partition_add(nparts-1, result-i1, dim): | |
yield (i1,) + i2 | |
def _partition_mul(nparts, result, dim): | |
if nparts == 0: | |
raise ValueError("Expected nparts >= 1") | |
elif nparts == 1: | |
if 1 <= result <= dim: | |
yield (result, ) | |
else: | |
raise StopIteration | |
else: | |
for i1 in range(1, dim+1): | |
if result % i1 != 0: | |
continue | |
for i2 in _partition_mul(nparts-1, int(result//i1), dim): | |
yield (i1,) + i2 | |
def _partition_sub(nparts, result, dim): | |
if nparts != 2: | |
raise ValueError("Expected nparts = 2.") | |
for i1 in range(1, dim+1): | |
for i2 in (i1-result, i1+result): | |
if 1 <= i2 <= dim: | |
yield(i1, i2) | |
def _partition_div(nparts, result, dim): | |
if nparts != 2: | |
raise ValueError("Expected nparts = 2.") | |
for i1 in range(1, dim+1): | |
if i1 % result == 0: | |
i2 = int(i1 // result) | |
if 1 <= i2 <= dim: | |
yield (i1, i2) | |
i2 = result * i1 | |
if 1 <= i2 <= dim: | |
yield (i1, i2) | |
def _partition_eq(nparts, result, dim): | |
if nparts != 1: | |
raise ValueError("Expected nparts = 1.") | |
if 1 <= result <= dim: | |
yield (result,) | |
# Dispatch table from a cage's operation symbol to the generator that
# enumerates its local solutions. '!' and '=' both map to the same
# single-cell partitioner.
_partitioners = dict([
    ('+', _partition_add),
    ('*', _partition_mul),
    ('-', _partition_sub),
    ('/', _partition_div),
    ('!', _partition_eq),
    ('=', _partition_eq),
])
def partition(op, nparts, result, dim):
    """Partition `result` into `nparts` each in the range `1..dim`
    which can be obtained using the operation `op`.
    For example, all pairs of numbers in the range 1..6 whose ratio is 3:
    >>> sorted(partition('/', 2, 3, 6))
    [(1, 3), (2, 6), (3, 1), (6, 2)]
    All triples of numbers in the range 1..6 whose product is 4:
    >>> sorted(partition('*', 3, 4, 6))
    [(1, 1, 4), (1, 2, 2), (1, 4, 1), (2, 1, 2), (2, 2, 1), (4, 1, 1)]
    All single numbers which are exactly equal to 4:
    >>> sorted(partition('=', 1, 4, 6))
    [(4,)]
    An unknown `op` raises KeyError when iteration starts.
    """
    partitioner = _partitioners[op]
    for combination in partitioner(nparts, result, dim):
        yield combination
######################################################################## | |
# A cage is a sub-region of the puzzle with a specified arithmetic | |
# operation and resulting value. | |
# | |
class Cage:
    """A set of cells which achieves a value by an arithmetic operation.

    Attributes
    ----------
    op : str
        Operation symbol, used as the key for `partition`
        ('+', '*', '-', '/', '=' or '!').
    value : int
        The result the cage's digits must produce under `op`.
    cells : sequence of (row, col) tuples
        The cells that belong to this cage (1-based coordinates).
    """

    def __init__(self, op, value, cells):
        self.op = op
        self.value = value
        self.cells = cells

    def __str__(self):
        return "%s(%r, %s, %s)" % (
            self.__class__.__name__, self.op, self.value, self.cells)

    # The str() form is already an unambiguous, constructor-like
    # description, so reuse it for debugging output.
    def __repr__(self):
        return self.__str__()

    def dnf_clauses(self, dim, variable):
        """Yield clauses in disjunctive normal form which satisfy the
        cage constraints.

        Essentially this returns clauses which correspond to
            (digit(cell[0]) == sol1[0] AND digit(cell[1]) == sol1[1] AND ...)
         OR (digit(cell[0]) == sol2[0] AND digit(cell[1]) == sol2[1] AND ...)
         OR ...
        where sol1, sol2, ... are all possible local solutions to the cage
        constraints. Global constraints, such as no duplicates in rows or
        columns, are ignored.

        Parameters
        ----------
        dim : int
            Puzzle dimension
        variable : callable, as variable(i, j, d)
            Returns the variable number corresponding to `d` in cell `(i,j)`.

        Yields
        ------
        tuple of int
            One conjunction of positive literals per local solution.
        """
        for vals in partition(self.op, len(self.cells), self.value, dim):
            yield tuple(variable(row, col, val)
                        for (row, col), val in zip(self.cells, vals))
######################################################################## | |
# Main Entry Point | |
# | |
class KenKenPuzzle:
    """A KenKen puzzle of dimension ``size`` built from arithmetic cages.

    Cells are addressed by 1-based ``(row, column)`` tuples.
    """

    def __init__(self, size, cages=None):
        self.size = size
        # Copy the cages into a fresh list so that add_cage() never mutates
        # a sequence owned by the caller (and tuples are accepted too).
        self._cages = list(cages) if cages else []

    @classmethod
    def from_text(cls, text):
        """Instantiate a puzzle from a text description.

        The first non-blank line must be ``# <size>``. Every following
        non-blank line is a cage in the format accepted by `add_text_cage`.

        Raises
        ------
        InvalidPuzzle
            If the header line is missing or does not start with ``#``, or
            if the cages do not cover every cell exactly once.
        ValueError
            If the header does not have exactly two fields, the size is not
            an integer, or a cage line is malformed.
        """
        lines = iter(text.splitlines())
        for line in lines:
            line = line.strip()
            if not line:
                continue
            op, value = line.split()
            if op != '#':
                raise InvalidPuzzle(
                    "Expected a '# <size>' header line, got %r" % (line,))
            size = int(value)
            break
        else:
            raise InvalidPuzzle("Missing '# <size>' header line")
        puzzle = cls(size)
        # `lines` is the same iterator, so this continues after the header.
        for line in lines:
            line = line.strip()
            if not line:
                continue
            puzzle.add_text_cage(line)
        puzzle.assert_valid()
        return puzzle

    def assert_valid(self):
        """Check that each cell belongs to exactly one cage.

        Raises
        ------
        InvalidPuzzle
            On duplicate, missing, or out-of-grid cells.
        """
        visited = set()
        for cage in self._cages:
            frontier = set(cage.cells)
            if not visited.isdisjoint(frontier):
                raise InvalidPuzzle("Duplicate cells: %s"
                                    % (visited.intersection(frontier),))
            visited.update(frontier)
        expected = set(itertools.product(range(1, self.size+1),
                                         range(1, self.size+1)))
        missing = expected.difference(visited)
        unknown = visited.difference(expected)
        if missing or unknown:
            messages = []
            if missing:
                messages.append("Missing cells: %s" % (missing,))
            if unknown:
                messages.append("Unknown cells: %s" % (unknown,))
            raise InvalidPuzzle(" ".join(messages))

    def add_cage(self, op, result, cells):
        """Add a cage with operation `op`, target `result` and `cells`."""
        self._cages.append(Cage(op, result, cells))

    def _cell_as_tuple(self, cell):
        """Convert cell notation from text form to tuple form.

        The row letter is case-insensitive. For example, 'B7' -> (2, 7)
        and 'b7' -> (2, 7).
        """
        row, col = cell[:1].upper(), cell[1:]
        return ord(row) - ord('A') + 1, int(col)

    def add_text_cage(self, text):
        """Add a cage given as ``<op> <value> <cell1> ... <cellN>``."""
        op, result, *cells = text.split()
        result = int(result)
        cells = tuple(self._cell_as_tuple(cell) for cell in cells)
        self.add_cage(op, result, cells)

    def variable(self, i, j, d):
        """Return the number of the variable which corresponds to cell (i, j)
        containing digit d.

        The numbers are 1..size**3, laid out row-major by (i, j, d).
        """
        n = self.size
        # NOTE: asserts are stripped under `python -O`; they guard internal
        # misuse only, since puzzle input is validated in assert_valid().
        assert 1 <= i <= n
        assert 1 <= j <= n
        assert 1 <= d <= n
        return n*n * (i - 1) + n * (j - 1) + d

    def _unvariable(self, v):
        """Return ``((i, j), d)`` such that ``self.variable(i, j, d) == v``."""
        n = self.size
        # Decompose the 0-based offset. Working 1-based makes d == n carry
        # into j (and j == n carry into i), which gives wrong cells.
        rest, d0 = divmod(v - 1, n)
        i0, j0 = divmod(rest, n)
        return ((i0 + 1, j0 + 1), d0 + 1)

    def clauses(self):
        """Yield all CNF clauses (lists of int literals) for the puzzle.

        Raises
        ------
        ValueError
            If some cage has no local solution.
        """
        n = self.size
        v = self.variable
        digits = range(1, n + 1)

        for i in digits:
            for j in digits:
                # Each cell denotes (at least) one of the n digits ...
                yield [v(i, j, d) for d in digits]
                # ... and never two different digits at once.
                for d, dp in itertools.combinations(digits, 2):
                    yield [-v(i, j, d), -v(i, j, dp)]

        def distinct(cells):
            # No digit may appear in two of the given cells.
            for (ri, ci), (rj, cj) in itertools.combinations(cells, 2):
                for d in digits:
                    yield [-v(ri, ci, d), -v(rj, cj, d)]

        # Rows and columns contain distinct values.
        for k in digits:
            yield from distinct([(k, j) for j in digits])
            yield from distinct([(j, k) for j in digits])

        # The cages return their clauses in disjunctive normal form, but
        # the SAT solver needs conjunctive normal form. To convert without
        # exponential growth, introduce auxiliary variables numbered after
        # the n**3 cell/digit variables.
        auxiliary_vars = itertools.count(n**3 + 1)
        for cage in self._cages:
            dnf = list(cage.dnf_clauses(n, v))
            if not dnf:
                raise ValueError("Invalid cage: %s" % (cage,))
            yield from self._dnf_to_cnf(dnf, auxiliary_vars)

    @staticmethod
    def _dnf_to_cnf(dnf, auxiliary_vars):
        """Convert dnf to cnf.

        Each conjunction in `dnf` gets one auxiliary variable that is
        equivalent to it. The disjunction then becomes "at least one
        auxiliary variable is true".

        Parameters
        ----------
        dnf : list of sequences of int
            Clauses in disjunctive normal form
        auxiliary_vars : iterator of int
            Auxiliary variables, of which len(dnf) will be taken.
        """
        auxs = list(itertools.islice(auxiliary_vars, len(dnf)))
        yield auxs
        for aux, conj in zip(auxs, dnf):
            # aux -> every literal of the conjunction
            for lit in conj:
                yield [-aux, lit]
            # all literals of the conjunction -> aux
            yield [aux] + [-lit for lit in conj]

    def solve(self):
        """Return a solution as a list of rows, each a list of digits.

        Raises
        ------
        InvalidPuzzle
            If the SAT solver reports the puzzle is unsatisfiable.
        """
        sol = pycosat.solve(self.clauses())
        if sol == 'UNSAT':
            raise InvalidPuzzle("The puzzle has no solution")
        return self._sol_to_grid(set(sol))

    def itersolve(self):
        """Return an iterator over all solutions of the KenKen puzzle."""
        for sol in pycosat.itersolve(self.clauses()):
            yield self._sol_to_grid(set(sol))

    def _sol_to_grid(self, sol):
        """Convert a set of true literals to a list-of-rows grid.

        A cell with no true digit variable is reported as None.
        """
        n = self.size
        digits = range(1, n + 1)

        def read_cell(i, j):
            for d in digits:
                if self.variable(i, j, d) in sol:
                    return d
            return None

        return [[read_cell(i, j) for j in digits] for i in digits]
if __name__ == "__main__": | |
# Apologies if there is a standard format for representing | |
# KenKen puzzles which I am not aware of. | |
# The puzzle format here is the following: | |
# | |
# # <dim> | |
# <cage1> | |
# ... | |
# <cageN> | |
# | |
# where dim is the dimension of the puzzle (i.e., 6 if | |
# the grid is 6-by-6). Each cage is of the form | |
# | |
# <op> <value> <cell1> ... <cellN> | |
# | |
# where op is one of +, -, *, /, and = (or !), value | |
# is the resulting number, and the cells are specified | |
# as <row><column> where the rows are labeled A, B, C, | |
# etc and the columns are labeled 1, 2, 3, etc. | |
# | |
print("Solving the 6x6 KenKen Puzzle from Wikipedia.") | |
print("See http://en.wikipedia.org/wiki/File:KenKenProblem.svg") | |
puzzle = KenKenPuzzle.from_text( | |
""" | |
# 6 | |
+ 11 A1 B1 | |
/ 2 A2 A3 | |
* 20 A4 B4 | |
* 6 A5 A6 B6 C6 | |
- 3 B2 B3 | |
/ 3 B5 C5 | |
* 240 C1 C2 D1 D2 | |
* 6 C3 C4 | |
* 6 D3 E3 | |
+ 7 D4 E4 E5 | |
* 30 D5 D6 | |
* 6 E1 E2 | |
+ 9 E6 F6 | |
+ 8 F1 F2 F3 | |
/ 2 F4 F5 | |
""") | |
grid = puzzle.solve() | |
# Compare to the known solution at | |
# http://en.wikipedia.org/wiki/File:KenKenSolution.svg | |
assert grid == [[5,6,3,4,1,2], | |
[6,1,4,5,2,3], | |
[4,5,2,3,6,1], | |
[3,4,1,2,5,6], | |
[2,3,6,1,4,5], | |
[1,2,5,6,3,4]] | |
for row in grid: | |
print(row) | |
print() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
For Python 3.7 and newer, to avoid running into a `RuntimeError`,
you'll have to replace
`raise StopIteration`
with `return` (or `return None`)
in all places. This comes from PEP 479:
PEP 479 is enabled for all code in Python 3.7, meaning that StopIteration exceptions raised directly or indirectly in coroutines and generators are transformed into RuntimeError exceptions. (Contributed by Yury Selivanov in bpo-32670.)