Skip to content

Instantly share code, notes, and snippets.

@torchlight
Last active December 25, 2015 08:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save torchlight/ace4f5fee81104ff693d to your computer and use it in GitHub Desktop.
Save torchlight/ace4f5fee81104ff693d to your computer and use it in GitHub Desktop.
Sudoku
"""Sudoku."""
import random
from itertools import combinations
def complement(A, B):
    """Return A\\B as a list: the items of A, in order, that are not in B."""
    return [a for a in A if a not in B]
def intersect(A, B):
    """Return A \\cap B as a list, keeping the order (and any repeats) of B."""
    return [b for b in B if b in A]
def dot(A, B):
    """
    Return the dot product sum(A[i] * B[i]) over the indices of A.

    The lengths are assumed equal: items of B beyond len(A) are ignored, and a
    B shorter than A raises IndexError.
    """
    return sum(A[i] * B[i] for i in range(len(A)))
def shuffle(A):
    """Return a new list with the items of A in random order; A is untouched."""
    shuffled = list(A)
    random.shuffle(shuffled)  # in-place on our private copy only
    return shuffled
def box(xpitch, ypitch, start):
    """
    Return a list of indices, start + x*xpitch + y*ypitch, with x, y going from
    0 to 2 (x varies fastest). For example, the indices of the top-left 3x3
    box may be obtained with box(1,9,0).
    """
    indices = []
    for y in range(3):
        indices.extend(start + x * xpitch + y * ypitch for x in range(3))
    return indices
def gen_groups(standard=True, diagonals=False, hyper=False, disjoint=False):
    """
    Generate a list of groups for which the one rule should be satisfied.

    Each group is a list of nine cell indices (0-80, row-major) that may not
    contain the same value twice.

    standard  -- the nine rows, nine columns and nine 3x3 boxes.
    diagonals -- the two main diagonals.
    hyper     -- the four extra "window" boxes of hyper sudoku, plus the five
                 groups they imply together with the standard groups.
    disjoint  -- the nine groups of cells sharing a position within their box.
    """
    groups = []
    if standard:
        groups += [box(1, 3, x) for x in box(9, 27, 0)]  # rows,
        groups += [box(9, 27, x) for x in box(1, 3, 0)]  # columns,
        groups += [box(1, 9, x) for x in box(3, 27, 0)]  # and boxes
    if diagonals:
        groups += [box(10, 30, 0), box(8, 24, 8)]
    if hyper:
        groups += [box(1, 9, x) for x in (10, 14, 46, 50)]
        # The next five groups are implied by the standard groups along with
        # the above four, but we include them anyway to make the autosolvers'
        # jobs easier. Splitting rows and columns into {0,4,8}, {1,2,3} and
        # {5,6,7}, they are every row-set x column-set block that is not one
        # of the four windows. (Previously box(1,36,5) was listed twice and
        # the rows {0,4,8} x cols {1,2,3} group was missing.)
        groups += [
            box(1, 36, 1),  # rows 0,4,8 x cols 1-3
            box(1, 36, 5),  # rows 0,4,8 x cols 5-7
            box(4, 9, 9),   # rows 1-3   x cols 0,4,8
            box(4, 9, 45),  # rows 5-7   x cols 0,4,8
            box(4, 36, 0),  # rows 0,4,8 x cols 0,4,8
        ]
    if disjoint:
        # sudokuwiki calls this colour sudoku; some other sites call it
        # "disjoint groups"
        groups += [box(3, 27, x) for x in box(1, 9, 0)]
    return groups
def convert(g):
    """
    Convert a grid from a flat list to the internal format.

    g holds 81 entries: a falsy value (0) for an empty cell, otherwise the given
    digit. Each internal cell is a list [index, c1, ..., c9] where ck is True
    while k is still a candidate: all True for an empty cell, and only the
    given digit True otherwise.
    """
    grid = []
    for index in range(81):
        value = g[index]
        if not value:
            grid.append([index] + [True] * 9)
            continue
        cell = [index] + [False] * 9
        cell[value] = True
        grid.append(cell)
    return grid
def convert_str(s):
    """
    Convert a grid from a string representation to the internal format.

    Digit characters fill the cells in order, with '0', '.' and '_' all meaning
    an empty cell; every other character is skipped. Cells the string does not
    reach stay empty, and more than 81 cell characters raise IndexError.
    """
    flat = [0] * 81
    pos = 0
    for ch in s:
        if ch in '._':
            ch = '0'
        if not '0' <= ch <= '9':
            continue  # separators, whitespace, etc.
        flat[pos] = int(ch)
        pos += 1
    return convert(flat)
def unconvert(grid, hints=False):
    """
    Convert a grid from the internal format to a flat list of 81 ints.

    A cell with a single candidate becomes that digit; any other cell becomes 0.
    With hints set, every cell instead becomes its candidates written as one
    number in ascending digit order (candidates 2, 5, 9 -> 259; none -> 0).
    """
    flat = []
    for i in range(81):
        cell = grid[i]
        value = 0
        if sum(cell[1:]) == 1 or hints:
            for digit in range(1, 10):
                if cell[digit]:
                    value = value * 10 + digit
        flat.append(value)
    return flat
def print_grid(grid):
    """
    Print the grid in a fixed-width format: one line per row, each cell shown
    as the string of its remaining candidate digits padded to 10 columns.
    """
    rows = []
    for r in range(9):
        cells = []
        for c in range(9):
            candidates = grid[r * 9 + c]
            digits = ''.join(str(v) for v in range(1, 10) if candidates[v])
            cells.append(digits.ljust(10))
        rows.append(''.join(cells) + '\n')
    print(''.join(rows))
def copy_grid(grid):
    """Return a copy of grid in which every cell list is itself copied, so the
    result can be modified without touching the original."""
    return [cell[:] for cell in grid]
def count_candidates(grid):
    """Return the total number of candidates left across all cells (81 for a
    fully solved grid)."""
    return sum(sum(cell[1:]) for cell in grid)
def elim(grid, groups):
    """
    Eliminate conflicting candidates (naked singles).

    Works on a copy of grid. Whenever a cell of a group is down to a single
    candidate, that value is removed from every other cell of the group. Passes
    repeat until one makes no change; the new grid is returned.
    """
    result = copy_grid(grid)
    while True:
        progress = False
        for group in groups:
            for cell in group:
                if sum(result[cell][1:]) != 1:
                    continue
                # Exactly one flag is set, so the dot product with 0..9 yields
                # its index (slot 0, the cell index, is multiplied by 0).
                value = dot(result[cell], range(10))
                for other in group:
                    if other != cell and result[other][value]:
                        result[other][value] = False
                        progress = True
        if not progress:
            return result
def elim_multi(grid, groups, n=4):
    """
    Eliminate candidates conflicting with naked n-tuples.

    Tuples of size 1..n-1 are handled first by recursion (n == 1 is elim()).
    Then, within each group, any n cells whose combined candidates number
    exactly n claim those values, so the values are removed from the group's
    other cells. Passes repeat until nothing changes. If n cells turn up whose
    combined candidates number fewer than n, the grid cannot be solved and it
    is returned immediately.
    """
    if n == 1:
        return elim(grid, groups)
    grid = elim_multi(grid, groups, n - 1)  # a fresh copy made by elim()
    progress = True
    while progress:
        progress = False
        # Cells with 2..n candidates may form a naked n-tuple; this snapshot
        # is taken once per pass.
        eligible = [2 <= sum(grid[c][1:]) <= n for c in range(81)]
        for group in groups:
            members = [c for c in group if eligible[c]]
            for tuple_cells in combinations(members, n):
                union = [v for v in range(1, 10)
                         if any(grid[c][v] for c in tuple_cells)]
                if len(union) < n:
                    return grid  # not solvable
                if len(union) > n:
                    continue  # not a naked n-tuple
                for other in group:
                    if other in tuple_cells:
                        continue
                    for v in union:
                        if grid[other][v]:
                            grid[other][v] = False
                            progress = True
    return grid
def scan(grid, groups):
    """
    Eliminate candidates by hidden singles.

    Works on a copy of grid. Whenever a value is a candidate in only one cell
    of a group, that cell is reduced to that single candidate. Passes repeat
    until one makes no change; the new grid is returned.
    """
    result = copy_grid(grid)
    progress = True
    while progress:
        progress = False
        for group in groups:
            for v in range(1, 10):
                holders = [cell for cell in group if result[cell][v]]
                if len(holders) != 1:
                    continue
                (only,) = holders
                if sum(result[only][1:]) != 1:
                    result[only][1:] = [False] * 9
                    result[only][v] = True
                    progress = True
    return result
def scan_multi(grid, groups, n=4):
    """
    Eliminate candidates by hidden n-tuples.

    Tuples of size 1..n-1 are handled first by recursion, bottoming out at
    scan() (hidden singles) for n == 1, which also makes the working copy, so
    the input grid is not modified. Then, for each group and each set V of n
    values: if the values of V are candidates in exactly n cells of the group,
    those cells must hold V, so every other candidate is removed from them.
    Passes repeat until nothing changes. If the values of V fit in fewer than n
    cells, the grid cannot be solved and it is returned immediately.
    """
    if n == 1:
        return scan(grid, groups)
    grid = scan_multi(grid, groups, n - 1)
    changed = True
    while changed:
        changed = False
        for group in groups:
            for V in combinations(range(1, 10), n):
                cells = [cell for cell in group
                         if any(grid[cell][v] for v in V)]
                if len(cells) < n:
                    return grid  # not solvable
                if len(cells) > n:
                    continue  # not a hidden n-tuple
                for cell in cells:
                    if sum(grid[cell][1:]) == sum(grid[cell][v] for v in V):
                        # Nothing outside V left in this cell, but the other
                        # tuple cells may still need trimming, so keep going
                        # (a `break` here skipped them permanently).
                        continue
                    changed = True
                    for v in range(1, 10):
                        if v not in V:
                            grid[cell][v] = False
    return grid
def ir(grid, groups):
    """
    Eliminate candidates by intersection removal. This includes pointing
    n-tuples and box/line reduction under the standard rules, but is
    generalised to arbitrary pairs of regions.

    For every pair of groups sharing at least two cells: if a value present in
    the shared cells cannot go anywhere else in one of the groups, it must sit
    in the shared cells, so it is removed from the rest of the other group.
    Works on a copy of grid and repeats until a pass makes no change.
    """
    result = copy_grid(grid)
    progress = True
    while progress:
        progress = False
        for first, second in combinations(groups, 2):
            shared = intersect(first, second)
            if len(shared) < 2:
                continue  # we need at least two common cells
            present = [v for v in range(1, 10)
                       if any(result[c][v] for c in shared)]
            for source, target in ((first, second), (second, first)):
                # values that, within source, only fit in the shared cells
                confined = [v for v in present
                            if not any(result[c][v] for c in source
                                       if c not in shared)]
                for v in confined:
                    for c in target:
                        if c not in shared and result[c][v]:
                            result[c][v] = False
                            progress = True
    return result
def solve_logic(grid, groups, methods=(elim, scan), verbose=False):
    """
    Attempt to solve a puzzle using a specified sequence of methods, defaulting
    to naked/hidden singles.

    Each method is called as method(grid, groups) and its returned grid is fed
    to the next. The whole sequence is repeated until a full pass leaves the
    total candidate count unchanged. With verbose, the candidate count is
    printed before the first pass and after every pass.

    The default is a tuple rather than a list so the shared default object
    cannot be mutated between calls; any sequence of callables may be passed.
    """
    oldcount = 0
    count = count_candidates(grid)
    if verbose:
        print(count)
    while count != oldcount:
        for method in methods:
            grid = method(grid, groups)
        oldcount = count
        count = count_candidates(grid)
        if verbose:
            print(count)
    return grid
def solve_weak(grid, groups):
    """
    Attempt to solve a puzzle using only naked/hidden singles, alternating
    elim() and scan() until the candidate count stops changing.
    """
    previous, current = 0, count_candidates(grid)
    while current != previous:
        grid = elim(grid, groups)
        grid = scan(grid, groups)
        previous, current = current, count_candidates(grid)
    return grid
def solve_basic(grid, groups):
    """
    Attempt to solve a puzzle using the basic techniques: naked and hidden
    n-tuples (elim_multi/scan_multi with their default n=4) plus intersection
    removal, driven by solve_logic().
    """
    techniques = [elim_multi, scan_multi, ir]
    return solve_logic(grid, groups, techniques)
def check_validity(grid, groups):
    """
    Check that the grid is sensible without solving it.

    Returns False if any cell has no candidates left, or if two single-candidate
    cells that share a group hold the same value; True otherwise.
    """
    for cell in grid:
        total = sum(cell[1:])
        if total == 0:
            return False  # cell has no candidates
        if total != 1:
            continue
        index = cell[0]
        value = dot(cell, range(10))  # the lone candidate
        for group in groups:
            if index not in group:
                continue
            # Only partners with a larger index are examined, so each pair of
            # cells is compared once, from its lower index.
            for other in group:
                if other > index and sum(grid[other][1:]) == 1 and grid[other][value]:
                    return False
    return True
def solve_full(grid, groups, deterministic=False, order=None, maxdepth=0):
    """
    Solve a puzzle, falling back on brute force if necessary. For brute-forcing,
    order specifies the order in which cells should be attempted, and maxdepth
    specifies the maximum search depth plus one (0 means unlimited).

    Returns a solved grid, or None if no solution exists along this branch. If
    the depth limit is hit, or every cell left in order is already solved, the
    partially solved grid reached so far is returned instead. Unless
    deterministic is True, the default cell order and the order in which
    candidate values are tried are randomised.
    """
    grid = solve_weak(grid, groups)
    if not check_validity(grid, groups):
        return None  # invalid
    if count_candidates(grid) == 81:
        return grid  # solved
    if maxdepth == 1:
        return grid  # max depth reached
    if order is None:
        order = list(range(81))
        if not deterministic:
            random.shuffle(order)
    # Branch on the first cell in order that still has several candidates.
    for pos, cell in enumerate(order):
        if sum(grid[cell][1:]) != 1:
            break
    else:
        return grid  # no unsolved cells left to check
    remaining = order[pos + 1:]
    child_depth = maxdepth and maxdepth - 1  # 0 stays 0 (unlimited)
    values = list(range(1, 10))
    if not deterministic:
        random.shuffle(values)
    for value in values:
        if not grid[cell][value]:
            continue
        trial = copy_grid(grid)
        trial[cell][1:] = [False] * 9
        trial[cell][value] = True
        solved = solve_full(trial, groups, deterministic, remaining, child_depth)
        if solved:
            return solved
    return None  # no solution
def check_unique_solution(grid, groups, refgrid=None, deterministic=False):
    """
    Check that a puzzle has a unique solution. If refgrid is specified, try to
    find a solution that is not refgrid, assuming that refgrid is a solution.

    After simplifying with naked/hidden singles, every remaining candidate that
    differs from refgrid's value for its cell is forced in turn and brute-forced
    with solve_full(); any success means a second solution exists. Returns
    False for that case and also when the puzzle is invalid or has no solution.
    """
    grid = solve_weak(grid, groups)
    if not check_validity(grid, groups):
        return False
    if count_candidates(grid) == 81:
        return True  # singles alone solved it
    if refgrid is None:
        refgrid = solve_full(grid, groups, deterministic)
        if refgrid is None:
            return False  # no solution at all
    for cell in range(81):
        values = list(range(1, 10))
        if not deterministic:
            random.shuffle(values)
        alternatives = [v for v in values
                        if grid[cell][v] and not refgrid[cell][v]]
        for v in alternatives:
            trial = copy_grid(grid)
            trial[cell][1:] = [False] * 9
            trial[cell][v] = True
            if solve_full(trial, groups, deterministic) is not None:
                return False  # a different solution exists
    return True
def gen_standard():
    """Generate a random filled sudoku grid (standard rules, internal format)."""
    # Populate the three diagonal boxes (top-left, centre, bottom-right), which
    # share no row, column or box, then try to find a random solution; this is
    # moderately efficient (slightly faster than a row-by-row method) and has
    # pretty much negligible bias, but can occasionally take thirty times as
    # long as usual to generate a grid if the RNG hates you.
    while True:
        givens = [0] * 81
        for start in (0, 30, 60):
            digits = shuffle(range(1, 10))
            # the box's first cell takes the last shuffled digit, and so on
            for index, digit in zip(box(1, 9, start), reversed(digits)):
                givens[index] = digit
        grid = solve_full(convert(givens), gen_groups())
        if grid is not None:
            return grid
def add_holes(grid, groups, hard=True, order=None, verbose=False):
    """
    Add blanks to a grid to make it a puzzle.

    Cells are blanked one at a time, in order (a random order if order is not
    given or empty). A blank is kept if solve_basic() can still pin that cell
    down, or, when hard is True, if check_unique_solution() confirms the puzzle
    still has grid as its only solution. grid itself is not modified; the
    resulting puzzle grid is returned.
    """
    puzzle = copy_grid(grid)
    cells = order or shuffle(range(81))
    if verbose:
        print(cells)
    removed = 0
    for i in cells:
        attempt = copy_grid(puzzle)
        attempt[i][1:] = [True] * 9
        reduced = solve_basic(attempt, groups)
        forced = sum(reduced[i][1:]) == 1
        if forced or (hard and check_unique_solution(reduced, groups, grid)):
            removed += 1
            puzzle = attempt
            if verbose:
                print('added at %d (total %d)'%(i,removed))
        elif verbose:
            print('rejected add at %d'%i)
    return puzzle
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment