Skip to content

Instantly share code, notes, and snippets.

@kedarbellare
Created July 13, 2012 05:44
Show Gist options
  • Save kedarbellare/3102962 to your computer and use it in GitHub Desktop.
Save kedarbellare/3102962 to your computer and use it in GitHub Desktop.
Complete sudoku solver
# CHALLENGE PROBLEM:
#
# Use your check_sudoku function as the basis for solve_sudoku(): a
# function that takes a partially-completed Sudoku grid and replaces
# each 0 cell with a number in the range 1..9 in such a way that the
# final grid is valid.
#
# There are many ways to cleverly solve a partially-completed Sudoku
# puzzle, but a brute-force recursive solution with backtracking is a
# perfectly good option. The solver should return None for broken
# input, False for inputs that have no valid solutions, and a valid
# 9x9 Sudoku grid containing no 0 elements otherwise. In general, a
# partially-completed Sudoku grid does not have a unique solution. You
# should just return some member of the set of solutions.
#
# A solve_sudoku() in this style can be implemented in about 16 lines
# without making any particular effort to write concise code.
# solve_sudoku should return None
ill_formed = [
    [5, 3, 4, 6, 7, 8, 9, 1, 2],
    [6, 7, 2, 1, 9, 5, 3, 4, 8],
    [1, 9, 8, 3, 4, 2, 5, 6, 7],
    [8, 5, 9, 7, 6, 1, 4, 2, 3],
    [4, 2, 6, 8, 5, 3, 7, 9],      # <--- only 8 entries
    [7, 1, 3, 9, 2, 4, 8, 5, 6],
    [9, 6, 1, 5, 3, 7, 2, 8, 4],
    [2, 8, 7, 4, 1, 9, 6, 3, 5],
    [3, 4, 5, 2, 8, 6, 1, 7, 9],
]

# solve_sudoku should return valid unchanged.
valid = [
    [5, 3, 4, 6, 7, 8, 9, 1, 2],
    [6, 7, 2, 1, 9, 5, 3, 4, 8],
    [1, 9, 8, 3, 4, 2, 5, 6, 7],
    [8, 5, 9, 7, 6, 1, 4, 2, 3],
    [4, 2, 6, 8, 5, 3, 7, 9, 1],
    [7, 1, 3, 9, 2, 4, 8, 5, 6],
    [9, 6, 1, 5, 3, 7, 2, 8, 4],
    [2, 8, 7, 4, 1, 9, 6, 3, 5],
    [3, 4, 5, 2, 8, 6, 1, 7, 9],
]

# solve_sudoku should return False (row 2 contains two 8s).
invalid = [
    [5, 3, 4, 6, 7, 8, 9, 1, 2],
    [6, 7, 2, 1, 9, 5, 3, 4, 8],
    [1, 9, 8, 3, 8, 2, 5, 6, 7],
    [8, 5, 9, 7, 6, 1, 4, 2, 3],
    [4, 2, 6, 8, 5, 3, 7, 9, 1],
    [7, 1, 3, 9, 2, 4, 8, 5, 6],
    [9, 6, 1, 5, 3, 7, 2, 8, 4],
    [2, 8, 7, 4, 1, 9, 6, 3, 5],
    [3, 4, 5, 2, 8, 6, 1, 7, 9],
]

# solve_sudoku should return a sudoku grid which passes a sudoku checker.
# There may be multiple correct grids which can be made from this
# starting grid.
easy = [
    [2, 9, 0, 0, 0, 0, 0, 7, 0],
    [3, 0, 6, 0, 0, 8, 4, 0, 0],
    [8, 0, 0, 0, 4, 0, 0, 0, 2],
    [0, 2, 0, 0, 3, 1, 0, 0, 7],
    [0, 0, 0, 0, 8, 0, 0, 0, 0],
    [1, 0, 0, 9, 5, 0, 0, 6, 0],
    [7, 0, 0, 0, 9, 0, 0, 0, 1],
    [0, 0, 1, 2, 0, 0, 3, 0, 6],
    [0, 3, 0, 0, 0, 0, 0, 5, 9],
]

# Note: this may time out in the Udacity IDE! Try running it locally
# if you'd like to test your solution with it.
hard = [
    [1, 0, 0, 0, 0, 7, 0, 9, 0],
    [0, 3, 0, 0, 2, 0, 0, 0, 8],
    [0, 0, 9, 6, 0, 0, 5, 0, 0],
    [0, 0, 5, 3, 0, 0, 9, 0, 0],
    [0, 1, 0, 0, 8, 0, 0, 0, 2],
    [6, 0, 0, 0, 0, 4, 0, 0, 0],
    [3, 0, 0, 0, 0, 0, 0, 1, 0],
    [0, 4, 0, 0, 0, 0, 0, 0, 7],
    [0, 0, 7, 0, 0, 0, 3, 0, 0],
]

# Every row and every column is a permutation of 1..9, but the 3x3 boxes
# repeat digits, so check_sudoku should return False.
subg = [
    [9, 8, 7, 6, 5, 4, 3, 2, 1],
    [8, 7, 6, 5, 4, 3, 2, 1, 9],
    [7, 6, 5, 4, 3, 2, 1, 9, 8],
    [6, 5, 4, 3, 2, 1, 9, 8, 7],
    [5, 4, 3, 2, 1, 9, 8, 7, 6],
    [4, 3, 2, 1, 9, 8, 7, 6, 5],
    [3, 2, 1, 9, 8, 7, 6, 5, 4],
    [2, 1, 9, 8, 7, 6, 5, 4, 3],
    [1, 9, 8, 7, 6, 5, 4, 3, 2],
]

# A sparse puzzle: only 15 givens.
hard2 = [
    [0, 1, 0, 0, 0, 0, 0, 0, 4],
    [0, 4, 0, 0, 0, 7, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 7, 0, 0],
    [0, 0, 5, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 5, 0],
    [0, 0, 0, 9, 0, 0, 1, 0, 0],
    [0, 7, 0, 0, 2, 0, 0, 0, 9],
    [0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 9, 5, 0, 0, 0],
]
def cross(A, B):
    """Return every concatenation a + b for a in A and b in B.

    Elements of A vary slowest, e.g. cross('AB', '12') gives
    ['A1', 'A2', 'B1', 'B2'].
    """
    combos = []
    for a in A:
        combos.extend(a + b for b in B)
    return combos
digits = '123456789'
rows = 'ABCDEFGHI'
cols = digits
# Square names 'A1'..'I9': a row letter followed by a column digit.
squares = cross(rows, cols)
# All 27 units, each a list of 9 square names that must hold distinct
# digits: first the 9 columns, then the 9 rows, then the 9 3x3 boxes.
unitlist = ([cross(rows, c) for c in cols] +
            [cross(r, cols) for r in rows] +
            [cross(rs, cs)
             for rs in ('ABC', 'DEF', 'GHI')
             for cs in ('123', '456', '789')])
# units[s]: the units containing square s, in unitlist order
# (its column, its row, its box).
units = dict((s, [u for u in unitlist if s in u])
             for s in squares)
# peers[s]: every square other than s that shares some unit with s.
peers = dict((s, set(sq for u in units[s] for sq in u) - set([s]))
             for s in squares)
def check_row(nums):
    """Return True if every entry of nums lies in 0..9 and no digit 1..9 repeats.

    0 stands for an empty cell and may occur any number of times. Entries
    are examined in order and the first violation returns False.
    """
    tally = [0] * 10  # tally[d] = how many times d has been seen so far
    for value in nums:
        if value < 0 or value > 9:
            return False
        tally[value] += 1
        if value and tally[value] > 1:
            return False
    return True
def get_constraints(grid, i, j):
    """Return the candidate digits for cell (i, j) of grid as a string.

    Returns None if row i, column j, or the 3x3 box containing (i, j)
    fails check_row. Otherwise, if the cell is filled, returns its own
    digit (e.g. '7'). If the cell is 0, returns the digits 1..9 absent
    from its row, column and box, in ascending order ('' if none remain).
    """
    row = grid[i]
    col = [grid[k][j] for k in range(9)]
    # Top-left corner of the 3x3 box. Floor division keeps these ints;
    # plain `/` yields floats under Python 3 (or `from __future__ import
    # division`), which would make range() raise.
    bi, bj = i // 3 * 3, j // 3 * 3
    box = [grid[r][c] for r in range(bi, bi + 3) for c in range(bj, bj + 3)]
    if not (check_row(row) and check_row(col) and check_row(box)):
        return None
    if 0 < grid[i][j] < 10:
        return str(grid[i][j])
    candidates = set(range(1, 10)) - set(row) - set(col) - set(box)
    # Sort explicitly rather than relying on set iteration order.
    return ''.join(str(d) for d in sorted(candidates))
def check_sudoku(grid):
    """Validate a (possibly partial) Sudoku grid.

    Returns:
        None  if grid is not a list of 9 lists of 9 ints (bools and other
              int subclasses are rejected because the check is `type(...) is`);
        False if any row, column or 3x3 box holds a value outside 0..9 or
              a repeated nonzero digit (see check_row);
        True  otherwise. 0 marks an empty cell.
    """
    if type(grid) is not list or len(grid) != 9:
        return None
    for row in grid:
        if type(row) is not list or len(row) != 9:
            return None
        for num in row:
            if type(num) is not int:
                return None
    columns = [[grid[r][c] for r in range(9)] for c in range(9)]
    boxes = [[grid[r][c] for r in range(br, br + 3) for c in range(bc, bc + 3)]
             for br in (0, 3, 6) for bc in (0, 3, 6)]
    # Each of the 27 units is checked exactly once, rather than re-checking
    # a cell's row, column and box for every one of the 81 cells.
    for unit in grid + columns + boxes:
        if not check_row(unit):
            return False
    return True
def erase(values, to_cell, num):
    """Remove candidate digit num from values[to_cell] and propagate the result.

    values maps each square name to a string of its remaining candidate
    digits and is modified in place (peers and units are the module-level
    lookup tables). Returns values on success, or False as soon as a
    contradiction is found; values may already be partially modified then,
    which is why callers such as solve_partial pass a copy.
    """
    if num not in values[to_cell]:
        return values  # already eliminated: nothing to propagate
    values[to_cell] = values[to_cell].replace(num, '')
    if len(values[to_cell]) == 0:
        return False # contradiction: no candidate left for this square
    elif len(values[to_cell]) == 1:
        # The square is now decided, so its remaining digit is removed
        # from every peer.
        num2 = values[to_cell]
        if not all(erase(values, cell, num2) for cell in peers[to_cell]):
            return False
    # num was just removed from to_cell; in each unit containing to_cell,
    # see how many squares can still hold num.
    for unit in units[to_cell]:
        numplaces = [cell for cell in unit if num in values[cell]]
        if len(numplaces) == 0:
            return False  # num has no place left in this unit
        elif len(numplaces) == 1:
            # Only one square left for num in this unit: put it there.
            if not assign(values, numplaces[0], num):
                return False
    return values
def assign(values, to_cell, num):
    """Make num the only candidate left in values[to_cell].

    Each other candidate of to_cell is removed via erase(), which
    propagates the consequences through values in place. Returns values
    on success, or False at the first contradiction.
    """
    for other in values[to_cell].replace(num, ''):
        if not erase(values, to_cell, other):
            return False
    return values
def assign_from(grid, values):
    """Write every decided square of values into grid (in place) and return grid.

    A square is decided when its candidate string has exactly one digit;
    undecided squares leave their grid cell untouched. Square names are
    mapped to indices through the module-level rows/cols strings.
    """
    for cell, cands in values.items():
        if len(cands) != 1:
            continue
        row_label, col_label = cell
        grid[rows.index(row_label)][cols.index(col_label)] = int(cands)
    return grid
def solve_partial(values):
    """Depth-first search over candidate assignments, with propagation.

    values is a dict mapping square names to candidate-digit strings, or
    False (a failed propagation, passed straight through). Returns a dict
    in which every square has exactly one candidate, or False if no
    completion exists. Branches on the first undecided square with the
    fewest candidates, trying its digits in order, each on a copy.
    """
    if values is False:
        return False
    if any(len(values[sq]) == 0 for sq in values):
        return False
    undecided = [(len(values[sq]), sq) for sq in values if len(values[sq]) > 1]
    if not undecided:
        return values  # every square decided
    _, cell = min(undecided, key=lambda pair: pair[0])
    for digit in values[cell]:
        attempt = solve_partial(assign(values.copy(), cell, digit))
        if attempt:
            return attempt  # found a complete assignment
    return False
def solve_sudoku(grid):
    """Fill every 0 in grid so that the result is a valid Sudoku solution.

    Returns:
        None  if grid is ill-formed (check_sudoku returns None);
        False if grid breaks the rules or has no solution;
        grid  itself otherwise, mutated in place so that every 0 is
              replaced by a digit 1..9 (some solution, not necessarily unique).
    """
    grid_check = check_sudoku(grid)
    if not grid_check:
        return grid_check  # None (ill-formed) or False (rule violation)
    # Seed every square with all nine digits and assign() each given, so
    # erase()'s propagation runs for every clue. Seeding squares directly
    # with per-cell candidate strings never propagates forced
    # single-candidate blanks to their peers, and solve_partial accepts an
    # all-singleton dict without re-validating it.
    values = dict((s, digits) for s in squares)
    for s in squares:
        given = grid[rows.index(s[0])][cols.index(s[1])]
        if given and not assign(values, s, str(given)):
            return False  # the givens force a contradiction
    soln = solve_partial(values)
    if not soln:
        return False
    assign_from(grid, soln)
    # Internal invariant: propagation plus search only yields complete,
    # valid assignments.
    assert check_sudoku(grid) is True
    return grid
## Tests
import time, random
def from_file(filename, sep='\n'):
    """Parse a file into a list of strings, separated by sep.

    Leading/trailing whitespace of the whole file is stripped before
    splitting. The file is opened with a context manager so the handle is
    closed promptly (the Python-2-only file() builtin left it to the GC).
    """
    with open(filename) as f:
        return f.read().strip().split(sep)
def random_puzzle(N=30):
    """Generate a random puzzle as an 81-character string ('0' marks a blank).

    Squares are visited in random order and each is assigned a random
    remaining candidate (with constraint propagation) until at least N
    squares are decided and the decided squares use all 9 digits. If an
    assignment hits a contradiction, the attempt is discarded and a new one
    starts from scratch.

    Raises ValueError if N > 81, which can never be satisfied.
    """
    if N > 81:
        raise ValueError('N must be at most 81, got %r' % (N,))
    # Retry in a loop rather than by recursion, so a run of failed attempts
    # cannot exhaust the call stack.
    while True:
        values = dict((cell, '123456789') for cell in squares)
        for cell in shuffled(values.keys()):
            if not assign(values, cell, random.choice(values[cell])):
                break  # contradiction: start a fresh attempt
            ds = [values[sq] for sq in values if len(values[sq]) == 1]
            if len(ds) >= N and len(set(ds)) >= 9:
                return ''.join(values[c] if len(values[c]) == 1 else '0'
                               for c in squares)
def shuffled(seq):
    """Return a new list with the items of seq in random order; seq itself is untouched."""
    result = list(seq)
    random.shuffle(result)
    return result
def parse_grid(gridstr):
    """Convert an 81-cell puzzle string into a 9x9 list of lists of ints.

    Cells are read row by row. '0' or '.' marks an empty cell (parsed as
    0), so both the all-digit format and the dotted format used by common
    puzzle collections are accepted. Any other non-digit character
    (newlines, carriage returns, spaces, separators) is ignored.

    Raises ValueError if the string does not contain exactly 81 cells.
    """
    cells = [0 if ch == '.' else int(ch)
             for ch in gridstr if ch in '0123456789.']
    if len(cells) != 81:
        raise ValueError('expected 81 cells, got %d' % len(cells))
    return [cells[9 * r:9 * r + 9] for r in range(9)]
def solve_all(grids, name='', showif=0.0):
    """Attempt to solve a sequence of grids. Report results.
    When showif is a number of seconds, display puzzles that take longer.
    When showif is None, don't display any puzzles.

    grids is a sequence of puzzle strings accepted by parse_grid. Puzzles
    that cannot be solved are always reported as 'failed:'. When more than
    one grid is given, a summary line is printed that counts only the
    puzzles that were actually solved.
    """
    # time.clock was removed in Python 3.8; use perf_counter when available.
    clock = getattr(time, 'perf_counter', None) or time.clock

    def time_solve(grid):
        start = clock()
        soln = solve_sudoku(parse_grid(grid))
        t = clock() - start
        if not soln:  # False (no solution) or None (ill-formed)
            print('failed: %s' % grid)
            return (t, soln)
        ## Display puzzles that take long enough
        if showif is not None and t > showif:
            print(grid)
            print(''.join(''.join(str(x) for x in row) for row in soln))
            print('(%.2f seconds)\n' % t)
        return (t, soln)

    if not grids:
        return  # nothing to do (and zip(*[]) below could not be unpacked)
    times, results = zip(*[time_solve(grid) for grid in grids])
    N = len(grids)
    if N > 1:
        total = sum(times)
        solved = sum(1 for r in results if r)
        # A coarse clock can report a zero total; avoid dividing by it.
        rate = N / total if total > 0 else 0
        print("Solved %d of %d %s puzzles (total %.2f secs, avg %.2f secs (%d Hz), max %.2f secs)." % (
            solved, N, name, total, total / N, rate, max(times)))
if __name__ == '__main__':
    # Single-argument print(...) produces the same output as Python 2's
    # print statement, and also runs under Python 3.
    print(check_sudoku(subg))
    for puzzle in (easy, hard, hard2):
        print(solve_sudoku(puzzle))
    for gridstr in ('000000605000300090080004001040020970000000000031080060900600020010007000504000000',
                    '800000000003600000070090200050007000000045700000100030001000068008500010090000400',
                    '100007090030020008009600500005300900010080002600004000300000010041000007007000300'):
        print(solve_sudoku(parse_grid(gridstr)))
    # Benchmarks over puzzle collections (files expected in the working dir).
    solve_all(from_file("easy_sudoku.txt"), "easy", None)
    solve_all(from_file("hard_sudoku.txt"), "medium", None)
    solve_all(from_file("top95.txt"), "hard", None)
    solve_all([random_puzzle() for _ in range(100)], "random", None)
    solve_all(from_file("extreme_sudoku.txt"), "extreme", 0.1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment