Skip to content

Instantly share code, notes, and snippets.

@aflaxman
Created July 2, 2010 03:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aflaxman/460904 to your computer and use it in GitHub Desktop.
Save aflaxman/460904 to your computer and use it in GitHub Desktop.
from pylab import *
import random
# every [row, column] index pair of the 9x9 board, in row-major order
# (81 two-element lists; iterated over and sampled from by the functions below)
index_set = [[i,j] for i in range(9) for j in range(9)]
def solve(T):
    """ Fill in the blank cells of T by backtracking search

    T is a 9x9 array, with blank cells set to -1.  T is modified in
    place: on 'success' it holds a completed board, on 'failure' every
    cell tried by this call has been reset to -1.

    Returns the string 'success' or 'failure'.

    Example
    -------
    >>> T = rand(70)
    >>> solve(T)
    'success'
    >>> any(T == -1)
    False
    """
    # nothing left to fill in: the board is complete
    if (T > 0).all():
        return 'success'

    # branch on the blank cell that has the fewest candidate values
    candidates = possibilities(T)
    row, col = most_constrained(T, candidates)
    for guess in candidates[(row, col)]:
        T[row, col] = guess
        outcome = solve(T)
        if outcome == 'success':
            return outcome

    # no candidate led to a solution; blank the cell again and backtrack
    T[row, col] = -1
    return 'failure'
def count_solns(T, one_vs_many=False):
    """ How many unique solutions are there starting with T?

    T is a 9x9 array, with blank cells set to -1.  T is modified while
    searching, but every cell filled in by the search is reset to -1
    before this function returns.

    If one_vs_many is True, the search stops as soon as more than one
    solution has been found, so the result only tells apart 0, 1, and
    "many" (some value greater than 1).

    A board with no blank cells counts as exactly one solution; its
    validity is not re-checked here.

    Example
    -------
    >>> T = rand(81)
    >>> count_solns(T)
    1
    """
    # solve T recursively, by trying all values for the most constrained var
    pos = possibilities(T)

    # possibilities() has one entry per blank cell, so an empty dictionary
    # means the board is complete.  Test emptiness directly: comparing
    # pos.keys() to [] is always False on Python 3, where keys() is a view.
    if not pos:
        return 1

    i, j = most_constrained(T, pos)
    count = 0
    for val in pos[(i, j)]:
        T[i, j] = val
        count += count_solns(T, one_vs_many)
        if one_vs_many and count > 1:
            # already known to be "many"; restore the cell and stop early
            T[i, j] = -1
            return count

    # when this point is reached, reset most_constrained cell
    T[i, j] = -1
    return count
def rand(n, T=None):
    """ Create a random game, with n cells filled in

    Optionally start with an initialized board T (a 9x9 array with blank
    cells set to -1).  A supplied T is modified in place and is also the
    object returned; it must be solvable, otherwise the assertion below
    fails.

    Returns the 9x9 board, with all but n cells set to -1.

    Example
    -------
    >>> sum(rand(70) != -1)
    70
    """
    # start with an empty board
    # (identity test: `T == None` on an array compares elementwise, and the
    # resulting array has no single truth value)
    if T is None:
        T = -1*ones([9,9])

    # solve it to generate an initial solution
    res = solve(T)
    assert res == 'success', 'starting board has no solution'

    # do random shuffles to approximate uniformly random solution:
    # keep 20 cells, relabel them, and re-solve from there
    for _ in range(5):
        select_random_cells(T, 20)
        randomly_permute_labels(T)
        solve(T)

    # remove appropriate amount of labels
    select_random_cells(T, n)
    return T
def most_constrained(T, pos):
    """ Find blank cell which is most constrained by non-blank cells

    T is a 9x9 array with blank cells marked by negative values, and pos
    maps the (row, column) tuple of each blank cell to its collection of
    candidate values (as built by possibilities).

    Returns the (row, column) tuple of the blank cell with the fewest
    candidates; ties go to the first such cell in row-major order.

    Raises ValueError if T has no blank cells.
    """
    best_index = None
    best_count = float('inf')
    for i in range(9):
        for j in range(9):
            if T[i, j] < 0:
                cur_count = len(pos[(i, j)])
                # strict comparison keeps the earliest cell on ties
                if cur_count < best_count:
                    best_index = (i, j)
                    best_count = cur_count

    # without this check a full board would surface as an unbound local
    if best_index is None:
        raise ValueError('board has no blank cells')
    return best_index
def possibilities(T):
    """ Find all possibilities for each empty cell of T

    T is a 9x9 array with blank cells marked by negative values (-1).

    Returns a dictionary of sets: the key is the (row, column) tuple of a
    blank cell, and the value is the set of numbers 1..9 that do not
    already appear in that cell's row, column, or 3x3 super-cell.  Filled
    cells have no entry, and a blank cell with no legal value maps to an
    empty set.

    Example
    -------
    >>> pos = possibilities(-1*ones([9,9]))
    >>> sorted(pos[(0,0)])
    [1, 2, 3, 4, 5, 6, 7, 8, 9]
    """
    all_values = set(range(1, 10))
    pos = {}
    for i in range(9):
        for j in range(9):
            if T[i, j] < 0:
                # floor division locates the top-left corner of the 3x3
                # super-cell; true division would give float slice bounds
                r0 = 3 * (i // 3)
                c0 = 3 * (j // 3)
                used = (set(T[i, :]) | set(T[:, j])
                        | set(T[r0:r0 + 3, c0:c0 + 3].flatten()))
                pos[(i, j)] = all_values - used
    return pos
def select_random_cells(T, n):
    """ Keep n randomly chosen cells of T and blank out the rest

    The 81-n cells that are not kept are overwritten in place with -1,
    the marker for a blank cell.
    """
    blanked = random.sample(index_set, 81 - n)
    for row, col in blanked:
        T[row, col] = -1
def randomly_permute_labels(T):
    """ Permute the positive values of T uniformly at random

    A single random permutation of the labels 1..9 is drawn and applied
    in place to every positive cell of the 9x9 array T, so equal labels
    stay equal and distinct labels stay distinct.  Non-positive (blank)
    cells are left untouched.
    """
    # random.shuffle acts in-place, so it needs a real list (a range
    # object is immutable on Python 3)
    new_labels = list(range(1, 10))
    random.shuffle(new_labels)
    new_label_dict = dict(zip(range(1, 10), new_labels))

    for i in range(9):
        for j in range(9):
            if T[i, j] > 0:
                T[i, j] = new_label_dict[T[i, j]]
def draw(T, R=None):
    """ Use matplotlib to display 9x9 table T in a Sudoku style

    Cells of T with a positive value are drawn in bold.  If R is given
    (for example a solved copy of T), each remaining cell shows the
    corresponding value of R in normal weight; otherwise those cells are
    left empty.  The current figure is cleared before drawing.

    Example
    -------
    >>> T = rand(70)
    >>> R = copy(T)
    >>> solve(R)
    'success'
    >>> draw(T, R)
    """
    clf()

    # thick lines separating the 3x3 super-cells
    params = dict(linewidth=2)
    grid = [1./3., 2./3.]
    hlines(grid, 0, 1, **params)
    vlines(grid, 0, 1, **params)

    # thin lines separating the individual cells
    params = dict(linewidth=1)
    grid = arange(0., 1.1, 1./9.)
    hlines(grid, 0, 1, **params)
    vlines(grid, 0, 1, **params)

    # shade the four super-cells in the middle of each edge
    params = dict(facecolor='gray')
    mid_verts = [1/3., 2/3., 2/3., 1/3.]
    bot_verts = [0, 0, 1/3., 1/3.]
    top_verts = [1, 1, 2/3., 2/3.]
    fill(mid_verts, bot_verts, **params)
    fill(mid_verts, top_verts, **params)
    fill(bot_verts, mid_verts, **params)
    fill(top_verts, mid_verts, **params)

    # write each value at the centre of its cell; row 0 is at the top
    params = dict(fontsize=20, ha='center', va='center')
    for i in range(9):
        for j in range(9):
            row_pos = 1. - (i/9. + 1/18.)
            col_pos = j/9. + 1/18.
            if T[i,j] > 0:
                text(col_pos, row_pos, '%d'%T[i,j], weight='bold', **params)
            # identity test: `R != None` on an array compares elementwise
            # and the result has no single truth value
            elif R is not None:
                text(col_pos, row_pos, '%d'%R[i,j], **params)

    axis([0,1,0,1])
    xticks([])
    yticks([])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment