Skip to content

Instantly share code, notes, and snippets.

@teh
Last active January 31, 2020 17:43
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save teh/6847386 to your computer and use it in GitHub Desktop.
Save teh/6847386 to your computer and use it in GitHub Desktop.
Solve sudoku with a SAT solver.
import pycosat
import numpy
import itertools
# Sample puzzle (the name refers to the one the Telegraph called the world's
# hardest sudoku). One string per row, top to bottom; '.' marks an empty cell.
WORLDS_HARDEST_RIDDLE_ACCORDING_TO_TELEGRAPH = "\n".join(
    (
        "8........",
        "..36.....",
        ".7..9.2..",
        ".5...7...",
        "....457..",
        "...1...3.",
        "..1....68",
        "..85...1.",
        ".9....4..",
    )
)
def get_cnf(riddle):
    """Encode a 9x9 sudoku puzzle as a CNF formula for pycosat.

    Variable ``i * 81 + j * 9 + d + 1`` means "the cell in row ``i``, column
    ``j`` (both 0-based) holds digit ``d + 1``".  Numbering starts at 1
    because 0 is reserved in picosat; a literal's sign carries its polarity.

    Args:
        riddle: The puzzle as whitespace-separated rows, top to bottom.  Each
            row is a string of at most 9 characters: ``'1'``-``'9'`` are given
            digits, ``'.'`` or ``'0'`` mark an empty cell.  Missing trailing
            rows/cells are treated as empty.

    Returns:
        A list of clauses, each a list of non-zero Python ``int`` literals
        (positive = variable true, negative = variable false).

    Raises:
        ValueError: If the riddle has more than 9 rows, a row with more than
            9 cells, or a character other than ``'1'``-``'9'``, ``'.'``, ``'0'``.
    """
    # Object dtype so every literal handed to pycosat is a plain Python int
    # rather than a numpy integer.
    var = (numpy.arange(9 * 9 * 9).reshape(9, 9, 9) + 1).astype("object")
    cnf = []

    # At least one digit per cell.
    for i in range(9):
        for j in range(9):
            cnf.append(var[i, j, :].tolist())

    # At most one digit per cell: forbid every pair of digits in a cell.
    for i in range(9):
        for j in range(9):
            cnf.extend(itertools.combinations(-var[i, j, :], 2))

    # Each digit at most once per 3x3 box.  Combined with "one digit per
    # cell", this forces every box to contain 9 different digits.
    for bi in range(3):
        for bj in range(3):
            for d in range(9):
                box = var[bi * 3:bi * 3 + 3, bj * 3:bj * 3 + 3, d].ravel()
                cnf.extend(itertools.combinations(-box, 2))

    # Each digit at most once per row and once per column.
    for i in range(9):
        for d in range(9):
            cnf.extend(itertools.combinations(-var[i, :, d], 2))
            cnf.extend(itertools.combinations(-var[:, i, d], 2))

    # Transform the given digits of the riddle into unit clauses.
    rows = riddle.split()
    if len(rows) > 9:
        raise ValueError(f"riddle has {len(rows)} rows; at most 9 are allowed")
    for i, row in enumerate(rows):
        if len(row) > 9:
            raise ValueError(
                f"row {i} has {len(row)} cells; at most 9 are allowed: {row!r}"
            )
        for j, ch in enumerate(row):
            if ch in ".0":
                # Empty cell.  '0' previously mapped to index -1, i.e. it
                # silently forced the digit 9.
                continue
            if ch not in "123456789":
                raise ValueError(
                    f"invalid character {ch!r} in row {i}, column {j}"
                )
            cnf.append([var[i, j, int(ch) - 1]])

    return [list(clause) for clause in cnf]
def print_solution(solution):
    """Print a solved 9x9 grid: one row per line, each digit followed by a space.

    Args:
        solution: The model returned by ``pycosat.solve`` for a formula built
            by ``get_cnf``: 729 literals ordered by variable number, positive
            when the variable is true.

    Raises:
        ValueError: If ``solution`` is not a model, e.g. the ``"UNSAT"`` /
            ``"UNKNOWN"`` status strings pycosat returns when it finds none,
            or if it does not contain exactly 729 literals.
    """
    if isinstance(solution, str):
        # pycosat reports failure with a status string instead of a model;
        # reshaping it would fail with a misleading "cannot reshape" error.
        raise ValueError(
            f"no sudoku solution available: solver returned {solution!r}"
        )
    grid = numpy.asarray(solution).reshape(9, 9, 9) > 0
    for row in grid:
        # Every digit whose variable is true in each cell of this row
        # (exactly one per cell for a model of get_cnf's formula).
        print("".join(f"{d + 1} " for cell in row for d in numpy.flatnonzero(cell)))
if __name__ == "__main__":
    # Solve and print the sample puzzle only when run as a script, so that
    # importing this module (e.g. to reuse get_cnf) has no side effects.
    print_solution(pycosat.solve(get_cnf(WORLDS_HARDEST_RIDDLE_ACCORDING_TO_TELEGRAPH)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment